Conversation
Greptile Overview

Greptile Summary

Key Changes:
Critical Issue:
Limitations:
Confidence Score: 2/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant Autocast
    participant Recipe
    participant Linear
    participant Quantizer
    participant Backward

    User->>Autocast: Set NVTE_KEEP_BACKWARD_UNQUANTIZED=1
    User->>Autocast: autocast(enabled=True, recipe=...)
    Autocast->>Recipe: _default_quantize_backward()
    Recipe-->>Autocast: quantize_backward=False
    Autocast->>Autocast: _validate_recipe_quantization_flags()
    alt DelayedScaling recipe
        Autocast->>Autocast: Raise error (not supported)
    else Other recipes
        Autocast->>Autocast: Validation passes
    end

    User->>Linear: forward(input, weight)
    Linear->>Recipe: check quantize_backward
    Recipe-->>Linear: quantize_backward=False
    Linear->>Linear: keep_backward_unquantized=True
    Linear->>Linear: save_original_input=True
    Linear->>Quantizer: quantize input (FP8)
    Quantizer-->>Linear: quantized_input
    Linear->>Linear: forward gemm with FP8
    Linear->>Linear: Save unquantized input & weight for backward
    Linear->>Linear: Disable FP8 quantizers in ctx
    Linear-->>User: output

    User->>Backward: backward(grad_output)
    Backward->>Linear: Check ctx.keep_backward_unquantized
    alt keep_backward_unquantized=True
        Linear->>Linear: use_fp8_bwd=False
        Linear->>Linear: Use saved unquantized input & weight
        Linear->>Linear: Compute dgrad & wgrad in high precision
    else keep_backward_unquantized=False
        Linear->>Linear: Use quantized tensors
        Linear->>Linear: Compute dgrad & wgrad in FP8
    end

    Backward-->>User: gradients
```
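As a quick illustration of the flow above, here is a minimal, hypothetical usage sketch. The autocast and recipe entry points (`te.fp8_autocast`, `recipe.Float8CurrentScaling`, `te.Linear`) follow Transformer Engine's current public API; whether the env var must be set before import is an assumption, so it is set first here.

```python
# Minimal sketch of the intended workflow: FP8 forward, high-precision backward.
import os
os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"  # assumed to be read at autocast/module setup

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.Float8CurrentScaling()  # DelayedScaling is rejected in this mode
linear = te.Linear(1024, 1024).cuda()
x = torch.randn(16, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = linear(x)        # forward GEMM runs in FP8

out.sum().backward()       # dgrad/wgrad use the saved high-precision tensors
```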
I'll work on potential unit test breakage.
```python
ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
    ln_out_return = ln_out
ln_out_hp = ln_out if keep_backward_unquantized else None
```
Storing both ln_out (quantized) and ln_out_hp (high precision) doubles the memory footprint for this activation.
Verify this memory overhead is acceptable for your target models, especially during training with large batch sizes or long sequences.
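As a rough, hypothetical estimate of that overhead (assuming 1 byte/element for the FP8 copy and 2 bytes/element for a BF16 high-precision copy; the shapes are illustrative):

```python
# Rough activation-memory estimate for ln_out when both copies are saved.
def ln_out_activation_bytes(batch: int, seq: int, hidden: int, keep_backward_unquantized: bool) -> int:
    elems = batch * seq * hidden
    fp8_copy = elems * 1                                   # quantized copy
    hp_copy = elems * 2 if keep_backward_unquantized else 0  # BF16 copy kept for backward
    return fp8_copy + hp_copy

# e.g. batch=8, seq=8192, hidden=8192 for a single layer:
print(ln_out_activation_bytes(8, 8192, 8192, False) / 2**30)  # 0.5 GiB
print(ln_out_activation_bytes(8, 8192, 8192, True) / 2**30)   # 1.5 GiB
```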
Additional Comments (1)
In
Also appears in: none found in this diff.
Some nvfuser tests are failing:
timmoon10 left a comment:
This feature is reasonably straightforward, although I have some design suggestions to make it more general. Also, we should add some unit tests to make sure this works as expected.
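One possible shape for such a unit test, sketched under assumptions: it uses the env-var interface from this PR plus Transformer Engine's public `te.Linear` / `te.fp8_autocast`; the recipe class, shapes, and tolerances are illustrative, not prescribed by the PR.

```python
# Hypothetical test sketch: with NVTE_KEEP_BACKWARD_UNQUANTIZED=1 the forward
# pass is quantized, but gradients should closely match a high-precision reference.
import pytest
import torch

te = pytest.importorskip("transformer_engine.pytorch")
from transformer_engine.common import recipe


@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs a GPU with FP8 support")
def test_keep_backward_unquantized_grads_match_reference(monkeypatch):
    monkeypatch.setenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "1")
    torch.manual_seed(0)
    linear = te.Linear(256, 256).cuda()
    ref = torch.nn.Linear(256, 256).cuda()
    with torch.no_grad():
        ref.weight.copy_(linear.weight)
        ref.bias.copy_(linear.bias)

    x = torch.randn(32, 256, device="cuda", requires_grad=True)
    x_ref = x.detach().clone().requires_grad_()

    with te.fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
        out = linear(x)
    out.backward(torch.ones_like(out))
    ref(x_ref).backward(torch.ones_like(out))

    # Outputs differ (FP8 fprop), but dgrad should be close since backward is unquantized.
    torch.testing.assert_close(x.grad, x_ref.grad, rtol=5e-2, atol=5e-2)
```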
```python
recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
    return False
```
I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to secretly disobey their instructions.
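For example, failing loudly could look roughly like the sketch below. The function and attribute names (`_validate_recipe_quantization_flags`, `quantize_backward`, `delayed()`) are reused from the snippets quoted in this thread; the exact message and placement are illustrative.

```python
# Sketch of a loud failure instead of silently returning False.
def _validate_recipe_quantization_flags(recipe) -> None:
    if recipe is None:
        return
    if not getattr(recipe, "quantize_backward", True) and recipe.delayed():
        raise ValueError(
            "Unquantized backward (NVTE_KEEP_BACKWARD_UNQUANTIZED / "
            "quantize_backward=False) is not supported with delayed-scaling recipes."
        )
```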
```python
        return cls.HIGH_PRECISION_INIT_VAL

    @classmethod
    def keep_backward_unquantized(cls) -> bool:
```
I would prefer this option to live in Recipe rather than FP8GlobalStateManager. FP8GlobalStateManager is for state that changes very frequently (e.g. when entering or exiting a te.autocast), while Recipe has configs that persist throughout training. Exposing the option in Recipe also makes it easier to configure programmatically rather than with an obscure envvar.
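For illustration, a recipe-level option might look like the sketch below; the class and field names are hypothetical stand-ins, not the actual Recipe API.

```python
# Hypothetical sketch of the option as a recipe field that can be set in code
# rather than through the NVTE_KEEP_BACKWARD_UNQUANTIZED environment variable.
from dataclasses import dataclass

@dataclass
class ExampleRecipe:
    keep_backward_unquantized: bool = False

my_recipe = ExampleRecipe(keep_backward_unquantized=True)  # quantized fprop, unquantized bwd
```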
```python
        return cls.HIGH_PRECISION_INIT_VAL

    @classmethod
    def keep_backward_unquantized(cls) -> bool:
```
This option name is specific to this workflow and doesn't generalize well. How about we break this up into two options: `quantize_forward` and `quantize_backward`? We have the following cases (see the sketch after this list):

- `quantize_forward=True, quantize_backward=True`: Equivalent to the quantized case. In the future we might be able to replace `FP8GlobalStateManager.FP8_ENABLED` with `FP8GlobalStateManager.QUANTIZE_FORWARD or FP8GlobalStateManager.QUANTIZE_BACKWARD`.
- `quantize_forward=False, quantize_backward=False`: Equivalent to the unquantized case.
- `quantize_forward=True, quantize_backward=False`: Your desired workflow.
- `quantize_forward=False, quantize_backward=True`: We can error out in this case, but who knows if someone in the future might want this.
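A minimal sketch of how the four combinations could be resolved; the helper name and error message are illustrative, not part of the PR.

```python
# Illustrative mapping of the two proposed flags to (forward, backward) quantization.
def resolve_quantization(quantize_forward: bool, quantize_backward: bool) -> tuple[bool, bool]:
    if not quantize_forward and quantize_backward:
        # Unsupported for now; fail loudly rather than guessing intent.
        raise ValueError("quantize_backward=True requires quantize_forward=True")
    return quantize_forward, quantize_backward

assert resolve_quantization(True, True) == (True, True)      # fully quantized
assert resolve_quantization(False, False) == (False, False)  # fully unquantized
assert resolve_quantization(True, False) == (True, False)    # FP8 fprop, high-precision bwd
```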
```python
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.keep_backward_unquantized = keep_backward_unquantized
```
If the backward pass has unquantized compute, does it need to know that the forward pass was quantized? If possible, it would be nice to keep all the changes confined here, where we configure the autograd context.
Suggested change:

```diff
-ctx.fp8 = fp8
-ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
-ctx.keep_backward_unquantized = keep_backward_unquantized
+ctx.fp8 = fp8 and not keep_backward_unquantized
+ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
```
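To illustrate the suggestion, here is a toy autograd sketch (not TE internals): the flag is folded into `ctx.fp8` when the context is populated, so backward never has to consult `keep_backward_unquantized`.

```python
# Toy example only: fold keep_backward_unquantized into ctx.fp8 at save time,
# so the backward pass branches solely on ctx.fp8.
import torch

class ToyLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight, fp8, keep_backward_unquantized):
        ctx.save_for_backward(inp, weight)
        ctx.fp8 = fp8 and not keep_backward_unquantized  # the only place the flag is read
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors
        # A real implementation would dispatch to FP8 GEMMs when ctx.fp8 is True
        # and high-precision GEMMs otherwise; this toy uses plain matmuls for both.
        dgrad = grad_output @ weight
        wgrad = grad_output.t() @ inp
        return dgrad, wgrad, None, None

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
ToyLinearFunction.apply(x, w, True, True).sum().backward()
```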
```python
if enabled or calibrating:
    _validate_recipe_quantization_flags(fp8_recipe)
    quantize_forward = getattr(fp8_recipe, "quantize_forward", True)
    effective_enabled = enabled and quantize_forward
```
I am not sure whether we should disable autocast here when `quantize_forward` is false.
Description
@HumansAnd
Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: