Add multi-precision training support to FSDP script#2662
aagallo wants to merge 68 commits into NVIDIA:main from extend-precision
Conversation
Enable configurable precision training with support for FP32, FP16, FP8, MXFP8, and NVFP4 formats. Added a precision argument parser and a match statement to configure the appropriate dtype and recipe based on the selected precision.
- Add precision() type validator function
- Implement precision-based configuration in train()
- Support FP32, FP16, FP8, MXFP8, and NVFP4 formats
- Configure format-specific recipes (DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling)
- Set appropriate no_fp8 flags based on precision selection

Signed-off-by: aagallo <aagallo@amazon.com>
for more information, see https://pre-commit.ci
Greptile Summary: This PR adds multi-precision training support to the FSDP example script, enabling users to select among FP32, FP16, FP8, MXFP8, and NVFP4 formats via a --precision flag. Key changes:
Confidence Score: 3/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[CLI args parsed] --> B{--precision set?}
    B -- "Yes" --> C{Incompatible flags?}
    C -- "fp8/mxfp8/nvfp4 + --no-fp8" --> D[raise ValueError]
    C -- "fp32/fp16 + --no-fp8" --> E[warn: redundant flag]
    C -- OK --> F[get_precision_preset]
    F --> G{precision value}
    G -- fp32 --> H["dtype=float32, no_fp8=True, recipe=None"]
    G -- fp16 --> I["dtype=float16, no_fp8=True, recipe=None"]
    G -- fp8 --> J["dtype=bfloat16, no_fp8=False, recipe=DelayedScaling"]
    G -- mxfp8 --> K["dtype=bfloat16, no_fp8=False, recipe=MXFP8BlockScaling"]
    G -- nvfp4 --> L["dtype=bfloat16, no_fp8=False, recipe=NVFP4BlockScaling"]
    B -- "No" --> M["dtype=opts.dtype, no_fp8=opts.no_fp8"]
    M --> N{no_fp8?}
    N -- False --> O["recipe=DelayedScaling (HYBRID)"]
    N -- True --> P["recipe=None"]
    H & I & J & K & L & O & P --> Q{dtype_explicitly_set AND precision set?}
    Q -- "Yes + dtype differs" --> R[Override dtype, warn; keep preset recipe]
    Q -- No --> S[Use preset values as-is]
    R & S --> T[layer_kwargs params_dtype = dtype]
    T --> U{recipe is None?}
    U -- Yes --> V["recipe = DelayedScaling() (defensive fallback)"]
    U -- No --> W[keep existing recipe]
    V & W --> X{"no_fp8=False AND isinstance DelayedScaling?"}
    X -- Yes --> Y["amax_group = all_gpus"]
    X -- No --> Z["amax_group = None"]
    Y & Z --> AA["te.autocast(enabled=not no_fp8, recipe, amax_group)"]
```
Last reviewed commit: 16e38d2
```python
case "fp16":
    dtype = torch.bfloat16
    no_fp8 = True
case "fp8":
```
Incorrect fp16 dtype
In the case "fp16" branch, the code sets dtype = torch.bfloat16. That contradicts the meaning of fp16 and also diverges from the existing --dtype parsing which supports torch.float16. If a user runs with --precision fp16 expecting fp16 parameters/inputs, they’ll silently get bf16 instead.
```python
parser.add_argument(
    "--precision",
    type=precision,
    default="fp8",
    help="Precision to apply to model training (FP32, FP16, FP8, MXFP8, NVFP4)",
)
return parser.parse_args()
```
Conflicting CLI flags
--precision and --dtype/--no-fp8 now overlap: train() overrides dtype and no_fp8 based on --precision, but --dtype/--no-fp8 are still accepted and used as defaults. As written, --precision fp8 will force no_fp8=False even if the user explicitly passed --no-fp8, and --precision fp32 will ignore an explicit --dtype fp16/bf16. This makes the CLI behavior surprising and hard to reason about; either make --precision derive defaults only when the user didn’t specify --dtype/--no-fp8, or document/enforce precedence (e.g., error on incompatible combinations).
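The "derive defaults only when the user didn't specify" option can be sketched as follows. This is an illustrative stand-alone model of the precedence rule, not the actual fsdp.py implementation; the choices, preset table, and `resolve` helper are assumptions:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--precision", choices=["fp32", "fp16", "fp8"], default=None)
parser.add_argument("--dtype", choices=["fp32", "fp16", "bf16"], default=None)
parser.add_argument("--no-fp8", dest="no_fp8", action="store_true")

# precision -> (default dtype, default no_fp8)
PRESETS = {"fp32": ("fp32", True), "fp16": ("fp16", True), "fp8": ("bf16", False)}

def resolve(opts):
    # Error on genuinely incompatible combinations.
    if opts.precision == "fp8" and opts.no_fp8:
        raise ValueError("--no-fp8 conflicts with --precision fp8")
    preset_dtype, preset_no_fp8 = PRESETS.get(opts.precision, ("bf16", False))
    # Explicit user flags always win; the preset only fills in what was left unset.
    dtype = opts.dtype if opts.dtype is not None else preset_dtype
    no_fp8 = opts.no_fp8 or preset_no_fp8
    return dtype, no_fp8

opts = parser.parse_args(["--precision", "fp8", "--dtype", "fp16"])
print(resolve(opts))  # ('fp16', False): explicit --dtype overrides the preset's bf16
```

This keeps the CLI predictable: presets are defaults, never overrides.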
Additional Comments (1)
Correct FP16 precision to use torch.float16 instead of torch.bfloat16, and add precedence logic where --dtype and --no-fp8 flags override --precision when explicitly set, with warnings issued for conflicts.
- Fix case fp16 to use torch.float16 instead of torch.bfloat16
- Add flag precedence detection by comparing against default values
- Implement warning messages when --dtype or --no-fp8 override --precision
- Update argument parser help text to document precedence behavior
- Ensure --dtype and --no-fp8 take precedence over --precision presets

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

Add informative log messages and enhanced help text to clarify precision configuration behavior and flag precedence for better user transparency.
- Add log message showing which precision preset is being used
- Add warning logs when --dtype or --no-fp8 override --precision
- Add final training configuration log (dtype, FP8 status, recipe)
- Enhance argument parser help text with precedence examples
- Add inline code comments explaining precedence logic

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

Add recipe initialization for fp32 and fp16 precision cases to prevent undefined variable errors, even though recipe is not used when no_fp8 is set to True.
- Add DelayedScaling recipe setup for fp32 case with no_fp8=True
- Add DelayedScaling recipe setup for fp16 case with no_fp8=True
- Add inline comments explaining recipe is set up but not used by autocast
- Ensure recipe variable is defined in all precision branches for consistency

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

Update flag precedence detection to use sys.argv for checking if --dtype was explicitly set, ensuring dtype always overrides precision regardless of whether it matches the default value.
- Add sys import for command-line argument detection
- Change dtype_explicitly_set check to use '--dtype' in sys.argv
- Change no_fp8_explicitly_set check to use '--no-fp8' in sys.argv
- Ensure --dtype bf16 correctly overrides --precision even when matching default
- Maintain warning messages when explicit flags override precision presets

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
examples/pytorch/fsdp/fsdp.py
Outdated
```python
case _:
    dtype = torch.float16
    precision_format = Format.HYBRID
    recipe = DelayedScaling(
        fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
    )
    no_fp8 = opts.no_fp8
```
recipe variable not initialized in default case. If precision validator allows an unexpected value, this will cause UnboundLocalError at line 403 when passed to te.autocast().
Suggested change:

```diff
 case _:
     dtype = torch.float16
     precision_format = Format.HYBRID
     recipe = DelayedScaling(
         fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
     )
-    no_fp8 = opts.no_fp8
+    no_fp8 = False
```
examples/pytorch/fsdp/fsdp.py
Outdated
```python
dtype_explicitly_set = "--dtype" in sys.argv
no_fp8_explicitly_set = "--no-fp8" in sys.argv
```
Parsing sys.argv directly is fragile - will break if args are passed via config file, environment variables, or if arg uses = syntax (--dtype=fp16). Use parser.parse_known_args() or track which args were explicitly set via custom action class.
Example with custom action:
```python
import argparse

class StoreExplicitAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
        setattr(namespace, f"{self.dest}_explicitly_set", True)
```
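For instance, the action keeps working where a `sys.argv` scan fails, such as the `--dtype=bf16` '=' syntax. A self-contained sketch (the default value is an assumption):

```python
import argparse

class StoreExplicitAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
        setattr(namespace, f"{self.dest}_explicitly_set", True)

parser = argparse.ArgumentParser()
parser.add_argument("--dtype", default="bf16", action=StoreExplicitAction)

# With '=' syntax, "--dtype" never appears as a standalone argv token,
# yet the action still records the explicit use.
opts = parser.parse_args(["--dtype=bf16"])
print(getattr(opts, "dtype_explicitly_set", False))  # True
print(getattr(parser.parse_args([]), "dtype_explicitly_set", False))  # False
```

Note that when the flag is absent, argparse applies the default without invoking the action, so the sentinel attribute is simply missing and `getattr(..., False)` reports it as not explicitly set.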
Additional Comments (1)
Replace fragile sys.argv parsing with a robust custom argparse action class to track explicitly set arguments, and fix the default precision case to explicitly set no_fp8 to False for consistent FP8-enabled behavior.
- Add StoreExplicitAction custom action class for tracking explicit arguments
- Update --dtype argument to use StoreExplicitAction
- Replace sys.argv check with getattr for dtype_explicitly_set attribute
- Remove sys import from train() function
- Fix default case to set no_fp8 = False instead of opts.no_fp8
- Ensure recipe variable is properly initialized in all code paths
- Support all argument passing methods including config files and = syntax

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
…rmerEngine into extend-precision
examples/pytorch/fsdp/fsdp.py
Outdated
```python
def train(opts):
    # Check which flags were explicitly set
    dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False)
    no_fp8_explicitly_set = opts.no_fp8 != False
```
Detection of explicit --no-fp8 is incorrect - opts.no_fp8 != False is True when the flag is absent (defaults to False). User running with defaults will trigger else branch instead of precision preset logic.
```diff
-no_fp8_explicitly_set = opts.no_fp8 != False
+no_fp8_explicitly_set = "--no-fp8" in sys.argv
```
```diff
@@ -173,7 +193,11 @@ def parse_fsdp_args():
     "--no-fp8",
```
Missing action=StoreExplicitAction to track explicit --no-fp8 usage - required for precedence logic at line 250
```diff
 "--no-fp8",
+action=StoreExplicitAction,
```
examples/pytorch/fsdp/fsdp.py
Outdated
```python
dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}")
if not no_fp8:
    dist_print(f"Using FP8 recipe: {type(recipe).__name__}")
```
Logging the FP8 recipe when recipe=None (line 322) will print a misleading "Using FP8 recipe: NoneType" instead of a real recipe name; guard this log on recipe being set.
examples/pytorch/fsdp/fsdp.py
Outdated
```python
parser.add_argument(
    "--precision",
    type=precision,
    default="fp8",
```
Default changed from implicit FP8 disabled (via --dtype bf16 + --no-fp8 False) to explicit --precision fp8 enabled - existing scripts relying on default behavior will now use FP8 unexpectedly
Additional Comments (1)
Remove params_dtype initialization from get_layer_args() and update FSDP MixedPrecision to use the computed dtype variable instead of raw opts.dtype, ensuring precision presets are properly applied throughout the model.
- Remove params_dtype from get_layer_args() layer_kwargs initialization
- Update FSDP MixedPrecision param_dtype to use computed dtype variable
- Ensure precision preset logic is respected in both layer initialization and FSDP
- Maintain backward compatibility with original FP8-enabled default behavior

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
…rmerEngine into extend-precision
Additional Comments (1)
One fix is to perform the conversion inside the action (or alternatively, implement a custom action that subclasses …).
Add type converter application in StoreExplicitAction custom action to ensure --dtype values are properly converted from strings to torch dtype objects, preventing runtime errors in torch operations.
- Store type converter in StoreExplicitAction.__init__
- Apply type conversion in __call__ before setting attribute value
- Add error handling for invalid type conversions
- Ensure opts.dtype contains torch dtype object, not raw string
- Fix runtime errors in torch.rand() and MixedPrecision() calls

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
examples/pytorch/fsdp/fsdp.py
Outdated
```python
if not no_fp8:
    precision_format = Format.HYBRID
    recipe = DelayedScaling(
        fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
    )
else:
    recipe = None
```
CRITICAL BUG: Incorrect recipe selection when --dtype or --no-fp8 is explicitly set
When users explicitly set --dtype but also specify --precision mxfp8 or --precision nvfp4, the code ignores the precision preset and falls back to DelayedScaling recipe instead of the expected MXFP8BlockScaling or NVFP4BlockScaling recipe.
Reproduction:

```shell
# User expects MXFP8BlockScaling recipe
./fsdp.py --precision mxfp8 --dtype bf16
# But gets DelayedScaling recipe instead!
```

Root cause:
Lines 340-344 unconditionally use DelayedScaling when not no_fp8 is true, completely ignoring the opts.precision value.
Expected behavior:
The code should respect the precision preset even when dtype is explicitly set. The logic should check opts.precision and select the appropriate recipe (MXFP8BlockScaling for "mxfp8", NVFP4BlockScaling for "nvfp4", etc.).
Suggested fix:
```python
# Set up recipe based on precision preset
if not no_fp8:
    match opts.precision:
        case "mxfp8":
            recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)
        case "nvfp4":
            recipe = NVFP4BlockScaling()
        case _:
            precision_format = Format.HYBRID
            recipe = DelayedScaling(
                fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
            )
else:
    recipe = None
```

This is a merge-blocking bug as it breaks the advertised functionality of the --precision flag.
examples/pytorch/fsdp/fsdp.py
Outdated
```python
else:
    # dtype and/or no_fp8 were explicitly set - they take precedence
    dtype = opts.dtype
    no_fp8 = opts.no_fp8

    # Set up default recipe for FP8 cases
    if not no_fp8:
        precision_format = Format.HYBRID
        recipe = DelayedScaling(
            fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
        )
    else:
        recipe = None
```
Precision preset ignored
In the else: branch (when --dtype and/or --no-fp8 were explicitly set), the code unconditionally falls back to DelayedScaling whenever FP8 is enabled:
```python
if not no_fp8:
    recipe = DelayedScaling(...)
```

This ignores opts.precision entirely, so --precision mxfp8 or --precision nvfp4 will silently use DelayedScaling if the user also sets --dtype/--no-fp8 (e.g. --precision mxfp8 --dtype bf16). That breaks the advertised presets and applies the wrong quantization recipe.
Fix: in the explicit-flags path, either (a) still select recipe based on opts.precision when FP8 is enabled, or (b) explicitly error/warn and force opts.precision back to fp8 if you’re going to always use DelayedScaling.
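Option (a) can be sketched with string stand-ins for the recipe classes; the real code would construct `MXFP8BlockScaling`, `NVFP4BlockScaling`, or `DelayedScaling` objects from `transformer_engine.common.recipe`, and the helper name here is illustrative:

```python
def select_recipe(precision, no_fp8):
    """Pick the recipe from the precision preset even in the explicit-flags path."""
    if no_fp8:
        return None
    return {
        "mxfp8": "MXFP8BlockScaling",
        "nvfp4": "NVFP4BlockScaling",
    }.get(precision, "DelayedScaling")  # fp8 and default keep the HYBRID DelayedScaling

print(select_recipe("mxfp8", no_fp8=False))  # MXFP8BlockScaling, not DelayedScaling
```

Routing both branches through one selector like this removes the possibility of the explicit-flags path and the preset path disagreeing about the recipe.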
Fix two bugs in train() precision configuration block.
- Define dtype_name unconditionally before the dtype_explicitly_set block to prevent NameError in the config log when dtype_explicitly_set is False (the common case when --dtype is not explicitly passed)
- Guard dtype override warning behind 'dtype_explicitly_set and opts.precision is not None' to prevent spurious warning when user passes --dtype without --precision (original behavior path)

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Additional Comments (2)

Fix: in the `opts.precision is None` path, preserve the original default behavior (FP8 enabled uses DelayedScaling as before):

```python
if opts.precision is None:
    no_fp8 = opts.no_fp8
    if opts.dtype is not None:
        dtype = opts.dtype
    # Preserve original default: FP8 enabled → use DelayedScaling as before
    if not no_fp8:
        recipe = DelayedScaling(
            fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
        )
    else:
        recipe = None
```
…efault Fix two bugs in train() precision configuration block: stale dtype_name in log messages after dtype override, and behavioral regression where recipe=None was passed to te.autocast when FP8 was enabled in backward-compatible mode.
- Recompute dtype_name immediately after dtype = new_dtype in the dtype override branch so warning and config log reflect the effective dtype rather than the stale preset dtype
- Restore original default behavior in opts.precision is None path: when no_fp8 is False (FP8 enabled), supply DelayedScaling recipe to preserve the original te.autocast behavior instead of passing recipe=None which changed the implicit fallback behavior

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Additional Comments (2)
…assignment Remove two redundant lines in train() precision configuration block.
- Remove 'if opts.dtype is not None' guard in opts.precision is None branch; --dtype has default=torch.bfloat16 so opts.dtype is never None and the condition is always True
- Remove redundant 'no_fp8 = preset_no_fp8' assignment in the else branch; no_fp8 is already assigned from preset_no_fp8 at the tuple unpack above and reassigning it in the else branch adds noise without changing behavior

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Additional Comments (3)

The original code always passed a concrete recipe instance to te.autocast, even when FP8 was disabled; here recipe can be None. Consider defaulting recipe to a concrete DelayedScaling() instance, or, more explicitly, guard the assignment at the call site.

The `case None` branch in get_precision_preset() is dead code: its return values are immediately overridden by the caller. Consider removing it and guarding the call site instead:

```python
if opts.precision is not None:
    preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision)
    dtype, no_fp8, recipe = preset_dtype, preset_no_fp8, preset_recipe
```

The help strings for --no-fp8 and --dtype are excessively long. Consider keeping help text to a concise one-liner with an example, and moving the detailed precedence rules into the module-level docstring or a separate README section. Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Address three code review issues in examples/pytorch/fsdp/fsdp.py: recipe=None passed to te.autocast, dead case None in get_precision_preset, and excessively verbose help strings for --no-fp8 and --dtype.
- Use 'recipe or DelayedScaling()' fallback at te.autocast call site to preserve original defensive pattern of always passing a concrete recipe instance, even when enabled=False
- Remove case None from get_precision_preset() and guard call site in train() with 'if opts.precision is not None' to eliminate dead-code path whose return values were immediately overridden by the caller
- Replace multi-paragraph help strings for --no-fp8 and --dtype with concise one-liner synopses; move detailed precedence rules to module-level docstring or README

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Additional Comments (1)
Move the fallback outside the loop, or replace the inline `recipe or DelayedScaling()` expression with a one-time assignment, so that a new DelayedScaling() object is not allocated on every iteration when FP8 is disabled. And ensure te.autocast receives the same recipe object across all iterations.
…tion Move 'recipe or DelayedScaling()' fallback to a one-time assignment before the training loop to avoid allocating a new DelayedScaling() object on every iteration when FP8 is disabled.
- Add 'if recipe is None: recipe = DelayedScaling()' after the configuration block and before the training loop so the fallback object is created once and reused across all iterations
- Restore clean 'recipe=recipe' in te.autocast call, matching the original code pattern
- Add comment explaining why recipe is always set to a concrete object even when FP8 is disabled

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Additional Comments (1)

The `recipe or DelayedScaling()` fallback currently runs inside the training loop. Additionally, note that when this fallback fires (i.e. recipe is None because FP8 is disabled), a fresh DelayedScaling object is allocated on every iteration. Move the guard to just after the configuration block:

```python
# Ensure recipe is always a concrete object before passing to te.autocast.
# When FP8 is disabled, te.autocast ignores the recipe, but some TE versions
# perform attribute access on it regardless of the enabled flag.
if recipe is None:
    recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max")

for i in range(opts.num_iters):
    x = torch.rand(...)
    with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus):
        ...
```
…arameters Move 'if recipe is None: recipe = DelayedScaling(...)' guard to just before the training loop instead of inside it to avoid redundant is-None checks on every iteration and variable mutation inside the loop.
- Use consistent DelayedScaling parameters (fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo='max') matching the rest of the file, rather than plain DelayedScaling() with default args
- Guard runs once before the loop; recipe is stable for all iterations
- Restores clean 'recipe=recipe' in te.autocast call with no inline fallback expression

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Additional Comments (2)

Passing `amax_reduction_group=all_gpus` unconditionally is incorrect: the parameter is specific to DelayedScaling (per-tensor amax aggregation) and is not used by the block-scaling recipes. The group should be computed conditionally:

```python
# autocast needs to be given the FSDP process group for amax reductions
amax_group = all_gpus if isinstance(recipe, DelayedScaling) else None
with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=amax_group):
    y = te_model(x)
    loss = y.sum()
```

The help text for --no-fp8 is far too verbose for terminal --help output. Concise, one-sentence help strings for the CLI and a README or docstring for the extended documentation would be significantly more user-friendly. For example:

```python
parser.add_argument(
    "--no-fp8",
    action="store_true",
    default=False,
    help=(
        "Disable te.autocast() FP8 context. Incompatible with --precision fp8/mxfp8/nvfp4."
        " Default: False."
    ),
)
```

The detailed interaction rules can live in a module-level docstring or in the script's README. Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
…lp text Pass amax_reduction_group conditionally based on recipe type and replace verbose multi-paragraph help strings with concise one-liners.
- Compute amax_group = all_gpus if isinstance(recipe, DelayedScaling) else None and pass amax_group to te.autocast; amax_reduction_group is a DelayedScaling-specific parameter for per-tensor amax aggregation and is not accepted by MXFP8BlockScaling or NVFP4BlockScaling which use block-level scaling
- Replace multi-paragraph help strings for --no-fp8, --dtype, and --precision (with PRECEDENCE/BEHAVIOR/RATIONALE/EXAMPLES sections) with concise one-liner synopses suitable for terminal --help output
- Move detailed precedence rules to module-level docstring

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Additional Comments (1)

When no_fp8=True and recipe was None, the defensive fallback substitutes a DelayedScaling instance, so `isinstance(recipe, DelayedScaling)` returns True. This results in `amax_group=all_gpus` being passed to te.autocast even though FP8 is disabled. The guard condition should also check `not no_fp8`.
… comms Add 'not no_fp8' condition to amax_group assignment to prevent amax_reduction_group=all_gpus being passed to te.autocast when FP8 is disabled.
- When no_fp8=True and recipe was None, the DelayedScaling fallback causes isinstance(recipe, DelayedScaling) to return True, which incorrectly set amax_group=all_gpus even though enabled=False
- Add 'not no_fp8' guard so amax_group is only set to all_gpus when FP8 is active AND the recipe is DelayedScaling (per-tensor amax); all other cases (FP8 disabled, block-scaling recipes) use None

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Additional Comments (2)

The validation guard on line 281 only rejects `--no-fp8` combined with the FP8 precisions; combining it with fp32/fp16 is silently accepted even though the flag is redundant there. Consider emitting a warning for this case:

```python
if opts.precision in ["fp32", "fp16"] and opts.no_fp8:
    dist_print(
        f"Warning: --no-fp8 is redundant when using --precision {opts.precision} "
        "(FP8 is already disabled by this preset). The flag will be ignored."
    )
```

The expression silently passes `amax_group=None` for the block-scaling recipes. Please add an inline comment explaining why:

```python
# MXFP8 and NVFP4 use local block scaling — no distributed amax reduction group needed.
# amax_reduction_group is only required for DelayedScaling (global AMAX allreduce).
amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None
```
…p=None Emit a warning when --no-fp8 is combined with a non-FP8 precision preset and add an inline comment explaining why amax_reduction_group is None for block-scaling recipes.
- Add warning when opts.precision in ['fp32', 'fp16'] and opts.no_fp8 is set; FP8 is already disabled by these presets so the flag is redundant and silently ignored without this feedback
- Add inline comment on amax_group assignment explaining that MXFP8BlockScaling and NVFP4BlockScaling use local block scaling and do not require a distributed amax reduction group, and that None is also correct when FP8 is disabled to avoid unnecessary distributed communication

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Additional Comments (1)
This can silently break if the condition is ever refactored without updating both blocks, and it will trigger CI lint failures in projects that treat unbound-variable warnings as errors. Suggested fix — initialize the trio before the if-else so the later access is always safe: preset_dtype: torch.dtype = opts.dtype # sensible fallback
preset_recipe = None
if opts.precision is not None:
preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision)
... |
Initialize preset_dtype and preset_recipe with fallback values before the 'if opts.precision is not None' block to prevent static analyzer warnings about potentially unbound variables.
- Assign preset_dtype = opts.dtype and preset_recipe = None as sensible fallbacks before the if-else block; these are overwritten by get_precision_preset() when opts.precision is not None and are never accessed in the else branch
- Satisfies mypy, pylint, and pyflakes 'possibly undefined' / 'unbound' warnings that would otherwise trigger CI lint failures in projects treating unbound-variable warnings as errors
- No behavioral change; the if-else logic is unchanged

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Additional Comments (1)

After the defensive fallback replaces a None recipe with a DelayedScaling instance, `isinstance(recipe, DelayedScaling)` no longer reflects the user-selected recipe. A safer pattern is to compute `amax_group` before the fallback runs:

```python
# Compute amax_group before the defensive recipe fallback
amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None

# Ensure recipe is always a concrete object before passing to te.autocast.
if recipe is None:
    recipe = DelayedScaling()
```
…ce race Move amax_group computation before the 'if recipe is None' fallback assignment so isinstance(recipe, DelayedScaling) reflects the actual user-selected recipe rather than the defensive fallback object.
- When recipe is None (non-FP8 presets or --no-fp8), isinstance correctly returns False and amax_group is set to None before the fallback substitutes a DelayedScaling instance
- Prevents the fragile ordering dependency where not no_fp8 was the sole guard against passing all_gpus to a recipe that doesn't need it
- Add inline comment explaining why amax_group must be computed before the recipe fallback to preserve the invariant

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Description
This PR adds comprehensive precision parameter support to the FSDP training script, enabling users to configure training with multiple precision formats (FP32, FP16, FP8, MXFP8, NVFP4) via a command-line argument. The implementation includes automatic configuration of the appropriate dtype and a format-specific recipe for each precision type.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist:
Please reach out to Santosh Bhavani (sbhavani@nvidia.com) for additional context on the work