Add multi-precision training support to FSDP script #2662
aagallo wants to merge 44 commits into NVIDIA:main from extend-precision
Conversation
Enable configurable precision training with support for FP32, FP16, FP8, MXFP8, and NVFP4 formats. Added precision argument parser and match statement to configure appropriate dtype and recipe based on selected precision.
- Add precision() type validator function
- Implement precision-based configuration in train()
- Support FP32, FP16, FP8, MXFP8, and NVFP4 formats
- Configure format-specific recipes (DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling)
- Set appropriate no_fp8 flags based on precision selection

Signed-off-by: aagallo <aagallo@amazon.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR adds a --precision flag to the FSDP example script. The current revision has incorporated the bulk of previously-raised feedback:
Remaining concern worth addressing:
Confidence Score: 4/5
Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[parse_fsdp_args] --> B{--precision set?}
B -- "None (default)" --> C[Backward-compat path\ndtype = opts.dtype\nno_fp8 = opts.no_fp8]
C --> C1{no_fp8?}
C1 -- False --> C2[recipe = DelayedScaling]
C1 -- True --> C3[recipe = None]
B -- "fp32/fp16/fp8/mxfp8/nvfp4" --> D{--no-fp8 set with\nFP8 preset?}
D -- Yes --> E[raise ValueError]
D -- No --> F[get_precision_preset]
F --> G[dtype, no_fp8, recipe from preset]
G --> H{--dtype explicitly set?}
H -- No --> I[Keep preset values]
H -- Yes --> J{new_dtype == preset_dtype?}
J -- No --> K[Override dtype\nwarn user\nrecreate recipe if FP8]
J -- Yes --> L[Log 'no override needed'\nrecreate recipe if FP8]
C2 & C3 & I & K & L --> M[layer_kwargs params_dtype = dtype]
M --> N[Build TE model]
N --> O[Wrap with FSDP MixedPrecision param_dtype=dtype]
O --> P[Training loop\nte.autocast enabled=not no_fp8 recipe=recipe]
Last reviewed commit: afa756d
case "fp16":
    dtype = torch.bfloat16
    no_fp8 = True
case "fp8":
Incorrect fp16 dtype
In the case "fp16" branch, the code sets dtype = torch.bfloat16. That contradicts the meaning of fp16 and also diverges from the existing --dtype parsing which supports torch.float16. If a user runs with --precision fp16 expecting fp16 parameters/inputs, they’ll silently get bf16 instead.
parser.add_argument(
    "--precision",
    type=precision,
    default="fp8",
    help="Precision to apply to model training (FP32, FP16, FP8, MXFP8, NVFP4)",
)
return parser.parse_args()
Conflicting CLI flags
--precision and --dtype/--no-fp8 now overlap: train() overrides dtype and no_fp8 based on --precision, but --dtype/--no-fp8 are still accepted and used as defaults. As written, --precision fp8 will force no_fp8=False even if the user explicitly passed --no-fp8, and --precision fp32 will ignore an explicit --dtype fp16/bf16. This makes the CLI behavior surprising and hard to reason about; either make --precision derive defaults only when the user didn’t specify --dtype/--no-fp8, or document/enforce precedence (e.g., error on incompatible combinations).
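The "derive defaults only when the user didn't specify" option could look roughly like this. It is a stdlib-only sketch: the preset values, the string dtype names, and the `resolve` helper are assumptions for illustration, not the script's real code.

```python
import argparse

# Assumed presets; the real script maps these to torch dtypes and TE recipes.
PRESETS = {
    "fp32": {"dtype": "float32", "no_fp8": True},
    "fp16": {"dtype": "float16", "no_fp8": True},
    "fp8":  {"dtype": "bfloat16", "no_fp8": False},
}

def resolve(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--precision", choices=PRESETS, default="fp8")
    parser.add_argument("--dtype", default=None)  # None means "not explicitly set"
    parser.add_argument("--no-fp8", dest="no_fp8", action="store_true", default=None)
    opts = parser.parse_args(argv)
    cfg = dict(PRESETS[opts.precision])
    # Error on contradictory combinations up front.
    if opts.no_fp8 and not cfg["no_fp8"]:
        parser.error(f"--no-fp8 conflicts with --precision {opts.precision}")
    # Explicit flags take precedence over the preset defaults.
    if opts.dtype is not None:
        cfg["dtype"] = opts.dtype
    if opts.no_fp8 is not None:
        cfg["no_fp8"] = opts.no_fp8
    return cfg
```

Using sentinel `None` defaults makes "flag absent" distinguishable from "flag set to its default value", which avoids the surprising overrides described above.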
Correct FP16 precision to use torch.float16 instead of torch.bfloat16, and add precedence logic where --dtype and --no-fp8 flags override --precision when explicitly set, with warnings issued for conflicts.
- Fix case fp16 to use torch.float16 instead of torch.bfloat16
- Add flag precedence detection by comparing against default values
- Implement warning messages when --dtype or --no-fp8 override --precision
- Update argument parser help text to document precedence behavior
- Ensure --dtype and --no-fp8 take precedence over --precision presets

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Add informative log messages and enhanced help text to clarify precision configuration behavior and flag precedence for better user transparency.
- Add log message showing which precision preset is being used
- Add warning logs when --dtype or --no-fp8 override --precision
- Add final training configuration log (dtype, FP8 status, recipe)
- Enhance argument parser help text with precedence examples
- Add inline code comments explaining precedence logic

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Add recipe initialization for fp32 and fp16 precision cases to prevent undefined variable errors, even though recipe is not used when no_fp8 is set to True.
- Add DelayedScaling recipe setup for fp32 case with no_fp8=True
- Add DelayedScaling recipe setup for fp16 case with no_fp8=True
- Add inline comments explaining recipe is set up but not used by autocast
- Ensure recipe variable is defined in all precision branches for consistency

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Update flag precedence detection to use sys.argv for checking if --dtype was explicitly set, ensuring dtype always overrides precision regardless of whether it matches the default value.
- Add sys import for command-line argument detection
- Change dtype_explicitly_set check to use '--dtype' in sys.argv
- Change no_fp8_explicitly_set check to use '--no-fp8' in sys.argv
- Ensure --dtype bf16 correctly overrides --precision even when matching default
- Maintain warning messages when explicit flags override precision presets

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
for more information, see https://pre-commit.ci
examples/pytorch/fsdp/fsdp.py
Outdated
case _:
    dtype = torch.float16
    precision_format = Format.HYBRID
    recipe = DelayedScaling(
        fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
    )
    no_fp8 = opts.no_fp8
recipe variable not initialized in default case. If precision validator allows an unexpected value, this will cause UnboundLocalError at line 403 when passed to te.autocast().
Suggested change:

    case _:
        dtype = torch.float16
        precision_format = Format.HYBRID
        recipe = DelayedScaling(
            fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
        )
-       no_fp8 = opts.no_fp8
+       no_fp8 = False
examples/pytorch/fsdp/fsdp.py
Outdated
dtype_explicitly_set = "--dtype" in sys.argv
no_fp8_explicitly_set = "--no-fp8" in sys.argv
Parsing sys.argv directly is fragile - will break if args are passed via config file, environment variables, or if arg uses = syntax (--dtype=fp16). Use parser.parse_known_args() or track which args were explicitly set via custom action class.
Example with custom action:
class StoreExplicitAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
        setattr(namespace, f"{self.dest}_explicitly_set", True)
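For illustration, a self-contained run of such an action shows it also handles the `--dtype=value` syntax that a `"--dtype" in sys.argv` membership test misses (the "bf16" default here is a stand-in string, not the script's torch dtype):

```python
import argparse

class StoreExplicitAction(argparse.Action):
    # Records the value and marks the argument as explicitly provided.
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
        setattr(namespace, f"{self.dest}_explicitly_set", True)

parser = argparse.ArgumentParser()
parser.add_argument("--dtype", action=StoreExplicitAction, default="bf16")

# The '=' syntax that a sys.argv membership check would miss:
opts = parser.parse_args(["--dtype=fp16"])
assert opts.dtype == "fp16" and opts.dtype_explicitly_set

# Default case: the marker attribute is simply absent.
opts = parser.parse_args([])
assert not getattr(opts, "dtype_explicitly_set", False)
```

Checking the marker with `getattr(..., False)` keeps the default path free of special cases.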
Replace fragile sys.argv parsing with robust custom argparse action class to track explicitly set arguments, and fix default precision case to explicitly set no_fp8 to False for consistent FP8-enabled behavior.
- Add StoreExplicitAction custom action class for tracking explicit arguments
- Update --dtype argument to use StoreExplicitAction
- Replace sys.argv check with getattr for dtype_explicitly_set attribute
- Remove sys import from train() function
- Fix default case to set no_fp8 = False instead of opts.no_fp8
- Ensure recipe variable is properly initialized in all code paths
- Support all argument passing methods including config files and = syntax

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
…rmerEngine into extend-precision
for more information, see https://pre-commit.ci
examples/pytorch/fsdp/fsdp.py
Outdated
def train(opts):
    # Check which flags were explicitly set
    dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False)
    no_fp8_explicitly_set = opts.no_fp8 != False
Detection of explicit --no-fp8 is incorrect - opts.no_fp8 != False is True when the flag is absent (defaults to False). User running with defaults will trigger else branch instead of precision preset logic.
Suggested change:
-   no_fp8_explicitly_set = opts.no_fp8 != False
+   no_fp8_explicitly_set = "--no-fp8" in sys.argv
@@ -173,7 +193,11 @@ def parse_fsdp_args():
    "--no-fp8",
Missing action=StoreExplicitAction to track explicit --no-fp8 usage - required for precedence logic at line 250
Suggested change:
    "--no-fp8",
+   action=StoreExplicitAction,
examples/pytorch/fsdp/fsdp.py
Outdated
dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}")
if not no_fp8:
    dist_print(f"Using FP8 recipe: {type(recipe).__name__}")
Logging FP8 recipe when no_fp8=True and recipe=None (line 322) will cause AttributeError on type(recipe).__name__
examples/pytorch/fsdp/fsdp.py
Outdated
parser.add_argument(
    "--precision",
    type=precision,
    default="fp8",
Default changed from implicit FP8 disabled (via --dtype bf16 + --no-fp8 False) to explicit --precision fp8 enabled - existing scripts relying on default behavior will now use FP8 unexpectedly
Remove params_dtype initialization from get_layer_args() and update FSDP MixedPrecision to use computed dtype variable instead of raw opts.dtype, ensuring precision presets are properly applied throughout the model.
- Remove params_dtype from get_layer_args() layer_kwargs initialization
- Update FSDP MixedPrecision param_dtype to use computed dtype variable
- Ensure precision preset logic is respected in both layer initialization and FSDP
- Maintain backward compatibility with original FP8-enabled default behavior

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
…rmerEngine into extend-precision
for more information, see https://pre-commit.ci
Additional Comments (1)

One fix is to perform conversion inside the action (or alternatively, implement a custom action that subclasses …)
Add type converter application in StoreExplicitAction custom action to ensure --dtype values are properly converted from strings to torch dtype objects, preventing runtime errors in torch operations.
- Store type converter in StoreExplicitAction.__init__
- Apply type conversion in __call__ before setting attribute value
- Add error handling for invalid type conversions
- Ensure opts.dtype contains torch dtype object, not raw string
- Fix runtime errors in torch.rand() and MixedPrecision() calls

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
for more information, see https://pre-commit.ci
if not no_fp8:
    precision_format = Format.HYBRID
    recipe = DelayedScaling(
        fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
    )
else:
    recipe = None
CRITICAL BUG: Incorrect recipe selection when --dtype or --no-fp8 is explicitly set
When users explicitly set --dtype but also specify --precision mxfp8 or --precision nvfp4, the code ignores the precision preset and falls back to DelayedScaling recipe instead of the expected MXFP8BlockScaling or NVFP4BlockScaling recipe.
Reproduction:
# User expects MXFP8BlockScaling recipe
./fsdp.py --precision mxfp8 --dtype bf16
# But gets DelayedScaling recipe instead!

Root cause:
Lines 340-344 unconditionally use DelayedScaling when not no_fp8 is true, completely ignoring the opts.precision value.
Expected behavior:
The code should respect the precision preset even when dtype is explicitly set. The logic should check opts.precision and select the appropriate recipe (MXFP8BlockScaling for "mxfp8", NVFP4BlockScaling for "nvfp4", etc.).
Suggested fix:
# Set up recipe based on precision preset
if not no_fp8:
    match opts.precision:
        case "mxfp8":
            recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)
        case "nvfp4":
            recipe = NVFP4BlockScaling()
        case _:
            precision_format = Format.HYBRID
            recipe = DelayedScaling(
                fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
            )
else:
    recipe = None

This is a merge-blocking bug as it breaks the advertised functionality of the --precision flag.
examples/pytorch/fsdp/fsdp.py
Outdated
else:
    # dtype and/or no_fp8 were explicitly set - they take precedence
    dtype = opts.dtype
    no_fp8 = opts.no_fp8

    # Set up default recipe for FP8 cases
    if not no_fp8:
        precision_format = Format.HYBRID
        recipe = DelayedScaling(
            fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
        )
    else:
        recipe = None
Precision preset ignored
In the else: branch (when --dtype and/or --no-fp8 were explicitly set), the code unconditionally falls back to DelayedScaling whenever FP8 is enabled:
if not no_fp8:
    recipe = DelayedScaling(...)

This ignores opts.precision entirely, so --precision mxfp8 or --precision nvfp4 will silently use DelayedScaling if the user also sets --dtype/--no-fp8 (e.g. --precision mxfp8 --dtype bf16). That breaks the advertised presets and applies the wrong quantization recipe.
Fix: in the explicit-flags path, either (a) still select recipe based on opts.precision when FP8 is enabled, or (b) explicitly error/warn and force opts.precision back to fp8 if you’re going to always use DelayedScaling.
Additional Comments (3)
Line 106 is dead code left behind when params_dtype was moved into train(); remove the commented-out line. Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

When a user passes an unrecognised precision value, the bare TypeError gives no hint about valid choices. Prefer argparse.ArgumentTypeError with an explicit list of the supported values (fp32, fp16, fp8, mxfp8, nvfp4).

Unlike StoreExplicitAction, StoreTrueExplicitAction.__init__ does not accept and forward **kwargs, making its signature needlessly rigid.
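The ArgumentTypeError suggestion above can be sketched as follows; the exact wording in the script may differ, and the case-insensitive lowering is an assumption for illustration:

```python
import argparse

VALID_PRECISIONS = ("fp32", "fp16", "fp8", "mxfp8", "nvfp4")

def precision(value):
    # Validator for argparse type=; raising the argparse-aware error type
    # makes the CLI print a clean usage error instead of a raw traceback.
    v = str(value).lower()
    if v not in VALID_PRECISIONS:
        raise argparse.ArgumentTypeError(
            f"invalid precision {value!r}; supported values: {', '.join(VALID_PRECISIONS)}"
        )
    return v
```

Wired up via `parser.add_argument("--precision", type=precision, ...)`, a bad value now produces an error message that lists the valid choices.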
Resolve three code review issues in examples/pytorch/fsdp/fsdp.py: dead commented-out code, unhelpful TypeError in precision(), and rigid __init__ signature in StoreTrueExplicitAction.
- Remove commented-out layer_kwargs["params_dtype"] = dtype at line 106; dead code after params_dtype was moved to train()
- Replace bare raise TypeError in precision() with argparse.ArgumentTypeError and explicit list of supported values (fp32, fp16, fp8, mxfp8, nvfp4) for a meaningful error message
- Add **kwargs to StoreTrueExplicitAction.__init__ and forward to super().__init__(); aligns with StoreExplicitAction for robustness

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
for more information, see https://pre-commit.ci
Additional Comments (1)
The warning is emitted any time --dtype is explicitly set, even when the value matches what the precision preset would have chosen. Guard the message (and the redundant recipe re-creation) behind an actual value change; when the dtype is unchanged, keep the recipe already set from the preset.
Guard the --dtype override warning and dtype reassignment behind an actual value change check to avoid a false positive when the user explicitly passes --dtype with the same value the precision preset would have selected.
- Add new_dtype != preset_dtype guard inside the dtype_explicitly_set branch so warning and dtype reassignment only trigger on a real override
- Suppress redundant recipe re-creation when dtype matches preset default; recipe is already correctly set from preset_recipe above
- No behavioral change when --dtype differs from preset default

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
…rmerEngine into extend-precision
Additional Comments (1)
Argparse only catches conversion errors raised by its registered type= callables; an exception raised inside an action's __call__ escapes as a raw traceback. Since StoreExplicitAction applies the converter itself, conversion failures bypass argparse's normal error handling. Fix: use parser.error() inside __call__:

def __call__(self, parser, namespace, values, option_string=None):
    if self.type_converter is not None:
        try:
            values = self.type_converter(values)
        except (ValueError, TypeError, argparse.ArgumentTypeError) as e:
            parser.error(f"argument {option_string}: invalid value: {values!r}: {e}")
    setattr(namespace, self.dest, values)
    setattr(namespace, f"{self.dest}_explicitly_set", True)

This will produce the same clean exit-with-usage-error behaviour as argparse's own type validation.
…_call__

Apply type_converter inside __call__ with proper exception handling to ensure --dtype values are converted and validated at parse time rather than silently passing raw strings through.
- Wrap type_converter call in try/except catching ValueError, TypeError, and argparse.ArgumentTypeError to surface conversion failures via parser.error() with a descriptive message
- Guard conversion behind if self.type_converter is not None check for cases where no converter is registered
- Ensures --dtype argument is correctly converted and validated consistently with standard argparse type= behavior

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Additional Comments (3)
A simpler approach is to delegate to get_precision_preset():

def get_recipe_for_precision(precision_value):
    _, _, recipe = get_precision_preset(precision_value)
    if recipe is None:
        raise NotImplementedError(f"No FP8 recipe defined for precision '{precision_value}'")
    return recipe

This eliminates the duplication entirely and guarantees consistency.
When --dtype is explicitly set but matches the preset default, the recipe is needlessly re-created. Consider guarding the recipe re-creation:

if dtype_explicitly_set:
    new_dtype = opts.dtype
    if new_dtype != preset_dtype:
        dtype = new_dtype
        dist_print(
            f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting"
        )
        # dtype only affects parameter storage, not FP8 recipe
        if not no_fp8:
            recipe = get_recipe_for_precision(opts.precision)

This avoids the unnecessary second instantiation and keeps the recipe consistent with the preset.
…nstantiation

Address three code review issues in examples/pytorch/fsdp/fsdp.py: redundant opts.no_fp8 check, duplicated recipe construction, and unnecessary recipe re-instantiation when dtype matches preset.
- Remove redundant 'and opts.no_fp8' from no_fp8_explicitly_set guard at line 373; StoreTrueExplicitAction always sets opts.no_fp8 to True when it fires, making the extra check always True
- Refactor get_recipe_for_precision() to delegate to get_precision_preset() and extract the recipe, eliminating duplicated recipe construction logic and silent drift hazard when recipe parameters are tuned in one place but not the other
- Guard recipe re-creation inside new_dtype != preset_dtype branch to avoid unnecessary re-instantiation when dtype_explicitly_set but the value matches the preset default

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Additional Comments (2)
The incompatible-flags check at line 365 raises a ValueError only after dist.init_process_group() has run, which can leave the NCCL process group partially initialized on other ranks.

Consider validating the flag combination before calling dist.init_process_group():

def train(opts):
    dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False)
    no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False)
    # Validate flag combinations before touching distributed state
    if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set:
        raise ValueError(
            f"Cannot use --no-fp8 with --precision {opts.precision}. ..."
        )
    # Initialize torch.distributed global process group
    dist.init_process_group(backend="nccl")
    ...

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
…nt arg

Move incompatible-flags check before dist.init_process_group() to avoid leaving the NCCL process group partially initialized, and remove redundant fp8_format=Format.E4M3 from MXFP8BlockScaling().
- Move no_fp8_explicitly_set + precision conflict check to the top of train() before dist.init_process_group() to prevent deadlocks or 'Address already in use' errors on other ranks that are still waiting inside init_process_group when rank 0 raises ValueError
- Remove explicit fp8_format=Format.E4M3 from MXFP8BlockScaling() call; Format.E4M3 is already the dataclass default and passing it explicitly adds noise without adding clarity

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Additional Comments (2)
The fix is to forward the type= kwarg to super().__init__() so argparse performs the conversion itself:

class StoreExplicitAction(argparse.Action):
    """Custom action that tracks whether an argument was explicitly set."""
    def __init__(self, option_strings, dest, **kwargs):
        super().__init__(option_strings, dest, **kwargs)
    def __call__(self, parser, namespace, values, option_string=None):
        # values already converted by argparse via action.type
        setattr(namespace, self.dest, values)
        setattr(namespace, f"{self.dest}_explicitly_set", True)

With this change, remove the manual type-conversion logic from __call__.
Include the recipe type in the log output via type(recipe).__name__. This will emit messages like 'FP8=enabled (MXFP8BlockScaling)' or 'FP8=disabled', making it easy to verify which quantization scheme is active.
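A minimal sketch of that logging helper; the recipe class here is a stand-in for the real Transformer Engine recipe types, and the helper name is hypothetical:

```python
class MXFP8BlockScaling:
    """Stand-in for the TE recipe class of the same name."""
    pass

def fp8_status(no_fp8, recipe):
    # Builds the FP8 portion of the training-configuration log line.
    if no_fp8:
        return "FP8=disabled"
    return f"FP8=enabled ({type(recipe).__name__})"

print(fp8_status(False, MXFP8BlockScaling()))  # FP8=enabled (MXFP8BlockScaling)
print(fp8_status(True, None))                  # FP8=disabled
```

Because the recipe name is taken from the object actually passed to autocast, the log cannot drift from the configured scheme.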
Delegate type conversion to argparse in StoreExplicitAction and include active FP8 recipe type in the training configuration log.
- Remove self.type_converter field and manual try/except block from StoreExplicitAction.__call__; forward type= kwarg to super().__init__() so argparse handles conversion natively before __call__ is invoked, restoring standard error messages and %(type)s help interpolation
- Simplify StoreExplicitAction.__init__ to use **kwargs passthrough, removing the now-unnecessary type= interception logic
- Include active recipe type in training configuration log output using type(recipe).__name__ so log emits messages like 'FP8=enabled (MXFP8BlockScaling)' or 'FP8=disabled', making it easier to verify the intended quantization scheme is in use

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
for more information, see https://pre-commit.ci
Additional Comments (3)
Consider forwarding **kwargs to super().__init__() in StoreTrueExplicitAction so unexpected keyword arguments (e.g. metavar, choices) are not silently discarded if argument registration is ever extended.

The warning message at line 378 only fires on an actual override; when an explicitly-passed --dtype matches the preset default, the user receives no confirmation that the flag was acknowledged.

The ValueError raised before dist.init_process_group() assumes all ranks receive identical CLI arguments. Consider either documenting that assumption (torchrun guarantees it) or validating it explicitly.
…cument rank assumption

Address three code review issues in examples/pytorch/fsdp/fsdp.py: silent kwargs drop in StoreTrueExplicitAction, missing confirmation log when dtype matches preset, and undocumented torchrun assumption.
- Forward **kwargs to super().__init__() in StoreTrueExplicitAction to prevent silent discard of unexpected keyword arguments (e.g. metavar, choices) if argument registration is ever extended
- Add info log when dtype_explicitly_set but new_dtype == preset_dtype so user receives confirmation their --dtype flag was acknowledged even when it matches the preset default and no override is needed
- Add comment above no_fp8_explicitly_set validation documenting that raising ValueError before dist.init_process_group is safe because torchrun guarantees all ranks receive identical CLI arguments

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
…rmerEngine into extend-precision
for more information, see https://pre-commit.ci
Signed-off-by: Andrea Gallo <aagallo@amazon.com>
@vthumbe1503 Thanks for the feedback!
I also went through the different feedback cycles. Let me know if there is anything else pending on my side. Thanks again!
@vthumbe1503 I saw that the build is actually failing. Is there anything I need to change on my end? I haven't touched anything else other than the fsdp.py file.
Additional Comments (2)
Both the warning (line 387) and the info (line 393) messages format a raw torch.dtype:

f"Warning: --dtype {dtype} overrides ..."      # prints "torch.float32"
f"Info: --dtype {new_dtype} matches ..."       # prints "torch.bfloat16"

Because str() on a torch.dtype includes the "torch." prefix, the CLI-facing messages read awkwardly next to the CLI's own names (fp32, bf16). Consider formatting the dtype in a friendlier way, for example:

dtype_name = str(dtype).replace("torch.", "")
dist_print(
    f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype setting"
)

and similarly for the "Info:" message.
The original code unconditionally constructed a DelayedScaling recipe and let the enabled flag decide whether it was used:

# old code
fp8_recipe = DelayedScaling(...)
with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, ...):

The new code passes recipe=None when FP8 is disabled:

# new code (Case 1 path)
else:
    recipe = None

Inspecting te.autocast shows it substitutes get_default_fp8_recipe() internally when recipe is None and skips check_recipe_support when enabled=False, so the assignment is safe. However, the implicit fallback to a default recipe (even when FP8 is disabled) is subtle enough to deserve an inline comment.
Strip 'torch.' prefix from dtype in user-facing log messages and
add a comment documenting the intentional recipe=None behavior when
FP8 is disabled.
- Replace raw dtype formatting with str(dtype).replace('torch.', '')
in both Warning and Info log messages so users see 'float32' or
'bfloat16' instead of 'torch.float32' or 'torch.bfloat16'
- Add inline comment on recipe=None assignment explaining that
te.autocast safely substitutes get_default_fp8_recipe() internally
when recipe is None, and skips check_recipe_support when
enabled=False, so the assignment is intentional and safe despite
populating global FP8 state with a default recipe
Signed-off-by: Andrea Gallo <aagallo@amazon.com>
for more information, see https://pre-commit.ci
Description
This PR adds comprehensive precision parameter support to the FSDP training script, enabling users to configure training with multiple precision formats (FP32, FP16, FP8, MXFP8, NVFP4) via a command-line argument. The implementation includes automatic configuration of appropriate dtypes and format-specific recipes for each precision type.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist:
Please reach out to Santosh Bhavani (sbhavani@nvidia.com) for additional context on the work