-
Notifications
You must be signed in to change notification settings - Fork 635
Add multi-precision training support to FSDP script #2662
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Enable configurable precision training with support for FP32, FP16, FP8, MXFP8, and NVFP4 formats. Added precision argument parser and match statement to configure appropriate dtype and recipe based on selected precision. - Add precision() type validator function - Implement precision-based configuration in train() - Support FP32, FP16, FP8, MXFP8, and NVFP4 formats - Configure format-specific recipes (DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling) - Set appropriate no_fp8 flags based on precision selection Signed-off-by: aagallo <[email protected]>
for more information, see https://pre-commit.ci
Greptile OverviewGreptile SummaryThis PR extends the PyTorch FSDP example ( The script now tracks whether Confidence Score: 5/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant U as User/CLI
participant P as parse_fsdp_args()
participant T as train(opts)
participant TE as te.autocast
participant F as FSDP
U->>P: Run examples/pytorch/fsdp/fsdp.py with flags
P-->>T: opts incl. precision/dtype/no_fp8 + explicit-set markers
T->>T: Compute dtype, no_fp8, recipe (preset + overrides)
T->>F: Wrap model with MixedPrecision(param_dtype=dtype)
loop Each iteration
T->>T: Create input tensor x (dtype)
T->>TE: Enter autocast (enabled = not no_fp8)
TE-->>T: Forward executes with recipe
T->>T: Backward + optimizer step
end
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 file reviewed, 3 comments
| case "fp16": | ||
| dtype = torch.bfloat16 | ||
| no_fp8 = True | ||
| case "fp8": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Incorrect fp16 dtype
In the case "fp16" branch, the code sets dtype = torch.bfloat16. That contradicts the meaning of fp16 and also diverges from the existing --dtype parsing which supports torch.float16. If a user runs with --precision fp16 expecting fp16 parameters/inputs, they’ll silently get bf16 instead.
| parser.add_argument( | ||
| "--precision", | ||
| type=precision, | ||
| default="fp8", | ||
| help="Precision to apply to model training (FP32, FP16, FP8, MXFP8, NVFP4)", | ||
| ) | ||
| return parser.parse_args() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Conflicting CLI flags
--precision and --dtype/--no-fp8 now overlap: train() overrides dtype and no_fp8 based on --precision, but --dtype/--no-fp8 are still accepted and used as defaults. As written, --precision fp8 will force no_fp8=False even if the user explicitly passed --no-fp8, and --precision fp32 will ignore an explicit --dtype fp16/bf16. This makes the CLI behavior surprising and hard to reason about; either make --precision derive defaults only when the user didn’t specify --dtype/--no-fp8, or document/enforce precedence (e.g., error on incompatible combinations).
Additional Comments (1)
|
Correct FP16 precision to use torch.float16 instead of torch.bfloat16, and add precedence logic where --dtype and --no-fp8 flags override --precision when explicitly set, with warnings issued for conflicts. - Fix case fp16 to use torch.float16 instead of torch.bfloat16 - Add flag precedence detection by comparing against default values - Implement warning messages when --dtype or --no-fp8 override --precision - Update argument parser help text to document precedence behavior - Ensure --dtype and --no-fp8 take precedence over --precision presets Signed-off-by: Andrea Gallo <[email protected]>
Add informative log messages and enhanced help text to clarify precision configuration behavior and flag precedence for better user transparency. - Add log message showing which precision preset is being used - Add warning logs when --dtype or --no-fp8 override --precision - Add final training configuration log (dtype, FP8 status, recipe) - Enhance argument parser help text with precedence examples - Add inline code comments explaining precedence logic Signed-off-by: Andrea Gallo <[email protected]>
Add recipe initialization for fp32 and fp16 precision cases to prevent undefined variable errors, even though recipe is not used when no_fp8 is set to True. - Add DelayedScaling recipe setup for fp32 case with no_fp8=True - Add DelayedScaling recipe setup for fp16 case with no_fp8=True - Add inline comments explaining recipe is set up but not used by autocast - Ensure recipe variable is defined in all precision branches for consistency Signed-off-by: Andrea Gallo <[email protected]>
Update flag precedence detection to use sys.argv for checking if --dtype was explicitly set, ensuring dtype always overrides precision regardless of whether it matches the default value. - Add sys import for command-line argument detection - Change dtype_explicitly_set check to use '--dtype' in sys.argv - Change no_fp8_explicitly_set check to use '--no-fp8' in sys.argv - Ensure --dtype bf16 correctly overrides --precision even when matching default - Maintain warning messages when explicit flags override precision presets Signed-off-by: Andrea Gallo <[email protected]>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 file reviewed, 3 comments
examples/pytorch/fsdp/fsdp.py
Outdated
| case _: | ||
| dtype = torch.float16 | ||
| precision_format = Format.HYBRID | ||
| recipe = DelayedScaling( | ||
| fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" | ||
| ) | ||
| no_fp8 = opts.no_fp8 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
recipe variable not initialized in default case. If precision validator allows an unexpected value, this will cause UnboundLocalError at line 403 when passed to te.autocast().
| case _: | |
| dtype = torch.float16 | |
| precision_format = Format.HYBRID | |
| recipe = DelayedScaling( | |
| fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" | |
| ) | |
| no_fp8 = opts.no_fp8 | |
| case _: | |
| dtype = torch.float16 | |
| precision_format = Format.HYBRID | |
| recipe = DelayedScaling( | |
| fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" | |
| ) | |
| no_fp8 = False |
examples/pytorch/fsdp/fsdp.py
Outdated
| dtype_explicitly_set = "--dtype" in sys.argv | ||
| no_fp8_explicitly_set = "--no-fp8" in sys.argv |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Parsing sys.argv directly is fragile - will break if args are passed via config file, environment variables, or if arg uses = syntax (--dtype=fp16). Use parser.parse_known_args() or track which args were explicitly set via custom action class.
Example with custom action:
class StoreExplicitAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, values)
setattr(namespace, f'{self.dest}_explicitly_set', True)
Additional Comments (1)
|
Replace fragile sys.argv parsing with robust custom argparse action class to track explicitly set arguments, and fix default precision case to explicitly set no_fp8 to False for consistent FP8-enabled behavior. - Add StoreExplicitAction custom action class for tracking explicit arguments - Update --dtype argument to use StoreExplicitAction - Replace sys.argv check with getattr for dtype_explicitly_set attribute - Remove sys import from train() function - Fix default case to set no_fp8 = False instead of opts.no_fp8 - Ensure recipe variable is properly initialized in all code paths - Support all argument passing methods including config files and = syntax Signed-off-by: Andrea Gallo <[email protected]>
…rmerEngine into extend-precision
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 file reviewed, 5 comments
examples/pytorch/fsdp/fsdp.py
Outdated
| def train(opts): | ||
| # Check which flags were explicitly set | ||
| dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) | ||
| no_fp8_explicitly_set = opts.no_fp8 != False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Detection of explicit --no-fp8 is incorrect - opts.no_fp8 != False is True when the flag is absent (defaults to False). User running with defaults will trigger else branch instead of precision preset logic.
| no_fp8_explicitly_set = opts.no_fp8 != False | |
| no_fp8_explicitly_set = "--no-fp8" in sys.argv |
| @@ -173,7 +193,11 @@ def parse_fsdp_args(): | |||
| "--no-fp8", | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing action=StoreExplicitAction to track explicit --no-fp8 usage - required for precedence logic at line 250
| "--no-fp8", | |
| action=StoreExplicitAction, |
examples/pytorch/fsdp/fsdp.py
Outdated
| dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}") | ||
| if not no_fp8: | ||
| dist_print(f"Using FP8 recipe: {type(recipe).__name__}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Logging FP8 recipe when no_fp8=True and recipe=None (line 322) will cause AttributeError on type(recipe).__name__
examples/pytorch/fsdp/fsdp.py
Outdated
| parser.add_argument( | ||
| "--precision", | ||
| type=precision, | ||
| default="fp8", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Default changed from implicit FP8 disabled (via --dtype bf16 + --no-fp8 False) to explicit --precision fp8 enabled - existing scripts relying on default behavior will now use FP8 unexpectedly
Additional Comments (1)
|
Remove params_dtype initialization from get_layer_args() and update FSDP MixedPrecision to use computed dtype variable instead of raw opts.dtype, ensuring precision presets are properly applied throughout the model. - Remove params_dtype from get_layer_args() layer_kwargs initialization - Update FSDP MixedPrecision param_dtype to use computed dtype variable - Ensure precision preset logic is respected in both layer initialization and FSDP - Maintain backward compatibility with original FP8-enabled default behavior Signed-off-by: Andrea Gallo <[email protected]>
…rmerEngine into extend-precision
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 file reviewed, 1 comment
Additional Comments (1)
One fix is to perform conversion inside the action: (or alternatively, implement a custom action that subclasses |
Add type converter application in StoreExplicitAction custom action to ensure --dtype values are properly converted from strings to torch dtype objects, preventing runtime errors in torch operations. - Store type converter in StoreExplicitAction.__init__ - Apply type conversion in __call__ before setting attribute value - Add error handling for invalid type conversions - Ensure opts.dtype contains torch dtype object, not raw string - Fix runtime errors in torch.rand() and MixedPrecision() calls Signed-off-by: Andrea Gallo <[email protected]>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 file reviewed, 2 comments
| if not no_fp8: | ||
| precision_format = Format.HYBRID | ||
| recipe = DelayedScaling( | ||
| fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" | ||
| ) | ||
| else: | ||
| recipe = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CRITICAL BUG: Incorrect recipe selection when --dtype or --no-fp8 is explicitly set
When users explicitly set --dtype but also specify --precision mxfp8 or --precision nvfp4, the code ignores the precision preset and falls back to DelayedScaling recipe instead of the expected MXFP8BlockScaling or NVFP4BlockScaling recipe.
Reproduction:
# User expects MXFP8BlockScaling recipe
./fsdp.py --precision mxfp8 --dtype bf16
# But gets DelayedScaling recipe instead!Root cause:
Lines 340-344 unconditionally use DelayedScaling when not no_fp8 is true, completely ignoring the opts.precision value.
Expected behavior:
The code should respect the precision preset even when dtype is explicitly set. The logic should check opts.precision and select the appropriate recipe (MXFP8BlockScaling for "mxfp8", NVFP4BlockScaling for "nvfp4", etc.).
Suggested fix:
# Set up recipe based on precision preset
if not no_fp8:
match opts.precision:
case "mxfp8":
recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)
case "nvfp4":
recipe = NVFP4BlockScaling()
case _:
precision_format = Format.HYBRID
recipe = DelayedScaling(
fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
)
else:
recipe = NoneThis is a merge-blocking bug as it breaks the advertised functionality of the --precision flag.
examples/pytorch/fsdp/fsdp.py
Outdated
| else: | ||
| # dtype and/or no_fp8 were explicitly set - they take precedence | ||
| dtype = opts.dtype | ||
| no_fp8 = opts.no_fp8 | ||
|
|
||
| # Set up default recipe for FP8 cases | ||
| if not no_fp8: | ||
| precision_format = Format.HYBRID | ||
| recipe = DelayedScaling( | ||
| fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" | ||
| ) | ||
| else: | ||
| recipe = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Precision preset ignored
In the else: branch (when --dtype and/or --no-fp8 were explicitly set), the code unconditionally falls back to DelayedScaling whenever FP8 is enabled:
if not no_fp8:
recipe = DelayedScaling(...)This ignores opts.precision entirely, so --precision mxfp8 or --precision nvfp4 will silently use DelayedScaling if the user also sets --dtype/--no-fp8 (e.g. --precision mxfp8 --dtype bf16). That breaks the advertised presets and applies the wrong quantization recipe.
Fix: in the explicit-flags path, either (a) still select recipe based on opts.precision when FP8 is enabled, or (b) explicitly error/warn and force opts.precision back to fp8 if you’re going to always use DelayedScaling.
Address critical bugs where FP8 recipes were incorrectly selected when explicit flags were set, and add validation to prevent incompatible flag combinations that would silently disable FP8 training. - Remove default value from --precision parameter (set to None for backward compatibility) - Add get_precision_preset() and get_recipe_for_precision() helper functions - Implement two-path configuration logic: backward compatibility mode vs. precision preset mode - Add incompatibility validation: raise ValueError when --no-fp8 used with fp8/mxfp8/nvfp4 presets - Preserve FP8 recipe selection when --dtype explicitly overrides precision preset dtype - Fix fp16 case to correctly map to torch.float16 instead of torch.bfloat16 - Update parameter help text with precedence rules and usage examples - Ensure backward compatibility: scripts without --precision work identically to original version Signed-off-by: Andrea Gallo <[email protected]>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 file reviewed, 2 comments
| case _: | ||
| # Default to fp8 behavior | ||
| recipe = DelayedScaling( | ||
| fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unreachable default preset
In get_precision_preset(), the case _: branch defaults to FP8 behavior, but precision() already rejects any value outside {fp32, fp16, fp8, mxfp8, nvfp4}. As a result this branch should be dead code in normal execution, and keeping it can mask future bugs (e.g., if validation is accidentally bypassed). Consider replacing it with raise AssertionError/ValueError or removing the branch entirely so invalid values fail loudly.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| dtype = opts.dtype | ||
| dist_print( | ||
| f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting" | ||
| ) | ||
|
|
||
| # If FP8 is still enabled, keep recipe based on precision | ||
| # (dtype only affects parameter storage, not FP8 recipe) | ||
| if not no_fp8: | ||
| recipe = get_recipe_for_precision(opts.precision) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Redundant recipe recreation
When dtype_explicitly_set is true and FP8 remains enabled, the code overwrites recipe with get_recipe_for_precision(opts.precision), even though recipe was already set from the selected preset earlier. This creates a second recipe instance (and can diverge if preset recipe configuration changes). If the intent is “dtype override shouldn’t affect recipe”, you can keep the existing recipe rather than re-instantiating it.
Remove dead code in get_precision_preset() default case and eliminate redundant recipe recreation when dtype is explicitly overridden, ensuring cleaner logic flow and preventing duplicate recipe instantiation. - Remove unreachable case _: branch from get_precision_preset() function - Delete redundant recipe recreation when dtype_explicitly_set is true - Preserve existing recipe from preset when dtype override occurs - Ensure dtype override only affects parameter storage, not FP8 recipe selection Signed-off-by: Andrea Gallo <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 file reviewed, 1 comment
| def get_precision_preset(precision_value): | ||
| """Get dtype, no_fp8, and recipe based on precision preset. | ||
|
|
||
| Returns: | ||
| tuple: (dtype, no_fp8, recipe) | ||
| """ | ||
| match precision_value: | ||
| case "fp32": | ||
| return torch.float32, True, None | ||
| case "fp16": | ||
| return torch.float16, True, None | ||
| case "fp8": | ||
| recipe = DelayedScaling( | ||
| fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" | ||
| ) | ||
| return torch.bfloat16, False, recipe | ||
| case "mxfp8": | ||
| recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) | ||
| return torch.bfloat16, False, recipe | ||
| case "nvfp4": | ||
| recipe = NVFP4BlockScaling() | ||
| return torch.bfloat16, False, recipe | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing default return
get_precision_preset() has no case _: / explicit raise, so it can return None and train() will crash on unpacking (preset_dtype, preset_no_fp8, preset_recipe = ...) if precision_value is ever outside the handled set (e.g., if validation is bypassed or a new preset is added but this match isn’t updated). Consider adding an explicit case _: that raises ValueError/AssertionError so invalid values fail loudly instead of a TypeError later.
Prevent silent failures when precision validation is bypassed or new presets are added without updating get_precision_preset() function by adding explicit ValueError for unhandled cases. - Add case _: branch to get_precision_preset() that raises ValueError - Ensure invalid precision values fail loudly with clear error message - Prevent TypeError on tuple unpacking if function returns None - Improve maintainability when adding new precision presets Signed-off-by: Andrea Gallo <[email protected]>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 file reviewed, no comments
Description
This PR adds comprehensive precision parameter support to the FSDP training script, enabling users to configure training with multiple precision formats (FP32, FP16, FP8, MXFP8, NVFP4) via command-line argument. The implementation includes automatic configuration of appropriate dtypes and format-specific recipes for each precision type.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist:
Please reach out to Santosh Bhavani ([email protected]) for additional context on the work