Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
578 changes: 578 additions & 0 deletions benchmarks/bench_kda_decode_mtp.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions cula/kda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,19 @@
from cula.kda.chunk import chunk_kda
from cula.kda.hopper_fused_fwd import cula_kda_prefill as kda_prefill_hopper
from cula.ops.kda_decode import fused_sigmoid_gating_delta_rule_update, kda_decode
from cula.ops.kda_decode_mtp import (
kda_decode_mtp,
kda_decode_mtp_small_batch,
kda_decode_mtp_ws,
)

__all__ = [
"chunk_kda",
"kda_prefill_blackwell",
"kda_decode",
"kda_decode_mtp",
"kda_decode_mtp_ws",
"kda_decode_mtp_small_batch",
"fused_sigmoid_gating_delta_rule_update",
"kda_prefill_hopper",
]
8 changes: 8 additions & 0 deletions cula/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,18 @@
# limitations under the License.

from cula.ops.kda_decode import fused_sigmoid_gating_delta_rule_update, kda_decode
from cula.ops.kda_decode_mtp import (
kda_decode_mtp,
kda_decode_mtp_small_batch,
kda_decode_mtp_ws,
)
from cula.ops.la_decode import linear_attention_decode

__all__ = [
"kda_decode",
"kda_decode_mtp",
"kda_decode_mtp_ws",
"kda_decode_mtp_small_batch",
"fused_sigmoid_gating_delta_rule_update",
"linear_attention_decode",
]
16 changes: 15 additions & 1 deletion cula/ops/kda_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def _try_fast_dense_decode(
softplus_threshold: float,
out: torch.Tensor | None,
state_layout: str | None,
opt_level: int = 1,
):
"""Fast path for the common dense decode case used by the benchmark.

Expand Down Expand Up @@ -267,6 +268,7 @@ def _try_fast_dense_decode(
dense_small_hv_parallel=dense_small_hv_parallel,
softplus_beta=softplus_beta,
softplus_threshold=softplus_threshold,
opt_level=opt_level,
)
compiled_kernel(
cu_seqlens_to_use,
Expand Down Expand Up @@ -1552,12 +1554,18 @@ def _get_compiled_kernel(
dense_small_hv_parallel,
softplus_beta,
softplus_threshold,
opt_level=1,
):
"""Get or lazily compile one CuteDSL decode kernel variant.

Compile-time specialization is still important here, so we cache the result
by shape, layout, and constexpr options. The compiled function is emitted
with TVM-FFI enabled so runtime calls can pass torch tensors directly.

``opt_level`` selects the CuTe DSL ``--opt-level`` (codegen optimization;
NOT a kernel constexpr). It is part of the cache key so the same shape can
be compiled at multiple opt-levels without colliding. Default 1 keeps the
historical behavior; 2/3 are experiments (see issue 17 compile-knob tuning).
"""
global _compiled_kernels

Expand All @@ -1578,6 +1586,7 @@ def _get_compiled_kernel(
dense_small_hv_parallel,
softplus_beta,
softplus_threshold,
opt_level,
)
if key in _compiled_kernels:
return _compiled_kernels[key]
Expand Down Expand Up @@ -1656,7 +1665,7 @@ def _get_compiled_kernel(
num_blocks_per_state_small=num_blocks_per_state_small,
dense_small_hv_parallel=dense_small_hv_parallel,
stream=stream,
options="--enable-tvm-ffi --opt-level 1",
options=f"--enable-tvm-ffi --opt-level {opt_level}",
)

_compiled_kernels[key] = compiled_kernel
Expand Down Expand Up @@ -1809,6 +1818,7 @@ def fused_sigmoid_gating_delta_rule_update(
is_kda: bool = False,
out: torch.Tensor | None = None,
state_layout: str = "vk",
opt_level: int = 1,
):
"""Public cuLA decode API backed by CuTe DSL.

Expand Down Expand Up @@ -1839,6 +1849,7 @@ def fused_sigmoid_gating_delta_rule_update(
softplus_threshold=softplus_threshold,
out=out,
state_layout=state_layout,
opt_level=opt_level,
)


Expand All @@ -1859,6 +1870,7 @@ def kda_decode(
softplus_threshold: float = 20.0,
out: torch.Tensor | None = None,
state_layout: str = "vk",
opt_level: int = 1,
) -> torch.Tensor:
"""CuTe DSL implementation of fused sigmoid gating KDA update.

Expand Down Expand Up @@ -1911,6 +1923,7 @@ def kda_decode(
softplus_threshold,
out,
state_layout,
opt_level,
)
if fast_dense_out is not None:
return fast_dense_out
Expand Down Expand Up @@ -2074,6 +2087,7 @@ def kda_decode(
dense_small_hv_parallel=dense_small_hv_parallel,
softplus_beta=softplus_beta,
softplus_threshold=softplus_threshold,
opt_level=opt_level,
)

# With TVM-FFI enabled at compile time, the runtime launch can pass torch
Expand Down
Loading