[PyTorch] Add grouped linear op and experimental fusion for grouped MLP #2622
base: main
Conversation
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order. Signed-off-by: Tim Moon <tmoon@nvidia.com>
Test is too permissive since the test should still be failing. The weights are not properly interleaved yet. Signed-off-by: Tim Moon <tmoon@nvidia.com>
Review suggestion from @greptile-apps Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
9 files reviewed, 4 comments
```python
    quantizer=fc2_input_quantizers[group_idx],
    requires_grad=False,
    with_gemm_swizzled_scales=True,
```
Incorrect grad-required flags
In ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.fuser_forward, swiglu_ctx.input_requires_grad and swiglu_ctx.extra_input_requires_grad are set to True unconditionally (and input_requires_grad is set to requires_grad unconditionally). This will make ScaledSwiGLU.fuser_backward compute grad_input and grad_extra_input even when neither input_ nor scales require grads, which violates autograd semantics and can raise (e.g., scales.detach() passed into the fused kernel, but extra_input_requires_grad=True forces a gradient).
These flags should be set based on the actual requirements:
- `input_requires_grad = input_.requires_grad`
- `swiglu_ctx.extra_input_requires_grad = scales.requires_grad`
- and for the FC weights, check each parameter's `requires_grad` (not just `weight0`).
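A minimal sketch of this suggestion, assuming the local names visible in the diff (`input_`, `scales`, `swiglu_ctx`); `fc1_weights`/`fc2_weights` are hypothetical names for the per-group weight parameters of the two `GroupedLinear` ops:

```python
# Sketch only: derive the grad-required flags from the actual tensors instead
# of hard-coding True. fc1_weights/fc2_weights are hypothetical names for the
# per-group weight parameters of the two GroupedLinear ops.
swiglu_ctx.input_requires_grad = input_.requires_grad
swiglu_ctx.extra_input_requires_grad = scales.requires_grad
fc1_weight_requires_grad = any(w.requires_grad for w in fc1_weights)
fc2_weight_requires_grad = any(w.requires_grad for w in fc2_weights)
```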
```python
# Return immediately if fused kernel is not supported
if not BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
    return ops

# Check if recipe is supported
if recipe is None:
    return ops
if not recipe.mxfp8():
    return ops

# Scan through ops, fusing if possible
out = []
window, ops = ops[:3], ops[3:]
while len(window) == 3:

    # Check if window matches pattern
    matches_pattern = True
    if not (
        isinstance(window[0], GroupedLinear)
        and isinstance(window[1], ScaledSwiGLU)
        and isinstance(window[2], GroupedLinear)
    ):
        matches_pattern = False
    elif window[0].has_bias or window[2].has_bias:
        matches_pattern = False
    elif window[0].num_groups != window[2].num_groups:
        matches_pattern = False
    elif (
        window[0].in_features % 256 != 0
        or window[0].out_features % 256 != 0
        or window[2].in_features % 256 != 0
        or window[2].out_features % 256 != 0
    ):
        matches_pattern = False
    elif window[1].glu_interleave_size != 32:
        matches_pattern = False

    if matches_pattern:
        # Construct fused op if window matches pattern
        op = BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8(
            fc1=window[0],
```
Broken fusion window scan
Both fuse_backward_ops and fuse_forward_ops have a window/shift loop that can drop or reorder ops when the pattern doesn’t match. In the non-matching branch you do out.extend(window[:-2]); window = window[-2:] and then immediately do out.extend(window[:-3]) (which is a no-op for a 2-element window) before refilling. This causes the scan to advance by 1 op in some cases and by 2 in others, and it never emits window[-1] until the very end. For sequences like [A,B,C,D] where [A,B,C] doesn’t match but [B,C,D] would (or vice versa), this loop will not correctly consider all 3-op windows and can produce an incorrect fused op list.
This needs a standard sliding-window approach (advance by 1 when not matching; replace 3->1 when matching) to ensure no ops are skipped or duplicated.
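For reference, a minimal sketch of such a scan, with hypothetical `matches(window)` and `make_fused_op(window)` helpers standing in for the pattern check and fused-op construction shown above:

```python
def fuse_three_op_windows(ops, matches, make_fused_op):
    """Replace each matching run of three ops with a single fused op.

    Slides by one op when the window does not match, so every 3-op window
    is considered and no op is dropped, reordered, or duplicated.
    """
    out = []
    i = 0
    while i + 3 <= len(ops):
        window = ops[i : i + 3]
        if matches(window):
            out.append(make_fused_op(window))  # replace 3 ops with 1 fused op
            i += 3
        else:
            out.append(window[0])  # emit the first op, then slide by one
            i += 1
    out.extend(ops[i:])  # flush the trailing ops (fewer than 3 remain)
    return out
```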
```python
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for bias."""
```
Incorrect module docstring
transformer_engine/pytorch/ops/basic/grouped_linear.py starts with """Fusible operation for bias.""", which is a copy/paste from basic/bias.py and is incorrect for this file. This impacts generated docs and module-level help text and is user-visible.
Update it to describe GroupedLinear (e.g., “Fusible operation for grouped linear / grouped GEMM”).
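For instance, the suggested wording as the module docstring (exact phrasing up to the author):

```python
"""Fusible operation for grouped linear / grouped GEMM."""
```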
```python
ctx.dtype = dtype
ctx.save_for_backward(
    input_,
    scales if ctx.input_requires_grad else None,
)

return out, [()]
```
Always-saving scales tensor
In ScaledSwiGLU.fuser_forward, you set ctx.extra_input_requires_grad = extra_input.requires_grad, but you always save scales because the conditional is scales if ctx.input_requires_grad else None and ctx.input_requires_grad is forced to True. When scales.requires_grad=False, this needlessly keeps scales alive for backward and increases activation memory; worse, fuser_backward uses scales inside the ctx.input_requires_grad branch, so if someone later changes input_requires_grad to be accurate, the save condition would become wrong.
Save scales based on ctx.extra_input_requires_grad (or save both unconditionally but keep the grad-required flags consistent).
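A minimal sketch of the second option (flags derived from the actual tensors, both tensors saved unconditionally), assuming the locals shown in the diff above and that `scales` is derived from the op's `extra_input`:

```python
# Sketch only: keep the grad-required flags consistent with the actual tensors
# and save both tensors unconditionally, so fuser_backward never depends on a
# flag that was forced to True.
ctx.input_requires_grad = input_.requires_grad
ctx.extra_input_requires_grad = extra_input.requires_grad
ctx.dtype = dtype
ctx.save_for_backward(input_, scales)
```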
Description
This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
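For illustration, a hypothetical sketch of the 3-op pattern the fuser looks for. Op names follow the PR's diff where visible, but the constructor signatures and the `Sequential` container usage are assumptions, not code from this PR:

```python
# Hypothetical usage sketch; constructor arguments are assumptions, see the
# PR diff for the real API.
from transformer_engine.pytorch import ops as te_ops

num_groups, hidden_size, ffn_size = 8, 4096, 14336
grouped_mlp = te_ops.Sequential(
    te_ops.GroupedLinear(num_groups, hidden_size, 2 * ffn_size, bias=False),  # FC1
    te_ops.ScaledSwiGLU(),  # gated activation, scaled by an extra "scales" input
    te_ops.GroupedLinear(num_groups, ffn_size, hidden_size, bias=False),      # FC2
)
# With an MXFP8 recipe, the op fuser may replace this pattern with the
# experimental CuTe DSL grouped GEMM + SwiGLU kernel, subject to the
# constraints checked in the fusion code above (no bias, matching num_groups,
# feature dimensions divisible by 256, glu_interleave_size == 32).
```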
Type of change
Changes
Checklist: