[Feat] Fuse TopKGatingSoftmax and MoE Sorting kernels #582

Open
amd-wsung102 wants to merge 9 commits into ROCm:main from amd-wsung102:fuse_topk_sorting_updated

Contributor

@amd-wsung102 amd-wsung102 commented May 28, 2026

Motivation

The topk_gating_softmax_kernel.py and moe_sorting_kernel.py kernels can be fused to improve performance in eager mode, in graph mode, and in raw kernel time.

Relevant Files

kernels/moe_sorting_kernel.py - added fused topk and sorting
tests/kernels/test_moe_sorting.py - unit test for the fused kernels

Additional Details

The fusion applies only to the decode path in moe_sorting, and only when the number of tokens T is at most 16 (T <= 16). For T > 16, the fusion does not yet yield improvements; this is under investigation and can be addressed in a future PR.

Test Result - DeepSeek-R1: E=256, topk=8, model_dim=7168, bf16

All times are in microseconds (us).

  • Eager: 2-2.4x improvement
  • Graph: 1.12-1.14x improvement
  • Raw kernel: 1.3-1.4x improvement
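
| T | unfused_eager | fused_eager | unfused_graph | fused_graph | unfused_kernel | fused_kernel | eager speedup | graph speedup | kernel speedup |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 32 | 13.1 | 16.2 | 14.3 | 13.5 | 10.3 | 2.44 | 1.13 | 1.32 |
| 2 | 32.5 | 16 | 16.2 | 14.4 | 14 | 10.5 | 2.03 | 1.12 | 1.34 |
| 4 | 32.5 | 16.2 | 16.7 | 14.8 | 14.5 | 11 | 2 | 1.13 | 1.32 |
| 8 | 33.7 | 15.9 | 17.3 | 15.4 | 15 | 11.4 | 2.11 | 1.12 | 1.32 |
| 12 | 34 | 14.9 | 19.3 | 16.9 | 17.1 | 12.2 | 2.28 | 1.14 | 1.4 |
| 16 | 33.9 | 15.4 | 19.7 | 17.4 | 17.7 | 12.9 | 2.2 | 1.13 | 1.37 |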
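
Each speedup column is the ratio of unfused time to fused time. A minimal sketch that recomputes the three ratios for the T=16 row above; the helper name `speedup` is illustrative and not part of this PR, and because the reported times are rounded, ratios recomputed for other rows may differ from the table in the last digit:

```python
def speedup(unfused_us: float, fused_us: float) -> float:
    """Return the unfused/fused time ratio, rounded to two decimals."""
    return round(unfused_us / fused_us, 2)


# T=16 row of the DeepSeek-R1 table: (unfused, fused) times in microseconds.
row_t16 = {
    "eager": (33.9, 15.4),
    "graph": (19.7, 17.4),
    "kernel": (17.7, 12.9),
}

# Compute one speedup per measurement mode.
speedups = {mode: speedup(u, f) for mode, (u, f) in row_t16.items()}
print(speedups)
```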

Test Result - GPT-OSS 120B: E=128, topk=4, model_dim=2880, bf16

All times are in microseconds (us).

  • Eager: 2.13-2.24x improvement
  • Graph: 1.21-1.27x improvement
  • Raw kernel: 1.38-1.44x improvement

| T | unfused_eager | fused_eager | unfused_graph | fused_graph | unfused_kernel | fused_kernel | eager speedup | graph speedup | kernel speedup |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 31.2 | 14.4 | 13.3 | 10.9 | 9.9 | 7.1 | 2.16 | 1.22 | 1.39 |
| 2 | 31.4 | 14.6 | 13.5 | 11.1 | 10.2 | 7.3 | 2.15 | 1.21 | 1.39 |
| 4 | 31.7 | 14.8 | 13.8 | 11.3 | 10.4 | 7.5 | 2.14 | 1.21 | 1.38 |
| 8 | 31.6 | 14.8 | 13.9 | 11.4 | 10.5 | 7.6 | 2.13 | 1.21 | 1.38 |
| 12 | 32.8 | 14.6 | 15.6 | 12.3 | 12.2 | 8.5 | 2.24 | 1.27 | 1.44 |
| 16 | 33.4 | 14.9 | 15.8 | 12.5 | 12.3 | 8.7 | 2.24 | 1.27 | 1.42 |

Submission Checklist

Comment thread kernels/moe_sorting_kernel.py Outdated


@contextmanager
def _if_then(if_op):
Collaborator

Why is this scf if needed?

Contributor Author

The fused kernel needs the explicit scf.IfOp and the _if_then helper because, when a plain Python if is used instead, the kernel JIT fails:

File ".../kernels/moe_sorting_kernel.py", line 1242, in __then_1
for _z in range(_zs, _ze, _z1):
TypeError: 'ArithValue' object cannot be interpreted as an integer
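
The message in this traceback is the one Python's built-in range() raises for any argument it cannot convert to a concrete int. A minimal sketch of that behavior, using a hypothetical SymbolicValue class as a stand-in for ArithValue; it is not the kernel DSL's type and does not reproduce the JIT failure itself, only the resulting TypeError:

```python
class SymbolicValue:
    """Hypothetical stand-in for a value that only exists at trace time.

    It deliberately does not define __index__, so Python has no way
    to turn it into a concrete int.
    """

    def __init__(self, name: str) -> None:
        self.name = name


def try_range(stop):
    """Return the TypeError message raised by range(stop), or None."""
    try:
        range(stop)
    except TypeError as exc:
        return str(exc)
    return None


print(try_range(4))                   # concrete int: no error
print(try_range(SymbolicValue("n")))  # symbolic value: TypeError message
```

Any loop bound that is still symbolic when a Python-level range() runs will hit this error, whichever construct produced that bound.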

Contributor

> The fused kernel needs the explicit scf.IfOp and the _if_then helper because, when a plain Python if is used instead, the kernel JIT fails:
>
> File ".../kernels/moe_sorting_kernel.py", line 1242, in __then_1
> for _z in range(_zs, _ze, _z1):
> TypeError: 'ArithValue' object cannot be interpreted as an integer

I've fixed a similar bug before, but I'm not sure whether this is the same issue. Could you share a simplified test case that triggers the error so that I can reproduce it?

Contributor Author

@amd-wsung102 amd-wsung102 May 29, 2026

Hi @xudoyuan, you can use commit 5627ef2, which uses a regular Python if instead of the scf if. Then run pytest tests/kernels/test_moe_sorting.py::test_moe_softmax_sort_fused_oneshot -k "1-256-8-bf16".

It reports:

FAILED tests/kernels/test_moe_sorting.py::test_moe_softmax_sort_fused_oneshot[1-256-8-bf16] - TypeError: 'ArithValue' object cannot be interpreted as an integer

Once the regular Python if is switched to the scf if, as in commit 07ea93d, the test no longer raises the error.
