Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""
Copyright (C) 2026, Amazon.com. All Rights Reserved

NKI implementation for the all-gather + matmul ring tutorial.

This kernel performs a fused all-gather + matmul along a ring of TP ranks,
using ``nki.collectives.collective_permute_implicit`` (CPI) to overlap
communication with compute. Each step, a rank computes a local matmul
using the LHS fragment currently in its ring buffer, then passes the
fragment on to the next rank in the ring while receiving the previous
rank's fragment — the scheduler places the matmul of step (i) and the
CPI of step (i) on disjoint engines, so they run concurrently.

The matmul is row-parallel (LHS row-sharded, RHS column-sharded). After
RANK_N ring steps, every rank has computed one (M_LOCAL, N_LOCAL) slot of
the fully-gathered output for every source rank.
"""

import nki as nki
import nki.collectives as ncc
import nki.isa as nisa
import nki.language as nl


# NKI_EXAMPLE_AGMM_RING_BEGIN
@nki.jit
def allgather_matmul_ring(
    a_shard,
    b_shard,
    replica_group,
    RANK_N,
    M_LOCAL,
    N_LOCAL,
    K_TILE,
    N_TILE,
):
    """Ring all-gather + matmul.

    Runs RANK_N ring steps. Each step multiplies the LHS fragment currently
    held in the ring buffer by the local ``b_shard`` and stores the product
    in the output slot selected by that fragment's source rank. Every step
    except the last also forwards the fragment around the ring with
    ``collective_permute_implicit``.

    Args:
        a_shard (nl.ndarray): [M_LOCAL, K] — this rank's LHS row-slice.
        b_shard (nl.ndarray): [K, N_LOCAL] — this rank's RHS column-slice.
        replica_group (ncc.ReplicaGroup): Ring of all TP ranks.
        RANK_N (int): Number of TP ranks (ring length).
        M_LOCAL (int): Local M dimension per rank. This kernel does not tile
            M, so it has to fit in a single matmul tile.
        N_LOCAL (int): Local N dimension per rank (RHS column shard width).
        K_TILE (int): Contraction-dim tile size, must be 128 (Tensor Engine
            partition-dim limit).
        N_TILE (int): Free-dim tile size, must be <= 512.

    Returns:
        nl.ndarray of shape [RANK_N, M_LOCAL, N_LOCAL] with ``a_shard``'s
        dtype. Slot ``r`` contains the matmul contribution from rank ``r``'s
        LHS shard (a_shard_r @ b_shard).

    Note:
        Tile counts are computed with floor division and nothing in this
        kernel validates the shapes, so K is expected to be a multiple of
        K_TILE and N_LOCAL a multiple of N_TILE. A remainder would be
        silently left out of the matmul loops.
    """
    # Only K is read from the input shape; the M extent comes from M_LOCAL.
    _, K = a_shard.shape
    # Floor division: any K / N_LOCAL remainder is not covered by the loops.
    K_TILES = K // K_TILE
    N_TILES = N_LOCAL // N_TILE

    # Two ping-pong ring buffers on shared_hbm. Collectives cannot source
    # from IO tensors directly, so the ring is seeded with a DMA copy of
    # a_shard below instead of permuting a_shard itself.
    buf0 = nl.ndarray((M_LOCAL, K), dtype=a_shard.dtype, buffer=nl.shared_hbm,
                      name="ring_buf0")
    buf1 = nl.ndarray((M_LOCAL, K), dtype=a_shard.dtype, buffer=nl.shared_hbm,
                      name="ring_buf1")

    # Output: one (M_LOCAL, N_LOCAL) slot per source-rank.
    out = nl.ndarray((RANK_N, M_LOCAL, N_LOCAL), dtype=a_shard.dtype,
                     buffer=nl.shared_hbm, name="out")

    # Seed the ring: copy local a_shard into buf0.
    nisa.dma_copy(dst=buf0, src=a_shard)

    # `step` is an ordinary Python integer here (meta-programming): the
    # compiler specializes the kernel body for each concrete value of step.
    for step in range(RANK_N):
        # Alternate the two ring buffers across steps: even steps read buf0
        # and receive into buf1, odd steps do the opposite.
        if step % 2 == 0:
            buf_cur, buf_next = buf0, buf1
        else:
            buf_cur, buf_next = buf1, buf0

        # src_rank is a runtime-valued rank ID held in a register. It can only
        # be used as scalar_offset in an access pattern (not compared, not
        # materialized as a Python value).
        src_rank = ncc.collective_permute_implicit_current_processing_rank_id(
            iteration_id=step, replica_group=replica_group, channel_id=0,
        )

        # Launch the CPI for the NEXT step first. The matmul below and this
        # CPI both read buf_cur — independent readers — and CPI writes to a
        # buffer not touched until the next iteration. Placing the CPI first
        # hints the scheduler to overlap the two on different engines.
        # The final step has nothing left to forward, so it skips the CPI.
        if step < RANK_N - 1:
            ncc.collective_permute_implicit(
                srcs_by_channel=[[buf_cur]],
                dsts_by_channel=[[buf_next]],
                replica_group=replica_group,
                channel_ids=[0],
            )

        # Local matmul (M_LOCAL, N_LOCAL) = a_shard_cur @ b_shard.
        # Tiled over K (K_TILES × K_TILE accumulated into PSUM) and N
        # (N_TILES × N_TILE). M is not tiled: the whole M_LOCAL extent is
        # handled by one matmul tile.
        # partial_sb is the SBUF staging tile for this step's full result.
        partial_sb = nl.ndarray((M_LOCAL, N_LOCAL), dtype=a_shard.dtype,
                                buffer=nl.sbuf)
        for nt in nl.affine_range(N_TILES):
            n0 = nt * N_TILE
            # float32 accumulator for this N tile, zero-initialized.
            psum = nl.zeros((M_LOCAL, N_TILE), dtype=nl.float32, buffer=nl.psum)
            for kt in nl.affine_range(K_TILES):
                k0 = kt * K_TILE
                # Fresh SBUF tiles per K-step — avoid false WAR chains across
                # iterations so the compiler can pipeline K-tile loads.
                a_tile = nl.ndarray((K_TILE, M_LOCAL), dtype=a_shard.dtype,
                                    buffer=nl.sbuf)
                b_tile = nl.ndarray((K_TILE, N_TILE), dtype=b_shard.dtype,
                                    buffer=nl.sbuf)
                # nc_matmul wants stationary in [K, M] layout. buf_cur is
                # [M, K], so dma_transpose swaps axes during the load.
                nisa.dma_transpose(dst=a_tile, src=buf_cur[:, k0 : k0 + K_TILE])
                nisa.dma_copy(dst=b_tile,
                              src=b_shard[k0 : k0 + K_TILE, n0 : n0 + N_TILE])
                # Auto-accumulate into psum: first write overwrites, later writes add.
                nisa.nc_matmul(dst=psum, stationary=a_tile, moving=b_tile)
            # All K tiles are accumulated: move the float32 PSUM tile into
            # its column range of partial_sb (stored in a_shard's dtype).
            nisa.tensor_copy(dst=partial_sb[:, n0 : n0 + N_TILE], src=psum)

        # Write this step's result into the src_rank slot of `out`.
        # scalar_offset=src_rank + indirect_dim=0 indexes the leading
        # RANK_N dimension at runtime. The pattern describes one
        # (M_LOCAL, N_LOCAL) slot: M_LOCAL rows of stride N_LOCAL, each with
        # N_LOCAL unit-stride elements.
        nisa.dma_copy(
            dst=out.ap(
                pattern=[[N_LOCAL, M_LOCAL], [1, N_LOCAL]],
                offset=0,
                scalar_offset=src_rank,
                indirect_dim=0,
            ),
            src=partial_sb,
        )

    return out
# NKI_EXAMPLE_AGMM_RING_END
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
Copyright (C) 2026, Amazon.com. All Rights Reserved

PyTorch/XLA runner for the all-gather + matmul ring NKI tutorial.

Launches ``allgather_matmul_ring`` across TP ranks using
``torch_xla.distributed.xla_multiprocessing.spawn``, constructs
deterministic LHS/RHS shards, and validates each rank's output against a
reference matmul computed on the host.

Run on trn2 with TP=16 LNC=2 (16 ranks across 4 devices):

NEURON_CC_FLAGS="--lnc=2" \\
NEURON_LOGICAL_NC_CONFIG=2 \\
NEURONCORE_NUM_DEVICES=16 \\
python allgather_matmul_ring_torch.py

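TP=4 LNC=2 on a single Neuron device also works, provided ``RANK_N`` below
is changed to 4 first (the runner asserts that the world size equals
``RANK_N``). Other TP values may fail if the replica group does not map to
a valid CPI ring topology.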
"""

import os

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr

import nki.collectives as ncc

from allgather_matmul_ring_nki_kernels import allgather_matmul_ring


# Problem configuration. runner() builds the full operands from these sizes
# and passes every value except K to the kernel (K is taken from the shard
# shape inside the kernel).
RANK_N = 16    # TP degree, also ring length; runner() asserts world size == RANK_N
M_LOCAL = 128  # per-rank M slice (M_full = M_LOCAL * RANK_N = 2048)
N_LOCAL = 512  # per-rank N slice (N_full = N_LOCAL * RANK_N = 8192)
K = 2048       # shared contraction dim
K_TILE = 128   # per-nc_matmul contract-dim tile (Tensor Engine partition-dim limit)
N_TILE = 512   # per-nc_matmul free-dim tile (Tensor Engine free-dim limit)


def _lnc_from_env() -> int:
    """Return the LNC value configured via ``NEURON_LOGICAL_NC_CONFIG``.

    An unset, empty, or whitespace-only variable yields the default of 1, so
    a shell line such as ``export NEURON_LOGICAL_NC_CONFIG=`` behaves the
    same as not exporting it at all.

    Returns:
        The integer value of the environment variable, or 1 by default.

    Raises:
        ValueError: If the variable is set to text that is not an integer.
    """
    raw = os.environ.get("NEURON_LOGICAL_NC_CONFIG", "").strip()
    if not raw:
        return 1
    try:
        return int(raw)
    except ValueError as err:
        # Name the variable so the failure is actionable; a bare int() error
        # only shows the offending literal.
        raise ValueError(
            f"NEURON_LOGICAL_NC_CONFIG must be an integer, got {raw!r}"
        ) from err


def runner(rank):
    """Run the ring kernel for one rank and check it against a host matmul.

    Every rank seeds torch identically and builds the same full LHS/RHS
    tensors, slices out its own row/column shards, launches
    ``allgather_matmul_ring`` on the XLA device, and compares the result
    with a float32 host reference rounded to bfloat16.

    Args:
        rank: Index of this worker; used to pick the LHS row block and the
            RHS column block. Presumably the process index supplied by
            ``xmp.spawn`` — confirm against the launcher.

    Raises:
        AssertionError: If the world size differs from RANK_N, or if the
            relative error of the kernel output is 0.10 or more.
    """
    device = xm.xla_device()
    tp_size = xr.world_size()
    assert tp_size == RANK_N, f"need tp={RANK_N}, got {tp_size}"
    ring = tuple(range(tp_size))
    replica_group = ncc.ReplicaGroup((ring,))

    # Same seed on every rank, so all ranks generate identical operands.
    torch.manual_seed(0)
    lhs_full = torch.randn((M_LOCAL * RANK_N, K), dtype=torch.bfloat16)
    rhs_full = torch.randn((K, N_LOCAL * RANK_N), dtype=torch.bfloat16)

    # Views of the row block / column block owned by this rank.
    row_lo = rank * M_LOCAL
    col_lo = rank * N_LOCAL
    lhs_rows = lhs_full[row_lo : row_lo + M_LOCAL, :]
    rhs_cols = rhs_full[:, col_lo : col_lo + N_LOCAL]

    # Host reference: full-M product for this rank's columns, laid out as
    # one (M_LOCAL, N_LOCAL) slot per source rank like the kernel output.
    reference = lhs_full.float() @ rhs_cols.float()
    expected = reference.to(torch.bfloat16).reshape(RANK_N, M_LOCAL, N_LOCAL)

    a_shard = lhs_rows.contiguous().to(device)
    b_shard = rhs_cols.contiguous().to(device)

    out = allgather_matmul_ring[_lnc_from_env()](
        a_shard, b_shard, replica_group,
        RANK_N, M_LOCAL, N_LOCAL, K_TILE, N_TILE,
    )
    xm.mark_step()
    got = out.cpu()

    # Max absolute error, normalized by the largest reference magnitude
    # (the epsilon guards against dividing by zero).
    exp_f = expected.float()
    abs_diff = (got.float() - exp_f).abs()
    max_abs_err = abs_diff.max().item()
    scale = exp_f.abs().max().item() + 1e-9
    rel_err = max_abs_err / scale
    ok = rel_err < 0.10

    report = (
        f"Rank {rank}: out shape={tuple(got.shape)} "
        f"max_abs_err={max_abs_err:.3f} rel_err={rel_err:.4f} "
        f"{'PASS' if ok else 'FAIL'}"
    )
    print(report)
    assert ok, f"Rank {rank}: output mismatch"


# Script entry point: hand runner() to torch_xla's multiprocessing launcher
# with no extra arguments.
if __name__ == "__main__":
    xmp.spawn(runner, args=())