
Add multi-precision training support to FSDP script #2662

Open

aagallo wants to merge 44 commits into NVIDIA:main from aagallo:extend-precision

Conversation


@aagallo aagallo commented Feb 9, 2026

Description

This PR adds comprehensive precision parameter support to the FSDP training script, enabling users to configure training with multiple precision formats (FP32, FP16, FP8, MXFP8, NVFP4) via command-line argument. The implementation includes automatic configuration of appropriate dtypes and format-specific recipes for each precision type.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added precision() type validator function supporting fp32, fp16, fp8, mxfp8, and nvfp4 formats
  • Added --precision command-line argument to parse_fsdp_args() with default value "fp8"
  • Implemented match statement in train() function to configure precision-based training parameters
  • Configured format-specific recipes for each precision type:
    • FP32/FP16: Uses standard PyTorch dtypes with FP8 disabled
    • FP8: Uses DelayedScaling recipe with HYBRID format
    • MXFP8: Uses MXFP8BlockScaling recipe with E4M3 format
    • NVFP4: Uses NVFP4BlockScaling recipe with bfloat16 dtype
  • Set appropriate no_fp8 flags based on precision selection
  • Updated layer_kwargs["params_dtype"] to use precision-determined dtype
  • Imported required recipe classes: MXFP8BlockScaling and NVFP4BlockScaling

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Please reach out to Santosh Bhavani (sbhavani@nvidia.com) for additional context on the work

aagallo and others added 3 commits February 9, 2026 16:28
Enable configurable precision training with support for FP32, FP16, FP8,
MXFP8, and NVFP4 formats. Added precision argument parser and match statement
to configure appropriate dtype and recipe based on selected precision.

- Add precision() type validator function
- Implement precision-based configuration in train()
- Support FP32, FP16, FP8, MXFP8, and NVFP4 formats
- Configure format-specific recipes (DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling)
- Set appropriate no_fp8 flags based on precision selection

Signed-off-by: aagallo <aagallo@amazon.com>

greptile-apps bot commented Feb 9, 2026

Greptile Summary

This PR adds a --precision convenience flag to the FSDP training example, supporting fp32, fp16, fp8, mxfp8, and nvfp4 presets. It introduces StoreExplicitAction to distinguish explicitly-set --dtype values from defaults, a get_precision_preset() helper with proper fallback, and clear precedence rules between --precision, --dtype, and --no-fp8.

The current revision has incorporated the bulk of previously-raised feedback:

  • get_precision_preset now raises ValueError on unknown presets instead of silently returning None
  • The fp16 case correctly returns torch.float16
  • Incompatible --precision fp8/mxfp8/nvfp4 + --no-fp8 combinations now raise a ValueError before dist.init_process_group()
  • The training-config log is safe: type(recipe).__name__ is only reached when no_fp8=False, at which point recipe is always a non-None object

Remaining concern worth addressing:

  • In the dtype_explicitly_set path, get_recipe_for_precision(opts.precision) is called in both the matching-dtype and the non-matching-dtype branches (lines 375 and 386), even though recipe was already correctly assigned from get_precision_preset() at line 360. This creates a redundant second recipe instance; keeping the existing preset_recipe is both simpler and avoids any subtle divergence if the preset's recipe configuration ever differs from what get_recipe_for_precision returns independently.

Confidence Score: 4/5

  • PR is safe to merge for an example script; most previously raised issues have been addressed and the logic is correct, with only minor redundancy remaining.
  • The core logic is sound: preset/flag precedence is well-defined, the incompatible-flag guard fires before distributed init, the log line is safe (recipe is non-None whenever no_fp8=False), and backward compatibility is preserved when --precision is omitted. The one unresolved structural issue is the redundant recipe recreation in the dtype-override branch, which is wasteful but not incorrect. The only new comment raised (redundant init) is cosmetic.
  • examples/pytorch/fsdp/fsdp.py — lines 374-386 (redundant get_recipe_for_precision calls in dtype-override path)

Important Files Changed

Filename Overview
examples/pytorch/fsdp/fsdp.py Adds multi-precision training support (fp32/fp16/fp8/mxfp8/nvfp4) via --precision CLI argument, StoreExplicitAction for explicit-flag detection, get_precision_preset helper, and revised train() logic; several previously-raised concerns (recipe redundancy, flag precedence, log crash) are partially addressed but some patterns still warrant attention.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[parse_fsdp_args] --> B{--precision set?}
    B -- "None (default)" --> C[Backward-compat path\ndtype = opts.dtype\nno_fp8 = opts.no_fp8]
    C --> C1{no_fp8?}
    C1 -- False --> C2[recipe = DelayedScaling]
    C1 -- True --> C3[recipe = None]

    B -- "fp32/fp16/fp8/mxfp8/nvfp4" --> D{--no-fp8 set with\nFP8 preset?}
    D -- Yes --> E[raise ValueError]
    D -- No --> F[get_precision_preset]
    F --> G[dtype, no_fp8, recipe from preset]
    G --> H{--dtype explicitly set?}
    H -- No --> I[Keep preset values]
    H -- Yes --> J{new_dtype == preset_dtype?}
    J -- No --> K[Override dtype\nwarn user\nrecreate recipe if FP8]
    J -- Yes --> L[Log 'no override needed'\nrecreate recipe if FP8]

    C2 & C3 & I & K & L --> M[layer_kwargs params_dtype = dtype]
    M --> N[Build TE model]
    N --> O[Wrap with FSDP MixedPrecision param_dtype=dtype]
    O --> P[Training loop\nte.autocast enabled=not no_fp8 recipe=recipe]

Last reviewed commit: afa756d


@greptile-apps greptile-apps bot left a comment


1 file reviewed, 3 comments


Comment on lines +243 to +246
    case "fp16":
        dtype = torch.bfloat16
        no_fp8 = True
    case "fp8":

Incorrect fp16 dtype
In the case "fp16" branch, the code sets dtype = torch.bfloat16. That contradicts the meaning of fp16 and also diverges from the existing --dtype parsing which supports torch.float16. If a user runs with --precision fp16 expecting fp16 parameters/inputs, they’ll silently get bf16 instead.

Comment on lines +206 to 212
    parser.add_argument(
        "--precision",
        type=precision,
        default="fp8",
        help="Precision to apply to model training (FP32, FP16, FP8, MXFP8, NVFP4)",
    )
    return parser.parse_args()

Conflicting CLI flags
--precision and --dtype/--no-fp8 now overlap: train() overrides dtype and no_fp8 based on --precision, but --dtype/--no-fp8 are still accepted and used as defaults. As written, --precision fp8 will force no_fp8=False even if the user explicitly passed --no-fp8, and --precision fp32 will ignore an explicit --dtype fp16/bf16. This makes the CLI behavior surprising and hard to reason about; either make --precision derive defaults only when the user didn’t specify --dtype/--no-fp8, or document/enforce precedence (e.g., error on incompatible combinations).
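One way to implement the precedence the comment asks for can be sketched in isolation. The function name and signature here are hypothetical (the PR later solves this with explicit-flag tracking): `--precision` supplies a default only when `--no-fp8` was not given, and contradictory combinations raise instead of silently winning.

```python
def resolve_no_fp8(precision, no_fp8_explicit, no_fp8_value):
    """Derive the effective no_fp8 flag, erroring on contradictions.

    Hypothetical helper: precision is the --precision preset, no_fp8_explicit
    says whether the user actually passed --no-fp8, no_fp8_value is its value.
    """
    fp8_presets = ("fp8", "mxfp8", "nvfp4")
    if precision in fp8_presets and no_fp8_explicit and no_fp8_value:
        raise ValueError(f"--no-fp8 is incompatible with --precision {precision}")
    if no_fp8_explicit:
        # The explicit flag wins over the preset.
        return no_fp8_value
    # Otherwise the preset decides: fp32/fp16 presets disable FP8.
    return precision in ("fp32", "fp16")
```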


greptile-apps bot commented Feb 9, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py
FSDP mixed_precision mismatch
layer_kwargs["params_dtype"] and the input tensor dtype are overridden by --precision, but FSDP is still configured with mixed_precision=MixedPrecision(param_dtype=opts.dtype, ...) (and opts.dtype no longer matches the model param dtype when --precision is used). This will cause inconsistent param casting/communication behavior under FSDP for e.g. --precision fp32 (params are fp32 but FSDP thinks they’re bf16) and --precision fp16 (currently sets dtype=torch.bfloat16). FSDP param_dtype should be driven by the same dtype selected in the precision switch, or the precision switch should not override param dtype when FSDP mixed precision is enabled.

Correct FP16 precision to use torch.float16 instead of torch.bfloat16,
and add precedence logic where --dtype and --no-fp8 flags override
--precision when explicitly set, with warnings issued for conflicts.

- Fix case fp16 to use torch.float16 instead of torch.bfloat16
- Add flag precedence detection by comparing against default values
- Implement warning messages when --dtype or --no-fp8 override --precision
- Update argument parser help text to document precedence behavior
- Ensure --dtype and --no-fp8 take precedence over --precision presets

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
@ptrendx ptrendx requested a review from vthumbe1503 February 10, 2026 20:09
aagallo and others added 5 commits February 10, 2026 15:31
Add informative log messages and enhanced help text to clarify precision
configuration behavior and flag precedence for better user transparency.

- Add log message showing which precision preset is being used
- Add warning logs when --dtype or --no-fp8 override --precision
- Add final training configuration log (dtype, FP8 status, recipe)
- Enhance argument parser help text with precedence examples
- Add inline code comments explaining precedence logic

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Add recipe initialization for fp32 and fp16 precision cases to prevent
undefined variable errors, even though recipe is not used when no_fp8
is set to True.

- Add DelayedScaling recipe setup for fp32 case with no_fp8=True
- Add DelayedScaling recipe setup for fp16 case with no_fp8=True
- Add inline comments explaining recipe is set up but not used by autocast
- Ensure recipe variable is defined in all precision branches for consistency

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Update flag precedence detection to use sys.argv for checking if --dtype
was explicitly set, ensuring dtype always overrides precision regardless
of whether it matches the default value.

- Add sys import for command-line argument detection
- Change dtype_explicitly_set check to use '--dtype' in sys.argv
- Change no_fp8_explicitly_set check to use '--no-fp8' in sys.argv
- Ensure --dtype bf16 correctly overrides --precision even when matching default
- Maintain warning messages when explicit flags override precision presets

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 3 comments


Comment on lines +296 to +302
    case _:
        dtype = torch.float16
        precision_format = Format.HYBRID
        recipe = DelayedScaling(
            fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
        )
        no_fp8 = opts.no_fp8

recipe variable not initialized in default case. If precision validator allows an unexpected value, this will cause UnboundLocalError at line 403 when passed to te.autocast().

Suggested change
    case _:
        dtype = torch.float16
        precision_format = Format.HYBRID
        recipe = DelayedScaling(
            fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
        )
-       no_fp8 = opts.no_fp8
+       no_fp8 = False

Comment on lines +242 to +243
    dtype_explicitly_set = "--dtype" in sys.argv
    no_fp8_explicitly_set = "--no-fp8" in sys.argv

Parsing sys.argv directly is fragile - will break if args are passed via config file, environment variables, or if arg uses = syntax (--dtype=fp16). Use parser.parse_known_args() or track which args were explicitly set via custom action class.

Example with custom action:

class StoreExplicitAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
        setattr(namespace, f'{self.dest}_explicitly_set', True)
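Wired into a parser, the suggested action looks like this runnable sketch (argument names mirror the script; the `bf16`/`fp16` string values are placeholders for the real `type=torch_dtype` conversion):

```python
import argparse

class StoreExplicitAction(argparse.Action):
    """Store the value and record that the flag appeared on the command line."""
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
        setattr(namespace, f"{self.dest}_explicitly_set", True)

parser = argparse.ArgumentParser()
parser.add_argument("--dtype", action=StoreExplicitAction, default="bf16")

explicit = parser.parse_args(["--dtype=fp16"])  # '=' syntax works too
defaulted = parser.parse_args([])
```

Unlike scanning `sys.argv`, the action fires for every way the flag can be supplied, and `getattr(opts, "dtype_explicitly_set", False)` cleanly covers the default case where the attribute was never set.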


greptile-apps bot commented Feb 10, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py
FSDP MixedPrecision uses opts.dtype instead of local dtype variable. When --precision is used without explicit --dtype, FSDP will use the default bfloat16 instead of the precision-determined dtype (e.g., float16 for fp8/mxfp8 or float32 for fp32).

        mixed_precision=MixedPrecision(
            param_dtype=dtype,
            reduce_dtype=torch.float32,
        ),

aagallo and others added 3 commits February 10, 2026 16:15
Replace fragile sys.argv parsing with robust custom argparse action class
to track explicitly set arguments, and fix default precision case to
explicitly set no_fp8 to False for consistent FP8-enabled behavior.

- Add StoreExplicitAction custom action class for tracking explicit arguments
- Update --dtype argument to use StoreExplicitAction
- Replace sys.argv check with getattr for dtype_explicitly_set attribute
- Remove sys import from train() function
- Fix default case to set no_fp8 = False instead of opts.no_fp8
- Ensure recipe variable is properly initialized in all code paths
- Support all argument passing methods including config files and = syntax

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 5 comments


def train(opts):
    # Check which flags were explicitly set
    dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False)
    no_fp8_explicitly_set = opts.no_fp8 != False

Detection of explicit --no-fp8 is incorrect - opts.no_fp8 != False is True when the flag is absent (defaults to False). User running with defaults will trigger else branch instead of precision preset logic.

Suggested change
-   no_fp8_explicitly_set = opts.no_fp8 != False
+   no_fp8_explicitly_set = "--no-fp8" in sys.argv

@@ -173,7 +193,11 @@ def parse_fsdp_args():
"--no-fp8",

Missing action=StoreExplicitAction to track explicit --no-fp8 usage - required for precedence logic at line 250

Suggested change
    "--no-fp8",
+   action=StoreExplicitAction,

Comment on lines +332 to +334
    dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}")
    if not no_fp8:
        dist_print(f"Using FP8 recipe: {type(recipe).__name__}")

Logging FP8 recipe when no_fp8=True and recipe=None (line 322) will cause AttributeError on type(recipe).__name__

    parser.add_argument(
        "--precision",
        type=precision,
        default="fp8",

Default changed from implicit FP8 disabled (via --dtype bf16 + --no-fp8 False) to explicit --precision fp8 enabled - existing scripts relying on default behavior will now use FP8 unexpectedly


greptile-apps bot commented Feb 10, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py
param_dtype set to opts.dtype instead of local dtype variable - ignores precision preset

            param_dtype=dtype,

aagallo and others added 3 commits February 10, 2026 16:47
Remove params_dtype initialization from get_layer_args() and update FSDP
MixedPrecision to use computed dtype variable instead of raw opts.dtype,
ensuring precision presets are properly applied throughout the model.

- Remove params_dtype from get_layer_args() layer_kwargs initialization
- Update FSDP MixedPrecision param_dtype to use computed dtype variable
- Ensure precision preset logic is respected in both layer initialization and FSDP
- Maintain backward compatibility with original FP8-enabled default behavior

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 1 comment



greptile-apps bot commented Feb 10, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py
Broken --dtype parsing

--dtype uses a custom action=StoreExplicitAction, but StoreExplicitAction.__call__() just assigns values directly and never applies the type=torch_dtype conversion. With a custom Action, argparse won’t run the type converter for you, so passing --dtype bf16 will make opts.dtype a raw string (e.g., 'bf16') and later calls like torch.rand(..., dtype=dtype) / MixedPrecision(param_dtype=dtype) will fail.

One fix is to perform conversion inside the action:

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, torch_dtype(values))
        setattr(namespace, f"{self.dest}_explicitly_set", True)

(or alternatively, implement a custom action that subclasses _StoreAction and preserves normal type handling).

aagallo and others added 3 commits February 10, 2026 17:05
Add type converter application in StoreExplicitAction custom action to
ensure --dtype values are properly converted from strings to torch dtype
objects, preventing runtime errors in torch operations.

- Store type converter in StoreExplicitAction.__init__
- Apply type conversion in __call__ before setting attribute value
- Add error handling for invalid type conversions
- Ensure opts.dtype contains torch dtype object, not raw string
- Fix runtime errors in torch.rand() and MixedPrecision() calls

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 2 comments


Comment on lines +340 to +346
    if not no_fp8:
        precision_format = Format.HYBRID
        recipe = DelayedScaling(
            fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
        )
    else:
        recipe = None

CRITICAL BUG: Incorrect recipe selection when --dtype or --no-fp8 is explicitly set

When users explicitly set --dtype but also specify --precision mxfp8 or --precision nvfp4, the code ignores the precision preset and falls back to DelayedScaling recipe instead of the expected MXFP8BlockScaling or NVFP4BlockScaling recipe.

Reproduction:

# User expects MXFP8BlockScaling recipe
./fsdp.py --precision mxfp8 --dtype bf16
# But gets DelayedScaling recipe instead!

Root cause:
Lines 340-344 unconditionally use DelayedScaling when not no_fp8 is true, completely ignoring the opts.precision value.

Expected behavior:
The code should respect the precision preset even when dtype is explicitly set. The logic should check opts.precision and select the appropriate recipe (MXFP8BlockScaling for "mxfp8", NVFP4BlockScaling for "nvfp4", etc.).

Suggested fix:

# Set up recipe based on precision preset
if not no_fp8:
    match opts.precision:
        case "mxfp8":
            recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)
        case "nvfp4":
            recipe = NVFP4BlockScaling()
        case _:
            precision_format = Format.HYBRID
            recipe = DelayedScaling(
                fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
            )
else:
    recipe = None

This is a merge-blocking bug as it breaks the advertised functionality of the --precision flag.

Comment on lines +334 to +346
    else:
        # dtype and/or no_fp8 were explicitly set - they take precedence
        dtype = opts.dtype
        no_fp8 = opts.no_fp8

        # Set up default recipe for FP8 cases
        if not no_fp8:
            precision_format = Format.HYBRID
            recipe = DelayedScaling(
                fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
            )
        else:
            recipe = None

Precision preset ignored

In the else: branch (when --dtype and/or --no-fp8 were explicitly set), the code unconditionally falls back to DelayedScaling whenever FP8 is enabled:

    if not no_fp8:
        recipe = DelayedScaling(...)

This ignores opts.precision entirely, so --precision mxfp8 or --precision nvfp4 will silently use DelayedScaling if the user also sets --dtype/--no-fp8 (e.g. --precision mxfp8 --dtype bf16). That breaks the advertised presets and applies the wrong quantization recipe.

Fix: in the explicit-flags path, either (a) still select recipe based on opts.precision when FP8 is enabled, or (b) explicitly error/warn and force opts.precision back to fp8 if you’re going to always use DelayedScaling.


greptile-apps bot commented Mar 3, 2026

Additional Comments (3)

examples/pytorch/fsdp/fsdp.py, line 106
Commented-out code should be deleted

Line 106 is dead code left behind when params_dtype was moved to train(). Commented-out code in a committed diff makes the history harder to follow; just delete it since layer_kwargs["params_dtype"] = dtype is now set explicitly inside train() at line 401.

        "device": "cuda" if opts.no_defer_init else "meta",



examples/pytorch/fsdp/fsdp.py, line 80
precision() raises a bare TypeError with no message

When a user passes an unrecognised precision value (e.g. --precision bf16), argparse catches the exception and uses its message to build the error string. A bare raise TypeError produces an empty message, so the user sees something like:

error: argument --precision: invalid precision value: 'bf16'

…which gives no hint about valid choices. Prefer argparse.ArgumentTypeError with an explicit list:

def precision(d):
    typemap = ["fp32", "fp16", "fp8", "mxfp8", "nvfp4"]
    if d.lower() not in typemap:
        raise argparse.ArgumentTypeError(
            f"invalid precision '{d}'. Supported values: {', '.join(typemap)}"
        )
    return d.lower()

examples/pytorch/fsdp/fsdp.py, line 156
StoreTrueExplicitAction.__init__ has a rigid signature

Unlike StoreExplicitAction (which uses **kwargs), this class lists only a fixed set of parameters (default, required, help). If argparse ever passes an unexpected keyword (e.g. metavar, type, or any future addition), the call will raise TypeError. The safer pattern – consistent with StoreExplicitAction – is to forward unknown kwargs:

class StoreTrueExplicitAction(argparse.Action):
    """Custom action for store_true that tracks whether flag was explicitly set."""

    def __init__(self, option_strings, dest, default=False, required=False, help=None, **kwargs):
        super().__init__(
            option_strings, dest, nargs=0, const=True, default=default,
            required=required, help=help, **kwargs
        )
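Filling in the `__call__` the snippet above omits, a complete runnable sketch of this pattern might look as follows; the tracking-attribute name (`<dest>_explicitly_set`) is assumed to mirror StoreExplicitAction, and the argument name mirrors the script's `--no-fp8` flag:

```python
import argparse

class StoreTrueExplicitAction(argparse.Action):
    """store_true variant that also records that the flag was given."""

    def __init__(self, option_strings, dest, default=False, required=False,
                 help=None, **kwargs):
        # Forward extra kwargs so an unexpected argparse keyword doesn't raise.
        super().__init__(option_strings, dest, nargs=0, const=True,
                         default=default, required=required, help=help, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, True)
        setattr(namespace, f"{self.dest}_explicitly_set", True)

parser = argparse.ArgumentParser()
parser.add_argument("--no-fp8", dest="no_fp8", action=StoreTrueExplicitAction)
```

With this in place, `getattr(opts, "no_fp8_explicitly_set", False)` distinguishes a user-supplied `--no-fp8` from the `False` default, which is exactly the check the precedence logic needs.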

aagallo and others added 3 commits March 3, 2026 17:29
Resolve three code review issues in examples/pytorch/fsdp/fsdp.py:
dead commented-out code, unhelpful TypeError in precision(), and
rigid __init__ signature in StoreTrueExplicitAction.

- Remove commented-out layer_kwargs["params_dtype"] = dtype at
  line 106; dead code after params_dtype was moved to train()
- Replace bare raise TypeError in precision() with
  argparse.ArgumentTypeError and explicit list of supported values
  (fp32, fp16, fp8, mxfp8, nvfp4) for a meaningful error message
- Add **kwargs to StoreTrueExplicitAction.__init__ and forward to
  super().__init__(); aligns with StoreExplicitAction for robustness

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

greptile-apps bot commented Mar 3, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py, line 392
Spurious override warning when dtype matches preset default

The warning is emitted any time --dtype is explicitly provided alongside --precision, even when the user supplies the same value the preset would have chosen (e.g. --precision mxfp8 --dtype bf16). In that case the dtype isn't actually being "overridden" — it's identical — but the user still sees a confusing warning.

Guard the message (and the redundant recipe re-creation) behind an actual value change:

        if dtype_explicitly_set:
            new_dtype = opts.dtype
            if new_dtype != preset_dtype:
                dtype = new_dtype
                dist_print(
                    f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting"
                )

If you need the recipe regeneration guard as well, keep the if not no_fp8: recipe = get_recipe_for_precision(opts.precision) block inside the if new_dtype != preset_dtype: branch (or remove it entirely, since recipe is already correctly set from preset_recipe above).

aagallo added 2 commits March 3, 2026 17:40
Guard the --dtype override warning and dtype reassignment behind an
actual value change check to avoid a false positive when the user
explicitly passes --dtype with the same value the precision preset
would have selected.

- Add new_dtype != preset_dtype guard inside the dtype_explicitly_set
  branch so warning and dtype reassignment only trigger on a real
  override
- Suppress redundant recipe re-creation when dtype matches preset
  default; recipe is already correctly set from preset_recipe above
- No behavioral change when --dtype differs from preset default

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

greptile-apps bot commented Mar 3, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py, line 141
ArgumentTypeError raised inside __call__ won't be caught by argparse

Argparse only catches ArgumentTypeError, TypeError, and ValueError when they originate from the registered type function (inside _get_value()). When these exceptions are raised from inside an action's __call__ method, argparse does NOT intercept them — the user gets a raw Python traceback instead of a clean error: argument --dtype: invalid value message.

Since action.type is None here (not passed to super().__init__()), argparse skips its own type-conversion step and passes the raw string directly to __call__. The manual conversion is correct functionally, but the error path produces an unformatted traceback.

Fix: use parser.error() (available as the first argument in __call__) or raise SystemExit with a formatted message:

def __call__(self, parser, namespace, values, option_string=None):
    if self.type_converter is not None:
        try:
            values = self.type_converter(values)
        except (ValueError, TypeError, argparse.ArgumentTypeError) as e:
            parser.error(f"argument {option_string}: invalid value: {values!r}: {e}")

    setattr(namespace, self.dest, values)
    setattr(namespace, f"{self.dest}_explicitly_set", True)

This will produce the same clean exit-with-usage-error behaviour as argparse's own type validation.


Apply type_converter inside __call__ with proper exception handling
to ensure --dtype values are converted and validated at parse time
rather than silently passing raw strings through.

- Wrap type_converter call in try/except catching ValueError,
  TypeError, and argparse.ArgumentTypeError to surface conversion
  failures via parser.error() with a descriptive message
- Guard conversion behind if self.type_converter is not None check
  for cases where no converter is registered
- Ensures --dtype argument is correctly converted and validated
  consistently with standard argparse type= behavior

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

greptile-apps bot commented Mar 3, 2026

Additional Comments (3)

examples/pytorch/fsdp/fsdp.py, line 373
Redundant opts.no_fp8 condition

StoreTrueExplicitAction.__call__ always sets opts.no_fp8 = True when it fires, so no_fp8_explicitly_set = True implies opts.no_fp8 = True. The trailing and opts.no_fp8 is therefore always True whenever no_fp8_explicitly_set is True, making it redundant. Consider simplifying:

        if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set:

examples/pytorch/fsdp/fsdp.py, line 338
Duplicated recipe construction logic

get_recipe_for_precision recreates the exact same recipe objects that get_precision_preset already constructs. Because both functions hard-code the same parameters (fp8_format=Format.HYBRID, amax_history_len=32, etc.), any future tuning of a recipe in one function must be mirrored in the other — a silent drift hazard.

A simpler approach is to delegate to get_precision_preset and extract the recipe:

def get_recipe_for_precision(precision_value):
    _, _, recipe = get_precision_preset(precision_value)
    if recipe is None:
        raise NotImplementedError(f"No FP8 recipe defined for precision '{precision_value}'")
    return recipe

This eliminates the duplication entirely and guarantees consistency.


examples/pytorch/fsdp/fsdp.py, line 399
Unnecessary recipe re-instantiation when dtype is unchanged

When dtype_explicitly_set=True but new_dtype == preset_dtype (i.e., the user explicitly passed the same dtype that the preset would have chosen anyway), dtype is correctly left unchanged — but get_recipe_for_precision(opts.precision) still creates a brand-new recipe instance on line 399. Since the recipe is determined by opts.precision (not dtype), this re-creation is unnecessary when the dtype hasn't actually changed.

Consider guarding the recipe re-creation:

        if dtype_explicitly_set:
            new_dtype = opts.dtype
            if new_dtype != preset_dtype:
                dtype = new_dtype
                dist_print(
                    f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting"
                )
                # dtype only affects parameter storage, not FP8 recipe
                if not no_fp8:
                    recipe = get_recipe_for_precision(opts.precision)

This avoids the unnecessary second instantiation and keeps the recipe consistent with the one returned by get_precision_preset.

…nstantiation

Address three code review issues in examples/pytorch/fsdp/fsdp.py:
redundant opts.no_fp8 check, duplicated recipe construction, and
unnecessary recipe re-instantiation when dtype matches preset.

- Remove redundant 'and opts.no_fp8' from no_fp8_explicitly_set
  guard at line 373; StoreTrueExplicitAction always sets opts.no_fp8
  to True when it fires, making the extra check always True
- Refactor get_recipe_for_precision() to delegate to
  get_precision_preset() and extract the recipe, eliminating
  duplicated recipe construction logic and silent drift hazard
  when recipe parameters are tuned in one place but not the other
- Guard recipe re-creation inside new_dtype != preset_dtype branch
  to avoid unnecessary re-instantiation when dtype_explicitly_set
  but the value matches the preset default

Signed-off-by: Andrea Gallo <aagallo@amazon.com>


greptile-apps bot commented Mar 4, 2026

Additional Comments (2)

examples/pytorch/fsdp/fsdp.py, line 371
ValueError raised after dist.init_process_group()

The incompatible-flags check at line 365 raises a ValueError after dist.init_process_group() has already been called at line 339. This leaves the NCCL process group initialized without a corresponding dist.destroy_process_group() call, which can cause "Address already in use" or deadlock errors on the other ranks that are still waiting inside init_process_group.

Consider validating the flag combination before calling dist.init_process_group():

def train(opts):
    dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False)
    no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False)

    # Validate flag combinations before touching distributed state
    if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set:
        raise ValueError(
            f"Cannot use --no-fp8 with --precision {opts.precision}. ..."
        )

    # Initialize torch.distributed global process group
    dist.init_process_group(backend="nccl")
    ...

examples/pytorch/fsdp/fsdp.py, line 305
Redundant default format argument for MXFP8BlockScaling

Format.E4M3 is already the default value for MXFP8BlockScaling.fp8_format (as defined in the MXFP8BlockScaling dataclass). Passing it explicitly adds noise without adding clarity.

            recipe = MXFP8BlockScaling()

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

…nt arg

Move incompatible-flags check before dist.init_process_group() to
avoid leaving the NCCL process group partially initialized, and
remove redundant fp8_format=Format.E4M3 from MXFP8BlockScaling().

- Move no_fp8_explicitly_set + precision conflict check to the top
  of train() before dist.init_process_group() to prevent deadlocks
  or 'Address already in use' errors on other ranks that are still
  waiting inside init_process_group when rank 0 raises ValueError
- Remove explicit fp8_format=Format.E4M3 from MXFP8BlockScaling()
  call; Format.E4M3 is already the dataclass default and passing it
  explicitly adds noise without adding clarity

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

greptile-apps bot commented Mar 4, 2026

Additional Comments (2)

examples/pytorch/fsdp/fsdp.py, line 133
The type parameter is captured in self.type_converter but never forwarded to super().__init__(), which causes two issues:

  1. Error messages degrade — When the user supplies an invalid value (e.g., --dtype bad), argparse's standard error message format becomes generic "argument --dtype: invalid value: 'bad'" instead of the more specific "argument --dtype: invalid torch_dtype value: 'bad'".

  2. %(type)s help text interpolation fails — If help text uses %(type)s substitution, it won't expand to the converter name (torch_dtype), causing help display issues.

The fix is to forward type to the parent argparse.Action normally, letting argparse handle the conversion before the action is called:

class StoreExplicitAction(argparse.Action):
    """Custom action that tracks whether an argument was explicitly set."""

    def __init__(self, option_strings, dest, **kwargs):
        super().__init__(option_strings, dest, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # values already converted by argparse via action.type
        setattr(namespace, self.dest, values)
        setattr(namespace, f"{self.dest}_explicitly_set", True)

With this change, remove the manual try/except block and the self.type_converter field. This makes the action simpler and preserves argparse's native error reporting.


examples/pytorch/fsdp/fsdp.py, line 394
The final configuration log omits the active recipe type, making it hard to verify that the intended quantization scheme is in use—especially when --dtype overrides a precision preset.

Include the recipe type in the log output:

    dist_print(
        f"Training configuration: dtype={dtype}, "
        f"FP8={'disabled' if no_fp8 else f'enabled ({type(recipe).__name__})'}"
    )

This will emit messages like FP8=enabled (MXFP8BlockScaling) or FP8=disabled, providing clarity on which FP8 recipe is active.

aagallo and others added 2 commits March 3, 2026 20:59
Delegate type conversion to argparse in StoreExplicitAction and
include active FP8 recipe type in the training configuration log.

- Remove self.type_converter field and manual try/except block from
  StoreExplicitAction.__call__; forward type= kwarg to super().__init__()
  so argparse handles conversion natively before __call__ is invoked,
  restoring standard error messages and %(type)s help interpolation
- Simplify StoreExplicitAction.__init__ to use **kwargs passthrough,
  removing the now-unnecessary type= interception logic
- Include active recipe type in training configuration log output
  using type(recipe).__name__ so log emits messages like
  'FP8=enabled (MXFP8BlockScaling)' or 'FP8=disabled', making it
  easier to verify the intended quantization scheme is in use

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

greptile-apps bot commented Mar 4, 2026

Additional Comments (3)

examples/pytorch/fsdp/fsdp.py, line 150
StoreTrueExplicitAction silently drops kwargs

StoreTrueExplicitAction.__init__ captures **kwargs from argparse but never passes it to super().__init__(). This means any unexpected kwargs (e.g., metavar, choices, or future argparse internals) are silently discarded. While this is harmless for the current --no-fp8 usage (which has no type, choices, or metavar), it can mask misconfiguration if the argument registration is ever extended.

Consider forwarding **kwargs to super().__init__() the same way StoreExplicitAction does:

    def __init__(self, option_strings, dest, default=False, required=False, help=None, **kwargs):
        super().__init__(
            option_strings, dest, nargs=0, const=True, default=default, required=required, help=help, **kwargs
        )

examples/pytorch/fsdp/fsdp.py, line 384
Dtype override warning uses post-assignment variable

The warning message at line 378 formats dtype (the local variable) rather than the user-supplied string form of the type. By that point dtype has already been reassigned to new_dtype, so the log reads "Warning: --dtype torch.float16 overrides...", which is correct but could be clearer. More importantly, the recipe is only re-fetched from get_recipe_for_precision when new_dtype != preset_dtype (lines 383-384); when new_dtype == preset_dtype the code skips the log warning entirely, so the user gets no confirmation that their explicit --dtype flag was acknowledged. Consider logging the override even when the dtypes happen to match, or factoring out the warning so it always reflects the actual effective dtype.
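One way to make both paths explicit is to factor the decision into a small helper. A sketch under assumed names (resolve_dtype is illustrative, not a function in fsdp.py; strings stand in for torch.dtype objects):

```python
def resolve_dtype(preset_dtype, user_dtype, explicitly_set):
    """Return (effective_dtype, log_message); always confirms an explicit --dtype."""
    if not explicitly_set:
        return preset_dtype, None
    if user_dtype != preset_dtype:
        return user_dtype, (
            f"Warning: --dtype {user_dtype} overrides preset dtype {preset_dtype}"
        )
    # Previously this branch logged nothing; now the user still gets confirmation.
    return user_dtype, f"Info: --dtype {user_dtype} matches the preset default"
```

With this shape, the caller logs whatever message comes back, so an explicitly passed dtype is always acknowledged.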


examples/pytorch/fsdp/fsdp.py, line 339
Incompatibility validation runs before dist.init_process_group

The ValueError at lines 334–339 is raised before dist.init_process_group is called. In a torchrun launch, every process receives the same CLI arguments, so all processes will raise the same error simultaneously — making this safe in practice. However, if this function is ever called from a launcher that does NOT guarantee identical arguments across ranks (e.g., per-rank config injection), one process would die while others wait indefinitely at init_process_group, causing a hang.

Consider either:

  • Adding a comment documenting the assumption that all ranks have identical arguments, or
  • Moving the validation to after dist.init_process_group and using a broadcast to propagate the error to all ranks.

aagallo and others added 4 commits March 3, 2026 21:15
…cument rank assumption

Address three code review issues in examples/pytorch/fsdp/fsdp.py:
silent kwargs drop in StoreTrueExplicitAction, missing confirmation
log when dtype matches preset, and undocumented torchrun assumption.

- Forward **kwargs to super().__init__() in StoreTrueExplicitAction
  to prevent silent discard of unexpected keyword arguments (e.g.
  metavar, choices) if argument registration is ever extended
- Add info log when dtype_explicitly_set but new_dtype == preset_dtype
  so user receives confirmation their --dtype flag was acknowledged
  even when it matches the preset default and no override is needed
- Add comment above no_fp8_explicitly_set validation documenting that
  raising ValueError before dist.init_process_group is safe because
  torchrun guarantees all ranks receive identical CLI arguments

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Signed-off-by: Andrea Gallo <aagallo@amazon.com>

aagallo commented Mar 4, 2026

@vthumbe1503 Thanks for the feedback!

  • for "StoreTrueExplicitAction", I replied in the comment with an explanation. Let me know if you need more information
  • for the "precision_value is fp8 recipes", I have updated the code accordingly

I also went through the different feedback cycles. Let me know if there is anything else pending on my side. Thanks again!


aagallo commented Mar 4, 2026

@vthumbe1503 I saw that the build is actually failing. Is there anything I need to change on my end? I haven't touched anything other than the fsdp.py file.


greptile-apps bot commented Mar 4, 2026

Additional Comments (2)

examples/pytorch/fsdp/fsdp.py, line 395
torch.dtype object printed in user-facing log messages

Both the warning (line 387) and the info (line 393) messages format a torch.dtype object directly:

f"Warning: --dtype {dtype} overrides ..."          # prints "torch.float32"
f"Info: --dtype {new_dtype} matches ..."           # prints "torch.bfloat16"

Because torch_dtype() converts the CLI string (e.g., fp32, float32) into a torch.dtype object before it is stored in opts.dtype, the original user-supplied string is lost. The log then shows torch.float32 instead of the fp32 or float32 that the user typed, which can be confusing.

Consider formatting the dtype in a friendlier way, for example:

dtype_name = str(dtype).replace("torch.", "")
dist_print(
    f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype setting"
)

and similarly for the "Info:" message.


examples/pytorch/fsdp/fsdp.py, line 369
recipe=None is a behavioral change from the original code

The original code unconditionally constructed a DelayedScaling recipe object and always passed it to te.autocast, even when --no-fp8 was set (FP8 disabled):

# old code
fp8_recipe = DelayedScaling(...)
with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, ...):

The new code passes recipe=None to te.autocast whenever no_fp8=True:

# new code (Case 1 path)
else:
    recipe = None

Inspecting te.autocast and FP8GlobalStateManager.autocast_enter shows that passing recipe=None is handled safely: autocast_enter substitutes get_default_fp8_recipe() when fp8_recipe is None, and check_recipe_support is skipped entirely when enabled=False. So there is no runtime error.

However, the implicit fallback to a default recipe (even when enabled=False) means the global FP8 state is populated with a recipe object that the user never specified. While harmless today, this is worth documenting with a comment so future maintainers understand the intent, especially since TE evolves quickly.

aagallo and others added 2 commits March 4, 2026 13:43
Strip 'torch.' prefix from dtype in user-facing log messages and
add a comment documenting the intentional recipe=None behavior when
FP8 is disabled.

- Replace raw dtype formatting with str(dtype).replace('torch.', '')
  in both Warning and Info log messages so users see 'float32' or
  'bfloat16' instead of 'torch.float32' or 'torch.bfloat16'
- Add inline comment on recipe=None assignment explaining that
  te.autocast safely substitutes get_default_fp8_recipe() internally
  when recipe is None, and skips check_recipe_support when
  enabled=False, so the assignment is intentional and safe despite
  populating global FP8 state with a default recipe

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py, line 137
Redundant __init__ override in StoreExplicitAction

The __init__ override does nothing beyond calling super().__init__() with the same arguments, which Python would do automatically. This dead code adds noise without benefit and can be safely removed.

class StoreExplicitAction(argparse.Action):
    """Custom action that tracks whether an argument was explicitly set."""

    def __call__(self, parser, namespace, values, option_string=None):
        # values already converted by argparse via action.type
        setattr(namespace, self.dest, values)
        setattr(namespace, f"{self.dest}_explicitly_set", True)
