Skip to content
Closed
10 changes: 10 additions & 0 deletions kernels/flash_attn_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def _pointer_store(value: ir.Value, ptr: ir.Value):
return llvm.StoreOp(_llvm_value(value), _llvm_value(ptr))


@fx.source_loc_scope
def _waitcnt_vm_n(n):
"""Emit s_waitcnt vmcnt(n) only (lgkmcnt=63, expcnt=7)."""
val = (n & _VMCNT_LO_MASK) | _LGKMCNT_EXPCNT_BASE | (((n >> 4) & _VMCNT_HI_MASK) << _VMCNT_HI_SHIFT)
Expand Down Expand Up @@ -282,18 +283,23 @@ def flash_attn_func_kernel(
def _mfma(mfma_fn, a, b, c):
return mfma_fn(v16f32_type, [a, b, c])

@fx.source_loc_scope
def _fadd(a, b):
return arith.addf(_raw(a), _raw(b), fastmath=fm_fast)

@fx.source_loc_scope
def _fsub(a, b):
return arith.subf(_raw(a), _raw(b), fastmath=fm_fast)

@fx.source_loc_scope
def _fmul(a, b):
return arith.mulf(_raw(a), _raw(b), fastmath=fm_fast)

@fx.source_loc_scope
def _fmax(a, b):
return arith.MaxNumFOp(_raw(a), _raw(b), fastmath=fm_fast).result

@fx.source_loc_scope
def mfma_acc(a, b, c):
if const_expr(dtype_str == "bf16"):
if const_expr(USE_K16):
Expand Down Expand Up @@ -376,6 +382,7 @@ def _load_global_half_vec(ptr, base_idx, vec_elems: int):
gep = buffer_ops.get_element_ptr(ptr, fx.Int64(base_idx), elem_type=elem_type)
return _pointer_load(Vec.make_type(vec_elems, elem_dtype), gep)

@fx.source_loc_scope
def _store_global_half(ptr, base_idx, val):
gep = buffer_ops.get_element_ptr(ptr, fx.Int64(base_idx), elem_type=elem_type)
_pointer_store(val, gep)
Expand Down Expand Up @@ -524,6 +531,7 @@ def coop_store_v_lds(vecs, buf_id=0):
_dma_off = fx.Int32(0)
_dma_aux = fx.Int32(1)

@fx.source_loc_scope
def coop_dma_k(tile_start, buf_id=0):
"""Load K tile via DMA with XOR-swizzled global fetch."""
if const_expr(isinstance(buf_id, int)):
Expand Down Expand Up @@ -570,6 +578,7 @@ def _v_swizzle(row_idx, col_idx):
LANES_PER_V_ROW = HEAD_DIM * 2 // DMA_BYTES
ROWS_PER_DMA_BATCH_V = DMA_BATCH_BYTES // (HEAD_DIM * 2)

@fx.source_loc_scope
def coop_dma_v(tile_start, buf_id=0):
"""Load V tile via DMA with XOR-swizzled global fetch."""
v_lds_byte_base = lds_kv_base_idx + fx.Index((LDS_V_BASE + buf_id * LDS_V_TILE_SIZE) * 2)
Expand Down Expand Up @@ -1068,6 +1077,7 @@ def _k_idx_hi(ks):
_steps = [(dc, pks) for dc in range(D_CHUNKS) for pks in range(PV_K_STEPS)]
TOTAL_PV = len(_steps)

@fx.source_loc_scope
def _read_v_pack(step_idx):
dc, pks = _steps[step_idx]
if const_expr(USE_HW_TR):
Expand Down
2 changes: 2 additions & 0 deletions python/flydsl/expr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from .gpu import *
from .derived import *
from .struct import *
from .meta import source_loc as source_loc
from .meta import source_loc_scope as source_loc_scope

from . import utils as utils
from . import arith as arith
Expand Down
13 changes: 12 additions & 1 deletion python/flydsl/expr/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .._mlir.dialects import gpu
from .._mlir.dialects._fly_enum_gen import AddressSpace
from ..compiler.protocol import dsl_align_of, dsl_size_of
from .meta import traced_op
from .numeric import Uint8
from .primitive import get_dyn_shared, make_ptr
from .struct import (
Expand All @@ -41,7 +42,17 @@
block_dim = Tuple3D(gpu.block_dim)
grid_dim = Tuple3D(gpu.grid_dim)

barrier = gpu.barrier

@traced_op
def barrier(*, address_spaces=None, loc=None, ip=None):
"""``gpu.barrier`` wrapped so ROCprof ATT maps it to its call site.

A bare ``gpu.barrier()`` in a kernel body emits with no location and inherits the
coarse kernel-default location. ``@traced_op`` resolves ``loc`` (explicit > enclosing
``source_loc`` pin > caller line) and stamps the emitted op via the ambient location.
"""
return gpu.barrier(address_spaces=address_spaces, loc=loc, ip=ip)


_int = int

Expand Down
60 changes: 59 additions & 1 deletion python/flydsl/expr/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2025 FlyDSL Project Contributors

import inspect
import threading
from functools import wraps

from .._mlir import ir
Expand Down Expand Up @@ -52,12 +53,69 @@ def _caller_location(depth=1):
return ir.Location.name(label, childLoc=file_loc)


_scope = threading.local()


def _pinned_loc():
"""Return the call-site Location pinned by an enclosing ``source_loc`` scope, or None."""
return getattr(_scope, "cur", None)


class source_loc:
"""Pin the user call-site location for every op emitted inside the ``with`` block.

A kernel helper that emits device ops is itself one (or more) Python frames above the
user's scheduling line, so ``traced_op``'s ``_caller_location(depth=1)`` would attribute
those ops to the helper body instead of the call site. Wrapping the helper body in
``with source_loc():`` captures the caller once and makes both untraced ODS builders
(via MLIR's ambient location) and ``traced_op`` leaves (via the pin) resolve to it.

Re-entrant: a nested ``source_loc`` is a no-op so the outermost scope wins.
"""

def __init__(self):
self._own = _pinned_loc() is None
# depth 2: hop past __init__ and the helper/wrapper frame to the user call site.
self._loc = _caller_location(2) if self._own else None

def __enter__(self):
if self._own:
self._loc.__enter__()
_scope.cur = self._loc
return self

def __exit__(self, *exc):
if self._own:
try:
self._loc.__exit__(*exc)
finally:
_scope.cur = None
return False
Comment on lines +87 to +93


def source_loc_scope(fn):
"""Decorator form of :class:`source_loc` for a kernel helper.

Runs the whole helper body inside ``source_loc()`` so every op it emits attributes to
the helper's call site, without reindenting the body. The ``wrapper`` frame occupies
the same stack slot the helper body otherwise would, so the captured location resolves
to the helper's call site.
"""

@wraps(fn)
def wrapper(*args, **kwargs):
with source_loc():
return fn(*args, **kwargs)

return wrapper


def traced_op(op):
@wraps(op)
def wrapper(*args, **kwargs):
loc = kwargs.pop("loc", None)
if loc is None:
loc = _caller_location(depth=1)
loc = _pinned_loc() or _caller_location(depth=1)
args, kwargs = _flatten_args(args, kwargs)
with loc:
return op(*args, **kwargs)
Expand Down
45 changes: 37 additions & 8 deletions python/flydsl/expr/rocdl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
_ods_cluster_load_async_to_lds_b128 = cluster_load_async_to_lds_b128
_ods_s_wait_asynccnt = s_wait_asynccnt
_ods_readfirstlane = readfirstlane
_ods_s_waitcnt = s_waitcnt
_ods_sched_barrier = sched_barrier
_ods_sched_group_barrier = sched_group_barrier
_ods_mfma_f32_32x32x8f16 = globals().get("mfma_f32_32x32x8f16", None)
_ods_mfma_f32_32x32x8bf16_1k = globals().get("mfma_f32_32x32x8bf16_1k", None)
_ods_mfma_f32_32x32x16_f16 = globals().get("mfma_f32_32x32x16_f16", None)
Expand All @@ -50,20 +53,46 @@
mask_dswr = 0x200


def sched_mfma(cnt):
sched_group_barrier(mask_mfma, cnt, 0)
# Synchronization / scheduling primitives are wrapped with @traced_op so ROCprof ATT
# maps them to the user's kernel line instead of the coarse kernel-default location.
# traced_op resolves loc (explicit > source_loc pin > caller line) and stamps the op via
# the ambient location; their args are compile-time int masks, so the eager arg-unwrap is
# a no-op (unlike the high-level helpers, which must use source_loc_scope instead).


def sched_vmem(cnt):
sched_group_barrier(mask_vmem_rd, cnt, 0)
@traced_op
def s_waitcnt(bitfield, *, loc=None, ip=None):
return _ods_s_waitcnt(bitfield, loc=loc, ip=ip)
Comment on lines +64 to +65
Copy link
Copy Markdown
Collaborator

@sjfeng1999 sjfeng1999 May 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks not good. Can't we use @traced_op here to capture location info?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just get back. Let me quickly check traced_op here for this func.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call — switched these to @traced_op (0f4041c). Their args are compile-time int masks (no SSA operands), so traced_op's eager arg-unwrap is a no-op, and it matches the mfma_f32_* wrappers already in this file. Kept source_loc_scope only on the high-level helpers (coop_dma_k/mfma_acc/…) where the unwrap would raise ValueError on aggregate/Constexpr operands. Verified the attribution is unchanged (same distinct source lines, MFMA still on call sites) and debug-off ISA stays byte-identical.



@traced_op
def sched_barrier(mask, *, loc=None, ip=None):
return _ods_sched_barrier(mask, loc=loc, ip=ip)


def sched_dsrd(cnt):
sched_group_barrier(mask_dsrd, cnt, 0)
@traced_op
def sched_group_barrier(mask, size, group_id, *, loc=None, ip=None):
return _ods_sched_group_barrier(mask, size, group_id, loc=loc, ip=ip)
Comment on lines +64 to +75


@traced_op
def sched_mfma(cnt, *, loc=None, ip=None):
return _ods_sched_group_barrier(mask_mfma, cnt, 0, loc=loc, ip=ip)


def sched_dswr(cnt):
sched_group_barrier(mask_dswr, cnt, 0)
@traced_op
def sched_vmem(cnt, *, loc=None, ip=None):
return _ods_sched_group_barrier(mask_vmem_rd, cnt, 0, loc=loc, ip=ip)


@traced_op
def sched_dsrd(cnt, *, loc=None, ip=None):
return _ods_sched_group_barrier(mask_dsrd, cnt, 0, loc=loc, ip=ip)


@traced_op
def sched_dswr(cnt, *, loc=None, ip=None):
return _ods_sched_group_barrier(mask_dswr, cnt, 0, loc=loc, ip=ip)


def _unwrap_mfma_operand(v, *, loc=None):
Expand Down
Loading