diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index a13dfada79..5ee843987c 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -40,6 +40,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +NVTE_KEEP_BACKWARD_UNQUANTIZED=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_keep_backward_unquantized.xml $TE_PATH/tests/pytorch/test_keep_backward_unquantized.py || test_fail "test_keep_backward_unquantized.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" diff --git a/tests/pytorch/test_keep_backward_unquantized.py b/tests/pytorch/test_keep_backward_unquantized.py new file mode 100644 index 0000000000..f5c3339a71 --- /dev/null +++ b/tests/pytorch/test_keep_backward_unquantized.py @@ -0,0 +1,756 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +from __future__ import annotations + +from contextlib import nullcontext +import math +import os +from typing import Optional + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.common import recipe +from transformer_engine.pytorch.ops.fused import ( + BackwardActivationBias, + ForwardLinearBiasActivation, + ForwardLinearBiasAdd, + ForwardLinearScaleAdd, +) + +from utils import quantization_tols, reset_rng_states + + +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True +) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) + +# This file is intended to run in dedicated keep-backward-unquantized mode. +pytestmark = pytest.mark.skipif( + os.environ.get("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") != "1", + reason="Requires NVTE_KEEP_BACKWARD_UNQUANTIZED=1", +) + + +_quantized_numerics_recipe_list = [ + pytest.param( + "fp8_current_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + id="Float8CurrentScaling", + ), + pytest.param( + "mxfp8", + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + id="MXFP8BlockScaling", + ), + pytest.param( + "fp8_block_scaling", + marks=pytest.mark.skipif( + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling + ), + id="Float8BlockScaling", + ), + pytest.param( + "nvfp4", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP4BlockScaling", + ), +] + +_shape_test_cases = [ + pytest.param((1, 64), 64, id="2d_m1_k64_n64"), + pytest.param((32, 64), 64, id="2d_m32_k64_n64"), + pytest.param((32, 1, 64), 64, id="3d_m32_s1_k64_n64"), + pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), + 
pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), +] + +_bias_activation_shape_cases = [ + pytest.param((32, 64), id="2d_m32_k64"), + pytest.param((8, 4, 64), id="3d_m32_k64"), +] + + +def _make_recipe(recipe_name: str, quantize_backward: Optional[bool]) -> recipe.Recipe: + kwargs = {} + if quantize_backward is not None: + kwargs = {"quantize_forward": True, "quantize_backward": quantize_backward} + + if recipe_name == "fp8_current_scaling": + return recipe.Float8CurrentScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "mxfp8": + return recipe.MXFP8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "fp8_block_scaling": + return recipe.Float8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "nvfp4": + return recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + **kwargs, + ) + + raise ValueError(f"Unsupported recipe for keep-backward-unquantized test: {recipe_name}") + + +def _build_keep_backward_unquantized_recipe(recipe_name: str) -> recipe.Recipe: + fp8_recipe = _make_recipe(recipe_name, quantize_backward=None) + assert fp8_recipe.quantize_forward + assert not fp8_recipe.quantize_backward + return fp8_recipe + + +def _build_quantized_reference_recipe(recipe_name: str) -> recipe.Recipe: + return _make_recipe(recipe_name, quantize_backward=True) + + +def _copy_named_parameters(src_module: torch.nn.Module, dst_module: torch.nn.Module) -> None: + src_params = dict(src_module.named_parameters()) + with torch.no_grad(): + for name, dst_param in dst_module.named_parameters(): + if name not in src_params: + raise RuntimeError(f"Parameter {name} missing in source module") + dst_param.copy_(src_params[name]) + + +def _fprop_tolerances(recipe_name: str) -> dict[str, float]: + if recipe_name == "mxfp8": + return quantization_tols("mxfp8") + if recipe_name in ("fp8_current_scaling", "fp8_block_scaling"): + return 
quantization_tols("fp8_current_scaling") + if recipe_name == "nvfp4": + return quantization_tols("nvfp4") + raise ValueError(f"Unsupported recipe for keep-backward-unquantized test: {recipe_name}") + + +def _make_linear_like_module( + module_type: str, + in_features: int, + out_features: int, + dtype: torch.dtype, + bias: bool = False, +) -> torch.nn.Module: + if module_type == "linear": + return te.Linear( + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + ) + if module_type == "layernorm_linear": + return te.LayerNormLinear( + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + ) + if module_type == "ops_linear": + return te_ops.Linear( + in_features, + out_features, + bias=bias, + dtype=dtype, + device="cuda", + ) + raise ValueError(f"Unsupported module type: {module_type}") + + +def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: + if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": + pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") + + +def _maybe_skip_unsupported_recipe_shape( + recipe_name: str, + input_shape: tuple[int, ...], + module_type: str, +) -> None: + flat_first_dim = math.prod(input_shape[:-1]) + last_dim = input_shape[-1] + + # TE Linear / LayerNormLinear FP8 kernels require FP8-GEMM-compatible dimensions. + if module_type in ("linear", "layernorm_linear"): + if flat_first_dim % 8 != 0 or last_dim % 16 != 0: + pytest.skip( + "Linear/LayerNormLinear FP8 execution requires prod(shape[:-1]) divisible by 8 " + "and shape[-1] divisible by 16." + ) + return + + # te_ops.Linear (fusible ops) has stricter constraints for some block-scaled recipes. + if module_type == "ops_linear": + if recipe_name == "mxfp8" and (flat_first_dim % 32 != 0 or last_dim % 32 != 0): + pytest.skip( + "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." 
+ ) + if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + pytest.skip( + "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." + ) + + +def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]) -> None: + # Grouped GEMM paths enforce additional split-alignment constraints for block-scaled recipes. + non_empty_splits = [m for m in m_splits if m > 0] + if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): + pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") + if recipe_name == "fp8_block_scaling" and any(m % 4 != 0 for m in non_empty_splits): + pytest.skip( + "GroupedLinear + Float8BlockScaling requires each non-empty m_split divisible by 4." + ) + + +def _run_single_step( + module: torch.nn.Module, + x: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + y.backward(dy) + assert x_run.grad is not None + assert module.weight.grad is not None + return ( + y.detach().clone(), + x_run.grad.detach().clone(), + module.weight.grad.detach().clone(), + ) + + +def _extract_bias_grad(module: torch.nn.Module) -> Optional[torch.Tensor]: + bias = getattr(module, "bias", None) + if bias is None or bias.grad is None: + return None + return bias.grad.detach().clone() + + +def _run_grouped_linear_single_step( + module: te.GroupedLinear, + x: torch.Tensor, + m_splits: list[int], + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[Optional[torch.Tensor]]]: + module.zero_grad(set_to_none=True) + x_run = 
x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run, m_splits) + y.backward(dy) + assert x_run.grad is not None + weight_grads = [ + getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms) + ] + bias_grads: list[Optional[torch.Tensor]] = [] + for i in range(module.num_gemms): + if module.use_bias: + bias_grads.append(getattr(module, f"bias{i}").grad.detach().clone()) + else: + bias_grads.append(None) + return y.detach().clone(), x_run.grad.detach().clone(), weight_grads, bias_grads + + +def _make_fused_model( + pattern: str, + in_features: int, + out_features: int, + dtype: torch.dtype, + scale: float = 0.5, +) -> te_ops.Sequential: + if pattern == "bias_activation": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), + te_ops.ReLU(), + ) + if pattern == "bias_add": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), + te_ops.AddExtraInput(in_place=True), + ) + if pattern == "scale_add": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=False, device="cuda", dtype=dtype), + te_ops.ConstantScale(scale), + te_ops.AddExtraInput(in_place=True), + ) + raise ValueError(f"Unsupported fused test pattern: {pattern}") + + +def _run_fused_single_step( + pattern: str, + model: te_ops.Sequential, + x1: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], + x2: Optional[torch.Tensor] = None, +) -> tuple[ + torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor] +]: + model.zero_grad(set_to_none=True) + x1_run = x1.detach().clone().requires_grad_(True) + x2_run = x2.detach().clone().requires_grad_(True) if x2 is not None else None + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not 
None else nullcontext() + ) + with autocast_ctx: + if pattern in ("bias_add", "scale_add"): + assert x2_run is not None + y = model(x1_run, x2_run) + else: + y = model(x1_run) + y.backward(dy) + assert x1_run.grad is not None + weight_grad = model[0].weight.grad.detach().clone() + bias_grad = None + if getattr(model[0], "bias", None) is not None and model[0].bias.grad is not None: + bias_grad = model[0].bias.grad.detach().clone() + x2_grad = ( + x2_run.grad.detach().clone() if x2_run is not None and x2_run.grad is not None else None + ) + return y.detach().clone(), x1_run.grad.detach().clone(), x2_grad, weight_grad, bias_grad + + +def _run_quantize_op_single_step( + model: te_ops.Sequential, + x: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor]: + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = model(x_run) + y.backward(dy) + assert x_run.grad is not None + return y.detach().clone(), x_run.grad.detach().clone() + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +def test_keep_backward_unquantized_recipe_defaults(recipe_name: str): + _ = _build_keep_backward_unquantized_recipe(recipe_name) + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize( + "module_type", + ("linear", "layernorm_linear", "ops_linear"), +) +@pytest.mark.parametrize( + "input_shape,out_features", + _shape_test_cases, +) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads( + recipe_name: str, + module_type: str, + input_shape: tuple[int, ...], + out_features: int, + use_bias: bool, +): + reset_rng_states() + _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) + 
_maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) + dtype = torch.bfloat16 + in_features = input_shape[-1] + + module_quantized_ref = _make_linear_like_module( + module_type, in_features, out_features, dtype, bias=use_bias + ) + module_keep_bwd_hp = _make_linear_like_module( + module_type, in_features, out_features, dtype, bias=use_bias + ) + module_unquantized_ref = _make_linear_like_module( + module_type, in_features, out_features, dtype, bias=use_bias + ) + + # Start all runs from identical parameters. + _copy_named_parameters(module_quantized_ref, module_keep_bwd_hp) + _copy_named_parameters(module_quantized_ref, module_unquantized_ref) + + output_shape = input_shape[:-1] + (out_features,) + x = torch.randn(*input_shape, dtype=dtype, device="cuda") + dy = torch.randn(*output_shape, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _ = _run_single_step(module_quantized_ref, x, dy, quantized_ref_recipe) + y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp = _run_single_step( + module_keep_bwd_hp, x, dy, keep_bwd_hp_recipe + ) + _, dx_unquantized_ref, dw_unquantized_ref = _run_single_step( + module_unquantized_ref, x, dy, None + ) + + # Forward pass should still match quantized reference when only backward is unquantized. + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + + # Backward pass should match unquantized reference for dgrad and wgrad. 
+ torch.testing.assert_close(dx_keep_bwd_hp, dx_unquantized_ref, rtol=0, atol=0) + torch.testing.assert_close(dw_keep_bwd_hp, dw_unquantized_ref, rtol=0, atol=0) + if use_bias: + bgrad_keep = _extract_bias_grad(module_keep_bwd_hp) + bgrad_unquantized = _extract_bias_grad(module_unquantized_ref) + assert bgrad_keep is not None + assert bgrad_unquantized is not None + torch.testing.assert_close(bgrad_keep, bgrad_unquantized, rtol=0, atol=0) + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +@pytest.mark.parametrize( + "m_splits", + ([32, 32, 32, 32], [64, 0, 32, 32], [1, 31, 0, 96]), + ids=("uniform_splits", "with_empty_split", "small_and_empty_splits"), +) +def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_unquantized_grads( + recipe_name: str, + use_bias: bool, + m_splits: list[int], +): + if recipe_name == "nvfp4": + pytest.skip("NVFP4 not supported for grouped linear") + _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) + + reset_rng_states() + dtype = torch.bfloat16 + in_features = 64 + out_features = 64 + num_gemms = len(m_splits) + num_tokens = sum(m_splits) + + module_quantized_ref = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + module_keep_bwd_hp = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + module_unquantized_ref = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + + _copy_named_parameters(module_quantized_ref, module_keep_bwd_hp) + _copy_named_parameters(module_quantized_ref, module_unquantized_ref) + + x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") + + quantized_ref_recipe = 
_build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _, _ = _run_grouped_linear_single_step( + module_quantized_ref, x, m_splits, dy, quantized_ref_recipe + ) + y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp, db_keep_bwd_hp = _run_grouped_linear_single_step( + module_keep_bwd_hp, x, m_splits, dy, keep_bwd_hp_recipe + ) + _, dx_unquantized_ref, dw_unquantized_ref, db_unquantized_ref = _run_grouped_linear_single_step( + module_unquantized_ref, x, m_splits, dy, None + ) + + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + torch.testing.assert_close(dx_keep_bwd_hp, dx_unquantized_ref, rtol=0, atol=0) + for test_dw, ref_dw in zip(dw_keep_bwd_hp, dw_unquantized_ref): + torch.testing.assert_close(test_dw, ref_dw, rtol=0, atol=0) + if use_bias: + for test_db, ref_db in zip(db_keep_bwd_hp, db_unquantized_ref): + assert test_db is not None + assert ref_db is not None + torch.testing.assert_close(test_db, ref_db, rtol=0, atol=0) + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize( + "fused_pattern,expected_fused_op", + ( + ("bias_add", ForwardLinearBiasAdd), + ("scale_add", ForwardLinearScaleAdd), + ), +) +@pytest.mark.parametrize("m", (1, 32), ids=("m1", "m32")) +def test_keep_backward_unquantized_fused_linear_paths( + recipe_name: str, + fused_pattern: str, + expected_fused_op: type, + m: int, +): + # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. 
+ _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + + reset_rng_states() + dtype = torch.bfloat16 + in_features = 64 + out_features = 64 + _maybe_skip_unsupported_recipe_shape(recipe_name, (m, in_features), "ops_linear") + model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) + model_keep_bwd_hp = _make_fused_model(fused_pattern, in_features, out_features, dtype) + model_unquantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) + + _copy_named_parameters(model_quantized_ref, model_keep_bwd_hp) + _copy_named_parameters(model_quantized_ref, model_unquantized_ref) + + x1 = torch.randn(m, in_features, dtype=dtype, device="cuda") + x2 = None + if fused_pattern in ("bias_add", "scale_add"): + x2 = torch.randn(m, out_features, dtype=dtype, device="cuda") + dy = torch.randn(m, out_features, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _, _, _ = _run_fused_single_step( + fused_pattern, model_quantized_ref, x1, dy, quantized_ref_recipe, x2=x2 + ) + y_keep_bwd_hp, dx1_keep_bwd_hp, dx2_keep_bwd_hp, dw_keep_bwd_hp, db_keep_bwd_hp = ( + _run_fused_single_step( + fused_pattern, + model_keep_bwd_hp, + x1, + dy, + keep_bwd_hp_recipe, + x2=x2, + ) + ) + _, dx1_unquantized_ref, dx2_unquantized_ref, dw_unquantized_ref, db_unquantized_ref = ( + _run_fused_single_step( + fused_pattern, + model_unquantized_ref, + x1, + dy, + None, + x2=x2, + ) + ) + + # Ensure this test executes the fused path changed by the keep-bwd feature. 
+ fused_ops = model_keep_bwd_hp._module_groups[0]._forward_ops + assert len(fused_ops) >= 1 + assert isinstance(fused_ops[0][0], expected_fused_op) + + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + torch.testing.assert_close(dx1_keep_bwd_hp, dx1_unquantized_ref, rtol=0, atol=0) + torch.testing.assert_close(dw_keep_bwd_hp, dw_unquantized_ref, rtol=0, atol=0) + if dx2_keep_bwd_hp is not None and dx2_unquantized_ref is not None: + torch.testing.assert_close(dx2_keep_bwd_hp, dx2_unquantized_ref, rtol=0, atol=0) + if db_keep_bwd_hp is not None and db_unquantized_ref is not None: + torch.testing.assert_close(db_keep_bwd_hp, db_unquantized_ref, rtol=0, atol=0) + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize("input_shape", _bias_activation_shape_cases) +def test_keep_backward_unquantized_fused_bias_activation_matches_masked_linear_backward( + recipe_name: str, + input_shape: tuple[int, ...], +): + # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. 
+ _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + + reset_rng_states() + dtype = torch.bfloat16 + in_features = input_shape[-1] + out_features = 64 + + model_quantized_ref = _make_fused_model("bias_activation", in_features, out_features, dtype) + model_keep_bwd_hp = _make_fused_model("bias_activation", in_features, out_features, dtype) + linear_unquantized_ref = _make_linear_like_module( + "ops_linear", in_features, out_features, dtype, bias=True + ) + + _copy_named_parameters(model_quantized_ref, model_keep_bwd_hp) + _copy_named_parameters(model_keep_bwd_hp[0], linear_unquantized_ref) + + x1 = torch.randn(*input_shape, dtype=dtype, device="cuda") + out_shape = x1.shape[:-1] + (out_features,) + dy = torch.randn(*out_shape, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _, _, _ = _run_fused_single_step( + "bias_activation", model_quantized_ref, x1, dy, quantized_ref_recipe + ) + y_keep_bwd_hp, dx1_keep_bwd_hp, _, dw_keep_bwd_hp, db_keep_bwd_hp = _run_fused_single_step( + "bias_activation", model_keep_bwd_hp, x1, dy, keep_bwd_hp_recipe + ) + + # Ensure this test executes the fused path changed by the keep-bwd feature. + fused_ops = model_keep_bwd_hp._module_groups[0]._forward_ops + assert len(fused_ops) >= 1 + assert isinstance(fused_ops[0][0], ForwardLinearBiasActivation) + + # keep-bwd mode should disable backward-activation+bias fusion, while quantized + # reference should still use it. 
+ keep_bwd_backward_ops = model_keep_bwd_hp._module_groups[0]._backward_ops + assert not any(isinstance(op, BackwardActivationBias) for op, _ in keep_bwd_backward_ops) + quantized_ref_backward_ops = model_quantized_ref._module_groups[0]._backward_ops + assert any(isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops) + + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + + # In keep-backward-unquantized mode, backward should behave as high-precision linear backward + # given the ReLU mask induced by quantized forward activations. + dy_after_activation = dy * (y_keep_bwd_hp > 0).to(dy.dtype) + _, dx1_expected, dw_expected = _run_single_step( + linear_unquantized_ref, x1, dy_after_activation, None + ) + db_expected = _extract_bias_grad(linear_unquantized_ref) + assert db_keep_bwd_hp is not None + assert db_expected is not None + + torch.testing.assert_close(dx1_keep_bwd_hp, dx1_expected, rtol=0, atol=0) + torch.testing.assert_close(dw_keep_bwd_hp, dw_expected, rtol=0, atol=0) + torch.testing.assert_close(db_keep_bwd_hp, db_expected, rtol=0, atol=0) + + +def test_keep_backward_unquantized_autocast_respects_quantize_forward_flag(): + reset_rng_states() + dtype = torch.bfloat16 + in_features = 64 + out_features = 64 + + module_quantization_disabled = _make_linear_like_module( + "linear", in_features, out_features, dtype, bias=True + ) + module_unquantized_ref = _make_linear_like_module( + "linear", in_features, out_features, dtype, bias=True + ) + _copy_named_parameters(module_quantization_disabled, module_unquantized_ref) + + x = torch.randn(32, in_features, dtype=dtype, device="cuda") + dy = torch.randn(32, out_features, dtype=dtype, device="cuda") + + recipe_no_fwd_quant = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.E4M3, + quantize_forward=False, + quantize_backward=False, + ) + + y_test, dx_test, dw_test = _run_single_step( + module_quantization_disabled, x, dy, 
recipe_no_fwd_quant + ) + y_ref, dx_ref, dw_ref = _run_single_step(module_unquantized_ref, x, dy, None) + + torch.testing.assert_close(y_test, y_ref, rtol=0, atol=0) + torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0) + torch.testing.assert_close(dw_test, dw_ref, rtol=0, atol=0) + bgrad_test = _extract_bias_grad(module_quantization_disabled) + bgrad_ref = _extract_bias_grad(module_unquantized_ref) + assert bgrad_test is not None + assert bgrad_ref is not None + torch.testing.assert_close(bgrad_test, bgrad_ref, rtol=0, atol=0) + + +def test_keep_backward_unquantized_quantize_op_respects_recipe_overrides(): + reset_rng_states() + dtype = torch.bfloat16 + x = torch.randn(32, 64, dtype=dtype, device="cuda") + dy = torch.randn(32, 64, dtype=dtype, device="cuda") + + model_override = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) + model_ref = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) + + recipe_no_quant = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.E4M3, + quantize_forward=False, + quantize_backward=False, + ) + y_override, dx_override = _run_quantize_op_single_step(model_override, x, dy, recipe_no_quant) + y_ref, dx_ref = _run_quantize_op_single_step(model_ref, x, dy, None) + + torch.testing.assert_close(y_override, y_ref, rtol=0, atol=0) + torch.testing.assert_close(dx_override, dx_ref, rtol=0, atol=0) + + +def test_keep_backward_unquantized_is_invalid_for_delayed_scaling(): + with pytest.raises( + (AssertionError, ValueError), + match="Delayed scaling does not support quantize_backward=False", + ): + _ = recipe.DelayedScaling() + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +def test_keep_backward_unquantized_not_implemented_for_layernorm_mlp(): + reset_rng_states() + layer = te.LayerNormMLP( + hidden_size=64, + ffn_hidden_size=64, + params_dtype=torch.bfloat16, + bias=False, + device="cuda", + ) + x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True) + 
keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe("fp8_current_scaling")
+
+    with pytest.raises(
+        AssertionError, match="NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP"
+    ):
+        with te.autocast(enabled=True, recipe=keep_bwd_hp_recipe):
+            _ = layer(x)
diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py
index 64ee2a5a16..d534ad883b 100644
--- a/transformer_engine/common/recipe/__init__.py
+++ b/transformer_engine/common/recipe/__init__.py
@@ -181,6 +181,11 @@ def scaling_factor_compute(amax: Tensor,
         `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`.
         When `fp8_mha = True, fp8_dpa = True`, it becomes
         `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`.
+    quantize_forward : bool, default = True
+        Whether to quantize tensors in the forward pass.
+    quantize_backward : bool, default = True (False if NVTE_KEEP_BACKWARD_UNQUANTIZED=1)
+        Whether to quantize tensors in the backward pass. Delayed scaling
+        always quantizes backward; setting this to False is not supported.
 
     Notes
     -----
@@ -204,9 +209,15 @@
     reduce_amax: bool = True
     fp8_dpa: bool = False
     fp8_mha: bool = False
+    quantize_forward: bool = True
+    quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1")
 
     def __post_init__(self) -> None:
         assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
+        assert not (
+            not self.quantize_forward and self.quantize_backward
+        ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."
+        assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False."
def __repr__(self) -> str: return ( @@ -216,7 +227,9 @@ def __repr__(self) -> str: f"amax_history_len={self.amax_history_len}, " f"reduce_amax={self.reduce_amax}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -230,6 +243,10 @@ class Float8CurrentScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID Controls the FP8 data format used during forward and backward pass. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1" @@ -242,9 +259,14 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -257,7 +279,9 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -284,21 +308,32 @@ class MXFP8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. 
+ quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ margin: int = 0 fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " - f"format={str(self.fp8_format).split('.')[1]}" + f"format={str(self.fp8_format).split('.')[1]}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -327,6 +362,10 @@ class Float8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" @@ -343,6 +382,8 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" @@ -364,6 +405,9 @@ def __post_init__(self) -> None: not self.fp8_dpa and not self.fp8_mha ), "FP8 attention is not supported for Float8BlockScaling." 
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -379,7 +423,9 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -428,6 +474,10 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ # Configuration envvars @@ -443,10 +493,15 @@ class NVFP4BlockScaling(Recipe): # Not applying quantization to attention for now fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling" + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." 
# Quantization params # Note: RHT is currently only applied to column-wise usage so that @@ -474,6 +529,8 @@ def __repr__(self) -> str: f"fp8_format={str(self.fp8_format).split('.')[1]}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " @@ -505,12 +562,23 @@ class CustomRecipe(Recipe): - forward: "linear_input", "linear_weight", "linear_output" - backward: "linear_grad_output", "linear_grad_input" + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ qfactory: Callable[..., Any] fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __repr__(self) -> str: - return f"recipe_type={self.__class__.__name__}, qfactory={self.qfactory}" + return ( + f"recipe_type={self.__class__.__name__}, " + f"qfactory={self.qfactory}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" + ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 841cdf04ca..a878f2ace2 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1135,9 +1135,10 @@ def grad_output_preprocess( grad_output = grad_output.reshape((-1, grad_output.shape[-1])) grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel + use_fp8_bwd = ctx.fp8 and not ctx.keep_backward_unquantized # Non-FP8 case: bgrad is fused with wgrad for this case. 
- if not ctx.fp8 and not ctx.debug: + if not use_fp8_bwd and not ctx.debug: if gather_grad_output: if not ctx.ub_overlap_ag: # Perform NCCL all-gather grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c9ceb714e3..abe6df6875 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -96,6 +96,12 @@ def forward( save_original_input, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) + if keep_backward_unquantized: + # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used + save_original_input = True num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] @@ -286,6 +292,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -304,6 +311,17 @@ def forward( ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + # keep_backward_unquantized overrides + if keep_backward_unquantized: + ctx.fp8 = ctx.fp8 and not keep_backward_unquantized + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None + # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -395,13 +413,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) + 
weights_for_dgrad = weights + if ctx.keep_backward_unquantized: + weights_for_dgrad = origin_weights # Make sure weights are available in column-wise format # for dgrad computation. - for weight in weights: + for weight in weights_for_dgrad: if isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) general_grouped_gemm( - weights, + weights_for_dgrad, grad_output, [dgrad], ctx.grad_input_quantizers, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 702916696b..187fd70f92 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -141,6 +141,9 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" @@ -200,7 +203,10 @@ def forward( if fp8: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input and not keep_backward_unquantized, + ) if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data input_quantizer.set_usage(columnwise=False) @@ -213,6 +219,7 @@ def forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not keep_backward_unquantized and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() ) @@ -236,6 +243,7 @@ def forward( ln_out_return = None if return_layernorm_output or return_layernorm_output_gathered: ln_out_return = ln_out + ln_out_hp = ln_out if keep_backward_unquantized else None # ------------------------------------------------------ # Prepare 
GEMM input tensor @@ -409,13 +417,16 @@ def forward( # ------------------------------------------------------ if is_grad_enabled: + ln_out_to_save = ln_out + if keep_backward_unquantized: + ln_out_to_save = ln_out_hp ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) # Input with column-wise usage is needed for wgrad GEMM. - if backward_needs_input: + if backward_needs_input and not keep_backward_unquantized: if isinstance(ln_out, QuantizedTensorStorage): # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. For MXFP8, columnwise only data @@ -427,7 +438,7 @@ def forward( ln_out.update_usage(rowwise_usage=False) if cpu_offloading: - mark_activation_offload(inputmat, mu, rsigma, ln_out) + mark_activation_offload(inputmat, mu, rsigma, ln_out_to_save) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -439,7 +450,7 @@ def forward( mu, rsigma, weightmat if fp8 and not is_weight_param_quantized else None, - ln_out if weight.requires_grad else None, + ln_out_to_save if weight.requires_grad else None, ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") @@ -466,7 +477,7 @@ def forward( weight, bias, ln_weight, - ln_out, + ln_out_to_save, mu, rsigma, ) @@ -493,6 +504,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -523,6 +535,17 @@ def forward( ctx.wgrad_store = wgrad_store ctx.debug = debug + # keep_backward_unquantized overrides + if keep_backward_unquantized: + ctx.fp8 = ctx.fp8 and not keep_backward_unquantized + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad 
= False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None + # ------------------------------------------------------ # Cached state for backward pass is ready... # ------------------------------------------------------ @@ -665,7 +688,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None: + if ctx.input_quantizer is not None and ctx.fp8: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -703,7 +726,11 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage): + if ( + ctx.fp8 + and ctx.weight_quantizer is not None + and isinstance(weight, QuantizedTensorStorage) + ): weight.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator @@ -730,8 +757,11 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight + if ctx.keep_backward_unquantized: + weight_for_dgrad = origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( - weight, + weight_for_dgrad, grad_output, layout="NN", grad=True, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bec6744518..ac10534012 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -232,6 +232,12 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) + assert ( + not keep_backward_unquantized + ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not 
implemented in LayerNormMLP" # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: @@ -778,6 +784,7 @@ def _forward( ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 23ad8cacb0..7d960102ec 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -129,6 +129,12 @@ def forward( save_original_input, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) + if keep_backward_unquantized: + # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used + save_original_input = True # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" @@ -443,6 +449,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.input_quantizer = input_quantizer ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer @@ -486,6 +493,17 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module ctx.wgrad_store = wgrad_store + # keep_backward_unquantized overrides + if keep_backward_unquantized: + ctx.fp8 = ctx.fp8 and not keep_backward_unquantized + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + 
ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None + # ------------------------------------------------------ # Cached state for backward pass is ready... # ------------------------------------------------------ @@ -690,8 +708,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance( - weight_fp8, QuantizedTensorStorage + if ( + ctx.fp8 + and ctx.weight_quantizer is not None + and isinstance(weight_fp8, QuantizedTensorStorage) ): weight_fp8.update_usage(columnwise_usage=True) @@ -720,8 +740,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight_fp8 + if ctx.keep_backward_unquantized: + weight_for_dgrad = weight gemm_out, *_, reduce_scatter_out = general_gemm( - weight_fp8, + weight_for_dgrad, grad_output, layout="NN", grad=True, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index e640f3ffb1..16b7bcb7c5 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,12 +332,16 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
weight_requires_grad = requires_grad and self.weight.requires_grad + keep_backward_unquantized = FP8GlobalStateManager.is_fp8_enabled() and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) + columnwise_usage = weight_requires_grad and not keep_backward_unquantized input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) grad_output_quantizer = self.get_quantizer("backward", 0) - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) weight_quantizer.set_usage(rowwise=True, columnwise=False) - grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + grad_output_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) @@ -420,6 +424,7 @@ def _functional_forward( tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, + keep_backward_unquantized: bool = False, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -459,6 +464,8 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = `False` Whether to perform compute with quantized data. + keep_backward_unquantized: bool, default = `False` + Whether to skip quantized backward and use high precision. input_quantizer: Quantizer, optional Builder class for quantized input tensor. 
weight_quantizer: Quantizer, optional @@ -510,7 +517,10 @@ def _functional_forward( if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and not keep_backward_unquantized, + ) if with_x_all_gather: input_quantizer.set_usage(columnwise=False) x, x_async = gather_along_first_dim( @@ -542,7 +552,10 @@ def _functional_forward( elif with_quantized_compute and not is_quantized_tensor(w): if weight_quantizer is None: raise ValueError("Missing quantizer for weight tensor") - weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + weight_quantizer.set_usage( + rowwise=True, + columnwise=input_requires_grad and not keep_backward_unquantized, + ) w = weight_quantizer(w) # Check output tensor @@ -611,14 +624,23 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if w is not weight and with_quantized_compute and is_quantized_tensor(w): + if ( + w is not weight + and with_quantized_compute + and is_quantized_tensor(w) + and not keep_backward_unquantized + ): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: w = None # Prepare input tensor for backward pass if weight_requires_grad: - if with_quantized_compute and is_quantized_tensor(x_local): + if ( + with_quantized_compute + and is_quantized_tensor(x_local) + and not keep_backward_unquantized + ): if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -968,6 +990,9 @@ def op_forward( grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = with_quantized_compute 
and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -984,6 +1009,7 @@ def op_forward( tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -993,10 +1019,12 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = self.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - ctx.save_for_backward(x_local, w) - ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + ctx.save_for_backward(saved_input, saved_weight) + ctx.with_quantized_compute = with_quantized_compute and not keep_backward_unquantized ctx.input_quantizer = input_quantizer ctx.weight_quantizer = weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index 8b60251088..d0ff6d5e15 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -10,6 +10,7 @@ import torch import transformer_engine_torch as tex +from ...quantization import FP8GlobalStateManager from ..op import BasicOperation, OperationContext from ...utils import canonicalize_device, canonicalize_dtype from ...tensor import Quantizer @@ -123,7 +124,12 @@ def op_forward( b = self.bias.view([1] * (x.dim() - 1) + [self.local_size]) if ctx.requires_grad: - ctx.grad_input_quantizer = prev_op_grad_output_quantizer + keep_backward_unquantized = FP8GlobalStateManager.is_fp8_enabled() and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) + 
ctx.grad_input_quantizer = ( + None if keep_backward_unquantized else prev_op_grad_output_quantizer + ) return x + b diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index d126b554b5..33062d5b88 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -59,6 +59,15 @@ def op_forward( quantize_forward = fp8_enabled and self._quantize_forward quantize_backward = fp8_enabled and self._quantize_backward + # Recipe quantize overrides + if FP8GlobalStateManager.get_fp8_recipe() is not None: + quantize_forward = ( + quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward + ) + quantize_backward = ( + quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) + # Quantize if needed out = input_ if quantize_forward and not is_quantized_tensor(out): diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 4ab082d32b..395a9dbd67 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -104,8 +104,9 @@ def fuse_backward_ops( """ - # Check if recipe supports bias activation fusion - if recipe is None: + # Check if recipe supports bias activation fusion. + # keep-backward-unquantized mode should use unfused backward ops. 
+ if recipe is None or not recipe.quantize_backward: return ops # Scan through ops, fusing if possible diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index dfc11a19e7..42f459a41e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -92,6 +92,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -109,6 +112,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -118,10 +122,14 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer @@ -130,7 +138,9 @@ 
def fuser_forward( linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() + bias_op_ctx.grad_input_quantizer = ( + None if keep_backward_unquantized else linear_op.get_grad_output_quantizer() + ) return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 2dfc0566b7..75d58fd5cc 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -86,6 +86,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -106,6 +109,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -115,10 +119,14 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + 
linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer @@ -127,7 +135,9 @@ def fuser_forward( linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() + bias_op_ctx.grad_input_quantizer = ( + None if keep_backward_unquantized else linear_op.get_grad_output_quantizer() + ) return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index ae4bdd4b19..dfdd11a231 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -65,6 +65,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) # Get extra input tensor for add operation extra_input = basic_op_extra_inputs[2][0] @@ -87,6 +90,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -96,10 +100,14 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = 
input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index eba547afb0..00196c584f 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -842,14 +842,15 @@ def autocast( are reduced at the end of each training step. """ - if enabled: + effective_enabled = enabled and getattr(recipe, "quantize_forward", True) + if effective_enabled: check_recipe_support(recipe) # Save current state so we always restore it on exit. fp8_state = FP8GlobalStateManager.get_autocast_state() FP8GlobalStateManager.autocast_enter( - enabled=enabled, + enabled=effective_enabled, calibrating=calibrating, fp8_recipe=recipe, fp8_group=amax_reduction_group, @@ -859,7 +860,7 @@ def autocast( yield finally: FP8GlobalStateManager.set_autocast_state(fp8_state) - FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph) + FP8GlobalStateManager.autocast_exit(effective_enabled, _graph=_graph) def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: