Skip to content
Open
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
c960e6d
Add precision parameter support for multiple training formats
Feb 9, 2026
5949884
Merge branch 'NVIDIA:main' into extend-precision
aagallo Feb 9, 2026
cd97843
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2026
1341abb
Merge branch 'NVIDIA:main' into main
aagallo Feb 10, 2026
40919d8
Fix FP16 dtype mapping and implement CLI flag precedence
aagallo Feb 10, 2026
5c1db12
Add logging and documentation for precision configuration
aagallo Feb 10, 2026
e4846c8
Initialize recipe variable in all precision cases
aagallo Feb 10, 2026
26aee2f
Fix dtype flag detection to support explicit override behavior
aagallo Feb 10, 2026
295a106
Merge remote-tracking branch 'origin/main' into extend-precision
aagallo Feb 10, 2026
c9524e3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2026
a506d40
Replace sys.argv parsing with custom action and fix default case
aagallo Feb 10, 2026
b100f2c
Merge branch 'extend-precision' of https://github.com/aagallo/Transfo…
aagallo Feb 10, 2026
ec31f2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2026
07a87a7
Fix params_dtype to use computed dtype from precision logic
aagallo Feb 10, 2026
748bb39
Merge branch 'extend-precision' of https://github.com/aagallo/Transfo…
aagallo Feb 10, 2026
c6fb3a5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2026
da9f82b
Merge branch 'main' into extend-precision
aagallo Feb 10, 2026
2c5473e
Fix type conversion in StoreExplicitAction for --dtype argument
aagallo Feb 10, 2026
a9e664c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2026
196df3d
Fix precision preset recipe selection and add incompatibility validation
aagallo Feb 11, 2026
368820b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2026
9e2f34b
Merge branch 'main' into extend-precision
aagallo Feb 11, 2026
76dcb94
Fix unreachable default case and redundant recipe recreation
aagallo Feb 11, 2026
e22c2f2
Add explicit error handling for invalid precision presets
aagallo Feb 11, 2026
9d637d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2026
c3399a1
Merge branch 'main' into extend-precision
vthumbe1503 Mar 2, 2026
01852f4
Merge branch 'main' into extend-precision
aagallo Mar 3, 2026
cbad806
fix: address argparse robustness and cleanup issues in fsdp.py
aagallo Mar 3, 2026
70b7786
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2026
a766d06
Merge branch 'main' into extend-precision
aagallo Mar 3, 2026
3ee3973
fix: suppress spurious dtype override warning when value matches preset
aagallo Mar 3, 2026
32a54b4
Merge branch 'extend-precision' of https://github.com/aagallo/Transfo…
aagallo Mar 3, 2026
f53b845
fix: add type conversion with error handling in StoreExplicitAction._…
aagallo Mar 3, 2026
3d2caa4
fix: remove redundant condition, deduplicate recipe logic, guard re-i…
aagallo Mar 4, 2026
b473ab2
fix: validate flags before dist.init_process_group and remove redunda…
aagallo Mar 4, 2026
4e251ef
fix: simplify StoreExplicitAction and improve training config log
aagallo Mar 4, 2026
fa959d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2026
5e57ae4
fix: forward kwargs in StoreTrueExplicitAction, improve dtype log, do…
aagallo Mar 4, 2026
dcfd6b7
Merge branch 'extend-precision' of https://github.com/aagallo/Transfo…
aagallo Mar 4, 2026
1595dd9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2026
a601034
Fixing typo in code documentation
aagallo Mar 4, 2026
7755dc0
Merge branch 'main' into extend-precision
vthumbe1503 Mar 4, 2026
ef9879a
fix: format dtype in log messages and document recipe=None intent
aagallo Mar 4, 2026
afa756d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2026
91948e0
refactor: remove redundant __init__ override in StoreExplicitAction
aagallo Mar 4, 2026
6e69a8f
fix: use 'quantization' label in log and remove redundant recipe re-i…
aagallo Mar 4, 2026
4d30ba0
fix: remove redundant recipe re-instantiation in equal-dtype path
aagallo Mar 4, 2026
9509f45
Merge branch 'main' into extend-precision
vthumbe1503 Mar 4, 2026
4e3529a
fix: define dtype_name unconditionally and guard dtype override warning
aagallo Mar 4, 2026
246e948
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2026
aba45ea
fix: recompute dtype_name after override and restore DelayedScaling d…
aagallo Mar 4, 2026
90d44b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2026
f1cf722
refactor: simplify opts.dtype None check and remove redundant no_fp8 …
aagallo Mar 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 170 additions & 12 deletions examples/pytorch/fsdp/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
)

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.common.recipe import (
Format,
DelayedScaling,
MXFP8BlockScaling,
NVFP4BlockScaling,
)
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
Expand Down Expand Up @@ -64,10 +69,21 @@ def torch_dtype(d):
"bfloat16": torch.bfloat16,
}
if lowercase(d) not in typemap.keys():
raise TypeError
raise argparse.ArgumentTypeError(
f"invalid dtype '{d}'. Supported values: fp32/float32, fp16/float16, bf16/bfloat16"
)
return typemap[lowercase(d)]


def precision(d):
    """Argparse ``type`` validator for the ``--precision`` option.

    Normalizes *d* to lowercase and checks it against the supported preset
    names. Raises ``argparse.ArgumentTypeError`` for anything else so that
    argparse reports a clean usage error instead of a traceback.
    """
    supported = ("fp32", "fp16", "fp8", "mxfp8", "nvfp4")
    normalized = lowercase(d)
    if normalized not in supported:
        raise argparse.ArgumentTypeError(
            f"invalid precision '{d}'. Supported values: {', '.join(supported)}"
        )
    return normalized


te_layer_map = {
"linear": te.Linear,
"layernorm": te.LayerNorm,
Expand All @@ -91,7 +107,6 @@ def get_layer_args(opts):
hidden_size = opts.num_heads * opts.head_dim
layer_args = (hidden_size,)
layer_kwargs = {
"params_dtype": opts.dtype,
"device": "cuda" if opts.no_defer_init else "meta",
"get_rng_state_tracker": get_cuda_rng_tracker,
}
Expand All @@ -112,6 +127,15 @@ def get_layer_args(opts):
return layer_args, layer_kwargs


class StoreExplicitAction(argparse.Action):
    """Argparse action that records whether an option appeared on the CLI.

    Besides storing the (already type-converted) value on ``dest``, it sets a
    companion ``<dest>_explicitly_set`` attribute to True so downstream code
    can distinguish an explicit command-line value from the argparse default.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # argparse has already run the option's ``type`` converter on values.
        setattr(namespace, self.dest, values)
        # Companion flag: only ever set when the user actually passed the option.
        setattr(namespace, self.dest + "_explicitly_set", True)


def parse_fsdp_args():
parser = argparse.ArgumentParser(
description="Run Transformer Engine modules with the "
Expand Down Expand Up @@ -173,7 +197,19 @@ def parse_fsdp_args():
"--no-fp8",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing action=StoreExplicitAction to track explicit --no-fp8 usage - required for precedence logic at line 250

Suggested change
"--no-fp8",
action=StoreExplicitAction,

action="store_true",
default=False,
help="Disables the te.autocast() context.",
help=(
"Disables the te.autocast() context. When set, FP8 training is disabled and the model"
" trains in standard precision (as specified by --dtype). PRECEDENCE: This flag is"
" incompatible with FP8-based --precision presets. BEHAVIOR: - Without --precision:"
" Disables FP8 training (original behavior) - With --precision fp32/fp16: Redundant but"
" harmless (already non-FP8) - With --precision fp8/mxfp8/nvfp4: RAISES ERROR"
" (incompatible flags) RATIONALE: FP8 presets explicitly request FP8 training, so"
" disabling FP8 would contradict the user's intent. Use --precision fp32/fp16 instead"
" for non-FP8 training. EXAMPLES: '--no-fp8' disables FP8 (original behavior)."
" '--precision fp8 --no-fp8' raises ValueError (incompatible). '--precision fp16' is"
" the correct way to request non-FP8 training. Default: False (FP8 enabled based on"
" configuration)."
),
)
parser.add_argument(
"--no-defer-init",
Expand All @@ -189,7 +225,35 @@ def parse_fsdp_args():
"--dtype",
type=torch_dtype,
default=torch.bfloat16,
help="Data type for input tensor and Transformer Engine module parameters.",
action=StoreExplicitAction,
help=(
"Data type for input tensor and Transformer Engine module parameters. Supported values:"
" fp32/float32, fp16/float16, bf16/bfloat16. PRECEDENCE: When explicitly set, this flag"
" overrides the dtype from --precision preset. BEHAVIOR: - Without --precision:"
" Controls parameter dtype directly - With --precision: Overrides preset's dtype but"
" preserves FP8 recipe selection EXAMPLES: '--dtype bf16' uses bfloat16 parameters"
" (original behavior). '--precision mxfp8 --dtype fp16' uses fp16 parameters with"
" MXFP8BlockScaling recipe. A warning is issued when overriding --precision dtype."
" Default: bfloat16."
),
)
parser.add_argument(
"--precision",
type=precision,
default=None,
help=(
"Precision preset for model training. Supported values: fp32, fp16, fp8, mxfp8, nvfp4."
" This is a convenience flag that configures both dtype and FP8 settings automatically."
" - fp32/fp16: Non-FP8 training with specified dtype - fp8: FP8 training with"
" DelayedScaling recipe (bf16 parameters) - mxfp8: FP8 training with MXFP8BlockScaling"
" recipe (bf16 parameters) - nvfp4: FP8 training with NVFP4BlockScaling recipe (bf16"
" parameters) PRECEDENCE RULES: - If --dtype is explicitly set, it overrides the dtype"
" from this preset (with warning) - If --no-fp8 is set with fp8/mxfp8/nvfp4 presets, an"
" error is raised (incompatible) - If this flag is not set, original behavior is used"
" (--dtype and --no-fp8 control training) EXAMPLES: '--precision mxfp8' enables MXFP8"
" FP8 training with bf16 parameters. '--precision fp8 --dtype fp16' uses fp16"
" parameters but keeps DelayedScaling recipe. Default: None (backward compatible mode)."
),
)
return parser.parse_args()
Comment on lines +240 to 258
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conflicting CLI flags
--precision and --dtype/--no-fp8 now overlap: train() overrides dtype and no_fp8 based on --precision, but --dtype/--no-fp8 are still accepted and used as defaults. As written, --precision fp8 will force no_fp8=False even if the user explicitly passed --no-fp8, and --precision fp32 will ignore an explicit --dtype fp16/bf16. This makes the CLI behavior surprising and hard to reason about; either make --precision derive defaults only when the user didn’t specify --dtype/--no-fp8, or document/enforce precedence (e.g., error on incompatible combinations).


Expand All @@ -200,15 +264,113 @@ def dist_print(text, all_ranks=False, no_new_line=False):
print(f"[GPU-{LOCAL_RANK}] " + text, end=end)


def get_precision_preset(precision_value):
    """Map a ``--precision`` preset name to a training configuration.

    Args:
        precision_value: One of "fp32", "fp16", "fp8", "mxfp8", "nvfp4",
            or None when no preset was requested on the command line.

    Returns:
        tuple: (dtype, no_fp8, recipe) where ``dtype`` is the parameter
        torch dtype, ``no_fp8`` disables the te.autocast() context when
        True, and ``recipe`` is the quantization recipe (None for the
        non-FP8 presets and for the default case).

    Raises:
        ValueError: If ``precision_value`` is not a recognized preset.
    """
    # Non-FP8 presets: train in plain fp32/fp16, no quantization recipe.
    if precision_value == "fp32":
        return torch.float32, True, None
    if precision_value == "fp16":
        return torch.float16, True, None
    # FP8-family presets: bf16 parameter storage plus a quantization recipe.
    if precision_value == "fp8":
        delayed = DelayedScaling(
            fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
        )
        return torch.bfloat16, False, delayed
    if precision_value == "mxfp8":
        return torch.bfloat16, False, MXFP8BlockScaling()
    if precision_value == "nvfp4":
        return torch.bfloat16, False, NVFP4BlockScaling()
    # No preset requested: original behavior — dtype and no_fp8 are
    # controlled directly by the --dtype and --no-fp8 flags.
    if precision_value is None:
        return torch.bfloat16, True, None
    raise ValueError(
        f"Invalid precision preset: {precision_value}. "
        "Supported values: fp32, fp16, fp8, mxfp8, nvfp4"
    )


def train(opts):
# Check which flags were explicitly set
dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False)

# Validate flag combinations before touching distributed state.
# Error if user requests FP8-based precision but also sets --no-fp8
# Safe to raise here because torchrun guarantees all ranks receive
# identical CLI arguments; all ranks will raise simultaneously.
if opts.precision in ["fp8", "mxfp8", "nvfp4"] and opts.no_fp8:
raise ValueError(
f"Cannot use --no-fp8 with --precision {opts.precision}. "
"These flags are incompatible. "
f"Either remove --no-fp8 to use {opts.precision} training, "
"or use --precision fp32/fp16 for non-FP8 training."
)

# Initialize torch.distributed global process group
dist.init_process_group(backend="nccl")
torch.cuda.set_device(LOCAL_RANK)
dist_print(f"WORLD_SIZE = {WORLD_SIZE}")
torch.manual_seed(opts.seed)

# Start with precision preset values
preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision)

dtype = preset_dtype
no_fp8 = preset_no_fp8
recipe = preset_recipe

# When no --precision is set, respect --no-fp8 and --dtype directly (original behavior)
if opts.precision is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So backward compatibility in this script is not that big of a deal. Having default value of precision as DelayedScaling should be ok and this if condition can be removed to make the code simpler.

no_fp8 = opts.no_fp8
dtype = opts.dtype
# Preserve original default: FP8 enabled → use DelayedScaling as before
if not no_fp8:
recipe = DelayedScaling(
fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
)
else:
dist_print(f"Using precision preset: {opts.precision}")

dtype_name = str(dtype).replace("torch.", "")

# Apply explicit dtype override with warning
if dtype_explicitly_set and opts.precision is not None:
new_dtype = opts.dtype
if new_dtype != preset_dtype:
dtype = new_dtype
dtype_name = str(dtype).replace("torch.", "")

dist_print(
f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype"
" setting"
)
else:
new_dtype_name = str(new_dtype).replace("torch.", "")
dist_print(
f"Info: --dtype {new_dtype_name} matches --precision {opts.precision} preset"
" default, no override needed"
)

# recipe is already set correctly from preset_recipe above;
# dtype only affects parameter storage, not the quantization recipe

# Always log the final configuration being used
dist_print(
f"Training configuration: dtype={dtype_name}, "
f"quantization={'disabled' if no_fp8 else f'enabled ({type(recipe).__name__})'}"
)

# Construct a simple homogeneous model (only one layer type) with NO PARALLELISM
layer_args, layer_kwargs = get_layer_args(opts)
layer_kwargs["params_dtype"] = dtype

if opts.num_layers > 1:
te_layer_list = []
for i in range(opts.num_layers):
Expand Down Expand Up @@ -239,7 +401,7 @@ def train(opts):
process_group=all_gpus,
use_orig_params=True,
mixed_precision=MixedPrecision(
param_dtype=opts.dtype,
param_dtype=dtype,
reduce_dtype=torch.float32,
),
auto_wrap_policy=fsdp_wrap_policy,
Expand All @@ -258,10 +420,6 @@ def train(opts):
dist_print(f"Post-FSDP memory use = {post_mem_use}MiB")
dist_print(f"FSDP-Wrapped + Checkpointed TE Model:\n{te_model}")

# Fp8 setup for TE
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

# Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded
optim = torch.optim.Adam(te_model.parameters(), lr=0.0001)

Expand All @@ -281,11 +439,11 @@ def train(opts):
opts.seq_length,
opts.batch_size,
opts.num_heads * opts.head_dim,
dtype=opts.dtype,
dtype=dtype,
device="cuda",
)
# autocast needs to be given the FSDP process group for amax reductions
with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus):
with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus):
y = te_model(x)
loss = y.sum()
# calculate gradient and take training step outside the autocast context
Expand Down