Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions cula/ops/chunk_delta_h_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import torch.nn.functional as F
import triton
from cutlass._mlir.dialects import llvm as _llvm
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cute.nvgpu import OperandMajorMode, cpasync, tcgen05
from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_stream
from cutlass.cute.typing import Float32, Int32, Int64
from cutlass.cutlass_dsl import T as _T
Expand Down Expand Up @@ -308,8 +308,9 @@ def __call__(
# WH MMA: A=state(TMEM, K-major), B=W(SMEM, K-major)
wh_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
tcgen05.OperandMajorMode.K, # A: state, K-major (required for TMEM source)
tcgen05.OperandMajorMode.K, # B: W, K-major (BK contiguous)
self.io_dtype,
OperandMajorMode.K, # A: state, K-major (required for TMEM source)
OperandMajorMode.K, # B: W, K-major (BK contiguous)
self.acc_dtype,
self.cta_group,
self.wh_mma_tiler[:2],
Expand All @@ -319,8 +320,9 @@ def __call__(
# KV MMA: A=v_new^T(TMEM, K-major required), B=K^T(SMEM, MN-major)
kv_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
tcgen05.OperandMajorMode.K, # A: v_new, K-major (required for TMEM source)
tcgen05.OperandMajorMode.MN, # B: K^T, MN-major (BK contiguous)
self.io_dtype,
OperandMajorMode.K, # A: v_new, K-major (required for TMEM source)
OperandMajorMode.MN, # B: K^T, MN-major (BK contiguous)
self.acc_dtype,
self.cta_group,
self.kv_mma_tiler[:2],
Expand Down
37 changes: 22 additions & 15 deletions cula/ops/chunk_wy_dqkg_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
mbarrier_init_fence,
mbarrier_wait,
)
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cute.nvgpu import OperandMajorMode, cpasync, tcgen05
from cutlass.cute.nvgpu.tcgen05 import (
make_umma_smem_desc,
smem_descriptor_to_int,
Expand Down Expand Up @@ -605,8 +605,9 @@ def __call__(
# dq += do @ h, dk += vnew @ dh, dw += dv @ h
vloop_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
tcgen05.OperandMajorMode.K, # A: K-major
tcgen05.OperandMajorMode.K, # B: K-major
self.io_dtype,
OperandMajorMode.K, # A: K-major
OperandMajorMode.K, # B: K-major
self.acc_dtype,
self.cta_group,
self.vloop_gemm_tiler[:2], # (64, 128)
Expand All @@ -617,8 +618,9 @@ def __call__(
# dA += dv @ v^T, dA += dw @ kg^T
dA_vloop_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.K,
self.io_dtype,
OperandMajorMode.K,
OperandMajorMode.K,
self.acc_dtype,
self.cta_group,
self.dA_vloop_tiler[:2], # (64, 64)
Expand All @@ -629,8 +631,9 @@ def __call__(
# dvb = A @ dv, dkgb = A @ dw
dvb_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
tcgen05.OperandMajorMode.MN,
tcgen05.OperandMajorMode.MN,
self.io_dtype,
OperandMajorMode.MN,
OperandMajorMode.MN,
self.acc_dtype,
self.cta_group,
self.dvb_tiler[:2], # (64, 64)
Expand All @@ -639,8 +642,9 @@ def __call__(
# dkgb_tiled_mma: SS MN,MN (64,128) - dkgb
dkgb_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
tcgen05.OperandMajorMode.MN,
tcgen05.OperandMajorMode.MN,
self.io_dtype,
OperandMajorMode.MN,
OperandMajorMode.MN,
self.acc_dtype,
self.cta_group,
self.kloop_dkgb_tiler[:2], # (64, 128)
Expand All @@ -650,8 +654,9 @@ def __call__(
# dA += dw @ kg^T
dA_kloop_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.K,
self.io_dtype,
OperandMajorMode.K,
OperandMajorMode.K,
self.acc_dtype,
self.cta_group,
self.kloop_dA_tiler[:2], # (64, 64)
Expand All @@ -661,8 +666,9 @@ def __call__(
# dA = dA @ A
dA2post_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.K,
self.io_dtype,
OperandMajorMode.K,
OperandMajorMode.K,
self.acc_dtype,
self.cta_group,
self.dApost_tiler[:2], # (64, 64)
Expand All @@ -673,8 +679,9 @@ def __call__(
# dA = A @ dA
dA3post_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
tcgen05.OperandMajorMode.MN,
tcgen05.OperandMajorMode.MN,
self.io_dtype,
OperandMajorMode.MN,
OperandMajorMode.MN,
self.acc_dtype,
self.cta_group,
self.dApost_tiler[:2], # (64, 64)
Expand Down
12 changes: 7 additions & 5 deletions cula/ops/cp/pre_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import cutlass.utils as utils
import cutlass.utils.blackwell_helpers as sm100_utils
import torch
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cute.nvgpu import OperandMajorMode, cpasync, tcgen05
from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_stream
from cutlass.cute.typing import Float32, Int32, Int64

Expand Down Expand Up @@ -272,17 +272,19 @@ def __call__(
# ===================== MMA setup (same as fwd_h) =====================
wh_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.K,
self.io_dtype,
OperandMajorMode.K,
OperandMajorMode.K,
self.acc_dtype,
self.cta_group,
self.wh_mma_tiler[:2],
tcgen05.OperandSource.TMEM,
)
kv_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.MN,
self.io_dtype,
OperandMajorMode.K,
OperandMajorMode.MN,
self.acc_dtype,
self.cta_group,
self.kv_mma_tiler[:2],
Expand Down
17 changes: 10 additions & 7 deletions cula/ops/fwd_o_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
import cutlass.utils as utils
import cutlass.utils.blackwell_helpers as sm100_utils
import torch
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cute.nvgpu import OperandMajorMode, cpasync, tcgen05
from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_stream
from cutlass.cute.typing import Float32, Int32, Int64
from fla.ops.utils import prepare_chunk_indices
Expand Down Expand Up @@ -354,8 +354,9 @@ def __call__(
# B is MN-major because h_T GMEM has V(=N) contiguous
qh_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
tcgen05.OperandMajorMode.K, # A: K-major (TMEM requires K-major)
tcgen05.OperandMajorMode.MN, # B: MN-major (V contiguous in GMEM)
self.io_dtype,
OperandMajorMode.K, # A: K-major (TMEM requires K-major)
OperandMajorMode.MN, # B: MN-major (V contiguous in GMEM)
self.acc_dtype,
self.cta_group,
self.qh_mma_tiler[:2],
Expand All @@ -366,8 +367,9 @@ def __call__(
# B is MN-major because v_T GMEM has V(=N) contiguous
av_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
tcgen05.OperandMajorMode.K, # A: K-major (TMEM requires K-major)
tcgen05.OperandMajorMode.MN, # B: MN-major (V contiguous in GMEM)
self.io_dtype,
OperandMajorMode.K, # A: K-major (TMEM requires K-major)
OperandMajorMode.MN, # B: MN-major (V contiguous in GMEM)
self.acc_dtype,
self.cta_group,
self.av_mma_tiler[:2],
Expand Down Expand Up @@ -561,8 +563,9 @@ class SharedStorage:
# B operand majorness must match av_tiled_mma for C layout compatibility.
am_coord_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.MN,
self.io_dtype,
OperandMajorMode.K,
OperandMajorMode.MN,
self.acc_dtype,
self.cta_group,
(self.BT, self.BT),
Expand Down
11 changes: 4 additions & 7 deletions cula/ops/intrinsics_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@

import cutlass.cute as cute
from cutlass._mlir import ir as _ir_mod
from cutlass._mlir.dialects import arith as _arith
from cutlass._mlir.dialects import llvm
from cutlass._mlir.dialects import nvvm as _nvvm
from cutlass._mlir.dialects import vector as _vector
Expand Down Expand Up @@ -117,7 +116,6 @@ def _do(addr_val, *, loc=None, ip=None):
return _nvvm.tcgen05_ld(
res=vec_i32_ty,
shape=_nvvm.Tcgen05LdStShape.SHAPE_32X32B,
num=num,
tmem_addr=tmem_ptr,
loc=loc,
ip=ip,
Expand Down Expand Up @@ -154,9 +152,8 @@ def _do(addr_val, vec_val, *, loc=None, ip=None):
tmem_ptr = llvm.inttoptr(ptr6_ty, _to_ir(addr_val, loc, ip), loc=loc, ip=ip)
_nvvm.tcgen05_st(
shape=_nvvm.Tcgen05LdStShape.SHAPE_32X32B,
num=num,
tmem_addr=tmem_ptr,
r=_to_ir(vec_val, loc, ip),
val=_to_ir(vec_val, loc, ip),
loc=loc,
ip=ip,
)
Expand Down Expand Up @@ -279,12 +276,12 @@ def store_256b(gmem_ptr, vec):

@dsl_user_op
def _do(addr, v, *, loc=None, ip=None):
i32_ty = _ir_mod.IntegerType.get_signless(32)
ir_v = _to_ir(v, loc, ip)
elems = [
_vector.extractelement(
_vector.extract(
ir_v,
position=_arith.constant(i32_ty, i, loc=loc, ip=ip),
dynamic_position=[],
static_position=[i],
loc=loc,
ip=ip,
)
Expand Down
23 changes: 15 additions & 8 deletions cula/ops/kda_fully_fused_sm100_wip.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
import cutlass.utils as utils
import cutlass.utils.blackwell_helpers as sm100_utils
import torch
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cute.nvgpu import OperandMajorMode, cpasync, tcgen05
from cutlass.cute.runtime import from_dlpack
from cutlass.cute.typing import Int32, Int64
from fla.modules.l2norm import l2norm_fwd
Expand Down Expand Up @@ -478,20 +478,21 @@ def __call__(
self.k_major_mode = utils.LayoutEnum.from_tensor(k).mma_major_mode()
self.v_major_mode = utils.LayoutEnum.from_tensor(v).mma_major_mode()
self.g_major_mode = utils.LayoutEnum.from_tensor(g).mma_major_mode() # NEW for KDA
self.k_major_mode_kv = tcgen05.OperandMajorMode.MN # For V^T*K, S dimension coalesced
self.k_major_mode_kv = OperandMajorMode.MN # For V^T*K, S dimension coalesced
# TMEM register output results as (D, C)
self.o_layout = utils.LayoutEnum.from_tensor(o)

if cutlass.const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K):
if cutlass.const_expr(self.q_major_mode != OperandMajorMode.K):
raise RuntimeError("The layout of q is not supported")
if cutlass.const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K):
if cutlass.const_expr(self.k_major_mode != OperandMajorMode.K):
raise RuntimeError("The layout of k is not supported")
if cutlass.const_expr(self.o_layout != utils.LayoutEnum.COL_MAJOR):
raise RuntimeError("The layout of o is not supported")
if cutlass.const_expr(self.k_major_mode == self.k_major_mode_kv):
raise RuntimeError("The layout of k & k^t should be different")

qk_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.q_dtype,
self.q_dtype,
Comment on lines +495 to 496

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

For qk_tiled_mma, the operands are Q and K. The first two arguments should be self.q_dtype and self.k_dtype respectively, rather than passing self.q_dtype twice.

Suggested change
self.q_dtype,
self.q_dtype,
self.q_dtype,
self.k_dtype,

self.q_major_mode,
self.k_major_mode,
Expand All @@ -500,6 +501,7 @@ def __call__(
self.qk_mma_tiler[:2],
)
kk_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.k_dtype,
self.k_dtype,
# SHOULE BE both K-major
self.k_major_mode,
Expand All @@ -519,19 +521,21 @@ def __call__(
)
# State^T Q^T
sq_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
self.io_dtype,
Comment on lines +524 to 525

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

For sq_tiled_mma, the operands are State and Q. The first two arguments should be self.io_dtype and self.q_dtype respectively.

Suggested change
self.io_dtype,
self.io_dtype,
self.io_dtype,
self.q_dtype,

# State is in TMEM, always K major, TODO
tcgen05.OperandMajorMode.K,
OperandMajorMode.K,
self.q_major_mode,
self.acc_dtype,
self.cta_group,
self.sq_mma_tiler[:2],
a_source=tcgen05.OperandSource.TMEM,
)
ks_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.io_dtype,
self.io_dtype,
Comment on lines +535 to 536

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

For ks_tiled_mma, the operands are State and K. The first two arguments should be self.io_dtype and self.k_dtype respectively.

Suggested change
self.io_dtype,
self.io_dtype,
self.io_dtype,
self.k_dtype,

# State is in TMEM, always K major, TODO
tcgen05.OperandMajorMode.K,
OperandMajorMode.K,
# State is in TMEM, always K major, TODO
self.k_major_mode,
self.acc_dtype,
Expand All @@ -540,8 +544,9 @@ def __call__(
a_source=tcgen05.OperandSource.TMEM,
)

m_major_mode = tcgen05.OperandMajorMode.K
m_major_mode = OperandMajorMode.K
mv_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.v_dtype,
self.v_dtype,
self.v_major_mode,
m_major_mode,
Expand All @@ -550,8 +555,9 @@ def __call__(
self.mv_mma_tiler[:2],
)

p_major_mode = tcgen05.OperandMajorMode.K
p_major_mode = OperandMajorMode.K
vp_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.v_dtype,
self.v_dtype,
self.v_major_mode,
p_major_mode,
Expand Down Expand Up @@ -1459,6 +1465,7 @@ def kernel(
############################################
kv_mma_tiler2 = (self.kv_mma_tiler[0], self.kv_mma_tiler[1] // 2, self.kv_mma_tiler[2])
fake_kv_tiled_mma_acc32 = sm100_utils.make_trivial_tiled_mma(
self.k_dtype,
self.k_dtype,
Comment on lines +1468 to 1469

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

For fake_kv_tiled_mma_acc32, the operands are V and K. The first two arguments should be self.v_dtype and self.k_dtype respectively.

Suggested change
self.k_dtype,
self.k_dtype,
self.v_dtype,
self.k_dtype,

self.v_major_mode,
self.k_major_mode_kv,
Expand Down
Loading