
Add multi-precision training support to FSDP script #2662

Open
aagallo wants to merge 68 commits into NVIDIA:main from aagallo:extend-precision

Conversation


@aagallo aagallo commented Feb 9, 2026

Description

This PR adds comprehensive precision parameter support to the FSDP training script, enabling users to configure training with multiple precision formats (FP32, FP16, FP8, MXFP8, NVFP4) via a command-line argument. The implementation includes automatic configuration of appropriate dtypes and format-specific recipes for each precision type.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added precision() type validator function supporting fp32, fp16, fp8, mxfp8, and nvfp4 formats
  • Added --precision command-line argument to parse_fsdp_args() with default value "fp8"
  • Implemented match statement in train() function to configure precision-based training parameters
  • Configured format-specific recipes for each precision type:
    • FP32/FP16: Uses standard PyTorch dtypes with FP8 disabled
    • FP8: Uses DelayedScaling recipe with HYBRID format
    • MXFP8: Uses MXFP8BlockScaling recipe with E4M3 format
    • NVFP4: Uses NVFP4BlockScaling recipe with bfloat16 dtype
  • Set appropriate no_fp8 flags based on precision selection
  • Updated layer_kwargs["params_dtype"] to use precision-determined dtype
  • Imported required recipe classes: MXFP8BlockScaling and NVFP4BlockScaling

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Please reach out to Santosh Bhavani (sbhavani@nvidia.com) for additional context on the work

aagallo and others added 3 commits February 9, 2026 16:28
Enable configurable precision training with support for FP32, FP16, FP8,
MXFP8, and NVFP4 formats. Added precision argument parser and match statement
to configure appropriate dtype and recipe based on selected precision.

- Add precision() type validator function
- Implement precision-based configuration in train()
- Support FP32, FP16, FP8, MXFP8, and NVFP4 formats
- Configure format-specific recipes (DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling)
- Set appropriate no_fp8 flags based on precision selection

Signed-off-by: aagallo <aagallo@amazon.com>

greptile-apps bot commented Feb 9, 2026

Greptile Summary

This PR adds multi-precision training support to the FSDP example script, enabling users to select fp32, fp16, fp8, mxfp8, or nvfp4 via a new --precision flag. It introduces get_precision_preset() to map precision strings to (dtype, no_fp8, recipe) tuples, a StoreExplicitAction to track whether --dtype was explicitly set (allowing user-supplied dtype to override the preset), and validation logic that raises early errors for incompatible flag combinations (e.g. --precision fp8 --no-fp8). The FSDP MixedPrecision and te.autocast paths are updated to use the resolved dtype/recipe/amax_group instead of hard-coded values.

Key changes:

  • get_precision_preset() correctly maps all five precision values to their respective recipes (DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling) and returns a ValueError for unexpected values.
  • --precision defaults to None, preserving existing default behavior (bfloat16 + FP8 via --dtype/--no-fp8). Note: the PR description states the default is "fp8", but the code correctly uses default=None.
  • amax_reduction_group is now None for block-scaling recipes (MXFP8BlockScaling, NVFP4BlockScaling) since they do not require global AMAX all-reduces, and is set to all_gpus only for DelayedScaling.
  • Ordering concern: amax_group is computed after the defensive recipe = DelayedScaling() fallback (line 415-416). This means isinstance(recipe, DelayedScaling) is True even when no_fp8=True. The result is still correct today because not no_fp8 short-circuits it to None, but the ordering creates a fragile invariant. Swapping the two blocks (compute amax_group first, then apply the fallback) would make the intent clearer and safer.
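The suggested swap can be sketched as follows; the class and helper here are illustrative stand-ins for the script's real objects, used only to show why deciding amax_group before the fallback removes the fragile invariant.

```python
# Stand-in for transformer_engine's DelayedScaling recipe class.
class DelayedScaling:
    pass

def resolve_autocast_args(recipe, no_fp8, all_gpus="all_gpus"):
    # Decide amax_group from the recipe the user actually selected, BEFORE
    # the defensive fallback: amax reduction applies only when FP8 is on
    # and delayed scaling is in use, so block-scaling recipes (or recipe=None
    # on the no_fp8 path) never pick up the group by accident.
    amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None
    if recipe is None:
        # Defensive fallback: autocast always receives a concrete recipe.
        recipe = DelayedScaling()
    return recipe, amax_group
```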

Confidence Score: 3/5

  • PR is functional but has a subtle ordering issue between the amax_group computation and the defensive recipe fallback that could silently break if the code is refactored.
  • The PR addresses the majority of concerns raised in previous review threads (correct fp16 dtype, StoreExplicitAction for dtype tracking, explicit error on incompatible flags, proper recipe selection per precision preset, defensive DelayedScaling fallback, and amax_group scoped to DelayedScaling only). One new ordering concern was identified: amax_group is computed after the recipe=DelayedScaling() fallback, making the isinstance guard technically redundant for the no_fp8=True path. This doesn't cause a bug today but is fragile. The PR description also incorrectly states the default for --precision is "fp8" when it is actually None. No critical logic errors remain.
  • examples/pytorch/fsdp/fsdp.py — specifically the amax_group / recipe-fallback ordering around lines 412–420.

Important Files Changed

Filename Overview
examples/pytorch/fsdp/fsdp.py Adds multi-precision training support (fp32, fp16, fp8, mxfp8, nvfp4) via --precision CLI argument; introduces StoreExplicitAction for dtype tracking and get_precision_preset() helper; several previous thread issues addressed but a few new concerns remain around the dist_print-before-init ordering and dead initialization code.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[CLI args parsed] --> B{--precision set?}

    B -- "Yes" --> C{Incompatible flags?}
    C -- "fp8/mxfp8/nvfp4 + --no-fp8" --> D[raise ValueError]
    C -- "fp32/fp16 + --no-fp8" --> E[warn: redundant flag]
    C -- OK --> F[get_precision_preset]

    F --> G{precision value}
    G -- fp32 --> H["dtype=float32, no_fp8=True, recipe=None"]
    G -- fp16 --> I["dtype=float16, no_fp8=True, recipe=None"]
    G -- fp8 --> J["dtype=bfloat16, no_fp8=False, recipe=DelayedScaling"]
    G -- mxfp8 --> K["dtype=bfloat16, no_fp8=False, recipe=MXFP8BlockScaling"]
    G -- nvfp4 --> L["dtype=bfloat16, no_fp8=False, recipe=NVFP4BlockScaling"]

    B -- "No" --> M["dtype=opts.dtype, no_fp8=opts.no_fp8"]
    M --> N{no_fp8?}
    N -- False --> O["recipe=DelayedScaling (HYBRID)"]
    N -- True --> P["recipe=None"]

    H & I & J & K & L & O & P --> Q{dtype_explicitly_set AND precision set?}
    Q -- Yes + dtype differs --> R[Override dtype, warn; keep preset recipe]
    Q -- No --> S[Use preset values as-is]
    R & S --> T[layer_kwargs params_dtype = dtype]

    T --> U{recipe is None?}
    U -- Yes --> V["recipe = DelayedScaling() (defensive fallback)"]
    U -- No --> W[keep existing recipe]

    V & W --> X{"no_fp8=False AND isinstance DelayedScaling?"}
    X -- Yes --> Y["amax_group = all_gpus"]
    X -- No --> Z["amax_group = None"]

    Y & Z --> AA["te.autocast(enabled=not no_fp8, recipe, amax_group)"]

Last reviewed commit: 16e38d2


@greptile-apps greptile-apps bot left a comment


1 file reviewed, 3 comments


Comment on lines +243 to +246
case "fp16":
    dtype = torch.bfloat16
    no_fp8 = True
case "fp8":

Incorrect fp16 dtype
In the case "fp16" branch, the code sets dtype = torch.bfloat16. That contradicts the meaning of fp16 and also diverges from the existing --dtype parsing which supports torch.float16. If a user runs with --precision fp16 expecting fp16 parameters/inputs, they’ll silently get bf16 instead.

Comment on lines +206 to 212
parser.add_argument(
    "--precision",
    type=precision,
    default="fp8",
    help="Precision to apply to model training (FP32, FP16, FP8, MXFP8, NVFP4)",
)
return parser.parse_args()

Conflicting CLI flags
--precision and --dtype/--no-fp8 now overlap: train() overrides dtype and no_fp8 based on --precision, but --dtype/--no-fp8 are still accepted and used as defaults. As written, --precision fp8 will force no_fp8=False even if the user explicitly passed --no-fp8, and --precision fp32 will ignore an explicit --dtype fp16/bf16. This makes the CLI behavior surprising and hard to reason about; either make --precision derive defaults only when the user didn’t specify --dtype/--no-fp8, or document/enforce precedence (e.g., error on incompatible combinations).
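The enforce-precedence option mentioned here could look roughly like the check below. The helper name and exact rules are assumptions for illustration; they follow the error/warn split the later review summary describes (hard error on fp8-family + --no-fp8, warning on merely redundant combinations).

```python
import warnings

def check_precision_flags(precision, no_fp8):
    # Contradictory: an FP8-family preset cannot coexist with --no-fp8.
    if precision in {"fp8", "mxfp8", "nvfp4"} and no_fp8:
        raise ValueError(f"--precision {precision} is incompatible with --no-fp8")
    # Redundant: fp32/fp16 presets already disable FP8.
    if precision in {"fp32", "fp16"} and no_fp8:
        warnings.warn(f"--no-fp8 is redundant with --precision {precision}")
```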


greptile-apps bot commented Feb 9, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py
FSDP mixed_precision mismatch
layer_kwargs["params_dtype"] and the input tensor dtype are overridden by --precision, but FSDP is still configured with mixed_precision=MixedPrecision(param_dtype=opts.dtype, ...) (and opts.dtype no longer matches the model param dtype when --precision is used). This will cause inconsistent param casting/communication behavior under FSDP for e.g. --precision fp32 (params are fp32 but FSDP thinks they’re bf16) and --precision fp16 (currently sets dtype=torch.bfloat16). FSDP param_dtype should be driven by the same dtype selected in the precision switch, or the precision switch should not override param dtype when FSDP mixed precision is enabled.

Correct FP16 precision to use torch.float16 instead of torch.bfloat16,
and add precedence logic where --dtype and --no-fp8 flags override
--precision when explicitly set, with warnings issued for conflicts.

- Fix case fp16 to use torch.float16 instead of torch.bfloat16
- Add flag precedence detection by comparing against default values
- Implement warning messages when --dtype or --no-fp8 override --precision
- Update argument parser help text to document precedence behavior
- Ensure --dtype and --no-fp8 take precedence over --precision presets

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
@ptrendx ptrendx requested a review from vthumbe1503 February 10, 2026 20:09
aagallo and others added 5 commits February 10, 2026 15:31
Add informative log messages and enhanced help text to clarify precision
configuration behavior and flag precedence for better user transparency.

- Add log message showing which precision preset is being used
- Add warning logs when --dtype or --no-fp8 override --precision
- Add final training configuration log (dtype, FP8 status, recipe)
- Enhance argument parser help text with precedence examples
- Add inline code comments explaining precedence logic

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Add recipe initialization for fp32 and fp16 precision cases to prevent
undefined variable errors, even though recipe is not used when no_fp8
is set to True.

- Add DelayedScaling recipe setup for fp32 case with no_fp8=True
- Add DelayedScaling recipe setup for fp16 case with no_fp8=True
- Add inline comments explaining recipe is set up but not used by autocast
- Ensure recipe variable is defined in all precision branches for consistency

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Update flag precedence detection to use sys.argv for checking if --dtype
was explicitly set, ensuring dtype always overrides precision regardless
of whether it matches the default value.

- Add sys import for command-line argument detection
- Change dtype_explicitly_set check to use '--dtype' in sys.argv
- Change no_fp8_explicitly_set check to use '--no-fp8' in sys.argv
- Ensure --dtype bf16 correctly overrides --precision even when matching default
- Maintain warning messages when explicit flags override precision presets

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 3 comments


Comment on lines +296 to +302
case _:
    dtype = torch.float16
    precision_format = Format.HYBRID
    recipe = DelayedScaling(
        fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
    )
    no_fp8 = opts.no_fp8

recipe variable not initialized in default case. If precision validator allows an unexpected value, this will cause UnboundLocalError at line 403 when passed to te.autocast().

Suggested change
case _:
    dtype = torch.float16
    precision_format = Format.HYBRID
    recipe = DelayedScaling(
        fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
    )
    no_fp8 = opts.no_fp8
case _:
    dtype = torch.float16
    precision_format = Format.HYBRID
    recipe = DelayedScaling(
        fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
    )
    no_fp8 = False

Comment on lines +242 to +243
dtype_explicitly_set = "--dtype" in sys.argv
no_fp8_explicitly_set = "--no-fp8" in sys.argv

Parsing sys.argv directly is fragile - will break if args are passed via config file, environment variables, or if arg uses = syntax (--dtype=fp16). Use parser.parse_known_args() or track which args were explicitly set via custom action class.

Example with custom action:

class StoreExplicitAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
        setattr(namespace, f'{self.dest}_explicitly_set', True)
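A runnable sketch of that custom-action approach is below; unlike scanning sys.argv for the literal string "--dtype", it also catches the "--dtype=bf16" form and the prefix abbreviations argparse accepts.

```python
import argparse

class StoreExplicitAction(argparse.Action):
    # Stores the value and records that the user supplied the flag explicitly,
    # regardless of which spelling argparse matched.
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
        setattr(namespace, f"{self.dest}_explicitly_set", True)

parser = argparse.ArgumentParser()
parser.add_argument("--dtype", action=StoreExplicitAction, default="bf16")

# '=' syntax is still detected, which a raw sys.argv scan would miss.
opts = parser.parse_args(["--dtype=fp16"])
assert getattr(opts, "dtype_explicitly_set", False)
# With no flag, the marker attribute is simply absent.
assert not getattr(parser.parse_args([]), "dtype_explicitly_set", False)
```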


greptile-apps bot commented Feb 10, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py
FSDP MixedPrecision uses opts.dtype instead of local dtype variable. When --precision is used without explicit --dtype, FSDP will use the default bfloat16 instead of the precision-determined dtype (e.g., float16 for fp8/mxfp8 or float32 for fp32).

        mixed_precision=MixedPrecision(
            param_dtype=dtype,
            reduce_dtype=torch.float32,
        ),

aagallo and others added 3 commits February 10, 2026 16:15
Replace fragile sys.argv parsing with robust custom argparse action class
to track explicitly set arguments, and fix default precision case to
explicitly set no_fp8 to False for consistent FP8-enabled behavior.

- Add StoreExplicitAction custom action class for tracking explicit arguments
- Update --dtype argument to use StoreExplicitAction
- Replace sys.argv check with getattr for dtype_explicitly_set attribute
- Remove sys import from train() function
- Fix default case to set no_fp8 = False instead of opts.no_fp8
- Ensure recipe variable is properly initialized in all code paths
- Support all argument passing methods including config files and = syntax

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 5 comments


def train(opts):
    # Check which flags were explicitly set
    dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False)
    no_fp8_explicitly_set = opts.no_fp8 != False

Detection of explicit --no-fp8 is incorrect - opts.no_fp8 != False is True when the flag is absent (defaults to False). User running with defaults will trigger else branch instead of precision preset logic.

Suggested change
no_fp8_explicitly_set = opts.no_fp8 != False
no_fp8_explicitly_set = "--no-fp8" in sys.argv

@@ -173,7 +193,11 @@ def parse_fsdp_args():
"--no-fp8",

Missing action=StoreExplicitAction to track explicit --no-fp8 usage - required for precedence logic at line 250

Suggested change
"--no-fp8",
action=StoreExplicitAction,

Comment on lines +332 to +334
dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}")
if not no_fp8:
    dist_print(f"Using FP8 recipe: {type(recipe).__name__}")

If the recipe log ever runs while recipe is None (line 322), type(recipe).__name__ evaluates to "NoneType", producing the misleading message "Using FP8 recipe: NoneType"; guard the log on recipe being set rather than only on no_fp8.

parser.add_argument(
    "--precision",
    type=precision,
    default="fp8",

Default changed from implicit FP8 disabled (via --dtype bf16 + --no-fp8 False) to explicit --precision fp8 enabled - existing scripts relying on default behavior will now use FP8 unexpectedly


greptile-apps bot commented Feb 10, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py
param_dtype set to opts.dtype instead of local dtype variable - ignores precision preset

            param_dtype=dtype,

aagallo and others added 3 commits February 10, 2026 16:47
Remove params_dtype initialization from get_layer_args() and update FSDP
MixedPrecision to use computed dtype variable instead of raw opts.dtype,
ensuring precision presets are properly applied throughout the model.

- Remove params_dtype from get_layer_args() layer_kwargs initialization
- Update FSDP MixedPrecision param_dtype to use computed dtype variable
- Ensure precision preset logic is respected in both layer initialization and FSDP
- Maintain backward compatibility with original FP8-enabled default behavior

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 1 comment



greptile-apps bot commented Feb 10, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py
Broken --dtype parsing

--dtype uses a custom action=StoreExplicitAction, but StoreExplicitAction.__call__() just assigns values directly and never applies the type=torch_dtype conversion. With a custom Action, argparse won’t run the type converter for you, so passing --dtype bf16 will make opts.dtype a raw string (e.g., 'bf16') and later calls like torch.rand(..., dtype=dtype) / MixedPrecision(param_dtype=dtype) will fail.

One fix is to perform conversion inside the action:

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, torch_dtype(values))
        setattr(namespace, f"{self.dest}_explicitly_set", True)

(or alternatively, implement a custom action that subclasses _StoreAction and preserves normal type handling).

aagallo and others added 3 commits February 10, 2026 17:05
Add type converter application in StoreExplicitAction custom action to
ensure --dtype values are properly converted from strings to torch dtype
objects, preventing runtime errors in torch operations.

- Store type converter in StoreExplicitAction.__init__
- Apply type conversion in __call__ before setting attribute value
- Add error handling for invalid type conversions
- Ensure opts.dtype contains torch dtype object, not raw string
- Fix runtime errors in torch.rand() and MixedPrecision() calls

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 2 comments


Comment on lines +340 to +346
if not no_fp8:
precision_format = Format.HYBRID
recipe = DelayedScaling(
fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
)
else:
recipe = None

CRITICAL BUG: Incorrect recipe selection when --dtype or --no-fp8 is explicitly set

When users explicitly set --dtype but also specify --precision mxfp8 or --precision nvfp4, the code ignores the precision preset and falls back to DelayedScaling recipe instead of the expected MXFP8BlockScaling or NVFP4BlockScaling recipe.

Reproduction:

# User expects MXFP8BlockScaling recipe
./fsdp.py --precision mxfp8 --dtype bf16
# But gets DelayedScaling recipe instead!

Root cause:
Lines 340-344 unconditionally use DelayedScaling when not no_fp8 is true, completely ignoring the opts.precision value.

Expected behavior:
The code should respect the precision preset even when dtype is explicitly set. The logic should check opts.precision and select the appropriate recipe (MXFP8BlockScaling for "mxfp8", NVFP4BlockScaling for "nvfp4", etc.).

Suggested fix:

# Set up recipe based on precision preset
if not no_fp8:
    match opts.precision:
        case "mxfp8":
            recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)
        case "nvfp4":
            recipe = NVFP4BlockScaling()
        case _:
            precision_format = Format.HYBRID
            recipe = DelayedScaling(
                fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
            )
else:
    recipe = None

This is a merge-blocking bug as it breaks the advertised functionality of the --precision flag.

Comment on lines +334 to +346
else:
    # dtype and/or no_fp8 were explicitly set - they take precedence
    dtype = opts.dtype
    no_fp8 = opts.no_fp8

    # Set up default recipe for FP8 cases
    if not no_fp8:
        precision_format = Format.HYBRID
        recipe = DelayedScaling(
            fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
        )
    else:
        recipe = None

Precision preset ignored

In the else: branch (when --dtype and/or --no-fp8 were explicitly set), the code unconditionally falls back to DelayedScaling whenever FP8 is enabled:

if not no_fp8:
    recipe = DelayedScaling(...)

This ignores opts.precision entirely, so --precision mxfp8 or --precision nvfp4 will silently use DelayedScaling if the user also sets --dtype/--no-fp8 (e.g. --precision mxfp8 --dtype bf16). That breaks the advertised presets and applies the wrong quantization recipe.

Fix: in the explicit-flags path, either (a) still select recipe based on opts.precision when FP8 is enabled, or (b) explicitly error/warn and force opts.precision back to fp8 if you’re going to always use DelayedScaling.

aagallo and others added 2 commits March 4, 2026 15:06
Fix two bugs in train() precision configuration block.

- Define dtype_name unconditionally before the dtype_explicitly_set
  block to prevent NameError in the config log when dtype_explicitly_set
  is False (the common case when --dtype is not explicitly passed)
- Guard dtype override warning behind 'dtype_explicitly_set and
  opts.precision is not None' to prevent spurious warning when user
  passes --dtype without --precision (original behavior path)

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

greptile-apps bot commented Mar 4, 2026

Additional Comments (2)

examples/pytorch/fsdp/fsdp.py, line 363
Stale dtype_name causes incorrect warning and config log

dtype_name is computed at line 338 from preset_dtype (before any override), but dtype is then updated to new_dtype on line 344 when the user explicitly sets --dtype. Neither the warning message nor the final configuration log recomputes dtype_name after the override, so both print the preset dtype instead of the effective dtype.

Example: Running --precision mxfp8 --dtype fp16:

  • Line 338: dtype_name = "bfloat16" (from MXFP8 preset)
  • Line 344: dtype = torch.float16 (overridden)
  • Line 346 warning: "Warning: --dtype bfloat16 overrides..." ← should say float16
  • Line 360 config log: "dtype=bfloat16" ← should say float16
    dtype_name = str(dtype).replace("torch.", "")

    # Apply explicit dtype override with warning
    if dtype_explicitly_set and opts.precision is not None:
        new_dtype = opts.dtype
        if new_dtype != preset_dtype:
            dtype = new_dtype
            dtype_name = str(dtype).replace("torch.", "")
            dist_print(
                f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype"
                " setting"
            )
        else:
            new_dtype_name = str(new_dtype).replace("torch.", "")
            dist_print(
                f"Info: --dtype {new_dtype_name} matches --precision {opts.precision} preset"
                " default, no override needed"
            )

        # recipe is already set correctly from preset_recipe above;
        # dtype only affects parameter storage, not the quantization recipe

    # Always log the final configuration being used
    dist_print(
        f"Training configuration: dtype={dtype_name}, "
        f"quantization={'disabled' if no_fp8 else f'enabled ({type(recipe).__name__})'}"
    )

examples/pytorch/fsdp/fsdp.py, line 337
Behavioral regression: recipe=None when FP8 is enabled in backward-compatible mode

When --precision is not set (the default) and --no-fp8 is also not set, the original behavior was FP8 enabled with a DelayedScaling recipe. In the new code this path results in:

  • recipe = None (returned by get_precision_preset(None))
  • no_fp8 = opts.no_fp8 = False (FP8 is enabled)
  • te.autocast(enabled=True, recipe=None, ...) ← different from original DelayedScaling

The case None: arm in get_precision_preset() was intended for "no preset, use original behavior", but original behavior was FP8 enabled with DelayedScaling, not FP8 enabled with recipe=None. Any user who relied on the default without passing --precision or --no-fp8 will now get a None recipe instead of DelayedScaling.

Fix: in the opts.precision is None branch, when no_fp8 resolves to False, supply a fallback DelayedScaling recipe to preserve the original default:

    if opts.precision is None:
        no_fp8 = opts.no_fp8
        if opts.dtype is not None:
            dtype = opts.dtype
        # Preserve original default: FP8 enabled → use DelayedScaling as before
        if not no_fp8:
            recipe = DelayedScaling(
                fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
            )
    else:

aagallo and others added 2 commits March 4, 2026 15:15
…efault

Fix two bugs in train() precision configuration block: stale
dtype_name in log messages after dtype override, and behavioral
regression where recipe=None was passed to te.autocast when FP8
was enabled in backward-compatible mode.

- Recompute dtype_name immediately after dtype = new_dtype in the
  dtype override branch so warning and config log reflect the
  effective dtype rather than the stale preset dtype
- Restore original default behavior in opts.precision is None path:
  when no_fp8 is False (FP8 enabled), supply DelayedScaling recipe
  to preserve the original te.autocast behavior instead of passing
  recipe=None which changed the implicit fallback behavior

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

greptile-apps bot commented Mar 4, 2026

Additional Comments (2)

examples/pytorch/fsdp/fsdp.py, line 333
The condition opts.dtype is not None is always True since --dtype has default=torch.bfloat16 and will never be None. You can simplify this to just dtype = opts.dtype directly.

        dtype = opts.dtype

examples/pytorch/fsdp/fsdp.py, line 341
The line no_fp8 = preset_no_fp8 on line 340 is redundant—the same assignment already occurred on line 326. When opts.precision is not None, this variable retains its preset value and doesn't need to be reset. Remove this line to clarify the control flow.

    else:
        dist_print(f"Using precision preset: {opts.precision}")

…assignment

Remove two redundant lines in train() precision configuration block.

- Remove 'if opts.dtype is not None' guard in opts.precision is None
  branch; --dtype has default=torch.bfloat16 so opts.dtype is never
  None and the condition is always True
- Remove redundant 'no_fp8 = preset_no_fp8' assignment in the else
  branch; no_fp8 is already assigned from preset_no_fp8 at the tuple
  unpack above and reassigning it in the else branch adds noise
  without changing behavior

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

greptile-apps bot commented Mar 4, 2026

Additional Comments (3)

examples/pytorch/fsdp/fsdp.py, line 446
recipe=None passed to te.autocast when FP8 is disabled

The original code always passed a concrete DelayedScaling instance to te.autocast even when enabled=False. This PR now passes recipe=None whenever FP8 is disabled (e.g. --precision fp32, --precision fp16, or --no-fp8 without a precision preset). If te.autocast performs any attribute access or type checking on recipe before consulting the enabled flag, this will raise a TypeError/AttributeError for all non-FP8 users.

Consider defaulting recipe to a no-op/default recipe when it would otherwise be None, matching the original defensive pattern:

        with te.autocast(enabled=not no_fp8, recipe=recipe or DelayedScaling(), amax_reduction_group=all_gpus):

Or, more explicitly, guard assignment in the opts.precision is None branch and the preset paths so recipe is never None when passed to te.autocast, by constructing a DelayedScaling() default for the non-FP8 presets.


examples/pytorch/fsdp/fsdp.py, line 298
case None in get_precision_preset returns stale defaults that are immediately overridden

The case None: branch (line 289-292) returns (torch.bfloat16, True, None), but every caller that reaches this branch (opts.precision is None) immediately overwrites all three values in the if opts.precision is None: block in train() (lines 330-337). The return value of this case is never actually used. This creates a dead-code path that can confuse future maintainers into believing the None case has meaningful semantics.

Consider either:

  • Raising a ValueError/AssertionError in case None and handling the None path entirely before calling get_precision_preset, or
  • Removing case None from this function entirely and adding a guard at the call site in train():
if opts.precision is not None:
    preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision)
    dtype, no_fp8, recipe = preset_dtype, preset_no_fp8, preset_recipe

examples/pytorch/fsdp/fsdp.py, line 213
Excessively verbose help text in --no-fp8 and --dtype

The help strings for --no-fp8 (and --dtype, lines 229-238) embed multiple paragraphs of prose, uppercase section headers (PRECEDENCE:, BEHAVIOR:, RATIONALE:, EXAMPLES:), and implementation details. --help output is intended to be a brief synopsis; detailed documentation belongs in a README or docstring. Users running --help will find the output difficult to scan.

Consider keeping help text to a concise one-liner with an example, and moving the detailed precedence rules into the module-level docstring or a separate README section.
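A concise help string might look like the following sketch. The flag name matches the script; the wording of the help text is a suggestion, not the script's current text.

```python
import argparse

# Sketch of a one-line help string for --no-fp8; detailed precedence rules
# would live in the module docstring or README instead.
parser = argparse.ArgumentParser(prog="fsdp.py")
parser.add_argument(
    "--no-fp8",
    action="store_true",
    default=False,
    help="Disable the te.autocast() FP8 context "
    "(see module docstring for precedence rules).",
)

opts = parser.parse_args(["--no-fp8"])
assert opts.no_fp8 is True
```

argparse converts the dash in --no-fp8 to an underscore, so the flag is read back as opts.no_fp8.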

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

aagallo and others added 2 commits March 4, 2026 15:33
Address three code review issues in examples/pytorch/fsdp/fsdp.py:
recipe=None passed to te.autocast, dead case None in get_precision_preset,
and excessively verbose help strings for --no-fp8 and --dtype.

- Use 'recipe or DelayedScaling()' fallback at te.autocast call site
  to preserve original defensive pattern of always passing a concrete
  recipe instance, even when enabled=False
- Remove case None from get_precision_preset() and guard call site in
  train() with 'if opts.precision is not None' to eliminate dead-code
  path whose return values were immediately overridden by the caller
- Replace multi-paragraph help strings for --no-fp8 and --dtype with
  concise one-liner synopses; move detailed precedence rules to
  module-level docstring or README

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py, line 439
New DelayedScaling() allocated every iteration

recipe or DelayedScaling() is evaluated on each training iteration. When no_fp8=True, recipe is None, so a fresh DelayedScaling() object is constructed on every call even though enabled=False makes it unused. The original code created the recipe object once outside the loop; the current pattern regresses that behaviour.

Move the fallback outside the loop or replace the or with an upfront assignment so each iteration reuses the same (no-op) object:

        with te.autocast(
            enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus
        ):

And ensure recipe is always non-None before the loop (e.g. recipe = recipe or DelayedScaling() after the configuration block).

aagallo and others added 2 commits March 4, 2026 15:49
…tion

Move 'recipe or DelayedScaling()' fallback to a one-time assignment
before the training loop to avoid allocating a new DelayedScaling()
object on every iteration when FP8 is disabled.

- Add 'if recipe is None: recipe = DelayedScaling()' after the
  configuration block and before the training loop so the fallback
  object is created once and reused across all iterations
- Restore clean 'recipe=recipe' in te.autocast call, matching the
  original code pattern
- Add comment explaining why recipe is always set to a concrete
  object even when FP8 is disabled

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py, line 441
Recipe fallback initialization belongs outside the training loop

The if recipe is None: recipe = DelayedScaling() guard is placed inside the for i in range(opts.num_iters): loop, but it only takes effect on the first iteration; after that, recipe is no longer None. Keeping it here has two downsides:

  1. The is None check runs on every iteration, even though it only matters once.
  2. It mutates a variable that was established before the loop, making the control-flow confusing for future readers.

Additionally, note that when this fallback fires (i.e. recipe was None), it creates a plain DelayedScaling() with default constructor arguments rather than the custom parameters (fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max") used throughout the rest of the file. While this is harmless at runtime because te.autocast(enabled=False, ...) ignores the recipe, it is inconsistent and could cause silent behavioral differences if no_fp8 logic ever changes.

Move the guard to just after the if opts.precision is not None / else block, before the loop begins:

    # Ensure recipe is always a concrete object before passing to te.autocast.
    # When FP8 is disabled, te.autocast ignores the recipe, but some TE versions
    # perform attribute access on it regardless of the enabled flag.
    if recipe is None:
        recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max")

    for i in range(opts.num_iters):
        x = torch.rand(...)
        with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus):
            ...

aagallo and others added 2 commits March 4, 2026 15:55
…arameters

Move 'if recipe is None: recipe = DelayedScaling(...)' guard to
just before the training loop instead of inside it to avoid
redundant is-None checks on every iteration and variable mutation
inside the loop.

- Use consistent DelayedScaling parameters (fp8_format=Format.HYBRID,
  amax_history_len=32, amax_compute_algo='max') matching the rest of
  the file, rather than plain DelayedScaling() with default args
- Guard runs once before the loop; recipe is stable for all iterations
- Restores clean 'recipe=recipe' in te.autocast call with no inline
  fallback expression

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
greptile-apps bot commented Mar 4, 2026

Additional Comments (2)

examples/pytorch/fsdp/fsdp.py, line 444
amax_reduction_group incompatible with block-scaling recipes

amax_reduction_group is an FP8-specific parameter used by DelayedScaling to aggregate per-tensor amax statistics across FSDP ranks. Both MXFP8BlockScaling and NVFP4BlockScaling use block-level scaling that does not require (and may not accept) this parameter.

Passing amax_reduction_group=all_gpus unconditionally when using --precision mxfp8 or --precision nvfp4 will likely trigger a TypeError or silently create unnecessary inter-rank communication, depending on which TE version is running.

The amax_reduction_group should only be passed when the active recipe uses per-tensor amax (i.e., DelayedScaling):

# autocast needs to be given the FSDP process group for amax reductions
amax_group = all_gpus if isinstance(recipe, DelayedScaling) else None
with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=amax_group):
    y = te_model(x)
    loss = y.sum()

examples/pytorch/fsdp/fsdp.py, line 258
Overly verbose help strings

The help text for --no-fp8, --dtype, and --precision embeds structured prose with PRECEDENCE:, BEHAVIOR:, RATIONALE:, EXAMPLES: sections. This content will be displayed verbatim by argparse --help and will be very hard to read in a terminal (no line-break control, wrapping at 80 columns, etc.).

Concise, one-sentence help strings for the CLI and a README or docstring for the extended documentation would be significantly more user-friendly. For example:

parser.add_argument(
    "--no-fp8",
    action="store_true",
    default=False,
    help=(
        "Disable te.autocast() FP8 context. Incompatible with --precision fp8/mxfp8/nvfp4."
        " Default: False."
    ),
)

The detailed interaction rules can live in a module-level docstring or in the script's README.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

aagallo and others added 2 commits March 4, 2026 16:06
…lp text

Pass amax_reduction_group conditionally based on recipe type and
replace verbose multi-paragraph help strings with concise one-liners.

- Compute amax_group = all_gpus if isinstance(recipe, DelayedScaling)
  else None and pass amax_group to te.autocast; amax_reduction_group
  is a DelayedScaling-specific parameter for per-tensor amax
  aggregation and is not accepted by MXFP8BlockScaling or
  NVFP4BlockScaling which use block-level scaling
- Replace multi-paragraph help strings for --no-fp8, --dtype, and
  --precision (with PRECEDENCE/BEHAVIOR/RATIONALE/EXAMPLES sections)
  with concise one-liner synopses suitable for terminal --help output
- Move detailed precedence rules to module-level docstring

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py, line 409
amax_group is incorrectly set to all_gpus when FP8 is disabled.

When no_fp8=True, recipe starts as None (from fp32/fp16 presets or --no-fp8 flag), gets replaced with DelayedScaling() on line 408, and then isinstance(recipe, DelayedScaling) on line 409 evaluates to True — causing amax_group to be set to all_gpus even though FP8 is disabled.

This results in te.autocast(enabled=False, ..., amax_reduction_group=all_gpus) at line 422, which could trigger unnecessary distributed all-reduce operations if TE processes the group regardless of the enabled flag.

The guard condition should also check not no_fp8:

    if recipe is None:
        recipe = DelayedScaling()
    amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None

aagallo and others added 2 commits March 4, 2026 16:18
… comms

Add 'not no_fp8' condition to amax_group assignment to prevent
amax_reduction_group=all_gpus being passed to te.autocast when
FP8 is disabled.

- When no_fp8=True and recipe was None, the DelayedScaling fallback
  causes isinstance(recipe, DelayedScaling) to return True, which
  incorrectly set amax_group=all_gpus even though enabled=False
- Add 'not no_fp8' guard so amax_group is only set to all_gpus when
  FP8 is active AND the recipe is DelayedScaling (per-tensor amax);
  all other cases (FP8 disabled, block-scaling recipes) use None

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
greptile-apps bot commented Mar 4, 2026

Additional Comments (2)

examples/pytorch/fsdp/fsdp.py, line 287
--no-fp8 silently ignored with fp32/fp16 precision presets

The validation guard on line 281 only rejects --no-fp8 when the precision preset is FP8-based (fp8, mxfp8, nvfp4). When --precision fp32 or --precision fp16 is combined with --no-fp8, no warning or error is raised, and opts.no_fp8 is never read in the if opts.precision is not None: branch; no_fp8 comes entirely from the preset, which already returns True. A user who explicitly passes --no-fp8 --precision fp32 therefore receives no feedback that the flag is redundant and ignored, which can be confusing.

Consider emitting a warning for this case:

if opts.precision in ["fp32", "fp16"] and opts.no_fp8:
    dist_print(
        f"Warning: --no-fp8 is redundant when using --precision {opts.precision} "
        "(FP8 is already disabled by this preset). The flag will be ignored."
    )

examples/pytorch/fsdp/fsdp.py, line 410
amax_reduction_group implicitly assumed None for MXFP8/NVFP4 — add a comment

The expression silently passes amax_reduction_group=None to te.autocast for MXFP8BlockScaling and NVFP4BlockScaling. While block-scaling recipes typically compute scales locally (no global amax reduction), this is a non-obvious correctness assumption. If a future TE version or variant of these recipes does require a process group, the call will silently fall back to local-only scaling without any diagnostic.

Please add an inline comment explaining why None is intentional, e.g.:

# MXFP8 and NVFP4 use local block scaling — no distributed amax reduction group needed.
# amax_reduction_group is only required for DelayedScaling (global AMAX allreduce).
amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None

…p=None

Emit a warning when --no-fp8 is combined with a non-FP8 precision
preset and add an inline comment explaining why amax_reduction_group
is None for block-scaling recipes.

- Add warning when opts.precision in ['fp32', 'fp16'] and opts.no_fp8
  is set; FP8 is already disabled by these presets so the flag is
  redundant and silently ignored without this feedback
- Add inline comment on amax_group assignment explaining that
  MXFP8BlockScaling and NVFP4BlockScaling use local block scaling
  and do not require a distributed amax reduction group, and that
  None is also correct when FP8 is disabled to avoid unnecessary
  distributed communication

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py, line 332
preset_dtype / preset_recipe potentially unbound from a static-analysis perspective

preset_dtype and preset_recipe (used at lines 319 and 328) are only assigned inside the if opts.precision is not None: block (lines 300-303). The outer guard at line 317 includes the same opts.precision is not None condition, so they are always logically initialized when accessed, but Python's static analyzers (mypy, pylint, pyflakes) will still raise a "possibly undefined" / "unbound" warning because the else branch (line 304) never assigns these names.

This can silently break if the condition is ever refactored without updating both blocks, and it will trigger CI lint failures in projects that treat unbound-variable warnings as errors.

Suggested fix — initialize the trio before the if-else so the later access is always safe:

preset_dtype: torch.dtype = opts.dtype   # sensible fallback
preset_recipe = None

if opts.precision is not None:
    preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision)
    ...

aagallo and others added 2 commits March 4, 2026 16:36
Initialize preset_dtype and preset_recipe with fallback values before
the 'if opts.precision is not None' block to prevent static analyzer
warnings about potentially unbound variables.

- Assign preset_dtype = opts.dtype and preset_recipe = None as
  sensible fallbacks before the if-else block; these are overwritten
  by get_precision_preset() when opts.precision is not None and are
  never accessed in the else branch
- Satisfies mypy, pylint, and pyflakes 'possibly undefined' / 'unbound'
  warnings that would otherwise trigger CI lint failures in projects
  treating unbound-variable warnings as errors
- No behavioral change; the if-else logic is unchanged

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py, line 416
amax_group check races with recipe mutation

After the recipe = DelayedScaling() fallback on line 416, recipe is now a DelayedScaling instance even when no_fp8=True. The amax_group computation on line 420 then evaluates isinstance(recipe, DelayedScaling) as True, relying solely on not no_fp8 to produce None. This is technically correct today, but the invariant is fragile: if the guard condition is ever reordered or the fallback recipe type changes, the FSDP process group could silently be passed to a recipe that doesn't expect it.

A safer pattern is to compute amax_group before the fallback assignment, or guard it only on no_fp8:

# Compute amax_group before the defensive recipe fallback
amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None

# Ensure recipe is always a concrete object before passing to te.autocast.
if recipe is None:
    recipe = DelayedScaling()

aagallo and others added 2 commits March 4, 2026 16:55
…ce race

Move amax_group computation before the 'if recipe is None' fallback
assignment so isinstance(recipe, DelayedScaling) reflects the actual
user-selected recipe rather than the defensive fallback object.

- When recipe is None (non-FP8 presets or --no-fp8), isinstance
  correctly returns False and amax_group is set to None before the
  fallback substitutes a DelayedScaling instance
- Prevents the fragile ordering dependency where not no_fp8 was the
  sole guard against passing all_gpus to a recipe that doesn't need it
- Add inline comment explaining why amax_group must be computed before
  the recipe fallback to preserve the invariant

Signed-off-by: Andrea Gallo <aagallo@amazon.com>