diff --git a/src/nki_samples/tutorials/allgather_matmul_ring/allgather_matmul_ring_nki_kernels.py b/src/nki_samples/tutorials/allgather_matmul_ring/allgather_matmul_ring_nki_kernels.py new file mode 100644 index 0000000..00575a3 --- /dev/null +++ b/src/nki_samples/tutorials/allgather_matmul_ring/allgather_matmul_ring_nki_kernels.py @@ -0,0 +1,139 @@ +""" +Copyright (C) 2026, Amazon.com. All Rights Reserved + +NKI implementation for the all-gather + matmul ring tutorial. + +This kernel performs a fused all-gather + matmul along a ring of TP ranks, +using ``nki.collectives.collective_permute_implicit`` (CPI) to overlap +communication with compute. Each step, a rank computes a local matmul +using the LHS fragment currently in its ring buffer, then passes the +fragment on to the next rank in the ring while receiving the previous +rank's fragment — the scheduler places the matmul of step (i) and the +CPI of step (i) on disjoint engines, so they run concurrently. + +The matmul is row-parallel (LHS row-sharded, RHS column-sharded). After +RANK_N ring steps, every rank has computed one (M_LOCAL, N_LOCAL) slot of +the fully-gathered output for every source rank. +""" + +import nki as nki +import nki.collectives as ncc +import nki.isa as nisa +import nki.language as nl + + +# NKI_EXAMPLE_AGMM_RING_BEGIN +@nki.jit +def allgather_matmul_ring( + a_shard, + b_shard, + replica_group, + RANK_N, + M_LOCAL, + N_LOCAL, + K_TILE, + N_TILE, +): + """Ring all-gather + matmul. + + Args: + a_shard (nl.ndarray): [M_LOCAL, K] — this rank's LHS row-slice. + b_shard (nl.ndarray): [K, N_LOCAL] — this rank's RHS column-slice. + replica_group (ncc.ReplicaGroup): Ring of all TP ranks. + RANK_N (int): Number of TP ranks (ring length). + M_LOCAL (int): Local M dimension per rank. + N_LOCAL (int): Local N dimension per rank (RHS column shard width). + K_TILE (int): Contraction-dim tile size, must be 128 (Tensor Engine + partition-dim limit). + N_TILE (int): Free-dim tile size, must be <= 512. + + Returns: + nl.ndarray of shape [RANK_N, M_LOCAL, N_LOCAL]. Slot ``r`` contains the + matmul contribution from rank ``r``'s LHS shard (a_shard_r @ b_shard). + """ + _, K = a_shard.shape + K_TILES = K // K_TILE + N_TILES = N_LOCAL // N_TILE + + # Two ping-pong ring buffers on shared_hbm. Collectives cannot source + # from IO tensors directly, so we also seed with a DMA copy below. + buf0 = nl.ndarray((M_LOCAL, K), dtype=a_shard.dtype, buffer=nl.shared_hbm, + name="ring_buf0") + buf1 = nl.ndarray((M_LOCAL, K), dtype=a_shard.dtype, buffer=nl.shared_hbm, + name="ring_buf1") + + # Output: one (M_LOCAL, N_LOCAL) slot per source-rank. + out = nl.ndarray((RANK_N, M_LOCAL, N_LOCAL), dtype=a_shard.dtype, + buffer=nl.shared_hbm, name="out") + + # Seed the ring: copy local a_shard into buf0. + nisa.dma_copy(dst=buf0, src=a_shard) + + # `step` is an ordinary Python integer here (meta-programming): the + # compiler specializes the kernel body for each concrete value of step. + for step in range(RANK_N): + # Alternate the two ring buffers across steps. + if step % 2 == 0: + buf_cur, buf_next = buf0, buf1 + else: + buf_cur, buf_next = buf1, buf0 + + # src_rank is a runtime-valued rank ID held in a register. It can only + # be used as scalar_offset in an access pattern (not compared, not + # materialized as a Python value). + src_rank = ncc.collective_permute_implicit_current_processing_rank_id( + iteration_id=step, replica_group=replica_group, channel_id=0, + ) + + # Launch the CPI for the NEXT step first. The matmul below and this + # CPI both read buf_cur — independent readers — and CPI writes to a + # buffer not touched until the next iteration. Placing the CPI first + # hints the scheduler to overlap the two on different engines. + if step < RANK_N - 1: + ncc.collective_permute_implicit( + srcs_by_channel=[[buf_cur]], + dsts_by_channel=[[buf_next]], + replica_group=replica_group, + channel_ids=[0], + ) + + # Local matmul (M_LOCAL, N_LOCAL) = a_shard_cur @ b_shard. + # Tiled over K (K_TILES × K_TILE accumulated into PSUM) and N + # (N_TILES × N_TILE). M fits in a single Tensor Engine partition. + partial_sb = nl.ndarray((M_LOCAL, N_LOCAL), dtype=a_shard.dtype, + buffer=nl.sbuf) + for nt in nl.affine_range(N_TILES): + n0 = nt * N_TILE + psum = nl.zeros((M_LOCAL, N_TILE), dtype=nl.float32, buffer=nl.psum) + for kt in nl.affine_range(K_TILES): + k0 = kt * K_TILE + # Fresh SBUF tiles per K-step — avoid false WAR chains across + # iterations so the compiler can pipeline K-tile loads. + a_tile = nl.ndarray((K_TILE, M_LOCAL), dtype=a_shard.dtype, + buffer=nl.sbuf) + b_tile = nl.ndarray((K_TILE, N_TILE), dtype=b_shard.dtype, + buffer=nl.sbuf) + # nc_matmul wants stationary in [K, M] layout. buf_cur is + # [M, K], so dma_transpose swaps axes during the load. + nisa.dma_transpose(dst=a_tile, src=buf_cur[:, k0 : k0 + K_TILE]) + nisa.dma_copy(dst=b_tile, + src=b_shard[k0 : k0 + K_TILE, n0 : n0 + N_TILE]) + # Auto-accumulate into psum: first write overwrites, later writes add. + nisa.nc_matmul(dst=psum, stationary=a_tile, moving=b_tile) + nisa.tensor_copy(dst=partial_sb[:, n0 : n0 + N_TILE], src=psum) + + # Write this step's result into the src_rank slot of `out`. + # scalar_offset=src_rank + indirect_dim=0 indexes the leading + # RANK_N dimension at runtime. + nisa.dma_copy( + dst=out.ap( + pattern=[[N_LOCAL, M_LOCAL], [1, N_LOCAL]], + offset=0, + scalar_offset=src_rank, + indirect_dim=0, + ), + src=partial_sb, + ) + + return out +# NKI_EXAMPLE_AGMM_RING_END diff --git a/src/nki_samples/tutorials/allgather_matmul_ring/allgather_matmul_ring_torch.py b/src/nki_samples/tutorials/allgather_matmul_ring/allgather_matmul_ring_torch.py new file mode 100644 index 0000000..151356c --- /dev/null +++ b/src/nki_samples/tutorials/allgather_matmul_ring/allgather_matmul_ring_torch.py @@ -0,0 +1,93 @@ +""" +Copyright (C) 2026, Amazon.com. All Rights Reserved + +PyTorch/XLA runner for the all-gather + matmul ring NKI tutorial. + +Launches ``allgather_matmul_ring`` across TP ranks using +``torch_xla.distributed.xla_multiprocessing.spawn``, constructs +deterministic LHS/RHS shards, and validates each rank's output against a +reference matmul computed on the host. + +Run on trn2 with TP=16 LNC=2 (16 ranks across 4 devices): + + NEURON_CC_FLAGS="--lnc=2" \\ + NEURON_LOGICAL_NC_CONFIG=2 \\ + NEURONCORE_NUM_DEVICES=16 \\ + python allgather_matmul_ring_torch.py + +TP=4 LNC=2 on a single Neuron device also works. Other TP values may +fail if the replica group does not map to a valid CPI ring topology. +""" + +import os + +import torch +import torch_xla.core.xla_model as xm +import torch_xla.distributed.xla_multiprocessing as xmp +import torch_xla.runtime as xr + +import nki.collectives as ncc + +from allgather_matmul_ring_nki_kernels import allgather_matmul_ring + + +RANK_N = 16 # TP degree, also ring length +M_LOCAL = 128 # per-rank M slice (M_full = 2048) +N_LOCAL = 512 # per-rank N slice (N_full = 8192) +K = 2048 # shared contraction dim +K_TILE = 128 # per-nc_matmul contract-dim tile (Tensor Engine partition-dim limit) +N_TILE = 512 # per-nc_matmul free-dim tile (Tensor Engine free-dim limit) + + +def _lnc_from_env() -> int: + return int(os.environ.get("NEURON_LOGICAL_NC_CONFIG", "1")) + + +def runner(rank): + device = xm.xla_device() + tp_size = xr.world_size() + assert tp_size == RANK_N, f"need tp={RANK_N}, got {tp_size}" + replica_group = ncc.ReplicaGroup((tuple(range(tp_size)),)) + + M_full = M_LOCAL * RANK_N + N_full = N_LOCAL * RANK_N + + # Deterministic full tensors on every rank; each rank takes its shard. + torch.manual_seed(0) + x_full = torch.randn((M_full, K), dtype=torch.bfloat16) + w_full = torch.randn((K, N_full), dtype=torch.bfloat16) + + # Reference: this rank owns N_LOCAL columns of the output; compute the + # full-M matmul on host for that column slice and compare. + n_start = rank * N_LOCAL + expected = ( + x_full.float() @ w_full[:, n_start : n_start + N_LOCAL].float() + ).to(torch.bfloat16).reshape(RANK_N, M_LOCAL, N_LOCAL) + + # This rank's LHS and RHS shards. + m_start = rank * M_LOCAL + a_shard = x_full[m_start : m_start + M_LOCAL, :].contiguous().to(device) + b_shard = w_full[:, n_start : n_start + N_LOCAL].contiguous().to(device) + + lnc = _lnc_from_env() + out = allgather_matmul_ring[lnc]( + a_shard, b_shard, replica_group, + RANK_N, M_LOCAL, N_LOCAL, K_TILE, N_TILE, + ) + xm.mark_step() + got = out.cpu() + + got_f, exp_f = got.float(), expected.float() + max_abs_err = (got_f - exp_f).abs().max().item() + rel_err = max_abs_err / (exp_f.abs().max().item() + 1e-9) + ok = rel_err < 0.10 + print( + f"Rank {rank}: out shape={tuple(got.shape)} " + f"max_abs_err={max_abs_err:.3f} rel_err={rel_err:.4f} " + f"{'PASS' if ok else 'FAIL'}", + ) + assert ok, f"Rank {rank}: output mismatch" + + +if __name__ == "__main__": + xmp.spawn(runner, args=())