From c960e6dc515b0d663ea904c7b6478b86f9f3d5d8 Mon Sep 17 00:00:00 2001 From: aagallo Date: Mon, 9 Feb 2026 10:03:06 -0500 Subject: [PATCH 01/56] Add precision parameter support for multiple training formats Enable configurable precision training with support for FP32, FP16, FP8, MXFP8, and NVFP4 formats. Added precision argument parser and match statement to configure appropriate dtype and recipe based on selected precision. - Add precision() type validator function - Implement precision-based configuration in train() - Support FP32, FP16, FP8, MXFP8, and NVFP4 formats - Configure format-specific recipes (DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling) - Set appropriate no_fp8 flags based on precision selection Signed-off-by: aagallo --- examples/pytorch/fsdp/fsdp.py | 65 +++++++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 7 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index b469ef56b7..2b8dc3a436 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -18,7 +18,7 @@ ) import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Format, DelayedScaling +from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) @@ -68,6 +68,19 @@ def torch_dtype(d): return typemap[lowercase(d)] +def precision(d): + typemap = [ + "fp32", + "fp16", + "fp8", + "mxfp8", + "nvfp4" + ] + if lowercase(d) not in typemap: + raise TypeError + return lowercase(d) + + te_layer_map = { "linear": te.Linear, "layernorm": te.LayerNorm, @@ -191,6 +204,12 @@ def parse_fsdp_args(): default=torch.bfloat16, help="Data type for input tensor and Transformer Engine module parameters.", ) + parser.add_argument( + "--precision", + type=precision, + default="fp8", + help="Precision to apply to model training (FP32, FP16, FP8, MXFP8, NVFP4)", + ) return parser.parse_args() @@ -209,6 +228,42 @@ def train(opts): # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM layer_args, layer_kwargs = get_layer_args(opts) + + # Determining the format and recipe for the training + precision_format = Format.HYBRID + recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + no_fp8 = opts.no_fp8 + dtype=opts.dtype + + match opts.precision: + case "fp32": + dtype=torch.float32 + no_fp8 = True + case "fp16": + dtype=torch.bfloat16 + no_fp8 = True + case "fp8": + dtype=torch.bfloat16 + precision_format = Format.HYBRID + recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + no_fp8 = False + case "mxfp8": + dtype=torch.bfloat16 + precision_format = Format.E4M3 + recipe = MXFP8BlockScaling(fp8_format=precision_format) + no_fp8 = False + case "nvfp4": + dtype=torch.bfloat16 # RHT only supports bfloat16 + recipe = NVFP4BlockScaling() + no_fp8 = False + case _: + dtype=torch.bfloat16 + precision_format = Format.HYBRID + recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + no_fp8 = opts.no_fp8 + + layer_kwargs["params_dtype"]=dtype + if opts.num_layers > 1: te_layer_list = [] for i in range(opts.num_layers): @@ -258,10 +313,6 @@ def train(opts): dist_print(f"Post-FSDP memory use = {post_mem_use}MiB") dist_print(f"FSDP-Wrapped + Checkpointed TE Model:\n{te_model}") - # Fp8 setup for TE - fp8_format = Format.HYBRID - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") - # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded optim = torch.optim.Adam(te_model.parameters(), lr=0.0001) @@ -281,11 +332,11 @@ def train(opts): opts.seq_length, opts.batch_size, opts.num_heads * opts.head_dim, - dtype=opts.dtype, + dtype=dtype, device="cuda", ) # autocast needs to be given the FSDP process group for amax reductions - with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus): + with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus): y = te_model(x) loss = y.sum() # calculate gradient and take training step outside the autocast context From cd9784307abf0b2b1a1dafd93c9e09905ad0017e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Feb 2026 21:40:44 +0000 Subject: [PATCH 02/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 43 +++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 2b8dc3a436..882e76ce4b 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -18,7 +18,12 @@ ) import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling +from transformer_engine.common.recipe import ( + Format, + DelayedScaling, + MXFP8BlockScaling, + NVFP4BlockScaling, +) from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) @@ -69,13 +74,7 @@ def torch_dtype(d): def precision(d): - typemap = [ - "fp32", - "fp16", - "fp8", - "mxfp8", - "nvfp4" - ] + typemap = ["fp32", "fp16", "fp8", "mxfp8", "nvfp4"] if lowercase(d) not in typemap: raise TypeError return lowercase(d) @@ -231,38 +230,44 @@ def train(opts): # Determining the format and recipe for the training precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = opts.no_fp8 - dtype=opts.dtype + dtype = opts.dtype match opts.precision: case "fp32": - dtype=torch.float32 + dtype = torch.float32 no_fp8 = True case "fp16": - dtype=torch.bfloat16 + dtype = torch.bfloat16 no_fp8 = True case "fp8": - dtype=torch.bfloat16 + dtype = torch.bfloat16 precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = False case "mxfp8": - dtype=torch.bfloat16 + dtype = torch.bfloat16 precision_format = Format.E4M3 recipe = MXFP8BlockScaling(fp8_format=precision_format) no_fp8 = False case "nvfp4": - dtype=torch.bfloat16 # RHT only supports bfloat16 + dtype = torch.bfloat16 # RHT only supports bfloat16 recipe = NVFP4BlockScaling() no_fp8 = False case _: - dtype=torch.bfloat16 + dtype = torch.bfloat16 precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = opts.no_fp8 - layer_kwargs["params_dtype"]=dtype + layer_kwargs["params_dtype"] = dtype if opts.num_layers > 1: te_layer_list = [] From 40919d8949df4cc2b9ae440bcc8433dd4577e93e Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 13:19:40 -0500 Subject: [PATCH 03/56] Fix FP16 dtype mapping and implement CLI flag precedence Correct FP16 precision to use torch.float16 instead of torch.bfloat16, and add precedence logic where --dtype and --no-fp8 flags override --precision when explicitly set, with warnings issued for conflicts. - Fix case fp16 to use torch.float16 instead of torch.bfloat16 - Add flag precedence detection by comparing against default values - Implement warning messages when --dtype or --no-fp8 override --precision - Update argument parser help text to document precedence behavior - Ensure --dtype and --no-fp8 take precedence over --precision presets Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 2b8dc3a436..f38e04207d 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -220,6 +220,12 @@ def dist_print(text, all_ranks=False, no_new_line=False): def train(opts): + + # Check which flags were explicitly set + dtype_explicitly_set = opts.dtype != torch.bfloat16 + no_fp8_explicitly_set = opts.no_fp8 != False + precision_is_non_default = opts.precision != "fp8" + # Initialize torch.distributed global process group dist.init_process_group(backend="nccl") torch.cuda.set_device(LOCAL_RANK) From 5c1db12c3b34e8d836be0285ed33a4f5b449c308 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 15:31:35 -0500 Subject: [PATCH 04/56] Add logging and documentation for precision configuration Add informative log messages and enhanced help text to clarify precision configuration behavior and flag precedence for better user transparency. - Add log message showing which precision preset is being used - Add warning logs when --dtype or --no-fp8 override --precision - Add final training configuration log (dtype, FP8 status, recipe) - Enhance argument parser help text with precedence examples - Add inline code comments explaining precedence logic Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 103 +++++++++++++++++++++++----------- 1 file changed, 69 insertions(+), 34 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index f38e04207d..78a4fa2115 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -186,7 +186,11 @@ def parse_fsdp_args(): "--no-fp8", action="store_true", default=False, - help="Disables the te.autocast() context.", + help="Disables the te.autocast() context. When set, FP8 training is disabled " + + "and the model trains in standard precision (as specified by --dtype). " + + "Takes precedence over --precision if both are specified. " + + "Example: '--no-fp8 --precision fp8' will disable FP8 despite fp8 preset. " + + "Default: False (FP8 enabled based on precision).", ) parser.add_argument( "--no-defer-init", @@ -202,13 +206,23 @@ def parse_fsdp_args(): "--dtype", type=torch_dtype, default=torch.bfloat16, - help="Data type for input tensor and Transformer Engine module parameters.", + help="Data type for input tensor and Transformer Engine module parameters. " + + "Supported values: fp32/float32, fp16/float16, bf16/bfloat16. " + + "Takes precedence over --precision if both are specified. " + + "Example: '--dtype fp16 --precision fp8' will use fp16 dtype and ignore fp8 preset. " + + "Default: bfloat16.", ) parser.add_argument( "--precision", type=precision, default="fp8", - help="Precision to apply to model training (FP32, FP16, FP8, MXFP8, NVFP4)", + help="Precision preset for model training. Supported values: FP32, FP16, FP8, MXFP8, NVFP4. " + + "This is a convenience flag that configures both dtype and FP8 settings automatically. " + + "If --dtype or --no-fp8 are explicitly specified, they take precedence over this flag " + + "and a warning will be issued. " + + "Precedence: --dtype and --no-fp8 override --precision. " + + "Example: Use '--precision fp8' for quick setup, or '--dtype bf16 --no-fp8' for explicit control. " + + "Default: fp8.", ) return parser.parse_args() @@ -235,39 +249,60 @@ def train(opts): # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM layer_args, layer_kwargs = get_layer_args(opts) - # Determining the format and recipe for the training - precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") - no_fp8 = opts.no_fp8 - dtype=opts.dtype - - match opts.precision: - case "fp32": - dtype=torch.float32 - no_fp8 = True - case "fp16": - dtype=torch.bfloat16 - no_fp8 = True - case "fp8": - dtype=torch.bfloat16 - precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") - no_fp8 = False - case "mxfp8": - dtype=torch.bfloat16 - precision_format = Format.E4M3 - recipe = MXFP8BlockScaling(fp8_format=precision_format) - no_fp8 = False - case "nvfp4": - dtype=torch.bfloat16 # RHT only supports bfloat16 - recipe = NVFP4BlockScaling() - no_fp8 = False - case _: - dtype=torch.bfloat16 + if not dtype_explicitly_set and not no_fp8_explicitly_set: + + dist_print(f"Using precision preset: {opts.precision}") + + match opts.precision: + case "fp32": + dtype=torch.float32 + no_fp8 = True + case "fp16": + dtype=torch.float16 + no_fp8 = True + case "fp8": + dtype=torch.float16 + precision_format = Format.HYBRID + recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + no_fp8 = False + case "mxfp8": + dtype=torch.float16 + precision_format = Format.E4M3 + recipe = MXFP8BlockScaling(fp8_format=precision_format) + no_fp8 = False + case "nvfp4": + dtype=torch.bfloat16 # RHT only supports bfloat16 + recipe = NVFP4BlockScaling() + no_fp8 = False + case _: + dtype=torch.float16 + precision_format = Format.HYBRID + recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + no_fp8 = opts.no_fp8 + else: + # dtype and/or no_fp8 were explicitly set - they take precedence + dtype = opts.dtype + no_fp8 = opts.no_fp8 + + # Set up default recipe for FP8 cases + if not no_fp8: precision_format = Format.HYBRID recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") - no_fp8 = opts.no_fp8 - + else: + recipe = None + + # Warn if precision was also set to non-default (being overridden) + if precision_is_non_default: + if dtype_explicitly_set: + dist_print(f"Warning: --dtype {dtype} overrides --precision {opts.precision}") + if no_fp8_explicitly_set: + dist_print(f"Warning: --no-fp8 overrides --precision {opts.precision}") + + # Always log the final configuration being used + dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}") + if not no_fp8: + dist_print(f"Using FP8 recipe: {type(recipe).__name__}") + layer_kwargs["params_dtype"]=dtype if opts.num_layers > 1: From e4846c8a48aaea2a1fcf8c41cb8e34aa12597a27 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 15:40:23 -0500 Subject: [PATCH 05/56] Initialize recipe variable in all precision cases Add recipe initialization for fp32 and fp16 precision cases to prevent undefined variable errors, even though recipe is not used when no_fp8 is set to True. - Add DelayedScaling recipe setup for fp32 case with no_fp8=True - Add DelayedScaling recipe setup for fp16 case with no_fp8=True - Add inline comments explaining recipe is set up but not used by autocast - Ensure recipe variable is defined in all precision branches for consistency Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 78a4fa2115..b28b3d83f4 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -256,9 +256,19 @@ def train(opts): match opts.precision: case "fp32": dtype=torch.float32 + + #set up, but not used by autocast with no-fp8 set to true + precision_format = Format.HYBRID + recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + no_fp8 = True case "fp16": dtype=torch.float16 + + #set up, but not used by autocast with no-fp8 set to true + precision_format = Format.HYBRID + recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + no_fp8 = True case "fp8": dtype=torch.float16 From 26aee2f25a85e0534dd778a8021a4e39a9d27c0f Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 15:48:42 -0500 Subject: [PATCH 06/56] Fix dtype flag detection to support explicit override behavior Update flag precedence detection to use sys.argv for checking if --dtype was explicitly set, ensuring dtype always overrides precision regardless of whether it matches the default value. - Add sys import for command-line argument detection - Change dtype_explicitly_set check to use '--dtype' in sys.argv - Change no_fp8_explicitly_set check to use '--no-fp8' in sys.argv - Ensure --dtype bf16 correctly overrides --precision even when matching default - Maintain warning messages when explicit flags override precision presets Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index b28b3d83f4..5b5e4b80ea 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -234,10 +234,11 @@ def dist_print(text, all_ranks=False, no_new_line=False): def train(opts): + import sys # Check which flags were explicitly set - dtype_explicitly_set = opts.dtype != torch.bfloat16 - no_fp8_explicitly_set = opts.no_fp8 != False + dtype_explicitly_set = '--dtype' in sys.argv + no_fp8_explicitly_set = '--no-fp8' in sys.argv precision_is_non_default = opts.precision != "fp8" # Initialize torch.distributed global process group @@ -264,7 +265,7 @@ def train(opts): no_fp8 = True case "fp16": dtype=torch.float16 - + #set up, but not used by autocast with no-fp8 set to true precision_format = Format.HYBRID recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") From c9524e38dbf9ce85a6c49a41eda66f753f601569 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 21:03:02 +0000 Subject: [PATCH 07/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 53 ++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index b8962e48f9..ae4f00c6cd 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -215,12 +215,15 @@ def parse_fsdp_args(): "--precision", type=precision, default="fp8", - help="Precision preset for model training. Supported values: FP32, FP16, FP8, MXFP8, NVFP4. " + help=( + "Precision preset for model training. Supported values: FP32, FP16, FP8, MXFP8, NVFP4. " + ) + "This is a convenience flag that configures both dtype and FP8 settings automatically. " + "If --dtype or --no-fp8 are explicitly specified, they take precedence over this flag " + "and a warning will be issued. " + "Precedence: --dtype and --no-fp8 override --precision. " - + "Example: Use '--precision fp8' for quick setup, or '--dtype bf16 --no-fp8' for explicit control. " + + "Example: Use '--precision fp8' for quick setup, or '--dtype bf16 --no-fp8' for explicit" + " control. " + "Default: fp8.", ) return parser.parse_args() @@ -236,8 +239,8 @@ def train(opts): import sys # Check which flags were explicitly set - dtype_explicitly_set = '--dtype' in sys.argv - no_fp8_explicitly_set = '--no-fp8' in sys.argv + dtype_explicitly_set = "--dtype" in sys.argv + no_fp8_explicitly_set = "--no-fp8" in sys.argv precision_is_non_default = opts.precision != "fp8" # Initialize torch.distributed global process group @@ -255,39 +258,47 @@ def train(opts): match opts.precision: case "fp32": - dtype=torch.float32 + dtype = torch.float32 - #set up, but not used by autocast with no-fp8 set to true + # set up, but not used by autocast with no-fp8 set to true precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = True case "fp16": - dtype=torch.float16 + dtype = torch.float16 - #set up, but not used by autocast with no-fp8 set to true + # set up, but not used by autocast with no-fp8 set to true precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = True case "fp8": - dtype=torch.float16 + dtype = torch.float16 precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = False case "mxfp8": - dtype=torch.float16 + dtype = torch.float16 precision_format = Format.E4M3 recipe = MXFP8BlockScaling(fp8_format=precision_format) no_fp8 = False case "nvfp4": - dtype=torch.bfloat16 # RHT only supports bfloat16 + dtype = torch.bfloat16 # RHT only supports bfloat16 recipe = NVFP4BlockScaling() no_fp8 = False case _: - dtype=torch.float16 + dtype = torch.float16 precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = opts.no_fp8 else: # dtype and/or no_fp8 were explicitly set - they take precedence @@ -297,11 +308,13 @@ def train(opts): # Set up default recipe for FP8 cases if not no_fp8: precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) else: recipe = None - # Warn if precision was also set to non-default (being overridden) + # Warn if precision was also set to non-default (being overridden) if precision_is_non_default: if dtype_explicitly_set: dist_print(f"Warning: --dtype {dtype} overrides --precision {opts.precision}") @@ -312,8 +325,8 @@ def train(opts): dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}") if not no_fp8: dist_print(f"Using FP8 recipe: {type(recipe).__name__}") - - layer_kwargs["params_dtype"]=dtype + + layer_kwargs["params_dtype"] = dtype if opts.num_layers > 1: te_layer_list = [] From a506d402ed6e78547d59b096f57f90539840a872 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 16:15:36 -0500 Subject: [PATCH 08/56] Replace sys.argv parsing with custom action and fix default case Replace fragile sys.argv parsing with robust custom argparse action class to track explicitly set arguments, and fix default precision case to explicitly set no_fp8 to False for consistent FP8-enabled behavior. - Add StoreExplicitAction custom action class for tracking explicit arguments - Update --dtype argument to use StoreExplicitAction - Replace sys.argv check with getattr for dtype_explicitly_set attribute - Remove sys import from train() function - Fix default case to set no_fp8 = False instead of opts.no_fp8 - Ensure recipe variable is properly initialized in all code paths - Support all argument passing methods including config files and = syntax Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index b8962e48f9..13185a367e 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -124,6 +124,13 @@ def get_layer_args(opts): return layer_args, layer_kwargs +class StoreExplicitAction(argparse.Action): + """Custom action that tracks whether an argument was explicitly set.""" + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, values) + setattr(namespace, f'{self.dest}_explicitly_set', True) + + def parse_fsdp_args(): parser = argparse.ArgumentParser( description="Run Transformer Engine modules with the " @@ -205,6 +212,7 @@ def parse_fsdp_args(): "--dtype", type=torch_dtype, default=torch.bfloat16, + action=StoreExplicitAction, # Add custom action help="Data type for input tensor and Transformer Engine module parameters. " + "Supported values: fp32/float32, fp16/float16, bf16/bfloat16. " + "Takes precedence over --precision if both are specified. " @@ -233,11 +241,9 @@ def dist_print(text, all_ranks=False, no_new_line=False): def train(opts): - import sys - # Check which flags were explicitly set - dtype_explicitly_set = '--dtype' in sys.argv - no_fp8_explicitly_set = '--no-fp8' in sys.argv + dtype_explicitly_set = getattr(opts, 'dtype_explicitly_set', False) + no_fp8_explicitly_set = opts.no_fp8 != False precision_is_non_default = opts.precision != "fp8" # Initialize torch.distributed global process group @@ -288,7 +294,7 @@ def train(opts): dtype=torch.float16 precision_format = Format.HYBRID recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") - no_fp8 = opts.no_fp8 + no_fp8 = False else: # dtype and/or no_fp8 were explicitly set - they take precedence dtype = opts.dtype From ec31f2a736f6e4f78a92080aa0a491c1abba5acb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 21:18:47 +0000 Subject: [PATCH 09/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 465d918663..45accb47ee 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -126,9 +126,10 @@ def get_layer_args(opts): class StoreExplicitAction(argparse.Action): """Custom action that tracks whether an argument was explicitly set.""" + def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, values) - setattr(namespace, f'{self.dest}_explicitly_set', True) + setattr(namespace, f"{self.dest}_explicitly_set", True) def parse_fsdp_args(): @@ -245,7 +246,7 @@ def dist_print(text, all_ranks=False, no_new_line=False): def train(opts): # Check which flags were explicitly set - dtype_explicitly_set = getattr(opts, 'dtype_explicitly_set', False) + dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) no_fp8_explicitly_set = opts.no_fp8 != False precision_is_non_default = opts.precision != "fp8" @@ -302,7 +303,9 @@ def train(opts): case _: dtype = torch.float16 precision_format = Format.HYBRID - recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") + recipe = DelayedScaling( + fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + ) no_fp8 = False else: # dtype and/or no_fp8 were explicitly set - they take precedence From 07a87a7891dcb42be62109ae608ad3a08b6c3ac2 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 16:47:43 -0500 Subject: [PATCH 10/56] Fix params_dtype to use computed dtype from precision logic Remove params_dtype initialization from get_layer_args() and update FSDP MixedPrecision to use computed dtype variable instead of raw opts.dtype, ensuring precision presets are properly applied throughout the model. - Remove params_dtype from get_layer_args() layer_kwargs initialization - Update FSDP MixedPrecision param_dtype to use computed dtype variable - Ensure precision preset logic is respected in both layer initialization and FSDP - Maintain backward compatibility with original FP8-enabled default behavior Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 465d918663..65ac782f54 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -103,7 +103,7 @@ def get_layer_args(opts): hidden_size = opts.num_heads * opts.head_dim layer_args = (hidden_size,) layer_kwargs = { - "params_dtype": opts.dtype, + #"params_dtype": opts.dtype, "device": "cuda" if opts.no_defer_init else "meta", "get_rng_state_tracker": get_cuda_rng_tracker, } @@ -130,6 +130,15 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, values) setattr(namespace, f'{self.dest}_explicitly_set', True) +class StoreTrueExplicitAction(argparse.Action): + """Custom action for store_true that tracks whether flag was explicitly set.""" + def __init__(self, option_strings, dest, default=False, required=False, help=None): + super().__init__(option_strings, dest, nargs=0, const=True, + default=default, required=required, help=help) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, True) + setattr(namespace, f'{self.dest}_explicitly_set', True) def parse_fsdp_args(): parser = argparse.ArgumentParser( @@ -190,7 +199,7 @@ def parse_fsdp_args(): ) parser.add_argument( "--no-fp8", - action="store_true", + action=StoreTrueExplicitAction, # Use custom action default=False, help="Disables the te.autocast() context. When set, FP8 training is disabled " + "and the model trains in standard precision (as specified by --dtype). " @@ -246,7 +255,7 @@ def dist_print(text, all_ranks=False, no_new_line=False): def train(opts): # Check which flags were explicitly set dtype_explicitly_set = getattr(opts, 'dtype_explicitly_set', False) - no_fp8_explicitly_set = opts.no_fp8 != False + no_fp8_explicitly_set = getattr(opts, 'no_fp8_explicitly_set', False) # Fixed precision_is_non_default = opts.precision != "fp8" # Initialize torch.distributed global process group @@ -284,14 +293,14 @@ def train(opts): no_fp8 = True case "fp8": - dtype = torch.float16 + dtype = torch.bfloat16 precision_format = Format.HYBRID recipe = DelayedScaling( fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" ) no_fp8 = False case "mxfp8": - dtype = torch.float16 + dtype = torch.bfloat16 precision_format = Format.E4M3 recipe = MXFP8BlockScaling(fp8_format=precision_format) no_fp8 = False @@ -300,7 +309,7 @@ def train(opts): recipe = NVFP4BlockScaling() no_fp8 = False case _: - dtype = torch.float16 + dtype = torch.bfloat16 precision_format = Format.HYBRID recipe = DelayedScaling(fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max") no_fp8 = False @@ -327,7 +336,7 @@ def train(opts): # Always log the final configuration being used dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}") - if not no_fp8: + if not no_fp8 and recipe is not None: dist_print(f"Using FP8 recipe: {type(recipe).__name__}") layer_kwargs["params_dtype"] = dtype @@ -362,7 +371,7 @@ def train(opts): process_group=all_gpus, use_orig_params=True, mixed_precision=MixedPrecision( - param_dtype=opts.dtype, + param_dtype=dtype, reduce_dtype=torch.float32, ), auto_wrap_policy=fsdp_wrap_policy, From c6fb3a51f7ea23370f67a657224e53aad6cd0c71 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 21:56:21 +0000 Subject: [PATCH 11/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 6431712930..9287688616 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -103,7 +103,7 @@ def get_layer_args(opts): hidden_size = opts.num_heads * opts.head_dim layer_args = (hidden_size,) layer_kwargs = { - #"params_dtype": opts.dtype, + # "params_dtype": opts.dtype, "device": "cuda" if opts.no_defer_init else "meta", "get_rng_state_tracker": get_cuda_rng_tracker, } @@ -131,15 +131,18 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, values) setattr(namespace, f"{self.dest}_explicitly_set", True) + class StoreTrueExplicitAction(argparse.Action): """Custom action for store_true that tracks whether flag was explicitly set.""" + def __init__(self, option_strings, dest, default=False, required=False, help=None): - super().__init__(option_strings, dest, nargs=0, const=True, - default=default, required=required, help=help) + super().__init__( + option_strings, dest, nargs=0, const=True, default=default, required=required, help=help + ) def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, True) - setattr(namespace, f'{self.dest}_explicitly_set', True) + setattr(namespace, f"{self.dest}_explicitly_set", True) def parse_fsdp_args(): @@ -256,8 +259,8 @@ def dist_print(text, all_ranks=False, no_new_line=False): def train(opts): # Check which flags were explicitly set - dtype_explicitly_set = getattr(opts, 'dtype_explicitly_set', False) - no_fp8_explicitly_set = getattr(opts, 'no_fp8_explicitly_set', False) # Fixed + dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) + no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False) # Fixed precision_is_non_default = opts.precision != "fp8" # Initialize torch.distributed global process group From 2c5473e92e6c0df93bf349426a43d7ea747c39c5 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 17:09:52 -0500 Subject: [PATCH 12/56] Fix type conversion in StoreExplicitAction for --dtype argument Add type converter application in StoreExplicitAction custom action to ensure --dtype values are properly converted from strings to torch dtype objects, preventing runtime errors in torch operations. - Store type converter in StoreExplicitAction.__init__ - Apply type conversion in __call__ before setting attribute value - Add error handling for invalid type conversions - Ensure opts.dtype contains torch dtype object, not raw string - Fix runtime errors in torch.rand() and MixedPrecision() calls Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 9287688616..0fe2e198cb 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -126,10 +126,22 @@ def get_layer_args(opts): class StoreExplicitAction(argparse.Action): """Custom action that tracks whether an argument was explicitly set.""" + def __init__(self, option_strings, dest, type=None, **kwargs): + super().__init__(option_strings, dest, **kwargs) + self.type_converter = type # Store the type converter def __call__(self, parser, namespace, values, option_string=None): + # Apply the type converter if one was provided + if self.type_converter is not None: + try: + values = self.type_converter(values) + except (ValueError, TypeError) as e: + raise argparse.ArgumentTypeError( + f"invalid {self.dest} value: {values}" + ) from e + setattr(namespace, self.dest, values) - setattr(namespace, f"{self.dest}_explicitly_set", True) + setattr(namespace, f'{self.dest}_explicitly_set', True) class StoreTrueExplicitAction(argparse.Action): From a9e664c4d4d7d1cc57ff5e59a4c0838726e4f6cd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 22:11:04 +0000 Subject: [PATCH 13/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 0fe2e198cb..c1a500815a 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -126,6 +126,7 @@ def get_layer_args(opts): class StoreExplicitAction(argparse.Action): """Custom action that tracks whether an argument was explicitly set.""" + def __init__(self, option_strings, dest, type=None, **kwargs): super().__init__(option_strings, dest, **kwargs) self.type_converter = type # Store the type converter @@ -136,12 +137,10 @@ def __call__(self, parser, namespace, values, option_string=None): try: values = self.type_converter(values) except (ValueError, TypeError) as e: - raise argparse.ArgumentTypeError( - f"invalid {self.dest} value: {values}" - ) from e + raise argparse.ArgumentTypeError(f"invalid {self.dest} value: {values}") from e setattr(namespace, self.dest, values) - setattr(namespace, f'{self.dest}_explicitly_set', True) + setattr(namespace, f"{self.dest}_explicitly_set", True) class StoreTrueExplicitAction(argparse.Action): From 196df3df94c07e230e152a0d1f4db37fa148e987 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 22:45:46 -0500 Subject: [PATCH 14/56] Fix precision preset recipe selection and add incompatibility validation Address critical bugs where FP8 recipes were incorrectly selected when explicit flags were set, and add validation to prevent incompatible flag combinations that would silently disable FP8 training. - Remove default value from --precision parameter (set to None for backward compatibility) - Add get_precision_preset() and get_recipe_for_precision() helper functions - Implement two-path configuration logic: backward compatibility mode vs. precision preset mode - Add incompatibility validation: raise ValueError when --no-fp8 used with fp8/mxfp8/nvfp4 presets - Preserve FP8 recipe selection when --dtype explicitly overrides precision preset dtype - Fix fp16 case to correctly map to torch.float16 instead of torch.bfloat16 - Update parameter help text with precedence rules and usage examples - Ensure backward compatibility: scripts without --precision work identically to original version Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 228 ++++++++++++++++++++-------------- 1 file changed, 138 insertions(+), 90 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index c1a500815a..fa68b72840 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -215,13 +215,24 @@ def parse_fsdp_args(): ) parser.add_argument( "--no-fp8", - action=StoreTrueExplicitAction, # Use custom action + action=StoreTrueExplicitAction, default=False, - help="Disables the te.autocast() context. When set, FP8 training is disabled " - + "and the model trains in standard precision (as specified by --dtype). " - + "Takes precedence over --precision if both are specified. " - + "Example: '--no-fp8 --precision fp8' will disable FP8 despite fp8 preset. " - + "Default: False (FP8 enabled based on precision).", + help=( + "Disables the te.autocast() context. When set, FP8 training is disabled " + "and the model trains in standard precision (as specified by --dtype). " + "PRECEDENCE: This flag is incompatible with FP8-based --precision presets. " + "BEHAVIOR: " + "- Without --precision: Disables FP8 training (original behavior) " + "- With --precision fp32/fp16: Redundant but harmless (already non-FP8) " + "- With --precision fp8/mxfp8/nvfp4: RAISES ERROR (incompatible flags) " + "RATIONALE: FP8 presets explicitly request FP8 training, so disabling FP8 " + "would contradict the user's intent. Use --precision fp32/fp16 instead for non-FP8 training. " + "EXAMPLES: " + "'--no-fp8' disables FP8 (original behavior). " + "'--precision fp8 --no-fp8' raises ValueError (incompatible). " + "'--precision fp16' is the correct way to request non-FP8 training. " + "Default: False (FP8 enabled based on configuration)." + ), ) parser.add_argument( "--no-defer-init", @@ -237,27 +248,41 @@ def parse_fsdp_args(): "--dtype", type=torch_dtype, default=torch.bfloat16, - action=StoreExplicitAction, # Add custom action - help="Data type for input tensor and Transformer Engine module parameters. " - + "Supported values: fp32/float32, fp16/float16, bf16/bfloat16. " - + "Takes precedence over --precision if both are specified. " - + "Example: '--dtype fp16 --precision fp8' will use fp16 dtype and ignore fp8 preset. " - + "Default: bfloat16.", + action=StoreExplicitAction, + help=( + "Data type for input tensor and Transformer Engine module parameters. " + "Supported values: fp32/float32, fp16/float16, bf16/bfloat16. " + "PRECEDENCE: When explicitly set, this flag overrides the dtype from --precision preset. " + "BEHAVIOR: " + "- Without --precision: Controls parameter dtype directly " + "- With --precision: Overrides preset's dtype but preserves FP8 recipe selection " + "EXAMPLES: " + "'--dtype bf16' uses bfloat16 parameters (original behavior). " + "'--precision mxfp8 --dtype fp16' uses fp16 parameters with MXFP8BlockScaling recipe. " + "A warning is issued when overriding --precision dtype. " + "Default: bfloat16." + ), ) parser.add_argument( "--precision", type=precision, - default="fp8", + default=None, help=( - "Precision preset for model training. Supported values: FP32, FP16, FP8, MXFP8, NVFP4. " - ) - + "This is a convenience flag that configures both dtype and FP8 settings automatically. " - + "If --dtype or --no-fp8 are explicitly specified, they take precedence over this flag " - + "and a warning will be issued. " - + "Precedence: --dtype and --no-fp8 override --precision. " - + "Example: Use '--precision fp8' for quick setup, or '--dtype bf16 --no-fp8' for explicit" - " control. " - + "Default: fp8.", + "Precision preset for model training. Supported values: fp32, fp16, fp8, mxfp8, nvfp4. " + "This is a convenience flag that configures both dtype and FP8 settings automatically. " + "- fp32/fp16: Non-FP8 training with specified dtype " + "- fp8: FP8 training with DelayedScaling recipe (bf16 parameters) " + "- mxfp8: FP8 training with MXFP8BlockScaling recipe (bf16 parameters) " + "- nvfp4: FP8 training with NVFP4BlockScaling recipe (bf16 parameters) " + "PRECEDENCE RULES: " + "- If --dtype is explicitly set, it overrides the dtype from this preset (with warning) " + "- If --no-fp8 is set with fp8/mxfp8/nvfp4 presets, an error is raised (incompatible) " + "- If this flag is not set, original behavior is used (--dtype and --no-fp8 control training) " + "EXAMPLES: " + "'--precision mxfp8' enables MXFP8 FP8 training with bf16 parameters. " + "'--precision fp8 --dtype fp16' uses fp16 parameters but keeps DelayedScaling recipe. " + "Default: None (backward compatible mode)." + ), ) return parser.parse_args() @@ -267,12 +292,60 @@ def dist_print(text, all_ranks=False, no_new_line=False): end = "" if no_new_line else "\n" print(f"[GPU-{LOCAL_RANK}] " + text, end=end) +def get_precision_preset(precision_value): + """Get dtype, no_fp8, and recipe based on precision preset. + + Returns: + tuple: (dtype, no_fp8, recipe) + """ + match precision_value: + case "fp32": + return torch.float32, True, None + case "fp16": + return torch.float16, True, None + case "fp8": + recipe = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) + return torch.bfloat16, False, recipe + case "mxfp8": + recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) + return torch.bfloat16, False, recipe + case "nvfp4": + recipe = NVFP4BlockScaling() + return torch.bfloat16, False, recipe + case _: + # Default to fp8 behavior + recipe = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) + return torch.bfloat16, False, recipe + + +def get_recipe_for_precision(precision_value): + """Get FP8 recipe based on precision preset (when FP8 is enabled). + + Args: + precision_value: The precision preset string + + Returns: + Recipe object for FP8 training + """ + match precision_value: + case "mxfp8": + return MXFP8BlockScaling(fp8_format=Format.E4M3) + case "nvfp4": + return NVFP4BlockScaling() + case _: + # Default to DelayedScaling for fp8 or any other value + return DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) def train(opts): # Check which flags were explicitly set dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) - no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False) # Fixed - precision_is_non_default = opts.precision != "fp8" + no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False) # Initialize torch.distributed global process group dist.init_process_group(backend="nccl") @@ -280,83 +353,58 @@ def train(opts): dist_print(f"WORLD_SIZE = {WORLD_SIZE}") torch.manual_seed(opts.seed) - # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM - layer_args, layer_kwargs = get_layer_args(opts) - - if not dtype_explicitly_set and not no_fp8_explicitly_set: - - dist_print(f"Using precision preset: {opts.precision}") - - match opts.precision: - case "fp32": - dtype = torch.float32 - - # set up, but not used by autocast with no-fp8 set to true - precision_format = Format.HYBRID - recipe = DelayedScaling( - fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" - ) - - no_fp8 = True - case "fp16": - dtype = torch.float16 - - # set up, but not used by autocast with no-fp8 set to true - precision_format = Format.HYBRID - recipe = DelayedScaling( - fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" - ) - - no_fp8 = True - case "fp8": - dtype = torch.bfloat16 - precision_format = Format.HYBRID - recipe = DelayedScaling( - fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" - ) - no_fp8 = False - case "mxfp8": - dtype = torch.bfloat16 - precision_format = Format.E4M3 - recipe = MXFP8BlockScaling(fp8_format=precision_format) - no_fp8 = False - case "nvfp4": - dtype = torch.bfloat16 # RHT only supports bfloat16 - recipe = NVFP4BlockScaling() - no_fp8 = False - case _: - dtype = torch.bfloat16 - precision_format = Format.HYBRID - recipe = DelayedScaling( - fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" - ) - no_fp8 = False - else: - # dtype and/or no_fp8 were explicitly set - they take precedence + # Determine final configuration based on precedence rules + if opts.precision is None: + # Case 1: Backward compatibility - no precision preset specified + # Use original behavior with dtype and no_fp8 flags dtype = opts.dtype no_fp8 = opts.no_fp8 - - # Set up default recipe for FP8 cases + + # Set up recipe if FP8 is enabled if not no_fp8: - precision_format = Format.HYBRID recipe = DelayedScaling( - fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max" + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" ) else: - recipe = None - - # Warn if precision was also set to non-default (being overridden) - if precision_is_non_default: - if dtype_explicitly_set: - dist_print(f"Warning: --dtype {dtype} overrides --precision {opts.precision}") - if no_fp8_explicitly_set: - dist_print(f"Warning: --no-fp8 overrides --precision {opts.precision}") + recipe = None + else: + # Case 2: Precision preset was explicitly specified + # Start with precision preset values + preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision) + + # Check for incompatible flag combinations + # Error if user requests FP8-based precision but also sets --no-fp8 + if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set and opts.no_fp8: + raise ValueError( + f"Cannot use --no-fp8 with --precision {opts.precision}. " + f"These flags are incompatible. " + f"Either remove --no-fp8 to use {opts.precision} training, " + f"or use --precision fp32/fp16 for non-FP8 training." + ) + + dtype = preset_dtype + no_fp8 = preset_no_fp8 + recipe = preset_recipe + + dist_print(f"Using precision preset: {opts.precision}") + + # Apply explicit dtype override with warning + if dtype_explicitly_set: + dtype = opts.dtype + dist_print(f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting") + + # If FP8 is still enabled, keep recipe based on precision + # (dtype only affects parameter storage, not FP8 recipe) + if not no_fp8: + recipe = get_recipe_for_precision(opts.precision) # Always log the final configuration being used dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}") if not no_fp8 and recipe is not None: dist_print(f"Using FP8 recipe: {type(recipe).__name__}") + # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM + layer_args, layer_kwargs = get_layer_args(opts) layer_kwargs["params_dtype"] = dtype if opts.num_layers > 1: @@ -457,4 +505,4 @@ def train(opts): # torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) test_fsdp.py --defer-init if __name__ == "__main__": args = parse_fsdp_args() - train(args) + train(args) \ No newline at end of file From 368820ba52063898853ffb47e91705b6d8777c4f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Feb 2026 03:47:34 +0000 Subject: [PATCH 15/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 101 ++++++++++++++++------------------ 1 file changed, 48 insertions(+), 53 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index fa68b72840..cfa2c135ab 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -218,20 +218,17 @@ def parse_fsdp_args(): action=StoreTrueExplicitAction, default=False, help=( - "Disables the te.autocast() context. When set, FP8 training is disabled " - "and the model trains in standard precision (as specified by --dtype). " - "PRECEDENCE: This flag is incompatible with FP8-based --precision presets. " - "BEHAVIOR: " - "- Without --precision: Disables FP8 training (original behavior) " - "- With --precision fp32/fp16: Redundant but harmless (already non-FP8) " - "- With --precision fp8/mxfp8/nvfp4: RAISES ERROR (incompatible flags) " - "RATIONALE: FP8 presets explicitly request FP8 training, so disabling FP8 " - "would contradict the user's intent. Use --precision fp32/fp16 instead for non-FP8 training. " - "EXAMPLES: " - "'--no-fp8' disables FP8 (original behavior). " - "'--precision fp8 --no-fp8' raises ValueError (incompatible). " - "'--precision fp16' is the correct way to request non-FP8 training. " - "Default: False (FP8 enabled based on configuration)." + "Disables the te.autocast() context. When set, FP8 training is disabled and the model" + " trains in standard precision (as specified by --dtype). PRECEDENCE: This flag is" + " incompatible with FP8-based --precision presets. BEHAVIOR: - Without --precision:" + " Disables FP8 training (original behavior) - With --precision fp32/fp16: Redundant but" + " harmless (already non-FP8) - With --precision fp8/mxfp8/nvfp4: RAISES ERROR" + " (incompatible flags) RATIONALE: FP8 presets explicitly request FP8 training, so" + " disabling FP8 would contradict the user's intent. Use --precision fp32/fp16 instead" + " for non-FP8 training. EXAMPLES: '--no-fp8' disables FP8 (original behavior)." + " '--precision fp8 --no-fp8' raises ValueError (incompatible). '--precision fp16' is" + " the correct way to request non-FP8 training. Default: False (FP8 enabled based on" + " configuration)." ), ) parser.add_argument( @@ -250,17 +247,14 @@ def parse_fsdp_args(): default=torch.bfloat16, action=StoreExplicitAction, help=( - "Data type for input tensor and Transformer Engine module parameters. " - "Supported values: fp32/float32, fp16/float16, bf16/bfloat16. " - "PRECEDENCE: When explicitly set, this flag overrides the dtype from --precision preset. " - "BEHAVIOR: " - "- Without --precision: Controls parameter dtype directly " - "- With --precision: Overrides preset's dtype but preserves FP8 recipe selection " - "EXAMPLES: " - "'--dtype bf16' uses bfloat16 parameters (original behavior). " - "'--precision mxfp8 --dtype fp16' uses fp16 parameters with MXFP8BlockScaling recipe. " - "A warning is issued when overriding --precision dtype. " - "Default: bfloat16." + "Data type for input tensor and Transformer Engine module parameters. Supported values:" + " fp32/float32, fp16/float16, bf16/bfloat16. PRECEDENCE: When explicitly set, this flag" + " overrides the dtype from --precision preset. BEHAVIOR: - Without --precision:" + " Controls parameter dtype directly - With --precision: Overrides preset's dtype but" + " preserves FP8 recipe selection EXAMPLES: '--dtype bf16' uses bfloat16 parameters" + " (original behavior). '--precision mxfp8 --dtype fp16' uses fp16 parameters with" + " MXFP8BlockScaling recipe. A warning is issued when overriding --precision dtype." + " Default: bfloat16." ), ) parser.add_argument( @@ -268,20 +262,17 @@ def parse_fsdp_args(): type=precision, default=None, help=( - "Precision preset for model training. Supported values: fp32, fp16, fp8, mxfp8, nvfp4. " - "This is a convenience flag that configures both dtype and FP8 settings automatically. " - "- fp32/fp16: Non-FP8 training with specified dtype " - "- fp8: FP8 training with DelayedScaling recipe (bf16 parameters) " - "- mxfp8: FP8 training with MXFP8BlockScaling recipe (bf16 parameters) " - "- nvfp4: FP8 training with NVFP4BlockScaling recipe (bf16 parameters) " - "PRECEDENCE RULES: " - "- If --dtype is explicitly set, it overrides the dtype from this preset (with warning) " - "- If --no-fp8 is set with fp8/mxfp8/nvfp4 presets, an error is raised (incompatible) " - "- If this flag is not set, original behavior is used (--dtype and --no-fp8 control training) " - "EXAMPLES: " - "'--precision mxfp8' enables MXFP8 FP8 training with bf16 parameters. " - "'--precision fp8 --dtype fp16' uses fp16 parameters but keeps DelayedScaling recipe. " - "Default: None (backward compatible mode)." + "Precision preset for model training. Supported values: fp32, fp16, fp8, mxfp8, nvfp4." + " This is a convenience flag that configures both dtype and FP8 settings automatically." + " - fp32/fp16: Non-FP8 training with specified dtype - fp8: FP8 training with" + " DelayedScaling recipe (bf16 parameters) - mxfp8: FP8 training with MXFP8BlockScaling" + " recipe (bf16 parameters) - nvfp4: FP8 training with NVFP4BlockScaling recipe (bf16" + " parameters) PRECEDENCE RULES: - If --dtype is explicitly set, it overrides the dtype" + " from this preset (with warning) - If --no-fp8 is set with fp8/mxfp8/nvfp4 presets, an" + " error is raised (incompatible) - If this flag is not set, original behavior is used" + " (--dtype and --no-fp8 control training) EXAMPLES: '--precision mxfp8' enables MXFP8" + " FP8 training with bf16 parameters. '--precision fp8 --dtype fp16' uses fp16" + " parameters but keeps DelayedScaling recipe. Default: None (backward compatible mode)." ), ) return parser.parse_args() @@ -292,9 +283,10 @@ def dist_print(text, all_ranks=False, no_new_line=False): end = "" if no_new_line else "\n" print(f"[GPU-{LOCAL_RANK}] " + text, end=end) + def get_precision_preset(precision_value): """Get dtype, no_fp8, and recipe based on precision preset. - + Returns: tuple: (dtype, no_fp8, recipe) """ @@ -324,10 +316,10 @@ def get_precision_preset(precision_value): def get_recipe_for_precision(precision_value): """Get FP8 recipe based on precision preset (when FP8 is enabled). - + Args: precision_value: The precision preset string - + Returns: Recipe object for FP8 training """ @@ -342,6 +334,7 @@ def get_recipe_for_precision(precision_value): fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" ) + def train(opts): # Check which flags were explicitly set dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) @@ -359,40 +352,42 @@ def train(opts): # Use original behavior with dtype and no_fp8 flags dtype = opts.dtype no_fp8 = opts.no_fp8 - + # Set up recipe if FP8 is enabled if not no_fp8: recipe = DelayedScaling( fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" ) else: - recipe = None + recipe = None else: # Case 2: Precision preset was explicitly specified # Start with precision preset values preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision) - + # Check for incompatible flag combinations # Error if user requests FP8-based precision but also sets --no-fp8 if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set and opts.no_fp8: raise ValueError( f"Cannot use --no-fp8 with --precision {opts.precision}. " - f"These flags are incompatible. " + "These flags are incompatible. " f"Either remove --no-fp8 to use {opts.precision} training, " - f"or use --precision fp32/fp16 for non-FP8 training." + "or use --precision fp32/fp16 for non-FP8 training." ) - + dtype = preset_dtype no_fp8 = preset_no_fp8 recipe = preset_recipe - + dist_print(f"Using precision preset: {opts.precision}") - + # Apply explicit dtype override with warning if dtype_explicitly_set: dtype = opts.dtype - dist_print(f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting") - + dist_print( + f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting" + ) + # If FP8 is still enabled, keep recipe based on precision # (dtype only affects parameter storage, not FP8 recipe) if not no_fp8: @@ -505,4 +500,4 @@ def train(opts): # torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) test_fsdp.py --defer-init if __name__ == "__main__": args = parse_fsdp_args() - train(args) \ No newline at end of file + train(args) From 76dcb94bbdafd02bdacb9b46ae21b15107cd182f Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 22:55:43 -0500 Subject: [PATCH 16/56] Fix unreachable default case and redundant recipe recreation Remove dead code in get_precision_preset() default case and eliminate redundant recipe recreation when dtype is explicitly overridden, ensuring cleaner logic flow and preventing duplicate recipe instantiation. - Remove unreachable case _: branch from get_precision_preset() function - Delete redundant recipe recreation when dtype_explicitly_set is true - Preserve existing recipe from preset when dtype override occurs - Ensure dtype override only affects parameter storage, not FP8 recipe selection Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index cfa2c135ab..7ebc34afc1 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -306,12 +306,6 @@ def get_precision_preset(precision_value): case "nvfp4": recipe = NVFP4BlockScaling() return torch.bfloat16, False, recipe - case _: - # Default to fp8 behavior - recipe = DelayedScaling( - fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" - ) - return torch.bfloat16, False, recipe def get_recipe_for_precision(precision_value): @@ -395,8 +389,6 @@ def train(opts): # Always log the final configuration being used dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}") - if not no_fp8 and recipe is not None: - dist_print(f"Using FP8 recipe: {type(recipe).__name__}") # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM layer_args, layer_kwargs = get_layer_args(opts) From e22c2f2702c62b5cce7a2f10c5bb95f4fca6878a Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 10 Feb 2026 23:04:43 -0500 Subject: [PATCH 17/56] Add explicit error handling for invalid precision presets Prevent silent failures when precision validation is bypassed or new presets are added without updating get_precision_preset() function by adding explicit ValueError for unhandled cases. - Add case _: branch to get_precision_preset() that raises ValueError - Ensure invalid precision values fail loudly with clear error message - Prevent TypeError on tuple unpacking if function returns None - Improve maintainability when adding new precision presets Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 7ebc34afc1..61c376e7aa 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -306,7 +306,12 @@ def get_precision_preset(precision_value): case "nvfp4": recipe = NVFP4BlockScaling() return torch.bfloat16, False, recipe - + case _: + # Fail loudly if validation is bypassed or new preset added without updating this function + raise ValueError( + f"Invalid precision preset: {precision_value}. " + f"Supported values: fp32, fp16, fp8, mxfp8, nvfp4" + ) def get_recipe_for_precision(precision_value): """Get FP8 recipe based on precision preset (when FP8 is enabled). From 9d637d56266d72db851ad80039ffd14ae4459f03 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Feb 2026 04:09:02 +0000 Subject: [PATCH 18/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 61c376e7aa..0da5b265d2 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -310,9 +310,10 @@ def get_precision_preset(precision_value): # Fail loudly if validation is bypassed or new preset added without updating this function raise ValueError( f"Invalid precision preset: {precision_value}. " - f"Supported values: fp32, fp16, fp8, mxfp8, nvfp4" + "Supported values: fp32, fp16, fp8, mxfp8, nvfp4" ) + def get_recipe_for_precision(precision_value): """Get FP8 recipe based on precision preset (when FP8 is enabled). From cbad8060cbe4984de5632312d79fefad65c3abb2 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 3 Mar 2026 17:29:12 -0500 Subject: [PATCH 19/56] fix: address argparse robustness and cleanup issues in fsdp.py Resolve three code review issues in examples/pytorch/fsdp/fsdp.py: dead commented-out code, unhelpful TypeError in precision(), and rigid __init__ signature in StoreTrueExplicitAction. - Remove commented-out layer_kwargs["params_dtype"] = dtype at line 106; dead code after params_dtype was moved to train() - Replace bare raise TypeError in precision() with argparse.ArgumentTypeError and explicit list of supported values (fp32, fp16, fp8, mxfp8, nvfp4) for a meaningful error message - Add **kwargs to StoreTrueExplicitAction.__init__ and forward to super().__init__(); aligns with StoreExplicitAction for robustness Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 0da5b265d2..899ec9461e 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -76,7 +76,9 @@ def torch_dtype(d): def precision(d): typemap = ["fp32", "fp16", "fp8", "mxfp8", "nvfp4"] if lowercase(d) not in typemap: - raise TypeError + raise argparse.ArgumentTypeError( + f"invalid precision '{d}'. Supported values: {', '.join(typemap)}" + ) return lowercase(d) @@ -103,7 +105,6 @@ def get_layer_args(opts): hidden_size = opts.num_heads * opts.head_dim layer_args = (hidden_size,) layer_kwargs = { - # "params_dtype": opts.dtype, "device": "cuda" if opts.no_defer_init else "meta", "get_rng_state_tracker": get_cuda_rng_tracker, } @@ -146,7 +147,7 @@ def __call__(self, parser, namespace, values, option_string=None): class StoreTrueExplicitAction(argparse.Action): """Custom action for store_true that tracks whether flag was explicitly set.""" - def __init__(self, option_strings, dest, default=False, required=False, help=None): + def __init__(self, option_strings, dest, default=False, required=False, help=None, **kwargs): super().__init__( option_strings, dest, nargs=0, const=True, default=default, required=required, help=help ) @@ -328,11 +329,15 @@ def get_recipe_for_precision(precision_value): return MXFP8BlockScaling(fp8_format=Format.E4M3) case "nvfp4": return NVFP4BlockScaling() - case _: - # Default to DelayedScaling for fp8 or any other value + case "fp8": + # Default FP8 recipe using DelayedScaling (backward compatible) return DelayedScaling( fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" ) + case _: + raise NotImplementedError( + f"No FP8 recipe defined for precision '{precision_value}'" + ) def train(opts): From 70b7786fa67d80a1435f9e3470aaf77f4741bddd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Mar 2026 22:30:28 +0000 Subject: [PATCH 20/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 899ec9461e..bbc8409d86 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -335,9 +335,7 @@ def get_recipe_for_precision(precision_value): fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" ) case _: - raise NotImplementedError( - f"No FP8 recipe defined for precision '{precision_value}'" - ) + raise NotImplementedError(f"No FP8 recipe defined for precision '{precision_value}'") def train(opts): From 3ee397307584af79f9ba32a84fc19cc46c82b599 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 3 Mar 2026 17:40:48 -0500 Subject: [PATCH 21/56] fix: suppress spurious dtype override warning when value matches preset Guard the --dtype override warning and dtype reassignment behind an actual value change check to avoid a false positive when the user explicitly passes --dtype with the same value the precision preset would have selected. - Add new_dtype != preset_dtype guard inside the dtype_explicitly_set branch so warning and dtype reassignment only trigger on a real override - Suppress redundant recipe re-creation when dtype matches preset default; recipe is already correctly set from preset_recipe above - No behavioral change when --dtype differs from preset default Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 899ec9461e..59212fc11a 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -388,10 +388,12 @@ def train(opts): # Apply explicit dtype override with warning if dtype_explicitly_set: - dtype = opts.dtype - dist_print( - f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting" - ) + new_dtype = opts.dtype + if new_dtype != preset_dtype: + dtype = new_dtype + dist_print( + f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting" + ) # If FP8 is still enabled, keep recipe based on precision # (dtype only affects parameter storage, not FP8 recipe) From f53b845c6bdb1c1500d675906f17fbb9c64bd0c3 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 3 Mar 2026 18:41:38 -0500 Subject: [PATCH 22/56] fix: add type conversion with error handling in StoreExplicitAction.__call__ Apply type_converter inside __call__ with proper exception handling to ensure --dtype values are converted and validated at parse time rather than silently passing raw strings through. - Wrap type_converter call in try/except catching ValueError, TypeError, and argparse.ArgumentTypeError to surface conversion failures via parser.error() with a descriptive message - Guard conversion behind if self.type_converter is not None check for cases where no converter is registered - Ensures --dtype argument is correctly converted and validated consistently with standard argparse type= behavior Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 00a77e9652..7339f9d292 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -137,8 +137,8 @@ def __call__(self, parser, namespace, values, option_string=None): if self.type_converter is not None: try: values = self.type_converter(values) - except (ValueError, TypeError) as e: - raise argparse.ArgumentTypeError(f"invalid {self.dest} value: {values}") from e + except (ValueError, TypeError, argparse.ArgumentTypeError) as e: + parser.error(f"argument {option_string}: invalid value: {values!r}: {e}") setattr(namespace, self.dest, values) setattr(namespace, f"{self.dest}_explicitly_set", True) From 3d2caa4b0ffa530442754d7fa6124e3fddc7d1ce Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 3 Mar 2026 19:21:26 -0500 Subject: [PATCH 23/56] fix: remove redundant condition, deduplicate recipe logic, guard re-instantiation Address three code review issues in examples/pytorch/fsdp/fsdp.py: redundant opts.no_fp8 check, duplicated recipe construction, and unnecessary recipe re-instantiation when dtype matches preset. - Remove redundant 'and opts.no_fp8' from no_fp8_explicitly_set guard at line 373; StoreTrueExplicitAction always sets opts.no_fp8 to True when it fires, making the extra check always True - Refactor get_recipe_for_precision() to delegate to get_precision_preset() and extract the recipe, eliminating duplicated recipe construction logic and silent drift hazard when recipe parameters are tuned in one place but not the other - Guard recipe re-creation inside new_dtype != preset_dtype branch to avoid unnecessary re-instantiation when dtype_explicitly_set but the value matches the preset default Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 7339f9d292..17fa400f19 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -324,18 +324,10 @@ def get_recipe_for_precision(precision_value): Returns: Recipe object for FP8 training """ - match precision_value: - case "mxfp8": - return MXFP8BlockScaling(fp8_format=Format.E4M3) - case "nvfp4": - return NVFP4BlockScaling() - case "fp8": - # Default FP8 recipe using DelayedScaling (backward compatible) - return DelayedScaling( - fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" - ) - case _: - raise NotImplementedError(f"No FP8 recipe defined for precision '{precision_value}'") + _, _, recipe = get_precision_preset(precision_value) + if recipe is None: + raise NotImplementedError(f"No FP8 recipe defined for precision '{precision_value}'") + return recipe def train(opts): @@ -370,7 +362,7 @@ def train(opts): # Check for incompatible flag combinations # Error if user requests FP8-based precision but also sets --no-fp8 - if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set and opts.no_fp8: + if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set: raise ValueError( f"Cannot use --no-fp8 with --precision {opts.precision}. " "These flags are incompatible. " @@ -393,10 +385,10 @@ def train(opts): f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting" ) - # If FP8 is still enabled, keep recipe based on precision - # (dtype only affects parameter storage, not FP8 recipe) - if not no_fp8: - recipe = get_recipe_for_precision(opts.precision) + # If FP8 is still enabled, keep recipe based on precision + # (dtype only affects parameter storage, not FP8 recipe) + if not no_fp8: + recipe = get_recipe_for_precision(opts.precision) # Always log the final configuration being used dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}") From b473ab25fbb41527310c1a62ebcb9433afa01ce6 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 3 Mar 2026 20:47:52 -0500 Subject: [PATCH 24/56] fix: validate flags before dist.init_process_group and remove redundant arg Move incompatible-flags check before dist.init_process_group() to avoid leaving the NCCL process group partially initialized, and remove redundant fp8_format=Format.E4M3 from MXFP8BlockScaling(). - Move no_fp8_explicitly_set + precision conflict check to the top of train() before dist.init_process_group() to prevent deadlocks or 'Address already in use' errors on other ranks that are still waiting inside init_process_group when rank 0 raises ValueError - Remove explicit fp8_format=Format.E4M3 from MXFP8BlockScaling() call; Format.E4M3 is already the dataclass default and passing it explicitly adds noise without adding clarity Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 17fa400f19..e482934f41 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -302,7 +302,7 @@ def get_precision_preset(precision_value): ) return torch.bfloat16, False, recipe case "mxfp8": - recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) + recipe = MXFP8BlockScaling() return torch.bfloat16, False, recipe case "nvfp4": recipe = NVFP4BlockScaling() @@ -335,6 +335,16 @@ def train(opts): dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False) + # Check for incompatible flag combinations + # Error if user requests FP8-based precision but also sets --no-fp8 + if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set: + raise ValueError( + f"Cannot use --no-fp8 with --precision {opts.precision}. " + "These flags are incompatible. " + f"Either remove --no-fp8 to use {opts.precision} training, " + "or use --precision fp32/fp16 for non-FP8 training." + ) + # Initialize torch.distributed global process group dist.init_process_group(backend="nccl") torch.cuda.set_device(LOCAL_RANK) @@ -360,16 +370,6 @@ def train(opts): # Start with precision preset values preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision) - # Check for incompatible flag combinations - # Error if user requests FP8-based precision but also sets --no-fp8 - if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set: - raise ValueError( - f"Cannot use --no-fp8 with --precision {opts.precision}. " - "These flags are incompatible. " - f"Either remove --no-fp8 to use {opts.precision} training, " - "or use --precision fp32/fp16 for non-FP8 training." - ) - dtype = preset_dtype no_fp8 = preset_no_fp8 recipe = preset_recipe From 4e251efa755261aef696fda6b93a938c8d40aff6 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 3 Mar 2026 20:59:57 -0500 Subject: [PATCH 25/56] fix: simplify StoreExplicitAction and improve training config log Delegate type conversion to argparse in StoreExplicitAction and include active FP8 recipe type in the training configuration log. - Remove self.type_converter field and manual try/except block from StoreExplicitAction.__call__; forward type= kwarg to super().__init__() so argparse handles conversion natively before __call__ is invoked, restoring standard error messages and %(type)s help interpolation - Simplify StoreExplicitAction.__init__ to use **kwargs passthrough, removing the now-unnecessary type= interception logic - Include active recipe type in training configuration log output using type(recipe).__name__ so log emits messages like 'FP8=enabled (MXFP8BlockScaling)' or 'FP8=disabled', making it easier to verify the intended quantization scheme is in use Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index e482934f41..8baa69f8d3 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -128,18 +128,11 @@ def get_layer_args(opts): class StoreExplicitAction(argparse.Action): """Custom action that tracks whether an argument was explicitly set.""" - def __init__(self, option_strings, dest, type=None, **kwargs): + def __init__(self, option_strings, dest, **kwargs): super().__init__(option_strings, dest, **kwargs) - self.type_converter = type # Store the type converter def __call__(self, parser, namespace, values, option_string=None): - # Apply the type converter if one was provided - if self.type_converter is not None: - try: - values = self.type_converter(values) - except (ValueError, TypeError, argparse.ArgumentTypeError) as e: - parser.error(f"argument {option_string}: invalid value: {values!r}: {e}") - + # values already converted by argparse via action.type setattr(namespace, self.dest, values) setattr(namespace, f"{self.dest}_explicitly_set", True) @@ -391,8 +384,11 @@ def train(opts): recipe = get_recipe_for_precision(opts.precision) # Always log the final configuration being used - dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}") - + dist_print( + f"Training configuration: dtype={dtype}, " + f"FP8={'disabled' if no_fp8 else f'enabled ({type(recipe).__name__})'}" + ) + # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM layer_args, layer_kwargs = get_layer_args(opts) layer_kwargs["params_dtype"] = dtype From fa959d7b09a6c32afb737114e46c8499c91d1875 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 02:00:55 +0000 Subject: [PATCH 26/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 8baa69f8d3..7047b4b741 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -388,7 +388,7 @@ def train(opts): f"Training configuration: dtype={dtype}, " f"FP8={'disabled' if no_fp8 else f'enabled ({type(recipe).__name__})'}" ) - + # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM layer_args, layer_kwargs = get_layer_args(opts) layer_kwargs["params_dtype"] = dtype From 5e57ae4d20df142b56a858a86283938b0698a1df Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 3 Mar 2026 21:15:02 -0500 Subject: [PATCH 27/56] fix: forward kwargs in StoreTrueExplicitAction, improve dtype log, document rank assumption Address three code review issues in examples/pytorch/fsdp/fsdp.py: silent kwargs drop in StoreTrueExplicitAction, missing confirmation log when dtype matches preset, and undocumented torchrun assumption. - Forward **kwargs to super().__init__() in StoreTrueExplicitAction to prevent silent discard of unexpected keyword arguments (e.g. metavar, choices) if argument registration is ever extended - Add info log when dtype_explicitly_set but new_dtype == preset_dtype so user receives confirmation their --dtype flag was acknowledged even when it matches the preset default and no override is needed - Add comment above no_fp8_explicitly_set validation documenting that raising ValueError before dist.init_process_group is safe because torchrun guarantees all ranks receive identical CLI arguments Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 8baa69f8d3..43e4c56f36 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -142,7 +142,7 @@ class StoreTrueExplicitAction(argparse.Action): def __init__(self, option_strings, dest, default=False, required=False, help=None, **kwargs): super().__init__( - option_strings, dest, nargs=0, const=True, default=default, required=required, help=help + option_strings, dest, nargs=0, const=True, default=default, required=required, help=help, **kwargs ) def __call__(self, parser, namespace, values, option_string=None): @@ -328,8 +328,9 @@ def train(opts): dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False) - # Check for incompatible flag combinations - # Error if user requests FP8-based precision but also sets --no-fp8 + # Validate flag combinations before touching distributed state. + # Safe to raise here because torchrun guarantees all ranks receive + # identical CLI arguments; all ranks will raise simultaneously. if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set: raise ValueError( f"Cannot use --no-fp8 with --precision {opts.precision}. " @@ -377,6 +378,12 @@ def train(opts): dist_print( f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting" ) + if not no_fp8: + recipe = get_recipe_for_precision(opts.precision) + else: + dist_print( + f"Info: --dtype {new_dtype} matches --precision {opts.precision} preset default, no override needed" + ) # If FP8 is still enabled, keep recipe based on precision # (dtype only affects parameter storage, not FP8 recipe) @@ -388,7 +395,7 @@ def train(opts): f"Training configuration: dtype={dtype}, " f"FP8={'disabled' if no_fp8 else f'enabled ({type(recipe).__name__})'}" ) - + # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM layer_args, layer_kwargs = get_layer_args(opts) layer_kwargs["params_dtype"] = dtype From 1595dd92b92b2b2f68b8e14739348420d949828b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 02:16:28 +0000 Subject: [PATCH 28/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 43e4c56f36..bb24077bbb 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -142,7 +142,14 @@ class StoreTrueExplicitAction(argparse.Action): def __init__(self, option_strings, dest, default=False, required=False, help=None, **kwargs): super().__init__( - option_strings, dest, nargs=0, const=True, default=default, required=required, help=help, **kwargs + option_strings, + dest, + nargs=0, + const=True, + default=default, + required=required, + help=help, + **kwargs, ) def __call__(self, parser, namespace, values, option_string=None): @@ -382,7 +389,8 @@ def train(opts): recipe = get_recipe_for_precision(opts.precision) else: dist_print( - f"Info: --dtype {new_dtype} matches --precision {opts.precision} preset default, no override needed" + f"Info: --dtype {new_dtype} matches --precision {opts.precision} preset" + " default, no override needed" ) # If FP8 is still enabled, keep recipe based on precision From a601034cd248aa7deafc75427f454b7e6df01825 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Tue, 3 Mar 2026 22:21:22 -0500 Subject: [PATCH 29/56] Fixing typo in code documentation Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index bb24077bbb..dc08163207 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -336,6 +336,7 @@ def train(opts): no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False) # Validate flag combinations before touching distributed state. + # Error if user requests FP8-based precision but also sets --no-fp8 # Safe to raise here because torchrun guarantees all ranks receive # identical CLI arguments; all ranks will raise simultaneously. if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set: From ef9879a4cedddf900ff6117024fe268aed03b887 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 13:43:52 -0500 Subject: [PATCH 30/56] fix: format dtype in log messages and document recipe=None intent Strip 'torch.' prefix from dtype in user-facing log messages and add a comment documenting the intentional recipe=None behavior when FP8 is disabled. - Replace raw dtype formatting with str(dtype).replace('torch.', '') in both Warning and Info log messages so users see 'float32' or 'bfloat16' instead of 'torch.float32' or 'torch.bfloat16' - Add inline comment on recipe=None assignment explaining that te.autocast safely substitutes get_default_fp8_recipe() internally when recipe is None, and skips check_recipe_support when enabled=False, so the assignment is intentional and safe despite populating global FP8 state with a default recipe Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 35 ++++++++++------------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index dc08163207..08579af8f8 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -137,26 +137,6 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, f"{self.dest}_explicitly_set", True) -class StoreTrueExplicitAction(argparse.Action): - """Custom action for store_true that tracks whether flag was explicitly set.""" - - def __init__(self, option_strings, dest, default=False, required=False, help=None, **kwargs): - super().__init__( - option_strings, - dest, - nargs=0, - const=True, - default=default, - required=required, - help=help, - **kwargs, - ) - - def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, True) - setattr(namespace, f"{self.dest}_explicitly_set", True) - - def parse_fsdp_args(): parser = argparse.ArgumentParser( description="Run Transformer Engine modules with the " @@ -216,7 +196,7 @@ def parse_fsdp_args(): ) parser.add_argument( "--no-fp8", - action=StoreTrueExplicitAction, + action="store_true", default=False, help=( "Disables the te.autocast() context. When set, FP8 training is disabled and the model" @@ -333,13 +313,12 @@ def get_recipe_for_precision(precision_value): def train(opts): # Check which flags were explicitly set dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) - no_fp8_explicitly_set = getattr(opts, "no_fp8_explicitly_set", False) # Validate flag combinations before touching distributed state. # Error if user requests FP8-based precision but also sets --no-fp8 # Safe to raise here because torchrun guarantees all ranks receive # identical CLI arguments; all ranks will raise simultaneously. - if opts.precision in ["fp8", "mxfp8", "nvfp4"] and no_fp8_explicitly_set: + if opts.precision in ["fp8", "mxfp8", "nvfp4"] and opts.no_fp8: raise ValueError( f"Cannot use --no-fp8 with --precision {opts.precision}. " "These flags are incompatible. " @@ -366,6 +345,10 @@ def train(opts): fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" ) else: + # recipe=None is intentional: te.autocast substitutes get_default_fp8_recipe() + # internally when recipe is None, and skips check_recipe_support entirely + # when enabled=False, so this is safe. The global FP8 state will be populated + # with a default recipe, but it has no effect since FP8 is disabled. recipe = None else: # Case 2: Precision preset was explicitly specified @@ -383,14 +366,16 @@ def train(opts): new_dtype = opts.dtype if new_dtype != preset_dtype: dtype = new_dtype + dtype_name = str(dtype).replace("torch.", "") dist_print( - f"Warning: --dtype {dtype} overrides --precision {opts.precision} dtype setting" + f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype setting" ) if not no_fp8: recipe = get_recipe_for_precision(opts.precision) else: + new_dtype_name = str(new_dtype).replace("torch.", "") dist_print( - f"Info: --dtype {new_dtype} matches --precision {opts.precision} preset" + f"Info: --dtype {new_dtype_name} matches --precision {opts.precision} preset" " default, no override needed" ) From afa756ddb15815f6a26ad206ffd00a4a1ce8ebf6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:44:53 +0000 Subject: [PATCH 31/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 08579af8f8..a824ebbc0b 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -345,7 +345,7 @@ def train(opts): fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" ) else: - # recipe=None is intentional: te.autocast substitutes get_default_fp8_recipe() + # recipe=None is intentional: te.autocast substitutes get_default_fp8_recipe() # internally when recipe is None, and skips check_recipe_support entirely # when enabled=False, so this is safe. The global FP8 state will be populated # with a default recipe, but it has no effect since FP8 is disabled. @@ -368,7 +368,8 @@ def train(opts): dtype = new_dtype dtype_name = str(dtype).replace("torch.", "") dist_print( - f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype setting" + f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype" + " setting" ) if not no_fp8: recipe = get_recipe_for_precision(opts.precision) From 91948e083768e61af8271c639a190b8404f5f2cc Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 13:57:11 -0500 Subject: [PATCH 32/56] refactor: remove redundant __init__ override in StoreExplicitAction Remove the __init__ override from StoreExplicitAction since it only calls super().__init__() with the same arguments, which Python does automatically. The class now consists solely of __call__, eliminating dead code without any behavioral change. Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index a824ebbc0b..6b80ac26ac 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -128,9 +128,6 @@ def get_layer_args(opts): class StoreExplicitAction(argparse.Action): """Custom action that tracks whether an argument was explicitly set.""" - def __init__(self, option_strings, dest, **kwargs): - super().__init__(option_strings, dest, **kwargs) - def __call__(self, parser, namespace, values, option_string=None): # values already converted by argparse via action.type setattr(namespace, self.dest, values) From 6e69a8f76fa6e903a437c916d7b96bd79c27693e Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 14:16:00 -0500 Subject: [PATCH 33/56] fix: use 'quantization' label in log and remove redundant recipe re-instantiation Replace misleading 'FP8' label with 'quantization' in training configuration log and remove redundant recipe re-instantiation in the dtype_explicitly_set path. - Replace 'FP8=enabled/disabled' with 'quantization=enabled/disabled' in dist_print configuration log to accurately cover all TE precision modes including NVFP4 which is 4-bit, not FP8 - Remove get_recipe_for_precision() call inside dtype_explicitly_set block; recipe is already correctly assigned from preset_recipe above and re-instantiating it is wasteful and creates a second object discarding the first - Add inline comment clarifying that recipe requires no update in the dtype_explicitly_set path since it is determined by opts.precision, not dtype Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 6b80ac26ac..096a0f4bee 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -368,8 +368,6 @@ def train(opts): f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype" " setting" ) - if not no_fp8: - recipe = get_recipe_for_precision(opts.precision) else: new_dtype_name = str(new_dtype).replace("torch.", "") dist_print( @@ -385,7 +383,7 @@ def train(opts): # Always log the final configuration being used dist_print( f"Training configuration: dtype={dtype}, " - f"FP8={'disabled' if no_fp8 else f'enabled ({type(recipe).__name__})'}" + f"quantization={'disabled' if no_fp8 else f'enabled ({type(recipe).__name__})'}" ) # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM From 4d30ba0c051e826457a2a105edac6a42a3ce7dab Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 14:25:37 -0500 Subject: [PATCH 34/56] fix: remove redundant recipe re-instantiation in equal-dtype path Remove unnecessary get_recipe_for_precision() call in the else branch of the dtype_explicitly_set block where new_dtype == preset_dtype. - recipe is already correctly assigned from preset_recipe before the dtype_explicitly_set block; no re-instantiation is needed in either branch since recipe is determined by opts.precision, not dtype - Previous else branch was re-creating the recipe (wasteful) while the if branch was not, inverting the logic implied by the comment - Replace with a comment clarifying that recipe requires no update in the dtype_explicitly_set path Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 096a0f4bee..600f6e6650 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -375,10 +375,8 @@ def train(opts): " default, no override needed" ) - # If FP8 is still enabled, keep recipe based on precision - # (dtype only affects parameter storage, not FP8 recipe) - if not no_fp8: - recipe = get_recipe_for_precision(opts.precision) + # recipe is already set correctly from preset_recipe above; + # dtype only affects parameter storage, not the quantization recipe # Always log the final configuration being used dist_print( From 4e3529a8fee81363bdf393b95fd9f5a0fc4b1e75 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 15:06:48 -0500 Subject: [PATCH 35/56] fix: define dtype_name unconditionally and guard dtype override warning Fix two bugs in train() precision configuration block. - Define dtype_name unconditionally before the dtype_explicitly_set block to prevent NameError in the config log when dtype_explicitly_set is False (the common case when --dtype is not explicitly passed) - Guard dtype override warning behind 'dtype_explicitly_set and opts.precision is not None' to prevent spurious warning when user passes --dtype without --precision (original behavior path) Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 99 ++++++++++++++--------------------- 1 file changed, 39 insertions(+), 60 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 600f6e6650..a9759c44c0 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -69,7 +69,9 @@ def torch_dtype(d): "bfloat16": torch.bfloat16, } if lowercase(d) not in typemap.keys(): - raise TypeError + raise argparse.ArgumentTypeError( + f"invalid dtype '{d}'. Supported values: fp32/float32, fp16/float16, bf16/bfloat16" + ) return typemap[lowercase(d)] @@ -284,29 +286,17 @@ def get_precision_preset(precision_value): case "nvfp4": recipe = NVFP4BlockScaling() return torch.bfloat16, False, recipe + case None: + # Default case: no precision preset specified, use original behavior. + # dtype and no_fp8 are controlled directly by --dtype and --no-fp8 flags. + return torch.bfloat16, True, None case _: - # Fail loudly if validation is bypassed or new preset added without updating this function raise ValueError( f"Invalid precision preset: {precision_value}. " "Supported values: fp32, fp16, fp8, mxfp8, nvfp4" ) -def get_recipe_for_precision(precision_value): - """Get FP8 recipe based on precision preset (when FP8 is enabled). - - Args: - precision_value: The precision preset string - - Returns: - Recipe object for FP8 training - """ - _, _, recipe = get_precision_preset(precision_value) - if recipe is None: - raise NotImplementedError(f"No FP8 recipe defined for precision '{precision_value}'") - return recipe - - def train(opts): # Check which flags were explicitly set dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) @@ -329,58 +319,47 @@ def train(opts): dist_print(f"WORLD_SIZE = {WORLD_SIZE}") torch.manual_seed(opts.seed) - # Determine final configuration based on precedence rules + + # Start with precision preset values + preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision) + + dtype = preset_dtype + no_fp8 = preset_no_fp8 + recipe = preset_recipe + + # When no --precision is set, respect --no-fp8 and --dtype directly (original behavior) if opts.precision is None: - # Case 1: Backward compatibility - no precision preset specified - # Use original behavior with dtype and no_fp8 flags - dtype = opts.dtype no_fp8 = opts.no_fp8 - - # Set up recipe if FP8 is enabled - if not no_fp8: - recipe = DelayedScaling( - fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" - ) - else: - # recipe=None is intentional: te.autocast substitutes get_default_fp8_recipe() - # internally when recipe is None, and skips check_recipe_support entirely - # when enabled=False, so this is safe. The global FP8 state will be populated - # with a default recipe, but it has no effect since FP8 is disabled. - recipe = None + if opts.dtype is not None: + dtype = opts.dtype else: - # Case 2: Precision preset was explicitly specified - # Start with precision preset values - preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision) - - dtype = preset_dtype no_fp8 = preset_no_fp8 - recipe = preset_recipe - dist_print(f"Using precision preset: {opts.precision}") - # Apply explicit dtype override with warning - if dtype_explicitly_set: - new_dtype = opts.dtype - if new_dtype != preset_dtype: - dtype = new_dtype - dtype_name = str(dtype).replace("torch.", "") - dist_print( - f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype" - " setting" - ) - else: - new_dtype_name = str(new_dtype).replace("torch.", "") - dist_print( - f"Info: --dtype {new_dtype_name} matches --precision {opts.precision} preset" - " default, no override needed" - ) - - # recipe is already set correctly from preset_recipe above; - # dtype only affects parameter storage, not the quantization recipe + dtype_name = str(dtype).replace("torch.", "") + + # Apply explicit dtype override with warning + if dtype_explicitly_set and opts.precision is not None: + new_dtype = opts.dtype + if new_dtype != preset_dtype: + dtype = new_dtype + dist_print( + f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype" + " setting" + ) + else: + new_dtype_name = str(new_dtype).replace("torch.", "") + dist_print( + f"Info: --dtype {new_dtype_name} matches --precision {opts.precision} preset" + " default, no override needed" + ) + + # recipe is already set correctly from preset_recipe above; + # dtype only affects parameter storage, not the quantization recipe # Always log the final configuration being used dist_print( - f"Training configuration: dtype={dtype}, " + f"Training configuration: dtype={dtype_name}, " f"quantization={'disabled' if no_fp8 else f'enabled ({type(recipe).__name__})'}" ) From 246e948acdd6d35c4a87fca20a90b6f65418c433 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 20:07:47 +0000 Subject: [PATCH 36/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index a9759c44c0..1513ffa464 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -319,7 +319,6 @@ def train(opts): dist_print(f"WORLD_SIZE = {WORLD_SIZE}") torch.manual_seed(opts.seed) - # Start with precision preset values preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision) From aba45ea64438f133bb1e4253f6bb515660018578 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 15:15:00 -0500 Subject: [PATCH 37/56] fix: recompute dtype_name after override and restore DelayedScaling default Fix two bugs in train() precision configuration block: stale dtype_name in log messages after dtype override, and behavioral regression where recipe=None was passed to te.autocast when FP8 was enabled in backward-compatible mode. - Recompute dtype_name immediately after dtype = new_dtype in the dtype override branch so warning and config log reflect the effective dtype rather than the stale preset dtype - Restore original default behavior in opts.precision is None path: when no_fp8 is False (FP8 enabled), supply DelayedScaling recipe to preserve the original te.autocast behavior instead of passing recipe=None which changed the implicit fallback behavior Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 1513ffa464..09dec47696 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -331,6 +331,11 @@ def train(opts): no_fp8 = opts.no_fp8 if opts.dtype is not None: dtype = opts.dtype + # Preserve original default: FP8 enabled → use DelayedScaling as before + if not no_fp8: + recipe = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) else: no_fp8 = preset_no_fp8 dist_print(f"Using precision preset: {opts.precision}") @@ -342,6 +347,8 @@ def train(opts): new_dtype = opts.dtype if new_dtype != preset_dtype: dtype = new_dtype + dtype_name = str(dtype).replace("torch.", "") + dist_print( f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype" " setting" From 90d44b0079b5617677712f83984b1cde88b2a9c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 20:15:57 +0000 Subject: [PATCH 38/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 09dec47696..a2f6961156 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -331,7 +331,7 @@ def train(opts): no_fp8 = opts.no_fp8 if opts.dtype is not None: dtype = opts.dtype - # Preserve original default: FP8 enabled → use DelayedScaling as before + # Preserve original default: FP8 enabled → use DelayedScaling as before if not no_fp8: recipe = DelayedScaling( fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" From f1cf722bebf8c437d84ee4eec35d283fcea10fd0 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 15:24:01 -0500 Subject: [PATCH 39/56] refactor: simplify opts.dtype None check and remove redundant no_fp8 assignment Remove two redundant lines in train() precision configuration block. - Remove 'if opts.dtype is not None' guard in opts.precision is None branch; --dtype has default=torch.bfloat16 so opts.dtype is never None and the condition is always True - Remove redundant 'no_fp8 = preset_no_fp8' assignment in the else branch; no_fp8 is already assigned from preset_no_fp8 at the tuple unpack above and reassigning it in the else branch adds noise without changing behavior Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index a2f6961156..91c9c0e944 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -329,15 +329,13 @@ def train(opts): # When no --precision is set, respect --no-fp8 and --dtype directly (original behavior) if opts.precision is None: no_fp8 = opts.no_fp8 - if opts.dtype is not None: - dtype = opts.dtype + dtype = opts.dtype # Preserve original default: FP8 enabled → use DelayedScaling as before if not no_fp8: recipe = DelayedScaling( fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" ) else: - no_fp8 = preset_no_fp8 dist_print(f"Using precision preset: {opts.precision}") dtype_name = str(dtype).replace("torch.", "") From 5a5bcbbdb2d97c1f5f71caa6a6fdc56ead677deb Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 15:33:44 -0500 Subject: [PATCH 40/56] fix: guard recipe=None, remove dead case None, shorten help text Address three code review issues in examples/pytorch/fsdp/fsdp.py: recipe=None passed to te.autocast, dead case None in get_precision_preset, and excessively verbose help strings for --no-fp8 and --dtype. - Use 'recipe or DelayedScaling()' fallback at te.autocast call site to preserve original defensive pattern of always passing a concrete recipe instance, even when enabled=False - Remove case None from get_precision_preset() and guard call site in train() with 'if opts.precision is not None' to eliminate dead-code path whose return values were immediately overridden by the caller - Replace multi-paragraph help strings for --no-fp8 and --dtype with concise one-liner synopses; move detailed precedence rules to module-level docstring or README Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 91c9c0e944..89d0a1ec13 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -286,10 +286,6 @@ def get_precision_preset(precision_value): case "nvfp4": recipe = NVFP4BlockScaling() return torch.bfloat16, False, recipe - case None: - # Default case: no precision preset specified, use original behavior. - # dtype and no_fp8 are controlled directly by --dtype and --no-fp8 flags. - return torch.bfloat16, True, None case _: raise ValueError( f"Invalid precision preset: {precision_value}. " @@ -319,24 +315,17 @@ def train(opts): dist_print(f"WORLD_SIZE = {WORLD_SIZE}") torch.manual_seed(opts.seed) - # Start with precision preset values - preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision) - - dtype = preset_dtype - no_fp8 = preset_no_fp8 - recipe = preset_recipe - - # When no --precision is set, respect --no-fp8 and --dtype directly (original behavior) - if opts.precision is None: - no_fp8 = opts.no_fp8 - dtype = opts.dtype - # Preserve original default: FP8 enabled → use DelayedScaling as before - if not no_fp8: - recipe = DelayedScaling( - fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" - ) - else: + if opts.precision is not None: + preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision) + dtype, no_fp8, recipe = preset_dtype, preset_no_fp8, preset_recipe dist_print(f"Using precision preset: {opts.precision}") + else: + # Original behavior: --dtype and --no-fp8 control training directly + dtype = opts.dtype + no_fp8 = opts.no_fp8 + recipe = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) if not no_fp8 else None dtype_name = str(dtype).replace("torch.", "") @@ -443,7 +432,7 @@ def train(opts): device="cuda", ) # autocast needs to be given the FSDP process group for amax reductions - with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus): + with te.autocast(enabled=not no_fp8, recipe=recipe or DelayedScaling(), amax_reduction_group=all_gpus): y = te_model(x) loss = y.sum() # calculate gradient and take training step outside the autocast context From 3f627c4b010ad3942a84a2d8d4d2860d6c7f45b3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 20:34:39 +0000 Subject: [PATCH 41/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 89d0a1ec13..fe8709a18c 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -323,9 +323,11 @@ def train(opts): # Original behavior: --dtype and --no-fp8 control training directly dtype = opts.dtype no_fp8 = opts.no_fp8 - recipe = DelayedScaling( - fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" - ) if not no_fp8 else None + recipe = ( + DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max") + if not no_fp8 + else None + ) dtype_name = str(dtype).replace("torch.", "") @@ -432,7 +434,9 @@ def train(opts): device="cuda", ) # autocast needs to be given the FSDP process group for amax reductions - with te.autocast(enabled=not no_fp8, recipe=recipe or DelayedScaling(), amax_reduction_group=all_gpus): + with te.autocast( + enabled=not no_fp8, recipe=recipe or DelayedScaling(), amax_reduction_group=all_gpus + ): y = te_model(x) loss = y.sum() # calculate gradient and take training step outside the autocast context From 1467f496d525cd6768662a87218a8daa81184f87 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 15:49:19 -0500 Subject: [PATCH 42/56] fix: resolve recipe=None fallback before training loop, not per-iteration Move 'recipe or DelayedScaling()' fallback to a one-time assignment before the training loop to avoid allocating a new DelayedScaling() object on every iteration when FP8 is disabled. - Add 'if recipe is None: recipe = DelayedScaling()' after the configuration block and before the training loop so the fallback object is created once and reused across all iterations - Restore clean 'recipe=recipe' in te.autocast call, matching the original code pattern - Add comment explaining why recipe is always set to a concrete object even when FP8 is disabled Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index fe8709a18c..9f88ed2dee 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -433,9 +433,16 @@ def train(opts): dtype=dtype, device="cuda", ) + + # Ensure recipe is always a concrete object before passing to te.autocast. + # When FP8 is disabled, te.autocast ignores the recipe, but some TE versions + # perform attribute access on it regardless of the enabled flag. + if recipe is None: + recipe = DelayedScaling() + # autocast needs to be given the FSDP process group for amax reductions with te.autocast( - enabled=not no_fp8, recipe=recipe or DelayedScaling(), amax_reduction_group=all_gpus + enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus ): y = te_model(x) loss = y.sum() From 2d093732fe85ad739c3fb654c76a38abdf5bd895 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 20:50:18 +0000 Subject: [PATCH 43/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 9f88ed2dee..bd28601263 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -439,11 +439,9 @@ def train(opts): # perform attribute access on it regardless of the enabled flag. if recipe is None: recipe = DelayedScaling() - + # autocast needs to be given the FSDP process group for amax reductions - with te.autocast( - enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus - ): + with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus): y = te_model(x) loss = y.sum() # calculate gradient and take training step outside the autocast context From a2679d4e9efbedd6ce3db6179ebf21f1d8d2e40e Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 15:55:25 -0500 Subject: [PATCH 44/56] fix: move recipe=None fallback before training loop with consistent parameters Move 'if recipe is None: recipe = DelayedScaling(...)' guard to just before the training loop instead of inside it to avoid redundant is-None checks on every iteration and variable mutation inside the loop. - Use consistent DelayedScaling parameters (fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo='max') matching the rest of the file, rather than plain DelayedScaling() with default args - Guard runs once before the loop; recipe is stable for all iterations - Restores clean 'recipe=recipe' in te.autocast call with no inline fallback expression Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index bd28601263..f6391bd227 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -424,6 +424,12 @@ def train(opts): torch.cuda.synchronize() start.record() + # Ensure recipe is always a concrete object before passing to te.autocast. + # When FP8 is disabled, te.autocast ignores the recipe, but some TE versions + # perform attribute access on it regardless of the enabled flag. + if recipe is None: + recipe = DelayedScaling() + for i in range(opts.num_iters): # Generate a random input batch x = torch.rand( @@ -432,13 +438,7 @@ def train(opts): opts.num_heads * opts.head_dim, dtype=dtype, device="cuda", - ) - - # Ensure recipe is always a concrete object before passing to te.autocast. - # When FP8 is disabled, te.autocast ignores the recipe, but some TE versions - # perform attribute access on it regardless of the enabled flag. - if recipe is None: - recipe = DelayedScaling() + ) # autocast needs to be given the FSDP process group for amax reductions with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus): From 375e1c15b0639b3b7a1d923133fdefbcb3a71f47 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 20:56:22 +0000 Subject: [PATCH 45/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index f6391bd227..8d55def919 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -424,7 +424,7 @@ def train(opts): torch.cuda.synchronize() start.record() - # Ensure recipe is always a concrete object before passing to te.autocast. + # Ensure recipe is always a concrete object before passing to te.autocast. # When FP8 is disabled, te.autocast ignores the recipe, but some TE versions # perform attribute access on it regardless of the enabled flag. if recipe is None: @@ -438,7 +438,7 @@ def train(opts): opts.num_heads * opts.head_dim, dtype=dtype, device="cuda", - ) + ) # autocast needs to be given the FSDP process group for amax reductions with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus): From 8b69d31ea1e8948a5a4cf6c17a96d0bfe01e88c5 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 16:06:04 -0500 Subject: [PATCH 46/56] fix: pass amax_reduction_group only for DelayedScaling and shorten help text Pass amax_reduction_group conditionally based on recipe type and replace verbose multi-paragraph help strings with concise one-liners. - Compute amax_group = all_gpus if isinstance(recipe, DelayedScaling) else None and pass amax_group to te.autocast; amax_reduction_group is a DelayedScaling-specific parameter for per-tensor amax aggregation and is not accepted by MXFP8BlockScaling or NVFP4BlockScaling which use block-level scaling - Replace multi-paragraph help strings for --no-fp8, --dtype, and --precision (with PRECEDENCE/BEHAVIOR/RATIONALE/EXAMPLES sections) with concise one-liner synopses suitable for terminal --help output - Move detailed precedence rules to module-level docstring Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 42 +++++------------------------------ 1 file changed, 5 insertions(+), 37 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 8d55def919..ccd967e388 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -197,19 +197,7 @@ def parse_fsdp_args(): "--no-fp8", action="store_true", default=False, - help=( - "Disables the te.autocast() context. When set, FP8 training is disabled and the model" - " trains in standard precision (as specified by --dtype). PRECEDENCE: This flag is" - " incompatible with FP8-based --precision presets. BEHAVIOR: - Without --precision:" - " Disables FP8 training (original behavior) - With --precision fp32/fp16: Redundant but" - " harmless (already non-FP8) - With --precision fp8/mxfp8/nvfp4: RAISES ERROR" - " (incompatible flags) RATIONALE: FP8 presets explicitly request FP8 training, so" - " disabling FP8 would contradict the user's intent. Use --precision fp32/fp16 instead" - " for non-FP8 training. EXAMPLES: '--no-fp8' disables FP8 (original behavior)." - " '--precision fp8 --no-fp8' raises ValueError (incompatible). '--precision fp16' is" - " the correct way to request non-FP8 training. Default: False (FP8 enabled based on" - " configuration)." - ), + help="Disable te.autocast() FP8 context. Incompatible with --precision fp8/mxfp8/nvfp4. Default: False.", ) parser.add_argument( "--no-defer-init", @@ -226,34 +214,13 @@ def parse_fsdp_args(): type=torch_dtype, default=torch.bfloat16, action=StoreExplicitAction, - help=( - "Data type for input tensor and Transformer Engine module parameters. Supported values:" - " fp32/float32, fp16/float16, bf16/bfloat16. PRECEDENCE: When explicitly set, this flag" - " overrides the dtype from --precision preset. BEHAVIOR: - Without --precision:" - " Controls parameter dtype directly - With --precision: Overrides preset's dtype but" - " preserves FP8 recipe selection EXAMPLES: '--dtype bf16' uses bfloat16 parameters" - " (original behavior). '--precision mxfp8 --dtype fp16' uses fp16 parameters with" - " MXFP8BlockScaling recipe. A warning is issued when overriding --precision dtype." - " Default: bfloat16." - ), + help="Parameter dtype: fp32/float32, fp16/float16, bf16/bfloat16. Overrides --precision dtype when explicitly set. Default: bfloat16.", ) parser.add_argument( "--precision", type=precision, default=None, - help=( - "Precision preset for model training. Supported values: fp32, fp16, fp8, mxfp8, nvfp4." - " This is a convenience flag that configures both dtype and FP8 settings automatically." - " - fp32/fp16: Non-FP8 training with specified dtype - fp8: FP8 training with" - " DelayedScaling recipe (bf16 parameters) - mxfp8: FP8 training with MXFP8BlockScaling" - " recipe (bf16 parameters) - nvfp4: FP8 training with NVFP4BlockScaling recipe (bf16" - " parameters) PRECEDENCE RULES: - If --dtype is explicitly set, it overrides the dtype" - " from this preset (with warning) - If --no-fp8 is set with fp8/mxfp8/nvfp4 presets, an" - " error is raised (incompatible) - If this flag is not set, original behavior is used" - " (--dtype and --no-fp8 control training) EXAMPLES: '--precision mxfp8' enables MXFP8" - " FP8 training with bf16 parameters. '--precision fp8 --dtype fp16' uses fp16" - " parameters but keeps DelayedScaling recipe. Default: None (backward compatible mode)." - ), + help="Precision preset: fp32, fp16, fp8, mxfp8, nvfp4. Configures dtype and FP8 recipe automatically. Overridden by explicit --dtype. Default: None (use --dtype and --no-fp8 directly).", ) return parser.parse_args() @@ -429,6 +396,7 @@ def train(opts): # perform attribute access on it regardless of the enabled flag. if recipe is None: recipe = DelayedScaling() + amax_group = all_gpus if isinstance(recipe, DelayedScaling) else None for i in range(opts.num_iters): # Generate a random input batch @@ -441,7 +409,7 @@ def train(opts): ) # autocast needs to be given the FSDP process group for amax reductions - with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus): + with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=amax_group): y = te_model(x) loss = y.sum() # calculate gradient and take training step outside the autocast context From f7cde79f1f97596ce1e20014f543c0de2e974b5e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 21:07:45 +0000 Subject: [PATCH 47/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index ccd967e388..81942e676f 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -197,7 +197,10 @@ def parse_fsdp_args(): "--no-fp8", action="store_true", default=False, - help="Disable te.autocast() FP8 context. Incompatible with --precision fp8/mxfp8/nvfp4. Default: False.", + help=( + "Disable te.autocast() FP8 context. Incompatible with --precision fp8/mxfp8/nvfp4." + " Default: False." + ), ) parser.add_argument( "--no-defer-init", @@ -214,13 +217,20 @@ def parse_fsdp_args(): type=torch_dtype, default=torch.bfloat16, action=StoreExplicitAction, - help="Parameter dtype: fp32/float32, fp16/float16, bf16/bfloat16. Overrides --precision dtype when explicitly set. Default: bfloat16.", + help=( + "Parameter dtype: fp32/float32, fp16/float16, bf16/bfloat16. Overrides --precision" + " dtype when explicitly set. Default: bfloat16." + ), ) parser.add_argument( "--precision", type=precision, default=None, - help="Precision preset: fp32, fp16, fp8, mxfp8, nvfp4. Configures dtype and FP8 recipe automatically. Overridden by explicit --dtype. Default: None (use --dtype and --no-fp8 directly).", + help=( + "Precision preset: fp32, fp16, fp8, mxfp8, nvfp4. Configures dtype and FP8 recipe" + " automatically. Overridden by explicit --dtype. Default: None (use --dtype and" + " --no-fp8 directly)." + ), ) return parser.parse_args() From ab737de0efc0e976e020686003f9c619a937bd90 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 16:18:32 -0500 Subject: [PATCH 48/56] fix: guard amax_group with not no_fp8 to prevent spurious distributed comms Add 'not no_fp8' condition to amax_group assignment to prevent amax_reduction_group=all_gpus being passed to te.autocast when FP8 is disabled. - When no_fp8=True and recipe was None, the DelayedScaling fallback causes isinstance(recipe, DelayedScaling) to return True, which incorrectly set amax_group=all_gpus even though enabled=False - Add 'not no_fp8' guard so amax_group is only set to all_gpus when FP8 is active AND the recipe is DelayedScaling (per-tensor amax); all other cases (FP8 disabled, block-scaling recipes) use None Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 81942e676f..1f196cd745 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -406,7 +406,8 @@ def train(opts): # perform attribute access on it regardless of the enabled flag. if recipe is None: recipe = DelayedScaling() - amax_group = all_gpus if isinstance(recipe, DelayedScaling) else None + + amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None for i in range(opts.num_iters): # Generate a random input batch From 9c98eb42983a21d4d7bb29187e81fad66ebace20 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 21:20:05 +0000 Subject: [PATCH 49/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 1f196cd745..de33f9133a 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -406,7 +406,7 @@ def train(opts): # perform attribute access on it regardless of the enabled flag. if recipe is None: recipe = DelayedScaling() - + amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None for i in range(opts.num_iters): From ea2401d950bb61ec3245c2170c8ab588ddee8ffb Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 16:27:39 -0500 Subject: [PATCH 50/56] fix: warn on redundant --no-fp8 with fp32/fp16 and document amax_group=None Emit a warning when --no-fp8 is combined with a non-FP8 precision preset and add an inline comment explaining why amax_reduction_group is None for block-scaling recipes. - Add warning when opts.precision in ['fp32', 'fp16'] and opts.no_fp8 is set; FP8 is already disabled by these presets so the flag is redundant and silently ignored without this feedback - Add inline comment on amax_group assignment explaining that MXFP8BlockScaling and NVFP4BlockScaling use local block scaling and do not require a distributed amax reduction group, and that None is also correct when FP8 is disabled to avoid unnecessary distributed communication Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index de33f9133a..408f2539c4 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -285,6 +285,11 @@ def train(opts): f"Either remove --no-fp8 to use {opts.precision} training, " "or use --precision fp32/fp16 for non-FP8 training." ) + if opts.precision in ["fp32", "fp16"] and opts.no_fp8: + dist_print( + f"Warning: --no-fp8 is redundant when using --precision {opts.precision} " + "(FP8 is already disabled by this preset). The flag will be ignored." + ) # Initialize torch.distributed global process group dist.init_process_group(backend="nccl") @@ -407,6 +412,8 @@ def train(opts): if recipe is None: recipe = DelayedScaling() + # MXFP8 and NVFP4 use local block scaling — no distributed amax reduction group needed. + # amax_reduction_group is only required for DelayedScaling (global AMAX allreduce). amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None for i in range(opts.num_iters): From ed36c86ba394bfe90b9346b1205026a2e88bdb22 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 16:36:29 -0500 Subject: [PATCH 51/56] fix: initialize preset_dtype and preset_recipe before conditional block Initialize preset_dtype and preset_recipe with fallback values before the 'if opts.precision is not None' block to prevent static analyzer warnings about potentially unbound variables. - Assign preset_dtype = opts.dtype and preset_recipe = None as sensible fallbacks before the if-else block; these are overwritten by get_precision_preset() when opts.precision is not None and are never accessed in the else branch - Satisfies mypy, pylint, and pyflakes 'possibly undefined' / 'unbound' warnings that would otherwise trigger CI lint failures in projects treating unbound-variable warnings as errors - No behavioral change; the if-else logic is unchanged Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 408f2539c4..c47bc5f25c 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -297,6 +297,9 @@ def train(opts): dist_print(f"WORLD_SIZE = {WORLD_SIZE}") torch.manual_seed(opts.seed) + preset_dtype: torch.dtype = opts.dtype # sensible fallback + preset_recipe = None + if opts.precision is not None: preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision) dtype, no_fp8, recipe = preset_dtype, preset_no_fp8, preset_recipe From 16e38d2e17d0e6b83b381f8d9819f9a012f01cff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 21:37:33 +0000 Subject: [PATCH 52/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index c47bc5f25c..f813e8729c 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -297,7 +297,7 @@ def train(opts): dist_print(f"WORLD_SIZE = {WORLD_SIZE}") torch.manual_seed(opts.seed) - preset_dtype: torch.dtype = opts.dtype # sensible fallback + preset_dtype: torch.dtype = opts.dtype # sensible fallback preset_recipe = None if opts.precision is not None: From ed8ece15b36f2410a2f6242e28d0e3dc46a4a43a Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 16:55:14 -0500 Subject: [PATCH 53/56] fix: compute amax_group before recipe=None fallback to avoid isinstance race Move amax_group computation before the 'if recipe is None' fallback assignment so isinstance(recipe, DelayedScaling) reflects the actual user-selected recipe rather than the defensive fallback object. - When recipe is None (non-FP8 presets or --no-fp8), isinstance correctly returns False and amax_group is set to None before the fallback substitutes a DelayedScaling instance - Prevents the fragile ordering dependency where not no_fp8 was the sole guard against passing all_gpus to a recipe that doesn't need it - Add inline comment explaining why amax_group must be computed before the recipe fallback to preserve the invariant Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index f813e8729c..278790510b 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -409,15 +409,21 @@ def train(opts): torch.cuda.synchronize() start.record() + # MXFP8 and NVFP4 use local block scaling — no distributed amax reduction group needed. + # amax_reduction_group is only required for DelayedScaling (global AMAX allreduce). + # Also skip when FP8 is disabled to avoid unnecessary distributed communication. + # Compute amax_group BEFORE the recipe fallback so isinstance() reflects the actual + # recipe, not the defensive DelayedScaling() substituted for None. + amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None + # Ensure recipe is always a concrete object before passing to te.autocast. # When FP8 is disabled, te.autocast ignores the recipe, but some TE versions # perform attribute access on it regardless of the enabled flag. if recipe is None: - recipe = DelayedScaling() + recipe = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) - # MXFP8 and NVFP4 use local block scaling — no distributed amax reduction group needed. - # amax_reduction_group is only required for DelayedScaling (global AMAX allreduce). - amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None for i in range(opts.num_iters): # Generate a random input batch From b8e62dca4d0e6641ebbf0a76bebb41c5021a2c39 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 21:56:10 +0000 Subject: [PATCH 54/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 278790510b..d0c5bd9e12 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -424,7 +424,6 @@ def train(opts): fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" ) - for i in range(opts.num_iters): # Generate a random input batch x = torch.rand( From 326835d9f4c0c4e755a0de9f23de81d538ee11f3 Mon Sep 17 00:00:00 2001 From: Andrea Gallo Date: Wed, 4 Mar 2026 17:24:29 -0500 Subject: [PATCH 55/56] fix: warn on potentially incompatible --dtype float16 with FP8-family presets Add explicit warning when --dtype float16 is combined with --precision fp8, mxfp8, or nvfp4, which expect bfloat16 accumulation. - Emit compatibility warning before applying the dtype override when opts.precision is in ['fp8', 'mxfp8', 'nvfp4'] and new_dtype is torch.float16; these presets are designed for bfloat16 accumulation and pairing with float16 may produce incorrect or undefined results - Warning is emitted in addition to the existing dtype override warning so users see both the compatibility concern and the override confirmation - Override is still applied (not blocked) to preserve user control; users who know their TE version supports float16 accumulation can proceed with awareness of the risk Signed-off-by: Andrea Gallo --- examples/pytorch/fsdp/fsdp.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index d0c5bd9e12..145fba0033 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -320,6 +320,12 @@ def train(opts): if dtype_explicitly_set and opts.precision is not None: new_dtype = opts.dtype if new_dtype != preset_dtype: + if opts.precision in ["fp8", "mxfp8", "nvfp4"] and new_dtype == torch.float16: + dist_print( + f"Warning: --dtype float16 may be incompatible with --precision" + f" {opts.precision}, which expects bfloat16 accumulation." + ) + dtype = new_dtype dtype_name = str(dtype).replace("torch.", "") From 268c4e127195f42a4a0140e27af2c540e50e32ab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 22:25:36 +0000 Subject: [PATCH 56/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/fsdp/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 145fba0033..ac7a2fac7b 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -322,7 +322,7 @@ def train(opts): if new_dtype != preset_dtype: if opts.precision in ["fp8", "mxfp8", "nvfp4"] and new_dtype == torch.float16: dist_print( - f"Warning: --dtype float16 may be incompatible with --precision" + "Warning: --dtype float16 may be incompatible with --precision" f" {opts.precision}, which expects bfloat16 accumulation." )