Add examples for MoE models - Mixtral in TE #2642
faradawn wants to merge 1 commit into NVIDIA:main from faradawn:add-moe-example
Conversation
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Greptile Overview

Greptile Summary

This PR adds a tutorial demonstrating how to integrate Mixtral's MoE (Mixture of Experts) layers with Transformer Engine's GroupedLinear and moe_permute/moe_unpermute APIs.

Key changes:
- Adds docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb, a notebook that replaces Mixtral's MoE block with a TE-based implementation built on GroupedLinear, moe_permute, and moe_unpermute.
Critical issue found: the m_splits calculation in the MoE block double-counts tokens by multiplying by top_k (see the review comment below).
Confidence Score: 2/5
Important Files Changed: docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb
Sequence Diagram

sequenceDiagram
participant User as User Code
participant MoE as TEMixtralSparseMoeBlock
participant Router as Router (gate)
participant Permute as te.moe_permute
participant GLinear as GroupedLinear
participant Unpermute as te.moe_unpermute
User->>MoE: forward(hidden_states)
MoE->>MoE: Flatten to [num_tokens, hidden_dim]
MoE->>Router: Get expert assignments
Router-->>MoE: router_logits [num_tokens, num_experts]
MoE->>MoE: Compute routing_weights & select top_k experts
MoE->>Permute: moe_permute(tokens, selected_experts)
Permute-->>MoE: permuted_tokens, row_id_map
MoE->>MoE: Calculate m_splits per expert
MoE->>GLinear: experts_gate_up(permuted_tokens, m_splits)
GLinear-->>MoE: intermediate [combined gate+up projections]
MoE->>MoE: Apply SwiGLU: silu(gate) * up
MoE->>GLinear: experts_down(intermediate_act, m_splits)
GLinear-->>MoE: expert_outputs
MoE->>Unpermute: moe_unpermute(expert_outputs, row_id_map, routing_weights)
Unpermute-->>MoE: final_hidden_states
MoE->>MoE: Reshape to [batch, seq_len, hidden_dim]
MoE-->>User: final_hidden_states, router_logits
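
For reference, a minimal sketch of the block the diagram describes, reconstructed from the diagram and the diff snippets in this PR. The class name and the hyperparameters (hidden_dim, ffn_dim, num_experts, top_k) are illustrative, the exact GroupedLinear / moe_permute / moe_unpermute signatures should be checked against the installed Transformer Engine version, and num_out_tokens is sized as num_tokens * top_k per the review comment further down.

import torch
import torch.nn.functional as F
import transformer_engine.pytorch as te

class TEMixtralSparseMoeBlock(torch.nn.Module):
    """Illustrative MoE block following the sequence diagram above."""

    def __init__(self, hidden_dim, ffn_dim, num_experts, top_k):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Router: one logit per expert for every token.
        self.gate = torch.nn.Linear(hidden_dim, num_experts, bias=False)
        # Grouped GEMMs: one weight per expert (fused gate+up, then down).
        self.experts_gate_up = te.GroupedLinear(num_experts, hidden_dim, 2 * ffn_dim, bias=False)
        self.experts_down = te.GroupedLinear(num_experts, ffn_dim, hidden_dim, bias=False)

    def forward(self, hidden_states):
        batch, seq_len, hidden_dim = hidden_states.shape
        hidden_states_flat = hidden_states.view(-1, hidden_dim)
        num_tokens = hidden_states_flat.shape[0]

        # 1. Route: pick top_k experts per token and normalize their weights.
        router_logits = self.gate(hidden_states_flat)
        routing_weights = F.softmax(router_logits, dim=-1)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

        # 2. Permute tokens so rows assigned to the same expert are contiguous.
        permuted_tokens, row_id_map = te.moe_permute(
            hidden_states_flat,
            selected_experts.to(torch.int32),
            num_out_tokens=num_tokens * self.top_k,
            max_token_num=num_tokens * self.top_k,
        )

        # 3. m_splits: number of permuted rows handled by each expert
        #    (each token contributes one row per distinct expert it selected).
        m_splits = [
            int((selected_experts == expert_idx).sum().item())
            for expert_idx in range(self.num_experts)
        ]

        # 4. Expert MLPs as two grouped GEMMs with SwiGLU in between.
        gate_up = self.experts_gate_up(permuted_tokens, m_splits)
        gate_proj, up_proj = gate_up.chunk(2, dim=-1)
        intermediate_act = F.silu(gate_proj) * up_proj
        expert_outputs = self.experts_down(intermediate_act, m_splits)

        # 5. Un-permute and combine expert outputs with the routing weights.
        final_hidden_states = te.moe_unpermute(expert_outputs, row_id_map, routing_weights)
        return final_hidden_states.view(batch, seq_len, hidden_dim), router_logits

A smoke test could instantiate the block with small dimensions on a CUDA device and push a random [batch, seq_len, hidden_dim] tensor through it.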
# Calculate m_splits: number of tokens assigned to each expert
m_splits = []
for expert_idx in range(self.num_experts):
    expert_mask = (selected_experts == expert_idx).any(dim=-1)
    m_splits.append(expert_mask.sum().item() * self.top_k)
Logic error in m_splits calculation. The current approach counts tokens incorrectly by multiplying by top_k after already considering all expert assignments.
The issue: expert_mask already captures ALL tokens that selected this expert (across all top-k positions), so multiplying by self.top_k double-counts.
For example, if token 0 selects experts [1, 3] and token 1 selects experts [1, 2], then for expert 1: expert_mask will be [True, True] (sum=2). Multiplying by top_k=2 gives 4, but only 2 tokens actually go to expert 1.
Suggested change (drop the multiplication by top_k):

# Calculate m_splits: number of tokens assigned to each expert
m_splits = []
for expert_idx in range(self.num_experts):
    expert_mask = (selected_experts == expert_idx).any(dim=-1)
    m_splits.append(expert_mask.sum().item())
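
To see the double counting concretely, here is a small standalone check using the reviewer's example (plain PyTorch, independent of the notebook):

import torch

top_k = 2
num_experts = 4
# Token 0 selects experts [1, 3]; token 1 selects experts [1, 2].
selected_experts = torch.tensor([[1, 3], [1, 2]])

buggy, fixed = [], []
for expert_idx in range(num_experts):
    expert_mask = (selected_experts == expert_idx).any(dim=-1)
    buggy.append(expert_mask.sum().item() * top_k)
    fixed.append(expert_mask.sum().item())

print(buggy)  # [0, 4, 2, 2] -> sums to 8, twice the 4 permuted rows
print(fixed)  # [0, 2, 1, 1] -> sums to 4 == num_tokens * top_k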
permuted_tokens, row_id_map = te.moe_permute(
    hidden_states_flat,
    selected_experts.to(torch.int32),
    num_out_tokens=None,  # Auto-calculate
    max_token_num=num_tokens,
)
Setting num_out_tokens to None is fine for auto-calculation, but when using top_k > 1, the expected output token count should be num_tokens times top_k since each token is routed to multiple experts.
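
A sketch of the adjustment this comment points at, reusing the variable names from the snippet above; whether moe_permute infers the count correctly when num_out_tokens is left as None may depend on the Transformer Engine version, so treat this as illustrative:

# With top_k > 1 each token is dispatched to top_k experts, so the permuted
# buffer holds num_tokens * top_k rows.
expected_out_tokens = num_tokens * self.top_k

permuted_tokens, row_id_map = te.moe_permute(
    hidden_states_flat,
    selected_experts.to(torch.int32),
    num_out_tokens=expected_out_tokens,
    max_token_num=expected_out_tokens,
)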
Description
Create an MoE tutorial for TE. The model used is Mixtral 8x7B.
View the rendered notebook here: https://github.com/faradawn/TransformerEngine/blob/add-moe-example/docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb
Fixes #2573
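
For context, the tutorial presumably swaps the stock Hugging Face MixtralSparseMoeBlock for a TE-based block such as the TEMixtralSparseMoeBlock sketched earlier. A hypothetical sketch of that substitution is below; replace_moe_blocks is illustrative and not taken from the notebook, and a real conversion would also copy the pretrained router and expert weights into the TE modules.

import torch
from transformers import AutoModelForCausalLM

def replace_moe_blocks(model, te_block_cls):
    # Hypothetical helper: swap every decoder layer's MoE block for the TE version.
    config = model.config
    for layer in model.model.layers:
        te_block = te_block_cls(
            hidden_dim=config.hidden_size,
            ffn_dim=config.intermediate_size,
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
        )
        # NOTE: weights are freshly initialized here; the notebook would also need
        # to load the pretrained router and expert weights into the TE modules.
        layer.block_sparse_moe = te_block.to(model.device, dtype=model.dtype)
    return model

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.bfloat16, device_map="cuda"
)
model = replace_moe_blocks(model, TEMixtralSparseMoeBlock)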
Type of change
Changes
Please list the changes introduced in this PR:
- Add docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb, a tutorial notebook showing how to accelerate the Hugging Face Mixtral MoE block with Transformer Engine.
Checklist: