Commits
40 commits
3afce1f
Add NVTE_KEEP_BACKWARD_UNQUANTIZED
zianglih Feb 3, 2026
72149be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 3, 2026
3e6eb64
Merge branch 'main' into keep-bwd
zianglih Feb 3, 2026
927d482
Disable ub and clean up
zianglih Feb 3, 2026
cc85b60
Drop fuser changes
zianglih Feb 3, 2026
fe24f95
Replace use_quantized_bwd with use_fp8_bwd
zianglih Feb 3, 2026
5ca3615
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 3, 2026
5ba7674
Ignore keep_backward_unquantized if delayed scaling
zianglih Feb 3, 2026
02b7b2a
Refactor ignoring NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling…
zianglih Feb 3, 2026
01a7de0
Add back missing ctx.debug
zianglih Feb 3, 2026
bf904aa
Refactor changes under fused
zianglih Feb 3, 2026
b449fc4
Clean up
zianglih Feb 3, 2026
de3acaf
Refactor high-precision overwrite if keep_backward_unquantized
zianglih Feb 3, 2026
fe65d34
Clean up
zianglih Feb 3, 2026
59aaf6b
Drop redundant fp8_recipe_bwd
zianglih Feb 4, 2026
44da625
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2026
0f58793
Drop redundant ub changes
zianglih Feb 4, 2026
192fbad
Drop more redundant ub changes
zianglih Feb 4, 2026
0dd1268
Drop redundant delayed scaling changes
zianglih Feb 4, 2026
216621d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2026
ab8749b
Drop unneeded backwards_needs_fc1_input
zianglih Feb 4, 2026
5881083
Drop and disallow LayerNormMLP implementation
zianglih Feb 4, 2026
431f0c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2026
937e34b
Move interface changes to recipe
zianglih Feb 5, 2026
0d26127
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2026
0135366
Move ub overrides to fwd
zianglih Feb 5, 2026
1de3c64
Remove duplication
zianglih Feb 5, 2026
04d3543
Simplify use_fp8_bwd logic in bwd
zianglih Feb 5, 2026
454976e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2026
f7794c9
Set grad quantizers to none if keep bwd unquantized
zianglih Feb 5, 2026
58db8ea
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2026
9d0b654
Drop delayed scaling change
zianglih Feb 6, 2026
004cb45
Simplify env var logic
zianglih Feb 9, 2026
9baccfd
Move validation check to recipe
zianglih Feb 9, 2026
207eb5a
Simplify effective_enabled
zianglih Feb 9, 2026
15117b1
Fix inverted assertion logic
zianglih Feb 9, 2026
3fc5270
Simplify changes under ops
zianglih Feb 9, 2026
9201d19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2026
1e0f1d2
Simplify ctx.keep_backward_unquantized
zianglih Feb 9, 2026
253873a
Fix missing attribute
zianglih Feb 9, 2026
78 changes: 73 additions & 5 deletions transformer_engine/common/recipe/__init__.py
@@ -181,6 +181,11 @@ def scaling_factor_compute(amax: Tensor,
`LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`.
When `fp8_mha = True, fp8_dpa = True`, it becomes
`LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`.
quantize_forward : bool, default = True
Whether to quantize tensors in the forward pass.
quantize_backward : bool, default = True
Whether to quantize tensors in the backward pass. Delayed scaling
always quantizes backward; setting this to False is not supported.

Notes
-----
@@ -204,9 +209,15 @@ def scaling_factor_compute(amax: Tensor,
reduce_amax: bool = True
fp8_dpa: bool = False
fp8_mha: bool = False
quantize_forward: bool = True
quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1")

def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
assert not (
not self.quantize_forward and self.quantize_backward
), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."
assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False."
Contributor:

This assertion prevents using NVTE_KEEP_BACKWARD_UNQUANTIZED=1 with the DelayedScaling recipe: when the env var is set, quantize_backward becomes False, so this assert fails and the feature is blocked entirely for this recipe type.
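
A minimal repro sketch of the conflict (hypothetical; it assumes the env var is exported before transformer_engine is first imported, since the dataclass default is evaluated at import time):

```python
import os

# The dataclass default for quantize_backward reads the env var when the
# recipe module is imported, so it must be set before the import below.
os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"

from transformer_engine.common.recipe import DelayedScaling

# quantize_backward defaults to False here, so __post_init__ raises
# AssertionError: Delayed scaling does not support quantize_backward=False.
recipe = DelayedScaling()
```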


def __repr__(self) -> str:
return (
@@ -216,7 +227,9 @@ def __repr__(self) -> str:
f"amax_history_len={self.amax_history_len}, "
f"reduce_amax={self.reduce_amax}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
f"fp8_mha={self.fp8_mha}, "
f"quantize_forward={self.quantize_forward}, "
f"quantize_backward={self.quantize_backward}"
)


@@ -230,6 +243,10 @@ class Float8CurrentScaling(Recipe):
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID
Controls the FP8 data format used during forward and backward
pass.
quantize_forward : bool, default = True
Whether to quantize tensors in the forward pass.
quantize_backward : bool, default = True
Whether to quantize tensors in the backward pass.
"""

use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1"
@@ -242,9 +259,14 @@ class Float8CurrentScaling(Recipe):
fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True)
fp8_dpa: bool = False
fp8_mha: bool = False
quantize_forward: bool = True
quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1")

def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
assert not (
not self.quantize_forward and self.quantize_backward
), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."

def __repr__(self) -> str:
return (
@@ -257,7 +279,9 @@ def __repr__(self) -> str:
f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, "
f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
f"fp8_mha={self.fp8_mha}, "
f"quantize_forward={self.quantize_forward}, "
f"quantize_backward={self.quantize_backward}"
)


@@ -284,21 +308,32 @@ class MXFP8BlockScaling(Recipe):
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3
Controls the FP8 data format used during forward and backward
pass.
quantize_forward : bool, default = True
Whether to quantize tensors in the forward pass.
quantize_backward : bool, default = True
Whether to quantize tensors in the backward pass.
"""

margin: int = 0
fp8_format: Format = Format.E4M3
fp8_dpa: bool = False
fp8_mha: bool = False
quantize_forward: bool = True
quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1")

def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
assert not (
not self.quantize_forward and self.quantize_backward
), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."

def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"margin={self.margin}, "
f"format={str(self.fp8_format).split('.')[1]}"
f"format={str(self.fp8_format).split('.')[1]}, "
f"quantize_forward={self.quantize_forward}, "
f"quantize_backward={self.quantize_backward}"
)


@@ -327,6 +362,10 @@ class Float8BlockScaling(Recipe):
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3
Controls the FP8 data format used during forward and backward
pass.
quantize_forward : bool, default = True
Whether to quantize tensors in the forward pass.
quantize_backward : bool, default = True
Whether to quantize tensors in the backward pass.
"""

use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1"
@@ -343,6 +382,8 @@ class Float8BlockScaling(Recipe):
fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True)
fp8_dpa: bool = False
fp8_mha: bool = False
quantize_forward: bool = True
quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1")

def __post_init__(self) -> None:
assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x"
@@ -364,6 +405,9 @@ def __post_init__(self) -> None:
not self.fp8_dpa and not self.fp8_mha
), "FP8 attention is not supported for Float8BlockScaling."
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
assert not (
not self.quantize_forward and self.quantize_backward
), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."

def __repr__(self) -> str:
return (
@@ -379,7 +423,9 @@ def __repr__(self) -> str:
f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, "
f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
f"fp8_mha={self.fp8_mha}, "
f"quantize_forward={self.quantize_forward}, "
f"quantize_backward={self.quantize_backward}"
)


@@ -428,6 +474,10 @@ class NVFP4BlockScaling(Recipe):
If set to `True`, stochastic rounding is disabled during quantization for all tensors.
disable_2d_quantization : bool, default = False
If set to `True`, 1D block scaling with block size 16 is used for all tensors.
quantize_forward : bool, default = True
Whether to quantize tensors in the forward pass.
quantize_backward : bool, default = True
Whether to quantize tensors in the backward pass.
"""

# Configuration envvars
@@ -443,10 +493,15 @@ class NVFP4BlockScaling(Recipe):
# Not applying quantization to attention for now
fp8_dpa: bool = False
fp8_mha: bool = False
quantize_forward: bool = True
quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1")

def __post_init__(self) -> None:
assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling"
assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling"
assert not (
not self.quantize_forward and self.quantize_backward
), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."

# Quantization params
# Note: RHT is currently only applied to column-wise usage so that
@@ -474,6 +529,8 @@ def __repr__(self) -> str:
f"fp8_format={str(self.fp8_format).split('.')[1]}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}, "
f"quantize_forward={self.quantize_forward}, "
f"quantize_backward={self.quantize_backward}, "
f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, "
f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, "
f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, "
@@ -505,12 +562,23 @@ class CustomRecipe(Recipe):

- forward: "linear_input", "linear_weight", "linear_output"
- backward: "linear_grad_output", "linear_grad_input"
quantize_forward : bool, default = True
Whether to quantize tensors in the forward pass.
quantize_backward : bool, default = True
Whether to quantize tensors in the backward pass.
"""

qfactory: Callable[..., Any]

fp8_dpa: bool = False
fp8_mha: bool = False
quantize_forward: bool = True
quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1")

def __repr__(self) -> str:
return f"recipe_type={self.__class__.__name__}, qfactory={self.qfactory}"
return (
f"recipe_type={self.__class__.__name__}, "
f"qfactory={self.qfactory}, "
f"quantize_forward={self.quantize_forward}, "
f"quantize_backward={self.quantize_backward}"
)
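
For context, a usage sketch of the new quantize_forward/quantize_backward flags (illustrative only, not part of this diff; it assumes the standard transformer_engine.pytorch fp8_autocast and Linear APIs and FP8-capable hardware):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

# Quantize the forward pass only; keep the backward pass in high precision.
# (Equivalent to exporting NVTE_KEEP_BACKWARD_UNQUANTIZED=1 before import.)
recipe = Float8CurrentScaling(quantize_forward=True, quantize_backward=False)

layer = te.Linear(1024, 1024).cuda()
inp = torch.randn(16, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = layer(inp)        # forward GEMMs run under the FP8 recipe
out.sum().backward()        # dgrad/wgrad are expected to stay unquantized
```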
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -1135,9 +1135,10 @@ def grad_output_preprocess(
grad_output = grad_output.reshape((-1, grad_output.shape[-1]))
grad_output = grad_output.contiguous()
gather_grad_output = row_parallel_mode and ctx.sequence_parallel
use_fp8_bwd = ctx.fp8 and not ctx.keep_backward_unquantized

# Non-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8 and not ctx.debug:
if not use_fp8_bwd and not ctx.debug:
if gather_grad_output:
if not ctx.ub_overlap_ag: # Perform NCCL all-gather
grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
25 changes: 23 additions & 2 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -96,6 +96,12 @@ def forward(
save_original_input,
debug,
) = non_tensor_args
keep_backward_unquantized = fp8 and (
not FP8GlobalStateManager.get_fp8_recipe().quantize_backward
)
if keep_backward_unquantized:
# Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used
save_original_input = True

num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
@@ -286,6 +292,7 @@ def forward(
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.keep_backward_unquantized = keep_backward_unquantized
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
@@ -304,6 +311,17 @@ def forward(
ctx.save_original_input = save_original_input
ctx.input_quantizers = input_quantizers

# keep_backward_unquantized overrides
if keep_backward_unquantized:
ctx.fp8 = ctx.fp8 and not keep_backward_unquantized
ctx.ub_overlap_ag = False
ctx.ub_overlap_rs_dgrad = False
ctx.ub_bulk_dgrad = False
ctx.ub_bulk_wgrad = False
ctx.grad_input_quantizer = None
ctx.grad_weight_quantizer = None
ctx.grad_output_quantizer = None

# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])

@@ -395,13 +413,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
dtype=ctx.activation_dtype,
device=ctx.device,
)
weights_for_dgrad = weights
if ctx.keep_backward_unquantized:
weights_for_dgrad = origin_weights
# Make sure weights are available in column-wise format
# for dgrad computation.
for weight in weights:
for weight in weights_for_dgrad:
if isinstance(weight, QuantizedTensorStorage):
weight.update_usage(columnwise_usage=True)
general_grouped_gemm(
weights,
weights_for_dgrad,
grad_output,
[dgrad],
ctx.grad_input_quantizers,