
Add NVTE_KEEP_BACKWARD_UNQUANTIZED #2644

Open
zianglih wants to merge 32 commits into NVIDIA:main from zianglih:keep-bwd

Conversation

zianglih commented Feb 3, 2026

Description


Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.
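A minimal usage sketch, assuming the env var is read before Transformer Engine initializes its quantization state. The recipe choice is illustrative, and the entry point is shown as te.fp8_autocast (this PR's discussion refers to autocast, which may be the newer name):

```python
# Hedged sketch: run with NVTE_KEEP_BACKWARD_UNQUANTIZED=1 in the environment,
# e.g. `NVTE_KEEP_BACKWARD_UNQUANTIZED=1 python train.py`.
import os
os.environ.setdefault("NVTE_KEEP_BACKWARD_UNQUANTIZED", "1")  # set before TE reads it

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.Linear(1024, 1024).cuda()
inp = torch.randn(16, 1024, device="cuda", requires_grad=True)

# DelayedScaling is rejected with this feature per the discussion below.
fp8_recipe = recipe.Float8CurrentScaling()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)        # forward GEMM runs quantized
out.sum().backward()        # wgrad/dgrad fall back to high precision
```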

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

zianglih and others added 2 commits February 2, 2026 16:45

greptile-apps bot commented Feb 3, 2026

Greptile Overview

Greptile Summary

Added NVTE_KEEP_BACKWARD_UNQUANTIZED environment variable to enable quantized forward pass with high-precision backward pass (wgrad & dgrad). The implementation adds quantize_forward and quantize_backward fields to all recipe classes, with backward quantization controlled by the env var via _default_quantize_backward().
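A rough sketch of what the env-var-driven default could look like. The function name _default_quantize_backward and the two field names come from the summary above; the parsing details and the stand-in recipe class are assumptions:

```python
import os
from dataclasses import dataclass, field

def _default_quantize_backward() -> bool:
    # Assumed behavior: backward stays quantized unless the env var asks
    # for a high-precision backward pass.
    return os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") not in ("1", "true", "True")

@dataclass
class ExampleRecipe:  # illustrative stand-in for the real recipe dataclasses
    quantize_forward: bool = True
    quantize_backward: bool = field(default_factory=_default_quantize_backward)
```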

Key Changes:

  • Added quantize_forward and quantize_backward fields to all recipe dataclasses (DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, Float8BlockScaling, NVFP4BlockScaling, CustomRecipe)
  • Implemented feature in Linear, LayerNormLinear, GroupedLinear, and BasicLinear by saving unquantized tensors and disabling FP8 backward
  • Added validation that blocks quantize_backward=False with DelayedScaling recipe
  • LayerNormMLP explicitly crashes when feature is enabled (not implemented)

Critical Issue:

  • LayerNormMLP has an assertion that crashes the module when NVTE_KEEP_BACKWARD_UNQUANTIZED=1, making it completely unusable with this feature

Limitations:

  • Feature cannot be used with DelayedScaling recipe (validation blocks it at line 98-102 in quantization.py)
  • LayerNormMLP not supported (crashes with assertion)
  • Increased memory usage due to storing both quantized and unquantized tensors

Confidence Score: 2/5

  • This PR has a critical issue that blocks usage of LayerNormMLP with the new feature
  • Score reflects the hard crash in LayerNormMLP when the env var is enabled, which makes a major module completely unusable. The feature is blocked for DelayedScaling recipe and lacks implementation for LayerNormMLP.
  • Pay close attention to transformer_engine/pytorch/module/layernorm_mlp.py which crashes when the feature is enabled

Important Files Changed

  • transformer_engine/common/recipe/__init__.py: Added quantize_forward and quantize_backward fields to all recipe classes with env var support
  • transformer_engine/pytorch/quantization.py: Added validation for recipe flags; blocks DelayedScaling with quantize_backward=False
  • transformer_engine/pytorch/module/layernorm_mlp.py: Crashes with assertion when NVTE_KEEP_BACKWARD_UNQUANTIZED=1 is set; feature not implemented
  • transformer_engine/pytorch/module/linear.py: Implements keep_backward_unquantized by saving unquantized tensors and disabling FP8 backward
  • transformer_engine/pytorch/module/layernorm_linear.py: Implements keep_backward_unquantized by storing high-precision ln_out for backward pass

Sequence Diagram

sequenceDiagram
    participant User
    participant Autocast
    participant Recipe
    participant Linear
    participant Quantizer
    participant Backward

    User->>Autocast: Set NVTE_KEEP_BACKWARD_UNQUANTIZED=1
    User->>Autocast: autocast(enabled=True, recipe=...)
    
    Autocast->>Recipe: _default_quantize_backward()
    Recipe-->>Autocast: quantize_backward=False
    
    Autocast->>Autocast: _validate_recipe_quantization_flags()
    alt DelayedScaling recipe
        Autocast->>Autocast: Raise error (not supported)
    else Other recipes
        Autocast->>Autocast: Validation passes
    end
    
    User->>Linear: forward(input, weight)
    Linear->>Recipe: check quantize_backward
    Recipe-->>Linear: quantize_backward=False
    
    Linear->>Linear: keep_backward_unquantized=True
    Linear->>Linear: save_original_input=True
    
    Linear->>Quantizer: quantize input (FP8)
    Quantizer-->>Linear: quantized_input
    
    Linear->>Linear: forward gemm with FP8
    Linear->>Linear: Save unquantized input & weight for backward
    Linear->>Linear: Disable FP8 quantizers in ctx
    Linear-->>User: output
    
    User->>Backward: backward(grad_output)
    Backward->>Linear: Check ctx.keep_backward_unquantized
    alt keep_backward_unquantized=True
        Linear->>Linear: use_fp8_bwd=False
        Linear->>Linear: Use saved unquantized input & weight
        Linear->>Linear: Compute dgrad & wgrad in high precision
    else keep_backward_unquantized=False
        Linear->>Linear: Use quantized tensors
        Linear->>Linear: Compute dgrad & wgrad in FP8
    end
    Backward-->>User: gradients

greptile-apps bot left a comment: 6 files reviewed, no comments


zianglih (Author) commented Feb 3, 2026

I'll work on potential unit test breakage.

greptile-apps bot left a comment: 5 files reviewed, no comments

greptile-apps bot left a comment: 5 files reviewed, 4 comments

zianglih and others added 2 commits February 3, 2026 09:56
greptile-apps bot left a comment: 4 files reviewed, 1 comment

greptile-apps bot left a comment: 5 files reviewed, 1 comment

greptile-apps bot left a comment: 5 files reviewed, 2 comments

ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
    ln_out_return = ln_out
ln_out_hp = ln_out if keep_backward_unquantized else None
greptile-apps bot commented:

storing both ln_out (quantized) and ln_out_hp (high precision) doubles the memory footprint for this activation

verify this memory overhead is acceptable for your target models, especially during training with large batch sizes or long sequences
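A back-of-the-envelope way to size the overhead being flagged (illustrative shapes, not taken from this PR):

```python
# Extra activation memory from keeping a BF16 copy of ln_out next to the FP8 copy.
batch, seq, hidden = 8, 4096, 8192          # example shapes (assumption)
elems = batch * seq * hidden
fp8_bytes = elems                            # 1 byte/element, ignoring scale factors
bf16_bytes = elems * 2                       # high-precision copy saved for backward
print(f"extra per layer: {bf16_bytes / 2**30:.2f} GiB")  # 0.50 GiB for these shapes
```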

greptile-apps bot left a comment: 4 files reviewed, no comments

greptile-apps bot left a comment: 5 files reviewed, no comments

greptile-apps bot left a comment: 5 files reviewed, no comments

greptile-apps bot left a comment: 5 files reviewed, 1 comment

greptile-apps bot left a comment: 1 file reviewed, 1 comment

greptile-apps bot commented Feb 4, 2026

Additional Comments (1)

transformer_engine/pytorch/module/layernorm_mlp.py
Incorrect instance check

In _LayerNormMLP.backward, this block checks isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) and then calls ctx.fc1_weight.update_usage(...).

QuantizedTensorStorage is a tensor storage type, not a quantizer, so this condition will never be true and usage for ctx.fc1_weight won't be updated when it should be (FP8 backward + quantized weight path). This looks like a typo for a check on the weight itself (i.e., isinstance(ctx.fc1_weight, QuantizedTensorStorage)) and can break backward passes that rely on correct usage flags.

Also appears in: none found in this diff.
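A self-contained illustration of the pattern described above, using hypothetical stand-in classes rather than Transformer Engine's real types:

```python
# Stand-ins only; the real quantizer and QuantizedTensorStorage classes live in TE.
class Quantizer:
    pass

class QuantizedTensorStorage:
    def update_usage(self):
        print("usage flags updated")

class Ctx:
    fc1_weight_quantizer = Quantizer()
    fc1_weight = QuantizedTensorStorage()

ctx = Ctx()

# As described in the comment: the quantizer is never a QuantizedTensorStorage,
# so this branch is dead and the usage flags are never updated.
if isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage):
    ctx.fc1_weight.update_usage()

# Likely intent per the comment: inspect the stored weight tensor itself.
if isinstance(ctx.fc1_weight, QuantizedTensorStorage):
    ctx.fc1_weight.update_usage()
```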

zianglih and others added 2 commits February 4, 2026 11:25
greptile-apps bot left a comment: 5 files reviewed, 1 comment

zianglih (Author) commented Feb 4, 2026

Some nvfuser tests are failing:

=================================================================== short test summary info ===================================================================
FAILED tests/pytorch/test_sanity.py::test_sanity_amp_and_nvfuser[True-small-None-dtype1] - RuntimeError: /root/TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:764 in function cublas_gemm: Assertion failed: status != CUBLAS_STAT...
FAILED tests/pytorch/test_sanity.py::test_sanity_amp_and_nvfuser[True-small-None-dtype2] - RuntimeError: /root/TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:764 in function cublas_gemm: Assertion failed: status != CUBLAS_STAT...
FAILED tests/pytorch/test_sanity.py::test_sanity_amp_and_nvfuser[False-small-None-dtype1] - RuntimeError: /root/TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:764 in function cublas_gemm: Assertion failed: status != CUBLAS_STAT...
FAILED tests/pytorch/test_sanity.py::test_sanity_amp_and_nvfuser[False-small-None-dtype2] - RuntimeError: /root/TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:764 in function cublas_gemm: Assertion failed: status != CUBLAS_STAT...
================================================ 4 failed, 12918 passed, 16523 skipped, 20 warnings in 40.71s =================================================
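For reference, the failing cases can be rerun in isolation with pytest's keyword filter (test path taken from the log above; equivalent to the usual CLI invocation):

```python
# Sketch: rerun only the nvfuser sanity tests reported above.
import pytest
pytest.main(["tests/pytorch/test_sanity.py", "-k", "test_sanity_amp_and_nvfuser", "-q"])
```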

zianglih and others added 2 commits February 4, 2026 14:01
greptile-apps bot left a comment: 5 files reviewed, no comments

timmoon10 (Collaborator) left a comment:

This feature is reasonably straightforward, although I have some design suggestions to make it more general. Also, we should add some unit tests to make sure this works as expected.

recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
    return False
timmoon10 commented:

I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to secretly disobey their instructions.
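A hypothetical sketch of the fail-loudly alternative. The helper name and recipe methods appear elsewhere in this PR; the error message and exact condition are illustrative:

```python
def _validate_recipe_quantization_flags(recipe) -> None:
    # Sketch: reject the unsupported combination instead of silently ignoring
    # NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is in use.
    if recipe.delayed() and not getattr(recipe, "quantize_backward", True):
        raise ValueError(
            "quantize_backward=False (NVTE_KEEP_BACKWARD_UNQUANTIZED=1) is not "
            "supported with the DelayedScaling recipe"
        )
```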

    return cls.HIGH_PRECISION_INIT_VAL

@classmethod
def keep_backward_unquantized(cls) -> bool:
timmoon10 commented:

I would prefer this option to live in Recipe rather than FP8GlobalStateManager. FP8GlobalStateManager is for state that changes very frequently (e.g. when entering or exiting a te.autocast), while Recipe has configs that persist throughout training. Exposing the option in Recipe also makes it easier to configure programmatically rather than with an obscure envvar.

    return cls.HIGH_PRECISION_INIT_VAL

@classmethod
def keep_backward_unquantized(cls) -> bool:
timmoon10 commented:

This option name is specific to this workflow and doesn't generalize well. How about we break this up into two options: quantize_forward and quantize_backward. We have the following cases:

  • quantize_forward=True, quantize_backward=True: Equivalent to quantized case. In the future we might be able to replace FP8GlobalStateManager.FP8_ENABLED with FP8GlobalStateManager.QUANTIZE_FORWARD or FP8GlobalStateManager.QUANTIZE_BACKWARD.
  • quantize_forward=False, quantize_backward=False: Equivalent to unquantized case.
  • quantize_forward=True, quantize_backward=False: Your desired workflow.
  • quantize_forward=False, quantize_backward=True: We can error out in this case, but who knows whether someone in the future might want this.
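A sketch of how the two flags from the four cases above might be combined by a module. The names follow the discussion; the actual wiring in the PR may differ:

```python
def resolve_quantization(recipe, fp8_enabled: bool) -> tuple[bool, bool]:
    """Hypothetical helper mapping recipe flags onto forward/backward compute."""
    quantize_forward = getattr(recipe, "quantize_forward", True)
    quantize_backward = getattr(recipe, "quantize_backward", True)
    if not quantize_forward and quantize_backward:
        # Case left open in the discussion; erroring out is one option.
        raise ValueError("quantize_forward=False with quantize_backward=True is unsupported")
    use_fp8_fwd = fp8_enabled and quantize_forward
    use_fp8_bwd = fp8_enabled and quantize_backward  # False reproduces keep_backward_unquantized
    return use_fp8_fwd, use_fp8_bwd
```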

Comment on lines 448 to +450
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.keep_backward_unquantized = keep_backward_unquantized
timmoon10 commented:

If the backward pass has unquantized compute, does it need to know that the forward pass was quantized? If possible, it would be nice to keep all the changes confined here, where we configure the autograd context.

Suggested change:
- ctx.fp8 = fp8
- ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
- ctx.keep_backward_unquantized = keep_backward_unquantized
+ ctx.fp8 = fp8 and not keep_backward_unquantized
+ ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None

zianglih and others added 2 commits February 5, 2026 13:10
greptile-apps bot left a comment: 12 files reviewed, 3 comments

greptile-apps bot commented Feb 5, 2026

Additional Comments (1)

transformer_engine/pytorch/quantization.py
autocast recipe defaulting

In autocast(), fp8_recipe is now set to get_default_fp8_recipe() when recipe is None, and _validate_recipe_quantization_flags(fp8_recipe) is run when enabled or calibrating (lines ~1014-1017). If get_default_fp8_recipe() can return None, this will throw (_validate_recipe_quantization_flags calls getattr/recipe.delayed()), and later check_recipe_support(fp8_recipe) will also fail. If None is a valid recipe value in this API, add a guard before validating/using it; otherwise, ensure get_default_fp8_recipe() never returns None and document/enforce that contract.
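A minimal guard along the lines suggested (sketch only; get_default_fp8_recipe and _validate_recipe_quantization_flags are the names used in the comment above, defined in the module under discussion):

```python
def _maybe_validate(recipe, enabled: bool, calibrating: bool) -> None:
    # Sketch: skip validation when no recipe resolves, otherwise validate as before.
    fp8_recipe = recipe if recipe is not None else get_default_fp8_recipe()
    if fp8_recipe is not None and (enabled or calibrating):
        _validate_recipe_quantization_flags(fp8_recipe)
```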

greptile-apps bot left a comment: 5 files reviewed, 3 comments

zianglih and others added 2 commits February 5, 2026 13:59
greptile-apps bot left a comment: 6 files reviewed, 2 comments

greptile-apps bot left a comment: 4 files reviewed, no comments

greptile-apps bot left a comment: 5 files reviewed, 2 comments

if enabled or calibrating:
    _validate_recipe_quantization_flags(fp8_recipe)
quantize_forward = getattr(fp8_recipe, "quantize_forward", True)
effective_enabled = enabled and quantize_forward
zianglih (Author) commented:

I am not very sure if we should disable the autocast when quantize_forward is false.
