-
Notifications
You must be signed in to change notification settings - Fork 653
Add multi-precision training support to FSDP script #2662
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c960e6d
5949884
cd97843
1341abb
40919d8
5c1db12
e4846c8
26aee2f
295a106
c9524e3
a506d40
b100f2c
ec31f2a
07a87a7
748bb39
c6fb3a5
da9f82b
2c5473e
a9e664c
196df3d
368820b
9e2f34b
76dcb94
e22c2f2
9d637d5
c3399a1
01852f4
cbad806
70b7786
a766d06
3ee3973
32a54b4
f53b845
3d2caa4
b473ab2
4e251ef
fa959d7
5e57ae4
dcfd6b7
1595dd9
a601034
7755dc0
ef9879a
afa756d
91948e0
6e69a8f
4d30ba0
9509f45
4e3529a
246e948
aba45ea
90d44b0
f1cf722
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,7 +18,12 @@ | |
| ) | ||
|
|
||
| import transformer_engine.pytorch as te | ||
| from transformer_engine.common.recipe import Format, DelayedScaling | ||
| from transformer_engine.common.recipe import ( | ||
| Format, | ||
| DelayedScaling, | ||
| MXFP8BlockScaling, | ||
| NVFP4BlockScaling, | ||
| ) | ||
| from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp | ||
|
|
||
| LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) | ||
|
|
@@ -64,10 +69,21 @@ def torch_dtype(d): | |
| "bfloat16": torch.bfloat16, | ||
| } | ||
| if lowercase(d) not in typemap.keys(): | ||
| raise TypeError | ||
| raise argparse.ArgumentTypeError( | ||
| f"invalid dtype '{d}'. Supported values: fp32/float32, fp16/float16, bf16/bfloat16" | ||
| ) | ||
| return typemap[lowercase(d)] | ||
|
|
||
|
|
||
| def precision(d): | ||
| typemap = ["fp32", "fp16", "fp8", "mxfp8", "nvfp4"] | ||
| if lowercase(d) not in typemap: | ||
| raise argparse.ArgumentTypeError( | ||
| f"invalid precision '{d}'. Supported values: {', '.join(typemap)}" | ||
| ) | ||
| return lowercase(d) | ||
|
|
||
|
|
||
| te_layer_map = { | ||
| "linear": te.Linear, | ||
| "layernorm": te.LayerNorm, | ||
|
|
@@ -91,7 +107,6 @@ def get_layer_args(opts): | |
| hidden_size = opts.num_heads * opts.head_dim | ||
| layer_args = (hidden_size,) | ||
| layer_kwargs = { | ||
| "params_dtype": opts.dtype, | ||
| "device": "cuda" if opts.no_defer_init else "meta", | ||
| "get_rng_state_tracker": get_cuda_rng_tracker, | ||
| } | ||
|
|
@@ -112,6 +127,15 @@ def get_layer_args(opts): | |
| return layer_args, layer_kwargs | ||
|
|
||
|
|
||
| class StoreExplicitAction(argparse.Action): | ||
| """Custom action that tracks whether an argument was explicitly set.""" | ||
|
|
||
| def __call__(self, parser, namespace, values, option_string=None): | ||
| # values already converted by argparse via action.type | ||
| setattr(namespace, self.dest, values) | ||
| setattr(namespace, f"{self.dest}_explicitly_set", True) | ||
|
|
||
|
|
||
| def parse_fsdp_args(): | ||
| parser = argparse.ArgumentParser( | ||
| description="Run Transformer Engine modules with the " | ||
|
|
@@ -173,7 +197,19 @@ def parse_fsdp_args(): | |
| "--no-fp8", | ||
| action="store_true", | ||
| default=False, | ||
| help="Disables the te.autocast() context.", | ||
| help=( | ||
| "Disables the te.autocast() context. When set, FP8 training is disabled and the model" | ||
| " trains in standard precision (as specified by --dtype). PRECEDENCE: This flag is" | ||
| " incompatible with FP8-based --precision presets. BEHAVIOR: - Without --precision:" | ||
| " Disables FP8 training (original behavior) - With --precision fp32/fp16: Redundant but" | ||
| " harmless (already non-FP8) - With --precision fp8/mxfp8/nvfp4: RAISES ERROR" | ||
| " (incompatible flags) RATIONALE: FP8 presets explicitly request FP8 training, so" | ||
| " disabling FP8 would contradict the user's intent. Use --precision fp32/fp16 instead" | ||
| " for non-FP8 training. EXAMPLES: '--no-fp8' disables FP8 (original behavior)." | ||
| " '--precision fp8 --no-fp8' raises ValueError (incompatible). '--precision fp16' is" | ||
| " the correct way to request non-FP8 training. Default: False (FP8 enabled based on" | ||
| " configuration)." | ||
| ), | ||
| ) | ||
| parser.add_argument( | ||
| "--no-defer-init", | ||
|
|
@@ -189,7 +225,35 @@ def parse_fsdp_args(): | |
| "--dtype", | ||
| type=torch_dtype, | ||
| default=torch.bfloat16, | ||
| help="Data type for input tensor and Transformer Engine module parameters.", | ||
| action=StoreExplicitAction, | ||
| help=( | ||
| "Data type for input tensor and Transformer Engine module parameters. Supported values:" | ||
| " fp32/float32, fp16/float16, bf16/bfloat16. PRECEDENCE: When explicitly set, this flag" | ||
| " overrides the dtype from --precision preset. BEHAVIOR: - Without --precision:" | ||
| " Controls parameter dtype directly - With --precision: Overrides preset's dtype but" | ||
| " preserves FP8 recipe selection EXAMPLES: '--dtype bf16' uses bfloat16 parameters" | ||
| " (original behavior). '--precision mxfp8 --dtype fp16' uses fp16 parameters with" | ||
| " MXFP8BlockScaling recipe. A warning is issued when overriding --precision dtype." | ||
| " Default: bfloat16." | ||
| ), | ||
| ) | ||
| parser.add_argument( | ||
| "--precision", | ||
| type=precision, | ||
| default=None, | ||
| help=( | ||
| "Precision preset for model training. Supported values: fp32, fp16, fp8, mxfp8, nvfp4." | ||
| " This is a convenience flag that configures both dtype and FP8 settings automatically." | ||
| " - fp32/fp16: Non-FP8 training with specified dtype - fp8: FP8 training with" | ||
| " DelayedScaling recipe (bf16 parameters) - mxfp8: FP8 training with MXFP8BlockScaling" | ||
| " recipe (bf16 parameters) - nvfp4: FP8 training with NVFP4BlockScaling recipe (bf16" | ||
| " parameters) PRECEDENCE RULES: - If --dtype is explicitly set, it overrides the dtype" | ||
| " from this preset (with warning) - If --no-fp8 is set with fp8/mxfp8/nvfp4 presets, an" | ||
| " error is raised (incompatible) - If this flag is not set, original behavior is used" | ||
| " (--dtype and --no-fp8 control training) EXAMPLES: '--precision mxfp8' enables MXFP8" | ||
| " FP8 training with bf16 parameters. '--precision fp8 --dtype fp16' uses fp16" | ||
| " parameters but keeps DelayedScaling recipe. Default: None (backward compatible mode)." | ||
| ), | ||
| ) | ||
| return parser.parse_args() | ||
|
Comment on lines
+240
to
258
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Conflicting CLI flags |
||
|
|
||
|
|
@@ -200,15 +264,113 @@ def dist_print(text, all_ranks=False, no_new_line=False): | |
| print(f"[GPU-{LOCAL_RANK}] " + text, end=end) | ||
|
|
||
|
|
||
| def get_precision_preset(precision_value): | ||
| """Get dtype, no_fp8, and recipe based on precision preset. | ||
|
|
||
| Returns: | ||
| tuple: (dtype, no_fp8, recipe) | ||
| """ | ||
| match precision_value: | ||
| case "fp32": | ||
| return torch.float32, True, None | ||
| case "fp16": | ||
| return torch.float16, True, None | ||
| case "fp8": | ||
| recipe = DelayedScaling( | ||
| fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" | ||
| ) | ||
| return torch.bfloat16, False, recipe | ||
| case "mxfp8": | ||
| recipe = MXFP8BlockScaling() | ||
| return torch.bfloat16, False, recipe | ||
| case "nvfp4": | ||
| recipe = NVFP4BlockScaling() | ||
| return torch.bfloat16, False, recipe | ||
| case None: | ||
| # Default case: no precision preset specified, use original behavior. | ||
| # dtype and no_fp8 are controlled directly by --dtype and --no-fp8 flags. | ||
| return torch.bfloat16, True, None | ||
| case _: | ||
| raise ValueError( | ||
| f"Invalid precision preset: {precision_value}. " | ||
| "Supported values: fp32, fp16, fp8, mxfp8, nvfp4" | ||
| ) | ||
|
|
||
|
|
||
| def train(opts): | ||
| # Check which flags were explicitly set | ||
| dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) | ||
|
|
||
| # Validate flag combinations before touching distributed state. | ||
| # Error if user requests FP8-based precision but also sets --no-fp8 | ||
| # Safe to raise here because torchrun guarantees all ranks receive | ||
| # identical CLI arguments; all ranks will raise simultaneously. | ||
| if opts.precision in ["fp8", "mxfp8", "nvfp4"] and opts.no_fp8: | ||
| raise ValueError( | ||
| f"Cannot use --no-fp8 with --precision {opts.precision}. " | ||
| "These flags are incompatible. " | ||
| f"Either remove --no-fp8 to use {opts.precision} training, " | ||
| "or use --precision fp32/fp16 for non-FP8 training." | ||
| ) | ||
|
|
||
| # Initialize torch.distributed global process group | ||
| dist.init_process_group(backend="nccl") | ||
| torch.cuda.set_device(LOCAL_RANK) | ||
| dist_print(f"WORLD_SIZE = {WORLD_SIZE}") | ||
| torch.manual_seed(opts.seed) | ||
|
|
||
| # Start with precision preset values | ||
| preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision) | ||
|
|
||
| dtype = preset_dtype | ||
| no_fp8 = preset_no_fp8 | ||
| recipe = preset_recipe | ||
|
|
||
| # When no --precision is set, respect --no-fp8 and --dtype directly (original behavior) | ||
| if opts.precision is None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So backward compatibility in this script is not that big of a deal. Having default value of precision as DelayedScaling should be ok and this if condition can be removed to make the code simpler. |
||
| no_fp8 = opts.no_fp8 | ||
| dtype = opts.dtype | ||
| # Preserve original default: FP8 enabled → use DelayedScaling as before | ||
| if not no_fp8: | ||
| recipe = DelayedScaling( | ||
| fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" | ||
| ) | ||
| else: | ||
| dist_print(f"Using precision preset: {opts.precision}") | ||
|
|
||
| dtype_name = str(dtype).replace("torch.", "") | ||
|
|
||
| # Apply explicit dtype override with warning | ||
| if dtype_explicitly_set and opts.precision is not None: | ||
| new_dtype = opts.dtype | ||
| if new_dtype != preset_dtype: | ||
| dtype = new_dtype | ||
| dtype_name = str(dtype).replace("torch.", "") | ||
|
|
||
| dist_print( | ||
| f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype" | ||
| " setting" | ||
| ) | ||
| else: | ||
| new_dtype_name = str(new_dtype).replace("torch.", "") | ||
| dist_print( | ||
| f"Info: --dtype {new_dtype_name} matches --precision {opts.precision} preset" | ||
| " default, no override needed" | ||
| ) | ||
|
|
||
| # recipe is already set correctly from preset_recipe above; | ||
| # dtype only affects parameter storage, not the quantization recipe | ||
|
|
||
| # Always log the final configuration being used | ||
| dist_print( | ||
| f"Training configuration: dtype={dtype_name}, " | ||
| f"quantization={'disabled' if no_fp8 else f'enabled ({type(recipe).__name__})'}" | ||
| ) | ||
|
|
||
| # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM | ||
| layer_args, layer_kwargs = get_layer_args(opts) | ||
| layer_kwargs["params_dtype"] = dtype | ||
|
|
||
| if opts.num_layers > 1: | ||
| te_layer_list = [] | ||
| for i in range(opts.num_layers): | ||
|
|
@@ -239,7 +401,7 @@ def train(opts): | |
| process_group=all_gpus, | ||
| use_orig_params=True, | ||
| mixed_precision=MixedPrecision( | ||
| param_dtype=opts.dtype, | ||
| param_dtype=dtype, | ||
| reduce_dtype=torch.float32, | ||
| ), | ||
| auto_wrap_policy=fsdp_wrap_policy, | ||
|
|
@@ -258,10 +420,6 @@ def train(opts): | |
| dist_print(f"Post-FSDP memory use = {post_mem_use}MiB") | ||
| dist_print(f"FSDP-Wrapped + Checkpointed TE Model:\n{te_model}") | ||
|
|
||
| # Fp8 setup for TE | ||
| fp8_format = Format.HYBRID | ||
| fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") | ||
|
|
||
| # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded | ||
| optim = torch.optim.Adam(te_model.parameters(), lr=0.0001) | ||
|
|
||
|
|
@@ -281,11 +439,11 @@ def train(opts): | |
| opts.seq_length, | ||
| opts.batch_size, | ||
| opts.num_heads * opts.head_dim, | ||
| dtype=opts.dtype, | ||
| dtype=dtype, | ||
| device="cuda", | ||
| ) | ||
| # autocast needs to be given the FSDP process group for amax reductions | ||
| with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus): | ||
| with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus): | ||
| y = te_model(x) | ||
| loss = y.sum() | ||
| # calculate gradient and take training step outside the autocast context | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing
`action=StoreExplicitAction` to track explicit `--no-fp8` usage — required for the precedence logic at line 250