Skip to content

feat(moe): add moe support and fused topk & moe kernels#37

Open
MikanAffine wants to merge 5 commits intoSJTU-DENG-Lab:mainfrom
MikanAffine:fusedmoe
Open

feat(moe): add moe support and fused topk & moe kernels#37
MikanAffine wants to merge 5 commits intoSJTU-DENG-Lab:mainfrom
MikanAffine:fusedmoe

Conversation

@MikanAffine
Copy link
Copy Markdown

@MikanAffine MikanAffine commented Mar 31, 2026

Feature:

  • add MoE support to SDAR-MOE
  • add expert parallelism support
  • refactored MoE code structure like SGLang style
  • add fused TopK and fused MoE triton kernel, and unit tests

Summary by CodeRabbit

  • New Features

    • Added support for expert-parallel distributed training alongside tensor parallelism
    • Implemented optimized fused kernel implementations for Mixture-of-Experts computation
    • Introduced configurable expert parallelism size parameter
  • Refactor

    • Restructured Mixture-of-Experts implementation with new dispatcher and runner abstractions
    • Unified model parallelism initialization and metadata management

@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented Mar 31, 2026

📝 Walkthrough

Walkthrough

This PR introduces a comprehensive refactoring of the Mixture-of-Experts (MoE) implementation. It replaces the monolithic SparseMoEBlock with a modular FusedMoE architecture, adds Triton-optimized kernels for fused MoE computation and top-k routing, establishes a pluggable dispatcher/runner/router framework, integrates expert parallelism support into distributed training, and consolidates model parallelism metadata management across the codebase.

Changes

Cohort / File(s) Summary
MoE Core Implementation
diffulex/moe/__init__.py, diffulex/moe/moe_impl.py, diffulex/moe/layers.py
Removed SparseMoEBlock from exports and file system; replaced with new FusedMoE class that integrates top-k routing, token dispatching, expert execution, and result combination in a single modular layer.
Triton Kernel Backends
diffulex_kernel/python/fused_moe_triton.py, diffulex_kernel/python/fused_topk_triton.py, diffulex_kernel/__init__.py
Added high-performance Triton implementations for fused MoE forward pass (including gated MLP and weight application) and fused top-k routing with softmax/sigmoid scoring and renormalization support. Extended kernel lazy-loader to expose both functions.
MoE Dispatcher Framework
diffulex/moe/dispatcher/__init__.py, diffulex/moe/dispatcher/base.py, diffulex/moe/dispatcher/datatype.py, diffulex/moe/dispatcher/trivial.py
Introduced abstract TokenDispatcher base class and TrivialTokenDispatcher implementation. Defined DispatchOutput and CombineInput dataclasses to standardize dispatcher I/O. Added factory build_dispatcher(...) for runtime selection.
MoE Runner Framework
diffulex/moe/runner/__init__.py, diffulex/moe/runner/base.py, diffulex/moe/runner/triton.py, diffulex/moe/runner/trivial.py
Introduced abstract MoERunner base class with _all_reduce_output_if_needed(...) helper. Implemented TritonFusedMoERunner for kernel-accelerated expert computation and TrivialMoERunner for reference CPU-like execution. Added factory build_runner(...).
MoE Top-K Router Framework
diffulex/moe/topk/__init__.py, diffulex/moe/topk/base.py, diffulex/moe/topk/datatype.py, diffulex/moe/topk/bypass.py, diffulex/moe/topk/trivial.py
Extracted TopKOutput dataclass and TopKRouter abstract base into separate modules. Refactored TrivialTopKRouter to inherit from base. Added BypassTopKRouter to skip top-k selection. Added factory build_topk_router(...).
Model Parallelism Infrastructure
diffulex/utils/parallelism.py, diffulex/utils/checkpoint.py
New parallelism.py module provides ModelParallelismMetadata dataclass, global initialization/reset/accessor functions for TP/EP rank/size, and init_process_group(...) with layout validation. New checkpoint.py defines LoadContext and ResolvedWeight immutable dataclasses for weight resolution.
Distributed Training Integration
diffulex/engine/dp_worker.py, diffulex/engine/model_runner.py, diffulex/engine/tp_worker.py
Updated DP/TP workers to compute combined model-parallel world size from tensor_parallel_size and new expert_parallel_size. Refactored model_runner.py to use centralized init_process_group(...) and init_model_parallelism_metadata(...) helpers instead of direct torch.distributed calls.
Model/Layer TP Integration
diffulex/layer/embed_head.py, diffulex/layer/linear.py, diffulex/model/dream.py, diffulex/model/fast_dllm_v2.py, diffulex/model/llada.py, diffulex/model/sdar.py
Replaced direct torch.distributed.get_rank()/get_world_size() calls with get_tp_rank() and get_tp_world_size() from new parallelism module, centralizing TP configuration lookup.
Configuration & Weight Loading
diffulex/config.py, diffulex/utils/loader.py
Added expert_parallel_size: int = 1 field to Config with validation. Extended weight loading pipeline in loader.py with resolve_weight_spec(...) and apply_resolved_weight(...) functions to support per-module checkpoint-weight resolution hooks, enabling flexible expert-weight sharding/slicing.
Sampler Registration
diffulex/sampler/sdar.py
Extended SDARSampler registration with AutoSampler to include additional "sdar_moe" key alongside existing "sdar".
Repository Config
.gitignore
Added ignore patterns for .claude and .codex directories.

Sequence Diagram(s)

sequenceDiagram
    participant Input as Input Token<br/>Sequence
    participant FusedMoE as FusedMoE<br/>Layer
    participant Router as TopK<br/>Router
    participant Dispatcher as Token<br/>Dispatcher
    participant Runner as MoE<br/>Runner
    participant Output as Output<br/>Sequence

    Input->>FusedMoE: hidden_states
    FusedMoE->>Router: router_logits
    Router->>Router: compute top-k<br/>routing scores
    Router->>Dispatcher: TopKOutput<br/>(ids, weights)
    Dispatcher->>Dispatcher: map tokens to<br/>local experts
    Dispatcher->>Runner: DispatchOutput
    Runner->>Runner: execute experts<br/>fused_moe(...) or<br/>per-expert MLP
    Runner->>Runner: all_reduce if<br/>expert_parallel
    Runner->>Output: CombineInput<br/>(combined_hidden_states)
    Output->>FusedMoE: (final_hidden_states,<br/>router_logits)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

This PR spans multiple interconnected subsystems (MoE architecture, kernel implementations, distributed parallelism metadata, checkpoint loading, and model integration). It requires careful review of: new abstract base classes and factory patterns across dispatcher/runner/router frameworks; high-density Triton kernel code with GEMM tiling and masking logic; parallelism metadata initialization and validation with cross-process consistency checks; complex weight resolution and tensor slicing for expert-parallel sharding; and distributed training coordination changes affecting multiple worker types.

Poem

🐰 The MoE hops with newfound grace,
Fused kernels racing through compute space,
Experts dispatched with purpose true,
Top-k routers guide the way through,
Parallel wisdom, sharded and bright—
Mixture refined to pure delight!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 34.88% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The PR title clearly summarizes the main feature additions: MoE support and fused TopK/MoE kernels. It accurately reflects the primary changes across multiple modules and aligns well with the PR objectives.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

🧹 Nitpick comments (3)
diffulex/moe/__init__.py (1)

10-16: Remove the dead fallback or make it a real branch.

Line 15 returns unconditionally, so Line 16 can never execute. If SparseMoEBlock is still meant to be a fallback, this needs an availability/config check instead of an unreachable return.

Suggested cleanup
 def build_mlp_or_moe(config, layer_idx: int, dense_factory):
     """Build a dense MLP or MoE block according to the config."""
     if is_moe_layer(config, layer_idx):
         return FusedSparseMoEBlock.from_config(config)
-        return SparseMoEBlock.from_config(config)
     return dense_factory()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/__init__.py` around lines 10 - 16, The function build_mlp_or_moe
currently returns FusedSparseMoEBlock.from_config(config) unconditionally,
making the subsequent return of SparseMoEBlock.from_config(config) dead code;
either remove the unreachable fallback or implement a real branch that chooses
SparseMoEBlock when FusedSparseMoEBlock is unavailable. Update build_mlp_or_moe
to check availability (e.g., try/except ImportError or a feature flag) before
calling FusedSparseMoEBlock.from_config(config) and only call
SparseMoEBlock.from_config(config) when the fused implementation is not
available, or delete the redundant return line if the fused block is the sole
supported implementation; reference the symbols build_mlp_or_moe,
FusedSparseMoEBlock, SparseMoEBlock, and is_moe_layer when making the change.
diffulex_kernel/__init__.py (1)

42-50: Keep the lazy-export surface consistent.

fused_topk is exposed through both __getattr__ and __all__, but fused_moe is only exposed through __getattr__ while Line 62 keeps it commented out. If fused_moe is public, export it consistently; if it is private, hiding it in one place and exposing it in another is confusing.

Also applies to: 55-63

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/__init__.py` around lines 42 - 50, The lazy-export surface is
inconsistent: __getattr__ exposes fused_topk and fused_moe but __all__ only
lists fused_topk (fused_moe is commented out). Make the exports consistent by
either adding "fused_moe" to the module-level __all__ list (or uncommenting the
existing entry) if it should be public, or remove/deny export in __getattr__ for
fused_moe if it should be private; update the code paths around __getattr__, the
fused_topk and fused_moe import lines, and the __all__ definition so both
symbols are treated the same way.
diffulex/moe/topk.py (1)

9-10: Lazy-import the Triton backend.

TopKRouter(impl="torch") and topk_pytorch_reference() do not need the kernel package, but the module-level from diffulex_kernel import fused_topk makes diffulex.moe.topk depend on that stack at import time anyway. If diffulex_kernel resolves the Triton module eagerly, CPU/reference users will fail before they can ever select the torch path.

♻️ Suggested change
-from diffulex_kernel import fused_topk
-
 def topk_pytorch_reference(
     router_logits: torch.Tensor,
     top_k: int,
@@
         if impl == "torch":
             self.impl = topk_pytorch_reference
         elif impl == "triton":
+            from diffulex_kernel import fused_topk
             self.impl = fused_topk
         else:
             raise ValueError(f"Unsupported impl: {impl!r}")

Also applies to: 59-64

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk.py` around lines 9 - 10, The module currently imports
fused_topk at top-level causing an eager dependency; change to lazy-import
diffulex_kernel.fused_topk only where needed: move the import into the code
paths that actually call it (e.g., inside TopKRouter implementation branch that
selects the Triton backend and inside the function that calls fused_topk), so
TopKRouter(impl="torch") and topk_pytorch_reference() can be imported without
resolving diffulex_kernel; ensure you reference fused_topk by importing it
locally just before use and keep existing function/class names (TopKRouter,
topk_pytorch_reference, fused_topk) to locate the spots to modify.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@diffulex_kernel/python/fused_moe_triton.py`:
- Around line 21-23: Rename the single-letter kernel dimension `I` to a
descriptive name (e.g., `I_dim` or `INPUT_SIZE`) both in the Triton kernel
signature (the `I: tl.constexpr` parameter) and in the wrapper/local variables
that reference it so Ruff E741 is resolved; update every usage (pointer
arithmetic, index calculations, and any calls that pass the argument) to the new
name, including the other occurrences noted around the later blocks (the local
wrapper variable and the other kernel signatures/usages), ensuring all
references (e.g., kernel definition, launch invocation, and any local variable
named `I`) are consistently renamed.
- Around line 266-343: Replace the fragile asserts in _run_fused_moe_kernels
(and validate in fused_moe caller) with explicit runtime checks that raise clear
exceptions: verify w13 and w2 are 3D tensors, check w13.shape[1] % 2 == 0 and
compute I = w13.shape[1] // 2 only after that, ensure w13.shape[2] == H and
w2.shape == (E, H, I), confirm topk_ids and topk_weights are 2D with identical
shapes and that topk_ids.shape[0] == M; raise ValueError (or TypeError) with
descriptive messages naming the offending tensor (w13, w2, topk_ids,
topk_weights) so kernel launches fail fast with clear Python errors.

In `@diffulex/engine/dllm_block.py`:
- Around line 36-47: The file defines __getstate__ and __setstate__ multiple
times which causes the later/older definitions to shadow the new weakref-based
serialization path; consolidate to a single pair of methods that implement the
weakref handling: keep the implementations that use weakref_fn and convert
s['_req'] and s['_dllm_block_buffer'] to/from weak references, remove or merge
the older duplicate __getstate__/__setstate__ definitions so only one canonical
implementation remains, and ensure the final __setstate__/__getstate__ reference
the weakref_fn helper and the _req and _dllm_block_buffer attributes
consistently.

In `@diffulex/moe/fused_moe.py`:
- Around line 29-38: Constructor currently accepts arbitrary hidden_act causing
instantiation of a fused MoE that only supports "silu"; update the validation to
fail fast by checking hidden_act in __init__ (and the alternate
constructor/loader referenced at lines ~148-156, e.g., from_config or similar
factory) and raise a clear ValueError if hidden_act != "silu" that points users
to use the unfused MoE block instead; ensure the check is implemented early in
the FusedMoE initialization path (reference symbols: __init__, hidden_act,
from_config) so unsupported configs never create a fused instance.

In `@test/python/kernel/test_fused_moe.py`:
- Around line 12-18: Rename the ambiguous dimension name `I` to a descriptive
name (e.g., `intermediate_size` or `intermediate_dim`) everywhere in this test
module: update the function signature of fused_moe_pytorch_reference (change
comments and type hints for w13 and w2), update usages inside
fused_moe_pytorch_reference, update the helper function `_run_test` and any
local variables or test locals that use `I` (including the later locals around
lines 58-69) so all occurrences are consistently renamed and lint E741 is
resolved.
- Around line 375-395: The test named test_determinism does not actually
guarantee determinism because top_k=2 allows atomic_add race-induced FP32
variations; update the test to either set top_k=1 (change the local top_k
variable to 1 so fused_moe runs without expert conflicts and true determinism is
validated) or rename the test (e.g., test_approximation) and its docstring to
reflect it verifies bounded numerical closeness for fused_moe with top_k=2; also
update the test name/docstring and any inline comment accordingly so readers and
CI expectations match the chosen behavior.

---

Nitpick comments:
In `@diffulex_kernel/__init__.py`:
- Around line 42-50: The lazy-export surface is inconsistent: __getattr__
exposes fused_topk and fused_moe but __all__ only lists fused_topk (fused_moe is
commented out). Make the exports consistent by either adding "fused_moe" to the
module-level __all__ list (or uncommenting the existing entry) if it should be
public, or remove/deny export in __getattr__ for fused_moe if it should be
private; update the code paths around __getattr__, the fused_topk and fused_moe
import lines, and the __all__ definition so both symbols are treated the same
way.

In `@diffulex/moe/__init__.py`:
- Around line 10-16: The function build_mlp_or_moe currently returns
FusedSparseMoEBlock.from_config(config) unconditionally, making the subsequent
return of SparseMoEBlock.from_config(config) dead code; either remove the
unreachable fallback or implement a real branch that chooses SparseMoEBlock when
FusedSparseMoEBlock is unavailable. Update build_mlp_or_moe to check
availability (e.g., try/except ImportError or a feature flag) before calling
FusedSparseMoEBlock.from_config(config) and only call
SparseMoEBlock.from_config(config) when the fused implementation is not
available, or delete the redundant return line if the fused block is the sole
supported implementation; reference the symbols build_mlp_or_moe,
FusedSparseMoEBlock, SparseMoEBlock, and is_moe_layer when making the change.

In `@diffulex/moe/topk.py`:
- Around line 9-10: The module currently imports fused_topk at top-level causing
an eager dependency; change to lazy-import diffulex_kernel.fused_topk only where
needed: move the import into the code paths that actually call it (e.g., inside
TopKRouter implementation branch that selects the Triton backend and inside the
function that calls fused_topk), so TopKRouter(impl="torch") and
topk_pytorch_reference() can be imported without resolving diffulex_kernel;
ensure you reference fused_topk by importing it locally just before use and keep
existing function/class names (TopKRouter, topk_pytorch_reference, fused_topk)
to locate the spots to modify.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2ad030d3-f366-4b9c-9e61-c3dd60f9c15e

📥 Commits

Reviewing files that changed from the base of the PR and between 9ead055 and 79f111a.

📒 Files selected for processing (15)
  • diffulex/engine/dllm_block.py
  • diffulex/mixin/multi_block/engine/request.py
  • diffulex/model/sdar_moe.py
  • diffulex/moe/__init__.py
  • diffulex/moe/fused_moe.py
  • diffulex/moe/moe_impl.py
  • diffulex/moe/topk.py
  • diffulex/sampler/sdar.py
  • diffulex/utils/loader.py
  • diffulex_kernel/__init__.py
  • diffulex_kernel/python/fused_moe_triton.py
  • diffulex_kernel/python/fused_topk_triton.py
  • pyproject.toml
  • test/python/kernel/test_fused_moe.py
  • test/python/kernel/test_fused_topk.py

Comment on lines +36 to +47
def __getstate__(self):
s = self.__dict__.copy()
s['_req'] = s['_req']()
if '_dllm_block_buffer' in s:
s['_dllm_block_buffer'] = s['_dllm_block_buffer']()
return s

def __setstate__(self, state):
s = self.__dict__ = state.copy()
s['_req'] = weakref_fn(s['_req'])
if '_dllm_block_buffer' in s:
s['_dllm_block_buffer'] = weakref_fn(s['_dllm_block_buffer'])
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

These new pickle hooks are shadowed by the older ones below.

Lines 73-80 and 223-229 redefine __getstate__/__setstate__, so Python drops the versions added here. That makes the new weakref rehydration path unreachable and leaves two conflicting serialization implementations in the same file.

Also applies to: 198-205

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/engine/dllm_block.py` around lines 36 - 47, The file defines
__getstate__ and __setstate__ multiple times which causes the later/older
definitions to shadow the new weakref-based serialization path; consolidate to a
single pair of methods that implement the weakref handling: keep the
implementations that use weakref_fn and convert s['_req'] and
s['_dllm_block_buffer'] to/from weak references, remove or merge the older
duplicate __getstate__/__setstate__ definitions so only one canonical
implementation remains, and ensure the final __setstate__/__getstate__ reference
the weakref_fn helper and the _req and _dllm_block_buffer attributes
consistently.

Comment on lines +29 to +38
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_experts: int,
top_k: int,
*,
hidden_act: str = "silu",
norm_topk_prob: bool = True,
) -> None:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Fail fast on unsupported activations.

from_config() forwards arbitrary config.hidden_act, but the backend currently only supports "silu". Right now an unsupported model will instantiate fine and then abort on its first forward. Guard it here or route those configs to the unfused MoE block.

♻️ Suggested change
         self.num_experts = num_experts
         self.top_k = top_k
         self.norm_topk_prob = norm_topk_prob
-        self.hidden_act = hidden_act
+        if hidden_act != "silu":
+            raise ValueError(
+                "FusedSparseMoEBlock only supports hidden_act='silu'"
+            )
+        self.hidden_act = hidden_act

Also applies to: 148-156

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/fused_moe.py` around lines 29 - 38, Constructor currently
accepts arbitrary hidden_act causing instantiation of a fused MoE that only
supports "silu"; update the validation to fail fast by checking hidden_act in
__init__ (and the alternate constructor/loader referenced at lines ~148-156,
e.g., from_config or similar factory) and raise a clear ValueError if hidden_act
!= "silu" that points users to use the unfused MoE block instead; ensure the
check is implemented early in the FusedMoE initialization path (reference
symbols: __init__, hidden_act, from_config) so unsupported configs never create
a fused instance.

Comment on lines +12 to +18
def fused_moe_pytorch_reference(
hidden_states: torch.Tensor, # (M, H)
w13: torch.Tensor, # (E, 2*I, H)
w2: torch.Tensor, # (E, H, I)
topk_weights: torch.Tensor, # (M, top_k)
topk_ids: torch.Tensor, # (M, top_k)
top_k: int,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Rename I across this file.

Ruff is already flagging every I here as E741, so the new test module will stay lint-red until the dimension name is expanded (intermediate_size, intermediate_dim, etc.). The same rename needs to be applied in _run_test and the later test locals.

Also applies to: 58-69

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@test/python/kernel/test_fused_moe.py` around lines 12 - 18, Rename the
ambiguous dimension name `I` to a descriptive name (e.g., `intermediate_size` or
`intermediate_dim`) everywhere in this test module: update the function
signature of fused_moe_pytorch_reference (change comments and type hints for w13
and w2), update usages inside fused_moe_pytorch_reference, update the helper
function `_run_test` and any local variables or test locals that use `I`
(including the later locals around lines 58-69) so all occurrences are
consistently renamed and lint E741 is resolved.

Comment on lines +375 to +395
def test_determinism():
"""Same inputs → same outputs across two runs."""
device = "cuda"
M, H, I, E, top_k = 64, 128, 64, 8, 2

torch.manual_seed(42)
hidden_states = torch.randn(M, H, device=device) * 0.1
w13 = torch.randn(E, 2 * I, H, device=device) * 0.1
w2 = torch.randn(E, H, I, device=device) * 0.1
router_logits = torch.randn(M, E, device=device)
scores = torch.softmax(router_logits, dim=-1)
topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_ids = topk_ids.to(torch.int32)

out1 = fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, top_k)
out2 = fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, top_k)

# Note: atomic_add may have non-deterministic ordering for fp32,
# but values should still be very close
torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -type f -name "*.py" | grep -E "(fused_moe|test_fused)" | head -20

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 217


🏁 Script executed:

find . -type f -name "*fused_moe*" | head -20

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 177


🏁 Script executed:

cat -n ./diffulex_kernel/python/fused_moe_triton.py | head -150

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 6333


🏁 Script executed:

rg "atomic_add" ./diffulex_kernel/python/fused_moe_triton.py -A 3 -B 3

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 351


🏁 Script executed:

cat -n ./test/python/kernel/test_fused_moe.py | sed -n '370,400p'

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 1621


🏁 Script executed:

cat -n ./diffulex/moe/fused_moe.py | head -100

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 3809


🏁 Script executed:

rg "determinism|atomic" ./test/python/kernel/test_fused_moe.py -B 2 -A 2

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 497


🏁 Script executed:

wc -l ./diffulex_kernel/python/fused_moe_triton.py

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 114


🏁 Script executed:

cat -n ./diffulex_kernel/python/fused_moe_triton.py | sed -n '150,180p'

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 1515


🏁 Script executed:

rg "top_k" ./test/python/kernel/test_fused_moe.py | head -20

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 960


🏁 Script executed:

rg "deterministic|CUDA_LAUNCH_BLOCKING" . --type py -l

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 216


🏁 Script executed:

cat -n ./diffulex_kernel/python/fused_moe_triton.py | sed -n '175,185p'

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 533


Rename or adjust this test—it does not validate determinism.

With top_k=2 and tl.atomic_add(), thread ordering is non-deterministic on GPU, causing rounding variations. The test mitigates this with relaxed tolerances (atol=1e-5, rtol=1e-5), making it a bounded-approximation test, not a determinism test. Either:

  1. Use top_k=1 to eliminate expert conflicts and actual determinism, or
  2. Rename to test_approximation() or similar to reflect what it actually validates.
🧰 Tools
🪛 Ruff (0.15.7)

[error] 378-378: Ambiguous variable name: I

(E741)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@test/python/kernel/test_fused_moe.py` around lines 375 - 395, The test named
test_determinism does not actually guarantee determinism because top_k=2 allows
atomic_add race-induced FP32 variations; update the test to either set top_k=1
(change the local top_k variable to 1 so fused_moe runs without expert conflicts
and true determinism is validated) or rename the test (e.g., test_approximation)
and its docstring to reflect it verifies bounded numerical closeness for
fused_moe with top_k=2; also update the test name/docstring and any inline
comment accordingly so readers and CI expectations match the chosen behavior.

…oading

- model-specific param remapping is not enough to support MoE weight repacking
- provides support for distinguishing between TP and EP world sizes and ranks
- add trivial and a2a token dispatcher for EP
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
diffulex/moe/topk/trivial.py (1)

21-22: ⚠️ Potential issue | 🟡 Minor

Potential division by zero in renormalization.

If topk_weights sums to zero (possible with extreme numerical edge cases), this division produces NaN. Consider adding a small epsilon for robustness, consistent with the Triton kernel which uses tl.maximum(selected_sum, 1e-20).

Suggested fix
         if self.renormalize:
-            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True).clamp(min=1e-20)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk/trivial.py` around lines 21 - 22, The renormalization step
in trivial.py can divide by zero when topk_weights.sum(...) == 0; modify the
block in the method where self.renormalize is used to compute a safe denominator
(e.g., denom = topk_weights.sum(dim=-1, keepdim=True).clamp_min(1e-20) or
torch.maximum(..., torch.tensor(1e-20, device=topk_weights.device))) and then do
topk_weights = topk_weights / denom so the division is robust to zero sums (use
1e-20 to match the Triton kernel).
♻️ Duplicate comments (1)
diffulex_kernel/python/fused_moe_triton.py (1)

41-51: ⚠️ Potential issue | 🟡 Minor

Missing validation: topk_ids.shape[0] must equal hidden_states.shape[0].

The validation checks that topk_ids and topk_weights have matching shapes, but does not verify that the number of tokens in topk_ids matches hidden_states. A mismatch would cause incorrect indexing in the kernel.

🛡️ Proposed fix
     if topk_ids.shape != topk_weights.shape:
         raise ValueError(
             f"topk_ids and topk_weights must have the same shape, got {topk_ids.shape} and {topk_weights.shape}."
         )
+    if topk_ids.shape[0] != hidden_states.shape[0]:
+        raise ValueError(
+            f"topk_ids must have one row per token, got {topk_ids.shape[0]} rows but {hidden_states.shape[0]} tokens."
+        )
     if w13.shape[0] != w2.shape[0] or w13.shape[2] != hidden_states.shape[1]:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_moe_triton.py` around lines 41 - 51, The code
validates shapes of topk_ids/topk_weights and weight matrices but misses
verifying that the number of tokens in topk_ids matches hidden_states, which can
break indexing; add a validation after the existing topk_ids/topk_weights check
that asserts topk_ids.shape[0] == hidden_states.shape[0] and raises a ValueError
with a clear message referencing topk_ids.shape and hidden_states.shape so
callers can see the mismatch (place this check alongside the other shape
validations in the same block inside fused_moe_triton where topk_ids,
topk_weights, and hidden_states are validated).
🧹 Nitpick comments (16)
diffulex/moe/topk/bypass.py (2)

7-8: Typo in docstring.

Minor: "implemenation" → "implementation".

✏️ Proposed fix
 class BypassTopKRouter(TopKRouter):
-    """Bypass implemenation, use this if fused moe runner also handles topk"""
+    """Bypass implementation, use this if fused moe runner also handles topk."""
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk/bypass.py` around lines 7 - 8, Fix the typo in the
BypassTopKRouter class docstring: change "Bypass implemenation, use this if
fused moe runner also handles topk" to "Bypass implementation, use this if fused
moe runner also handles topk" in the BypassTopKRouter (subclass of TopKRouter)
docstring.

10-15: Ensure correct router/runner pairing at runtime.

This bypass router returns weights=None and ids=None, which will cause TritonFusedMoERunner to raise a RuntimeError (per diffulex/moe/runner/triton.py:10-14). The design is intentional for runners that handle top-k internally, but consider adding a note in the docstring about compatible runners to prevent misconfiguration.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk/bypass.py` around lines 10 - 15, The bypass router's
forward currently returns weights=None and ids=None which is intentional for
routers that let runners perform top-k, but it lacks documentation and can lead
to a confusing RuntimeError when paired with runners that expect ids/weights
(see TritonFusedMoERunner). Update the forward method's docstring in the bypass
router to clearly state that returning None for weights and ids is intentional
and that this router must only be used with runners that implement internal
top-k (e.g., TritonFusedMoERunner-like runners); additionally, add an optional
runtime compatibility check in forward or router initialization to raise a clear
error message if paired with an incompatible runner instead of allowing the
downstream RuntimeError.
diffulex/utils/loader.py (2)

157-169: Consider using direct attribute access instead of getattr.

Static analysis (B009) flags getattr(param, "weight_loader") with a constant attribute. Since the missing attribute case is already handled by the try/except block, you could use direct attribute access for clarity.

♻️ Proposed fix
             try:
                 param = model.get_parameter(param_name)
-                weight_loader = partial(
-                    getattr(param, "weight_loader"),
-                    param,
-                    loaded_weight,
-                )
+                weight_loader = partial(param.weight_loader, param, loaded_weight)
                 if shard_id is None:
                     weight_loader()
                 else:
                     weight_loader(shard_id)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/utils/loader.py` around lines 157 - 169, The code uses
getattr(param, "weight_loader") inside loader logic though
AttributeError/KeyError is already caught; change to direct attribute access
param.weight_loader when building the partial (in the block that calls
model.get_parameter and constructs weight_loader) to satisfy static analysis and
improve clarity while keeping the same behavior — continue to wrap it in the
existing try/except, create the partial with param.weight_loader, param,
loaded_weight, and then call weight_loader() or weight_loader(shard_id)
depending on shard_id.

139-139: Prefer explicit exception over assert for runtime validation.

Using assert for runtime validation can be disabled with -O flag. Consider raising ValueError for clearer error handling.

🛡️ Proposed fix
-            assert v == "lm_head"
+            if v != "lm_head":
+                raise ValueError(f"Expected 'lm_head' for transformer.ff_out mapping, got '{v}'")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/utils/loader.py` at line 139, Replace the runtime assert "assert v
== 'lm_head'" with an explicit check that raises a ValueError when the condition
fails; specifically, change it to: if v != "lm_head": raise
ValueError(f"Expected 'lm_head' for variable v, got {v!r}"). This ensures the
validation cannot be disabled with -O and provides a clear, descriptive error
message including the actual value.
diffulex/moe/topk/triton.py (1)

2-2: Remove unused import.

torch.nn.functional as F is imported but never used in this file.

🧹 Proposed fix
 import torch
-import torch.nn.functional as F
 
 from diffulex.moe.topk.base import TopKRouter
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk/triton.py` at line 2, Remove the unused import "import
torch.nn.functional as F" from triton.py; locate the import statement in the top
of the file (the symbol to remove is "torch.nn.functional as F") and delete it
so the module no longer includes an unused dependency and to satisfy linting.
diffulex/moe/runner/__init__.py (2)

19-24: Consider sorting __all__ for consistency.

Per static analysis (RUF022), __all__ is not sorted. Sorting improves readability and makes diffs cleaner when adding new exports.

🧹 Proposed fix
 __all__ = [
     "MoERunner",
-    "TrivialMoERunner",
     "TritonFusedMoERunner",
+    "TrivialMoERunner",
     "build_runner",
 ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/runner/__init__.py` around lines 19 - 24, The __all__ export
list is unsorted; reorder the list assigned to __all__ so its string entries are
in lexicographical order (e.g., place "MoERunner", "TritonFusedMoERunner",
"TrivialMoERunner", "build_runner" sorted appropriately), updating the __all__
assignment to the sorted sequence to satisfy RUF022 and keep exports consistent.

15-16: Include the invalid impl value in the error message.

When an unsupported backend is requested, the error should indicate what value was passed for easier debugging. This is consistent with the error message pattern in build_dispatcher.

🔧 Proposed fix
     elif impl == "triton":
         return TritonFusedMoERunner(*args, **kwargs)
     else:
-        raise NotImplementedError
+        raise NotImplementedError(f"Unsupported runner backend: {impl!r}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/runner/__init__.py` around lines 15 - 16, The else branch
raising NotImplementedError in __init__.py should include the actual invalid
impl value for clearer debugging; update the error raised in that branch (the
raise in the same block that handles backend selection and mirrors
build_dispatcher’s pattern) to include impl in the message (e.g., "Unsupported
impl: {impl}") so callers see what value was passed.
diffulex/moe/runner/trivial.py (1)

1-5: Missing torch import.

The file uses dispatch_output.hidden_states.new_zeros(...) which works via the tensor method, but adding import torch would be more explicit and consistent with other files.

Suggested addition
 from __future__ import annotations
+import torch
 import torch.nn.functional as F
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/runner/trivial.py` around lines 1 - 5, The module is missing an
explicit torch import used when creating new tensors (e.g.,
dispatch_output.hidden_states.new_zeros(...)); add "import torch" to trivial.py
so tensor creation is explicit and consistent with other modules that use torch,
and ensure any references in MoERunner, CombineInput, or DispatchOutput code
that rely on torch remain unchanged.
diffulex/moe/topk/__init__.py (2)

19-20: Improve error message for unsupported implementations.

The NotImplementedError should include the invalid impl value and available options for easier debugging.

Suggested improvement
     else:
-        raise NotImplementedError
+        raise NotImplementedError(
+            f"Unknown top-k router implementation: {impl!r}. "
+            f"Supported: 'trivial', 'bypass', 'triton'."
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk/__init__.py` around lines 19 - 20, The current else branch
raises a bare NotImplementedError; update the else in
diffulex.moe.topk.__init__.py to raise a NotImplementedError that includes the
invalid impl value and the list of supported implementations (e.g.,
f"Unsupported impl '{impl}'; supported: {supported_list}") so callers can see
what was passed and what options exist; locate the else that currently does
raise NotImplementedError and replace it with an informative message referencing
the impl variable and the available implementations constant or keys from the
dispatch map used in this module.

23-30: Consider sorting __all__ for consistency.

Ruff flagged that __all__ is not sorted (RUF022). While optional, alphabetical ordering improves maintainability.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk/__init__.py` around lines 23 - 30, The __all__ list in this
module is not alphabetically sorted (RUF022); reorder the entries in the __all__
variable so they are in alphabetical order (e.g., "BypassTopKRouter",
"TopKOutput", "TopKRouter", "TritonFusedTopKRouter", "TrivialTopKRouter",
"build_topk_router") to satisfy the lint rule and improve maintainability;
update the __all__ definition accordingly.
diffulex/engine/model_runner.py (1)

60-61: Redundant assignments overwritten later.

self.world_size and self.rank are assigned here from config values, but they are immediately overwritten at lines 84-85 from the layout object. Consider removing these initial assignments to avoid confusion.

Suggested cleanup
-        self.world_size = config.tensor_parallel_size
-        self.rank = rank
         self.event = event
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/engine/model_runner.py` around lines 60 - 61, The assignments to
self.world_size and self.rank from config at the start of ModelRunner.__init__
are redundant because they are overwritten later from layout; remove the initial
assignments (self.world_size = config.tensor_parallel_size and self.rank = rank)
to avoid confusion and keep the single authoritative source (layout) — update or
delete those two lines in ModelRunner.__init__ and ensure any subsequent logic
relies on the values set from layout rather than the earlier config-derived
values.
diffulex/moe/runner/base.py (1)

36-38: Weight tensors not registered as parameters.

w13 and w2 are stored as plain attributes rather than nn.Parameter or registered buffers. This means they won't appear in state_dict() or be moved by .to(). If this is intentional (e.g., weights managed externally by the loader), consider adding a brief comment explaining the design choice.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/runner/base.py` around lines 36 - 38, w13 and w2 are being
stored as plain attributes (self.w13, self.w2) so they won't be included in
state_dict() or moved with .to(); update the class initializer to register them
properly: if these are trainable weights wrap them with torch.nn.Parameter and
assign to self.w13/self.w2, otherwise register them as buffers via
self.register_buffer('w13', w13) / self.register_buffer('w2', w2); if the
current design intentionally keeps them external, add a short comment next to
local_expert_start/w13/w2 explaining that these tensors are externally managed
and intentionally not registered so reviewers understand the choice.
diffulex_kernel/python/fused_topk_triton.py (1)

48-63: Output validation adds safety at runtime cost.

The _validate_fused_topk_outputs function performs GPU-to-CPU synchronization (detach().cpu()) on every call, which can impact performance. Consider making this validation optional (e.g., via an environment variable or debug flag) for production use.

Suggested change
+import os
+
+_VALIDATE_OUTPUTS = os.environ.get("DIFFULEX_VALIDATE_TOPK_OUTPUTS", "0") == "1"
+
 def _validate_fused_topk_outputs(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     *,
     num_experts: int,
 ) -> None:
+    if not _VALIDATE_OUTPUTS:
+        return
     invalid_id_mask = (topk_ids < 0) | (topk_ids >= num_experts)
     ...
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_topk_triton.py` around lines 48 - 63, The
runtime validation in _validate_fused_topk_outputs causes GPU->CPU sync
(detach().cpu()) on every call; make this validation conditional by introducing
a toggle (e.g., module-level boolean FUSED_TOPK_VALIDATE_OUTPUTS set from an
environment variable like "FUSED_TOPK_VALIDATE_OUTPUTS" or a debug flag) and
only perform the invalid_id_mask check, detach().cpu(), and isfinite checks when
that toggle is true; update the function to early-return if validation is
disabled and ensure the env-var default is disabled in production but can be
enabled for testing/debugging.
diffulex/moe/dispatcher/trivial.py (1)

67-70: Strict expert ID matching may be overly restrictive.

The current logic raises if active_expert_ids doesn't exactly match the expected list. Consider allowing a subset for scenarios where only specific experts need processing (e.g., during debugging or selective expert evaluation).

Alternative: allow subset
         expert_ids = list(active_expert_ids)
-        if expert_ids != expected:
-            raise ValueError(f"Expected active_expert_ids={expected}, got {expert_ids}.")
+        if not set(expert_ids).issubset(set(expected)):
+            raise ValueError(
+                f"active_expert_ids must be a subset of {expected}, got {expert_ids}."
+            )
         return expert_ids
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/dispatcher/trivial.py` around lines 67 - 70, Change the strict
equality check between active_expert_ids and expected to allow expected to be a
subset: compute expert_ids = list(active_expert_ids), then validate with
something like if expected and not set(expected).issubset(set(expert_ids)):
raise ValueError(...); finally return only the relevant experts (e.g., filter
expert_ids to those in expected if expected is provided, otherwise return all
expert_ids). Update the error message in the raise to reflect that expected must
be a subset when applicable; reference variables active_expert_ids, expected,
expert_ids and the existing ValueError raise site.
diffulex_kernel/python/fused_moe_triton.py (2)

131-132: Early return shape may be inconsistent with weight dimensions.

When hidden_states.shape[0] == 0, the function returns hidden_states.new_zeros((0, w2.shape[1])). However, w2.shape[1] is hidden_size (the output dimension of down_proj). This is correct, but when topk_ids.shape[1] == 0 (zero top-k), returning zeros is semantically correct since no experts contribute.

Consider adding a comment clarifying that w2.shape[1] == hidden_size to improve readability.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_moe_triton.py` around lines 131 - 132, The
early-return when hidden_states.shape[0] == 0 or topk_ids.shape[1] == 0 returns
a zero tensor using w2.shape[1]; clarify this by adding an inline comment in
fused_moe_triton.py next to the return explaining that w2.shape[1] is the model
hidden_size (the output dim of down_proj), and that returning zeros is correct
when there are no tokens or no top-k experts; keep the existing return value but
add the comment referencing hidden_states, topk_ids, and w2 for readability.

101-112: 3D tensor load may cause high register pressure.

The weight load creates a (BLOCK_M, BLOCK_N, BLOCK_K) tensor in registers before reduction. With BLOCK_M=8, BLOCK_N=64, BLOCK_K=32, this allocates 16K elements per thread block. This approach works but may limit occupancy on memory-bound workloads.

Consider whether a more traditional 2D tiled GEMM with per-row expert lookup would reduce register pressure, though this is acceptable for a draft implementation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_moe_triton.py` around lines 101 - 112, The
current tl.load builds a full (BLOCK_M, BLOCK_N, BLOCK_K) tensor (see the load
using w_ptr + expert_offsets... + current_k * stride_wk and mask combining
row_mask, valid_expert, offs_n, k_mask) which raises register pressure; change
the implementation to tile along the K dimension and load smaller 2D slices
instead (e.g., loop over K sub-blocks: load (BLOCK_M, BLOCK_K) or (BLOCK_N,
BLOCK_K) chunks per iteration, multiply-accumulate into acc incrementally, and
only keep a small b-slice in registers per loop) so you perform partial
reductions across K rather than materializing the full 3D block in registers.
Ensure masks (row_mask, valid_expert, offs_n < num_cols, k_mask) and strides
(stride_we, stride_wn, stride_wk) are applied per sub-load and that the
accumulation acc += tl.sum(a[:, None, :] * b, axis=2) still operates on the
smaller b-slice.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@diffulex/utils/loader.py`:
- Line 1: The file contains a UTF-8 BOM at the start (before the first token
"import os"); remove the leading BOM character so the file is plain UTF-8
without BOM and re-save loader.py (ensure your editor/IDE writes UTF-8 without
BOM) so the module import line ("import os") and the rest of the file no longer
include the hidden character.

---

Outside diff comments:
In `@diffulex/moe/topk/trivial.py`:
- Around line 21-22: The renormalization step in trivial.py can divide by zero
when topk_weights.sum(...) == 0; modify the block in the method where
self.renormalize is used to compute a safe denominator (e.g., denom =
topk_weights.sum(dim=-1, keepdim=True).clamp_min(1e-20) or torch.maximum(...,
torch.tensor(1e-20, device=topk_weights.device))) and then do topk_weights =
topk_weights / denom so the division is robust to zero sums (use 1e-20 to match
the Triton kernel).

---

Duplicate comments:
In `@diffulex_kernel/python/fused_moe_triton.py`:
- Around line 41-51: The code validates shapes of topk_ids/topk_weights and
weight matrices but misses verifying that the number of tokens in topk_ids
matches hidden_states, which can break indexing; add a validation after the
existing topk_ids/topk_weights check that asserts topk_ids.shape[0] ==
hidden_states.shape[0] and raises a ValueError with a clear message referencing
topk_ids.shape and hidden_states.shape so callers can see the mismatch (place
this check alongside the other shape validations in the same block inside
fused_moe_triton where topk_ids, topk_weights, and hidden_states are validated).

---

Nitpick comments:
In `@diffulex_kernel/python/fused_moe_triton.py`:
- Around line 131-132: The early-return when hidden_states.shape[0] == 0 or
topk_ids.shape[1] == 0 returns a zero tensor using w2.shape[1]; clarify this by
adding an inline comment in fused_moe_triton.py next to the return explaining
that w2.shape[1] is the model hidden_size (the output dim of down_proj), and
that returning zeros is correct when there are no tokens or no top-k experts;
keep the existing return value but add the comment referencing hidden_states,
topk_ids, and w2 for readability.
- Around line 101-112: The current tl.load builds a full (BLOCK_M, BLOCK_N,
BLOCK_K) tensor (see the load using w_ptr + expert_offsets... + current_k *
stride_wk and mask combining row_mask, valid_expert, offs_n, k_mask) which
raises register pressure; change the implementation to tile along the K
dimension and load smaller 2D slices instead (e.g., loop over K sub-blocks: load
(BLOCK_M, BLOCK_K) or (BLOCK_N, BLOCK_K) chunks per iteration,
multiply-accumulate into acc incrementally, and only keep a small b-slice in
registers per loop) so you perform partial reductions across K rather than
materializing the full 3D block in registers. Ensure masks (row_mask,
valid_expert, offs_n < num_cols, k_mask) and strides (stride_we, stride_wn,
stride_wk) are applied per sub-load and that the accumulation acc += tl.sum(a[:,
None, :] * b, axis=2) still operates on the smaller b-slice.

In `@diffulex_kernel/python/fused_topk_triton.py`:
- Around line 48-63: The runtime validation in _validate_fused_topk_outputs
causes GPU->CPU sync (detach().cpu()) on every call; make this validation
conditional by introducing a toggle (e.g., module-level boolean
FUSED_TOPK_VALIDATE_OUTPUTS set from an environment variable like
"FUSED_TOPK_VALIDATE_OUTPUTS" or a debug flag) and only perform the
invalid_id_mask check, detach().cpu(), and isfinite checks when that toggle is
true; update the function to early-return if validation is disabled and ensure
the env-var default is disabled in production but can be enabled for
testing/debugging.

In `@diffulex/engine/model_runner.py`:
- Around line 60-61: The assignments to self.world_size and self.rank from
config at the start of ModelRunner.__init__ are redundant because they are
overwritten later from layout; remove the initial assignments (self.world_size =
config.tensor_parallel_size and self.rank = rank) to avoid confusion and keep
the single authoritative source (layout) — update or delete those two lines in
ModelRunner.__init__ and ensure any subsequent logic relies on the values set
from layout rather than the earlier config-derived values.

In `@diffulex/moe/dispatcher/trivial.py`:
- Around line 67-70: Change the strict equality check between active_expert_ids
and expected to allow expected to be a subset: compute expert_ids =
list(active_expert_ids), then validate with something like if expected and not
set(expected).issubset(set(expert_ids)): raise ValueError(...); finally return
only the relevant experts (e.g., filter expert_ids to those in expected if
expected is provided, otherwise return all expert_ids). Update the error message
in the raise to reflect that expected must be a subset when applicable;
reference variables active_expert_ids, expected, expert_ids and the existing
ValueError raise site.

In `@diffulex/moe/runner/__init__.py`:
- Around line 19-24: The __all__ export list is unsorted; reorder the list
assigned to __all__ so its string entries are in lexicographical order (e.g.,
place "MoERunner", "TritonFusedMoERunner", "TrivialMoERunner", "build_runner"
sorted appropriately), updating the __all__ assignment to the sorted sequence to
satisfy RUF022 and keep exports consistent.
- Around line 15-16: The else branch raising NotImplementedError in __init__.py
should include the actual invalid impl value for clearer debugging; update the
error raised in that branch (the raise in the same block that handles backend
selection and mirrors build_dispatcher’s pattern) to include impl in the message
(e.g., "Unsupported impl: {impl}") so callers see what value was passed.

In `@diffulex/moe/runner/base.py`:
- Around line 36-38: w13 and w2 are being stored as plain attributes (self.w13,
self.w2) so they won't be included in state_dict() or moved with .to(); update
the class initializer to register them properly: if these are trainable weights
wrap them with torch.nn.Parameter and assign to self.w13/self.w2, otherwise
register them as buffers via self.register_buffer('w13', w13) /
self.register_buffer('w2', w2); if the current design intentionally keeps them
external, add a short comment next to local_expert_start/w13/w2 explaining that
these tensors are externally managed and intentionally not registered so
reviewers understand the choice.

In `@diffulex/moe/runner/trivial.py`:
- Around line 1-5: The module is missing an explicit torch import used when
creating new tensors (e.g., dispatch_output.hidden_states.new_zeros(...)); add
"import torch" to trivial.py so tensor creation is explicit and consistent with
other modules that use torch, and ensure any references in MoERunner,
CombineInput, or DispatchOutput code that rely on torch remain unchanged.

In `@diffulex/moe/topk/__init__.py`:
- Around line 19-20: The current else branch raises a bare NotImplementedError;
update the else in diffulex.moe.topk.__init__.py to raise a NotImplementedError
that includes the invalid impl value and the list of supported implementations
(e.g., f"Unsupported impl '{impl}'; supported: {supported_list}") so callers can
see what was passed and what options exist; locate the else that currently does
raise NotImplementedError and replace it with an informative message referencing
the impl variable and the available implementations constant or keys from the
dispatch map used in this module.
- Around line 23-30: The __all__ list in this module is not alphabetically
sorted (RUF022); reorder the entries in the __all__ variable so they are in
alphabetical order (e.g., "BypassTopKRouter", "TopKOutput", "TopKRouter",
"TritonFusedTopKRouter", "TrivialTopKRouter", "build_topk_router") to satisfy
the lint rule and improve maintainability; update the __all__ definition
accordingly.

In `@diffulex/moe/topk/bypass.py`:
- Around line 7-8: Fix the typo in the BypassTopKRouter class docstring: change
"Bypass implemenation, use this if fused moe runner also handles topk" to
"Bypass implementation, use this if fused moe runner also handles topk" in the
BypassTopKRouter (subclass of TopKRouter) docstring.
- Around line 10-15: The bypass router's forward currently returns weights=None
and ids=None which is intentional for routers that let runners perform top-k,
but it lacks documentation and can lead to a confusing RuntimeError when paired
with runners that expect ids/weights (see TritonFusedMoERunner). Update the
forward method's docstring in the bypass router to clearly state that returning
None for weights and ids is intentional and that this router must only be used
with runners that implement internal top-k (e.g., TritonFusedMoERunner-like
runners); additionally, add an optional runtime compatibility check in forward
or router initialization to raise a clear error message if paired with an
incompatible runner instead of allowing the downstream RuntimeError.

In `@diffulex/moe/topk/triton.py`:
- Line 2: Remove the unused import "import torch.nn.functional as F" from
triton.py; locate the import statement in the top of the file (the symbol to
remove is "torch.nn.functional as F") and delete it so the module no longer
includes an unused dependency and to satisfy linting.

In `@diffulex/utils/loader.py`:
- Around line 157-169: The code uses getattr(param, "weight_loader") inside
loader logic though AttributeError/KeyError is already caught; change to direct
attribute access param.weight_loader when building the partial (in the block
that calls model.get_parameter and constructs weight_loader) to satisfy static
analysis and improve clarity while keeping the same behavior — continue to wrap
it in the existing try/except, create the partial with param.weight_loader,
param, loaded_weight, and then call weight_loader() or weight_loader(shard_id)
depending on shard_id.
- Line 139: Replace the runtime assert "assert v == 'lm_head'" with an explicit
check that raises a ValueError when the condition fails; specifically, change it
to: if v != "lm_head": raise ValueError(f"Expected 'lm_head' for variable v, got
{v!r}"). This ensures the validation cannot be disabled with -O and provides a
clear, descriptive error message including the actual value.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 984d317a-ff55-46a2-a604-ddeb568c046d

📥 Commits

Reviewing files that changed from the base of the PR and between 79f111a and c962e83.

📒 Files selected for processing (35)
  • .gitignore
  • diffulex/config.py
  • diffulex/engine/dp_worker.py
  • diffulex/engine/model_runner.py
  • diffulex/engine/tp_worker.py
  • diffulex/layer/embed_head.py
  • diffulex/layer/linear.py
  • diffulex/model/dream.py
  • diffulex/model/fast_dllm_v2.py
  • diffulex/model/llada.py
  • diffulex/model/sdar.py
  • diffulex/moe/__init__.py
  • diffulex/moe/dispatcher/__init__.py
  • diffulex/moe/dispatcher/base.py
  • diffulex/moe/dispatcher/datatype.py
  • diffulex/moe/dispatcher/trivial.py
  • diffulex/moe/layers.py
  • diffulex/moe/moe_impl.py
  • diffulex/moe/runner/__init__.py
  • diffulex/moe/runner/base.py
  • diffulex/moe/runner/triton.py
  • diffulex/moe/runner/trivial.py
  • diffulex/moe/topk/__init__.py
  • diffulex/moe/topk/base.py
  • diffulex/moe/topk/bypass.py
  • diffulex/moe/topk/datatype.py
  • diffulex/moe/topk/triton.py
  • diffulex/moe/topk/trivial.py
  • diffulex/sampler/sdar.py
  • diffulex/utils/checkpoint.py
  • diffulex/utils/loader.py
  • diffulex/utils/parallelism.py
  • diffulex_kernel/__init__.py
  • diffulex_kernel/python/fused_moe_triton.py
  • diffulex_kernel/python/fused_topk_triton.py
💤 Files with no reviewable changes (1)
  • diffulex/moe/moe_impl.py
✅ Files skipped from review due to trivial changes (4)
  • .gitignore
  • diffulex/moe/topk/datatype.py
  • diffulex/layer/embed_head.py
  • diffulex/utils/checkpoint.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • diffulex/sampler/sdar.py
  • diffulex/moe/init.py
  • diffulex_kernel/init.py

@@ -1,4 +1,4 @@
import os
import os
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Remove BOM character from file.

The file starts with a UTF-8 BOM (byte order mark: ), which is unusual for Python source files and can cause issues with some tools. This was also flagged by static analysis (EXE002).

🔧 Proposed fix
-import os
+import os
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
import os
import os
🧰 Tools
🪛 Ruff (0.15.9)

[warning] 1-1: The file is executable but no shebang is present

(EXE002)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/utils/loader.py` at line 1, The file contains a UTF-8 BOM at the
start (before the first token "import os"); remove the leading BOM character so
the file is plain UTF-8 without BOM and re-save loader.py (ensure your
editor/IDE writes UTF-8 without BOM) so the module import line ("import os") and
the rest of the file no longer include the hidden character.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant