def precision(d):
    """argparse `type=` callable for the --precision flag.

    Accepts one of the precision presets understood by get_precision_preset()
    and returns the lowercased preset name so downstream comparisons are
    case-insensitive.

    Raises:
        argparse.ArgumentTypeError: if `d` is not a recognized preset, so
            argparse reports a clean usage error instead of a traceback.
    """
    valid = ["fp32", "fp16", "fp8", "mxfp8", "nvfp4"]
    if lowercase(d) not in valid:
        # Keep the message on a single line: the original patch had a raw
        # newline inside this f-string, which is a Python syntax error and
        # would also garble argparse's one-line usage error.
        raise argparse.ArgumentTypeError(
            f"invalid precision '{d}'. Supported values: {', '.join(valid)}"
        )
    return lowercase(d)


class StoreExplicitAction(argparse.Action):
    """Custom argparse action that tracks whether an argument was explicitly set.

    Besides storing the value on the namespace exactly like the default
    'store' action, it sets `<dest>_explicitly_set = True`, letting the
    training script distinguish an explicit `--dtype ...` on the command
    line from the argparse default value.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # `values` has already been converted by argparse via the argument's
        # `type=` callable, so it is stored verbatim.
        setattr(namespace, self.dest, values)
        setattr(namespace, f"{self.dest}_explicitly_set", True)
def get_precision_preset(precision_value):
    """Resolve a precision preset name into concrete training settings.

    Returns:
        tuple: (dtype, no_fp8, recipe) where `dtype` is the torch parameter
        dtype, `no_fp8` tells the caller to skip the te.autocast() context,
        and `recipe` is the quantization recipe (None when FP8 is disabled).

    Raises:
        ValueError: if `precision_value` is not a known preset.
    """
    # Each preset maps to a zero-arg factory so that recipe objects are only
    # constructed for the preset actually selected.
    factories = {
        "fp32": lambda: (torch.float32, True, None),
        "fp16": lambda: (torch.float16, True, None),
        "fp8": lambda: (
            torch.bfloat16,
            False,
            DelayedScaling(
                fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
            ),
        ),
        "mxfp8": lambda: (torch.bfloat16, False, MXFP8BlockScaling()),
        "nvfp4": lambda: (torch.bfloat16, False, NVFP4BlockScaling()),
    }
    if precision_value not in factories:
        raise ValueError(
            f"Invalid precision preset: {precision_value}. "
            "Supported values: fp32, fp16, fp8, mxfp8, nvfp4"
        )
    return factories[precision_value]()
+ ) + # Initialize torch.distributed global process group dist.init_process_group(backend="nccl") torch.cuda.set_device(LOCAL_RANK) dist_print(f"WORLD_SIZE = {WORLD_SIZE}") torch.manual_seed(opts.seed) + preset_dtype: torch.dtype = opts.dtype # sensible fallback + preset_recipe = None + + if opts.precision is not None: + preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision) + dtype, no_fp8, recipe = preset_dtype, preset_no_fp8, preset_recipe + dist_print(f"Using precision preset: {opts.precision}") + else: + # Original behavior: --dtype and --no-fp8 control training directly + dtype = opts.dtype + no_fp8 = opts.no_fp8 + recipe = ( + DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max") + if not no_fp8 + else None + ) + + dtype_name = str(dtype).replace("torch.", "") + + # Apply explicit dtype override with warning + if dtype_explicitly_set and opts.precision is not None: + new_dtype = opts.dtype + if new_dtype != preset_dtype: + dtype = new_dtype + dtype_name = str(dtype).replace("torch.", "") + + dist_print( + f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype" + " setting" + ) + else: + new_dtype_name = str(new_dtype).replace("torch.", "") + dist_print( + f"Info: --dtype {new_dtype_name} matches --precision {opts.precision} preset" + " default, no override needed" + ) + + # recipe is already set correctly from preset_recipe above; + # dtype only affects parameter storage, not the quantization recipe + + # Always log the final configuration being used + dist_print( + f"Training configuration: dtype={dtype_name}, " + f"quantization={'disabled' if no_fp8 else f'enabled ({type(recipe).__name__})'}" + ) + # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM layer_args, layer_kwargs = get_layer_args(opts) + layer_kwargs["params_dtype"] = dtype + if opts.num_layers > 1: te_layer_list = [] for i in range(opts.num_layers): @@ -239,7 +377,7 @@ def train(opts): 
process_group=all_gpus, use_orig_params=True, mixed_precision=MixedPrecision( - param_dtype=opts.dtype, + param_dtype=dtype, reduce_dtype=torch.float32, ), auto_wrap_policy=fsdp_wrap_policy, @@ -258,10 +396,6 @@ def train(opts): dist_print(f"Post-FSDP memory use = {post_mem_use}MiB") dist_print(f"FSDP-Wrapped + Checkpointed TE Model:\n{te_model}") - # Fp8 setup for TE - fp8_format = Format.HYBRID - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") - # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded optim = torch.optim.Adam(te_model.parameters(), lr=0.0001) @@ -275,17 +409,34 @@ def train(opts): torch.cuda.synchronize() start.record() + # MXFP8 and NVFP4 use local block scaling — no distributed amax reduction group needed. + # amax_reduction_group is only required for DelayedScaling (global AMAX allreduce). + # Also skip when FP8 is disabled to avoid unnecessary distributed communication. + # Compute amax_group BEFORE the recipe fallback so isinstance() reflects the actual + # recipe, not the defensive DelayedScaling() substituted for None. + amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None + + # Ensure recipe is always a concrete object before passing to te.autocast. + # When FP8 is disabled, te.autocast ignores the recipe, but some TE versions + # perform attribute access on it regardless of the enabled flag. 
+ if recipe is None: + recipe = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) + + for i in range(opts.num_iters): # Generate a random input batch x = torch.rand( opts.seq_length, opts.batch_size, opts.num_heads * opts.head_dim, - dtype=opts.dtype, + dtype=dtype, device="cuda", ) + # autocast needs to be given the FSDP process group for amax reductions - with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus): + with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=amax_group): y = te_model(x) loss = y.sum() # calculate gradient and take training step outside the autocast context