-
Notifications
You must be signed in to change notification settings - Fork 653
Add DCP compatibility for FSDP2-TP sharding in TransformerEngine. #2713
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -7,6 +7,7 @@ | |||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||
| import sys | ||||||||||||||||||||||||||||||||||||||
| import argparse | ||||||||||||||||||||||||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| import transformer_engine.pytorch as te | ||||||||||||||||||||||||||||||||||||||
| from transformer_engine.common.recipe import ( | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -18,6 +19,13 @@ | |||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||
| import torch.distributed as dist | ||||||||||||||||||||||||||||||||||||||
| from torch.distributed.checkpoint import save, load | ||||||||||||||||||||||||||||||||||||||
| from torch.distributed.checkpoint.state_dict import ( | ||||||||||||||||||||||||||||||||||||||
| StateDictOptions, | ||||||||||||||||||||||||||||||||||||||
| get_state_dict, | ||||||||||||||||||||||||||||||||||||||
| set_state_dict, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| from torch.distributed.checkpoint.stateful import Stateful | ||||||||||||||||||||||||||||||||||||||
| from torch.distributed.tensor import DTensor | ||||||||||||||||||||||||||||||||||||||
| import torch.nn.functional as F | ||||||||||||||||||||||||||||||||||||||
| from torch import nn, optim | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -30,6 +38,61 @@ | |||||||||||||||||||||||||||||||||||||
| LOCAL_RANK = None | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| @dataclass | ||||||||||||||||||||||||||||||||||||||
| class AppState(Stateful): | ||||||||||||||||||||||||||||||||||||||
| """AppState for FSDP2 checkpoint via Torch DCP. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Adapted from https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| model: torch.nn.Module | ||||||||||||||||||||||||||||||||||||||
| optimizer: torch.optim.Optimizer | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def state_dict(self): | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| Get the state dict for the model, optimizer, scheduler, and step. | ||||||||||||||||||||||||||||||||||||||
| This factory both retrieves the model state dictionary when saving | ||||||||||||||||||||||||||||||||||||||
| checkpoints and initializes a destination for the state read from | ||||||||||||||||||||||||||||||||||||||
| DCP checkpoint files when loading checkpoints. | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer) | ||||||||||||||||||||||||||||||||||||||
| for fqn in list(model_state_dict.keys()): | ||||||||||||||||||||||||||||||||||||||
| # Get the model parameter. | ||||||||||||||||||||||||||||||||||||||
| model_param = model_state_dict[fqn] | ||||||||||||||||||||||||||||||||||||||
| if isinstance(model_param, DTensor): | ||||||||||||||||||||||||||||||||||||||
| model_param = model_param.to_local() | ||||||||||||||||||||||||||||||||||||||
| if model_param.numel() == 0 and fqn in optimizer_state_dict["state"]: | ||||||||||||||||||||||||||||||||||||||
| # Empty model parameter. Clear the associated optimizer state | ||||||||||||||||||||||||||||||||||||||
| # when initializing the optimizer state upon DCP load, because | ||||||||||||||||||||||||||||||||||||||
| # empty optimizer state DTensors are not checkpointed with DCP, | ||||||||||||||||||||||||||||||||||||||
| # yet get_state_dict / _init_optim_state produce empty Tensors. | ||||||||||||||||||||||||||||||||||||||
| # TransformerEngine uses empty Tensors for dummy Parameters. | ||||||||||||||||||||||||||||||||||||||
| optimizer_state_dict["state"][fqn] = {} | ||||||||||||||||||||||||||||||||||||||
| if fqn.endswith("._extra_state"): | ||||||||||||||||||||||||||||||||||||||
| # Evict `_extra_state` quantization data from model checkpoint. | ||||||||||||||||||||||||||||||||||||||
| model_state_dict.pop(fqn) | ||||||||||||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||||||||||||
| "model": model_state_dict, | ||||||||||||||||||||||||||||||||||||||
| "optim": optimizer_state_dict, | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def load_state_dict(self, state_dict: dict): | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| Load the state dict for the model, optimizer, scheduler, and step. | ||||||||||||||||||||||||||||||||||||||
| Given the checkpoint-loaded state_dict, set the state of the model, | ||||||||||||||||||||||||||||||||||||||
| optimizer, scheduler, step, and epoch to the values in state_dict. | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| set_state_dict( | ||||||||||||||||||||||||||||||||||||||
| self.model, | ||||||||||||||||||||||||||||||||||||||
| self.optimizer, | ||||||||||||||||||||||||||||||||||||||
| model_state_dict=state_dict["model"], | ||||||||||||||||||||||||||||||||||||||
| optim_state_dict=state_dict["optim"], | ||||||||||||||||||||||||||||||||||||||
| # Non-strict checkpoint loading ignores empty optimizer states, | ||||||||||||||||||||||||||||||||||||||
| # skips loading non-FP8 checkpoint weights (e.g. _extra_state). | ||||||||||||||||||||||||||||||||||||||
| options=StateDictOptions(strict=False), | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def dist_print(msg): | ||||||||||||||||||||||||||||||||||||||
| if LOCAL_RANK == 0: | ||||||||||||||||||||||||||||||||||||||
| print(msg) | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -82,11 +145,16 @@ def _parse_args(argv=None, namespace=None): | |||||||||||||||||||||||||||||||||||||
| "--sharding-dims", | ||||||||||||||||||||||||||||||||||||||
| type=int, | ||||||||||||||||||||||||||||||||||||||
| nargs="+", | ||||||||||||||||||||||||||||||||||||||
| help='FSDP/HSDP sharding dimensions ("replicate", "shard")', | ||||||||||||||||||||||||||||||||||||||
| help='FSDP/HSDP sharding dimensions ("dp_replicate", "dp_shard", "tp")', | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| args = parser.parse_args(argv, namespace) | ||||||||||||||||||||||||||||||||||||||
| if args.sharding_dims: | ||||||||||||||||||||||||||||||||||||||
| assert len(args.sharding_dims) <= 2 | ||||||||||||||||||||||||||||||||||||||
| assert len(args.sharding_dims) <= 3 | ||||||||||||||||||||||||||||||||||||||
| if len(args.sharding_dims) >= 3: | ||||||||||||||||||||||||||||||||||||||
| # Set the TP size in args. | ||||||||||||||||||||||||||||||||||||||
| args.tp_size = args.sharding_dims[2] | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| args.tp_size = 1 | ||||||||||||||||||||||||||||||||||||||
| return args | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
151
to
158
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
At line 153, The
Suggested change
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
@@ -136,11 +204,17 @@ def init_te_model(config): | |||||||||||||||||||||||||||||||||||||
| "params_dtype": params_dtype, | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| kwargs["device"] = config.device | ||||||||||||||||||||||||||||||||||||||
| kwargs["tp_size"] = config.tp_size | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| layer_type = get_te_layer_from_string(config.layer_type) | ||||||||||||||||||||||||||||||||||||||
| # We are creating model in a way so that we can test both reshard_after_forward=True/False cases. | ||||||||||||||||||||||||||||||||||||||
| # more details below. | ||||||||||||||||||||||||||||||||||||||
| if layer_type in [te.MultiheadAttention, te.TransformerLayer]: | ||||||||||||||||||||||||||||||||||||||
| if layer_type in [ | ||||||||||||||||||||||||||||||||||||||
| te.TransformerLayer, | ||||||||||||||||||||||||||||||||||||||
| te.MultiheadAttention, | ||||||||||||||||||||||||||||||||||||||
| te.LayerNormMLP, | ||||||||||||||||||||||||||||||||||||||
| # TODO(@cspades): GroupedLinear testing. | ||||||||||||||||||||||||||||||||||||||
| ]: | ||||||||||||||||||||||||||||||||||||||
| # For this case, we are creating a model that resemebles production use-cases | ||||||||||||||||||||||||||||||||||||||
| # wherein there are mltiple TransformerLayers in the model. And we would need | ||||||||||||||||||||||||||||||||||||||
| # to shard each transformer layer. Since each transformer layer is not a root module, | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -150,44 +224,102 @@ def init_te_model(config): | |||||||||||||||||||||||||||||||||||||
| kwargs["fuse_qkv_params"] = True | ||||||||||||||||||||||||||||||||||||||
| if layer_type is te.MultiheadAttention: | ||||||||||||||||||||||||||||||||||||||
| kwargs["input_layernorm"] = True | ||||||||||||||||||||||||||||||||||||||
| # DeviceMesh / DTensor-related model parameter operations! | ||||||||||||||||||||||||||||||||||||||
| # NOTE(@cspades): `set_device_mesh` works, but needs to be called before reset_parameters. | ||||||||||||||||||||||||||||||||||||||
| # If not using meta device initialization, reset_parameters is called during __init__. | ||||||||||||||||||||||||||||||||||||||
| if config.tp_size > 1: | ||||||||||||||||||||||||||||||||||||||
| assert "dp_shard" in config.mesh.mesh_dim_names | ||||||||||||||||||||||||||||||||||||||
| assert "tp" in config.mesh.mesh_dim_names | ||||||||||||||||||||||||||||||||||||||
| dist_print(f"Tensor parallelism activated with size: {config.tp_size}") | ||||||||||||||||||||||||||||||||||||||
| # Activate TP in TE. | ||||||||||||||||||||||||||||||||||||||
| kwargs["set_parallel_mode"] = True | ||||||||||||||||||||||||||||||||||||||
| # For TP shards as DTensors. | ||||||||||||||||||||||||||||||||||||||
| kwargs["tp_mesh"] = config.mesh["tp"] | ||||||||||||||||||||||||||||||||||||||
| # For per-tensor quantization recipes with TP. | ||||||||||||||||||||||||||||||||||||||
| kwargs["weight_mesh"] = config.mesh["dp_shard", "tp"]._flatten("weight_mesh") | ||||||||||||||||||||||||||||||||||||||
| elif len(config.mesh.mesh_dim_names) > 1: | ||||||||||||||||||||||||||||||||||||||
| assert "dp_shard" in config.mesh.mesh_dim_names | ||||||||||||||||||||||||||||||||||||||
| # HSDP (DP-Repl, DP-Shard) requires a call to `set_device_mesh(weight_mesh)`. | ||||||||||||||||||||||||||||||||||||||
| # Used for per-tensor quantization recipes like Float8CurrentScaling. | ||||||||||||||||||||||||||||||||||||||
| kwargs["weight_mesh"] = config.mesh["dp_shard"] # Only sharding with FSDP. | ||||||||||||||||||||||||||||||||||||||
| # Initialize model. | ||||||||||||||||||||||||||||||||||||||
| model = nn.Sequential(*[layer_type(*args, **kwargs) for _ in range(config.num_layers)]) | ||||||||||||||||||||||||||||||||||||||
| elif layer_type == te.LayerNormLinear: | ||||||||||||||||||||||||||||||||||||||
| elif layer_type in [te.LayerNormLinear, te.Linear]: | ||||||||||||||||||||||||||||||||||||||
| # For this case, we are creating a model with just one LayerNormLinear layer | ||||||||||||||||||||||||||||||||||||||
| # so that the model itself is a root module, and FSDP2's fully_shard assigns | ||||||||||||||||||||||||||||||||||||||
| # reshard_after_forward=True for the parameters of these model. | ||||||||||||||||||||||||||||||||||||||
| args[1] *= 3 # QKV projection | ||||||||||||||||||||||||||||||||||||||
| out_shape[-1] *= 3 | ||||||||||||||||||||||||||||||||||||||
| # DeviceMesh / DTensor-related model parameter operations! | ||||||||||||||||||||||||||||||||||||||
| # NOTE(@cspades): `set_device_mesh` works, but needs to be called before reset_parameters. | ||||||||||||||||||||||||||||||||||||||
| # If not using meta device initialization, reset_parameters is called during __init__. | ||||||||||||||||||||||||||||||||||||||
| if config.tp_size > 1: | ||||||||||||||||||||||||||||||||||||||
| assert "dp_shard" in config.mesh.mesh_dim_names | ||||||||||||||||||||||||||||||||||||||
| assert "tp" in config.mesh.mesh_dim_names | ||||||||||||||||||||||||||||||||||||||
| dist_print(f"Tensor parallelism activated with size: {config.tp_size}") | ||||||||||||||||||||||||||||||||||||||
| # Activate TP in TE. | ||||||||||||||||||||||||||||||||||||||
| kwargs["parallel_mode"] = "column" | ||||||||||||||||||||||||||||||||||||||
| # For TP shards as DTensors. | ||||||||||||||||||||||||||||||||||||||
| kwargs["tp_mesh"] = config.mesh["tp"] | ||||||||||||||||||||||||||||||||||||||
| # For per-tensor quantization recipes with TP. | ||||||||||||||||||||||||||||||||||||||
| kwargs["weight_mesh"] = config.mesh["dp_shard", "tp"]._flatten("weight_mesh") | ||||||||||||||||||||||||||||||||||||||
| # Modify output shape for column-parallel Linear. | ||||||||||||||||||||||||||||||||||||||
| out_shape[-1] //= config.tp_size | ||||||||||||||||||||||||||||||||||||||
| elif len(config.mesh.mesh_dim_names) > 1: | ||||||||||||||||||||||||||||||||||||||
| assert "dp_shard" in config.mesh.mesh_dim_names | ||||||||||||||||||||||||||||||||||||||
| # HSDP (DP-Repl, DP-Shard) requires a call to `set_device_mesh(weight_mesh)`. | ||||||||||||||||||||||||||||||||||||||
| # Used for per-tensor quantization recipes like Float8CurrentScaling. | ||||||||||||||||||||||||||||||||||||||
| kwargs["weight_mesh"] = config.mesh["dp_shard"] # Only sharding with FSDP. | ||||||||||||||||||||||||||||||||||||||
| # Initialize model. | ||||||||||||||||||||||||||||||||||||||
| model = layer_type(*args, **kwargs) | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| # Other TE module. Just ambiguously initialize it. | ||||||||||||||||||||||||||||||||||||||
| model = layer_type(*args, **kwargs) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| return model, inp_shape, out_shape | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def get_device_mesh(world_size, sharding_dims): | ||||||||||||||||||||||||||||||||||||||
| dist_print(f"sharding-dims:{sharding_dims}") | ||||||||||||||||||||||||||||||||||||||
| dist_print(f"sharding-dims: {sharding_dims}") | ||||||||||||||||||||||||||||||||||||||
| device_ids = list(range(world_size)) | ||||||||||||||||||||||||||||||||||||||
| if sharding_dims is None: # FSDP | ||||||||||||||||||||||||||||||||||||||
| mesh = DeviceMesh("cuda", device_ids) | ||||||||||||||||||||||||||||||||||||||
| elif len(sharding_dims) == 1: | ||||||||||||||||||||||||||||||||||||||
| assert sharding_dims[0] == world_size | ||||||||||||||||||||||||||||||||||||||
| mesh = DeviceMesh("cuda", device_ids) | ||||||||||||||||||||||||||||||||||||||
| elif len(sharding_dims) == 2: # HSDP | ||||||||||||||||||||||||||||||||||||||
| # FSDP | ||||||||||||||||||||||||||||||||||||||
| if sharding_dims is None or len(sharding_dims) == 1: | ||||||||||||||||||||||||||||||||||||||
| assert sharding_dims is None or sharding_dims[0] == world_size | ||||||||||||||||||||||||||||||||||||||
| mesh = init_device_mesh( | ||||||||||||||||||||||||||||||||||||||
| "cuda", | ||||||||||||||||||||||||||||||||||||||
| (world_size,), | ||||||||||||||||||||||||||||||||||||||
| mesh_dim_names=("dp_shard",), | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| # HSDP | ||||||||||||||||||||||||||||||||||||||
| elif len(sharding_dims) == 2: | ||||||||||||||||||||||||||||||||||||||
| assert sharding_dims[0] * sharding_dims[1] == world_size | ||||||||||||||||||||||||||||||||||||||
| mesh = init_device_mesh( | ||||||||||||||||||||||||||||||||||||||
| "cuda", | ||||||||||||||||||||||||||||||||||||||
| (sharding_dims[0], sharding_dims[1]), | ||||||||||||||||||||||||||||||||||||||
| mesh_dim_names=("replicate", "shard"), | ||||||||||||||||||||||||||||||||||||||
| mesh_dim_names=("dp_replicate", "dp_shard"), | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| # (H/F)SDP-TP | ||||||||||||||||||||||||||||||||||||||
| elif len(sharding_dims) == 3: | ||||||||||||||||||||||||||||||||||||||
| assert sharding_dims[0] * sharding_dims[1] * sharding_dims[2] == world_size | ||||||||||||||||||||||||||||||||||||||
| mesh = init_device_mesh( | ||||||||||||||||||||||||||||||||||||||
| "cuda", | ||||||||||||||||||||||||||||||||||||||
| (sharding_dims[0], sharding_dims[1], sharding_dims[2]), | ||||||||||||||||||||||||||||||||||||||
| mesh_dim_names=("dp_replicate", "dp_shard", "tp"), | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| # Unsupported topology. | ||||||||||||||||||||||||||||||||||||||
| assert False | ||||||||||||||||||||||||||||||||||||||
| return mesh | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def shard_model_with_fsdp2(model, mesh): | ||||||||||||||||||||||||||||||||||||||
| assert "dp_shard" in mesh.mesh_dim_names | ||||||||||||||||||||||||||||||||||||||
| dp_dims = ( | ||||||||||||||||||||||||||||||||||||||
| ("dp_replicate", "dp_shard") if "dp_replicate" in mesh.mesh_dim_names else ("dp_shard",) | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| for child in model.children(): | ||||||||||||||||||||||||||||||||||||||
| fully_shard(child, mesh=mesh) | ||||||||||||||||||||||||||||||||||||||
| fully_shard(model, mesh=mesh) | ||||||||||||||||||||||||||||||||||||||
| fully_shard(child, mesh=mesh[dp_dims]) | ||||||||||||||||||||||||||||||||||||||
| fully_shard(model, mesh=mesh[dp_dims]) | ||||||||||||||||||||||||||||||||||||||
| return model | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
@@ -216,16 +348,18 @@ def restore_custom_attrs(module, custom_attrs): | |||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| @torch.no_grad() | ||||||||||||||||||||||||||||||||||||||
| def test_fp8_fsdp2_allgather(model): | ||||||||||||||||||||||||||||||||||||||
| # Do manual allgather in fp32 and match against fp8 allgather done | ||||||||||||||||||||||||||||||||||||||
| # with fsdp2 | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| Compare the result of the FP8 AG by FSDP2 with a manual AG in FP32 | ||||||||||||||||||||||||||||||||||||||
| after dequantizing the FP8 values. | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| # FP32 manual weight allgather | ||||||||||||||||||||||||||||||||||||||
| fp32_allgathered_params = {} | ||||||||||||||||||||||||||||||||||||||
| for name, param in model.named_parameters(): | ||||||||||||||||||||||||||||||||||||||
| assert isinstance(param, DTensor) | ||||||||||||||||||||||||||||||||||||||
| local_tensor = param._local_tensor | ||||||||||||||||||||||||||||||||||||||
| device_mesh = param.device_mesh | ||||||||||||||||||||||||||||||||||||||
| dist_group = ( | ||||||||||||||||||||||||||||||||||||||
| device_mesh.get_group(mesh_dim="shard") | ||||||||||||||||||||||||||||||||||||||
| device_mesh.get_group(mesh_dim="dp_shard") | ||||||||||||||||||||||||||||||||||||||
| if device_mesh.ndim > 1 | ||||||||||||||||||||||||||||||||||||||
| else device_mesh.get_group() | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -244,6 +378,10 @@ def test_fp8_fsdp2_allgather(model): | |||||||||||||||||||||||||||||||||||||
| module.unshard() | ||||||||||||||||||||||||||||||||||||||
| # Make sure allgathered parameters match exactly | ||||||||||||||||||||||||||||||||||||||
| for name, param in model.named_parameters(): | ||||||||||||||||||||||||||||||||||||||
| if isinstance(param, DTensor): | ||||||||||||||||||||||||||||||||||||||
| # Will still be a DTensor in the case of TP, even after FSDP2 AG, | ||||||||||||||||||||||||||||||||||||||
| # because we wrap our weights as DTensor shards over the TP group. | ||||||||||||||||||||||||||||||||||||||
| param = param._local_tensor | ||||||||||||||||||||||||||||||||||||||
| assert torch.allclose(param.dequantize(), fp32_allgathered_params[name]) | ||||||||||||||||||||||||||||||||||||||
| # Revert model to original sharded state | ||||||||||||||||||||||||||||||||||||||
| for module in model.modules(): | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -253,6 +391,9 @@ def test_fp8_fsdp2_allgather(model): | |||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _train(args): | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| Torch Distributed Initialization | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| global LOCAL_RANK | ||||||||||||||||||||||||||||||||||||||
| assert "TORCHELASTIC_RUN_ID" in os.environ | ||||||||||||||||||||||||||||||||||||||
| WORLD_RANK = int(os.getenv("RANK", "0")) | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -277,10 +418,20 @@ def _train(args): | |||||||||||||||||||||||||||||||||||||
| nccl_world = dist.new_group(backend="nccl") | ||||||||||||||||||||||||||||||||||||||
| device = torch.device(f"cuda:{LOCAL_RANK}") | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Create a DeviceMesh for fully_shard. | ||||||||||||||||||||||||||||||||||||||
| world_size = int(WORLD_SIZE) | ||||||||||||||||||||||||||||||||||||||
| # Setup the sharding mesh for FSDP/HSDP. | ||||||||||||||||||||||||||||||||||||||
| mesh = get_device_mesh(world_size, args.sharding_dims) | ||||||||||||||||||||||||||||||||||||||
| args.mesh = mesh | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| TransformerEngine Model Initialization | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| # FP8 Configuration | ||||||||||||||||||||||||||||||||||||||
| fp8_format = Format.HYBRID | ||||||||||||||||||||||||||||||||||||||
| fp8_recipe = get_recipe_from_string(args.recipe, fp8_format) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Model initialization context. | ||||||||||||||||||||||||||||||||||||||
| build_model_context_args = {} | ||||||||||||||||||||||||||||||||||||||
| if not args.fp8_init: | ||||||||||||||||||||||||||||||||||||||
| # Build model context (FP8 init) | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -301,29 +452,31 @@ def _train(args): | |||||||||||||||||||||||||||||||||||||
| f" {torch.cuda.memory_allocated(device)/1e6} MB" | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Creating a DeviceMesh for fully_shard | ||||||||||||||||||||||||||||||||||||||
| world_size = int(WORLD_SIZE) | ||||||||||||||||||||||||||||||||||||||
| # Setup the sharding mesh for FSDP/HSDP | ||||||||||||||||||||||||||||||||||||||
| mesh = get_device_mesh(world_size, args.sharding_dims) | ||||||||||||||||||||||||||||||||||||||
| # Avoid passing custom attributes to FSDP2. | ||||||||||||||||||||||||||||||||||||||
| custom_attrs = save_custom_attrs(model) | ||||||||||||||||||||||||||||||||||||||
| # Fully-shard the model. Will convert model parameters into DTensor | ||||||||||||||||||||||||||||||||||||||
| # if not already converted by TP. | ||||||||||||||||||||||||||||||||||||||
| model = shard_model_with_fsdp2(model, mesh) | ||||||||||||||||||||||||||||||||||||||
| # Restore custom attributes on parameters. | ||||||||||||||||||||||||||||||||||||||
| restore_custom_attrs(model, custom_attrs) | ||||||||||||||||||||||||||||||||||||||
| # model now has DTensors as its parameters | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if args.device == "meta": | ||||||||||||||||||||||||||||||||||||||
| # After FSDP2 has been applied, materialize and initialize the sharded parameters | ||||||||||||||||||||||||||||||||||||||
| # TE base.py's reset_parameters() handles DTensors with FP8 initialization | ||||||||||||||||||||||||||||||||||||||
| # TE base.py's reset_parameters() handles DTensors with FP8 initialization. | ||||||||||||||||||||||||||||||||||||||
| for module in model.modules(): | ||||||||||||||||||||||||||||||||||||||
| if hasattr(module, "reset_parameters"): | ||||||||||||||||||||||||||||||||||||||
| module.reset_parameters() | ||||||||||||||||||||||||||||||||||||||
| dist_print(f" Sharded parameters materialized and initialized on cuda device.") | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| dist_print( | ||||||||||||||||||||||||||||||||||||||
| f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB" | ||||||||||||||||||||||||||||||||||||||
| f"FSDP2 model in CUDA, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB" | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| optimizer = optim.Adam(model.parameters(), lr=1e-3) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| Pre-Save Training | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| for iteration in range(args.iter): | ||||||||||||||||||||||||||||||||||||||
| # Zero the parameter gradients | ||||||||||||||||||||||||||||||||||||||
| optimizer.zero_grad() | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DCP checkpoint functionality is not exercised in the test
The
AppStateclass (lines 42–93) and DCP checkpoint operations (save,load,get_state_dict,set_state_dict) are imported and fully implemented, but the training loop in_train()(lines 480–490) does not call any checkpoint save/load operations. The function ends at line 497 withdist.destroy_process_group()and no checkpoint round-trip.Since the PR title is "Add DCP compatibility for FSDP2-TP sharding," the checkpoint functionality is the headline feature. Without an actual save/load call in the test, neither the
AppState.state_dict()eviction logic nor theset_state_dict(strict=False)reload path is validated.Recommendation: Add a checkpoint save/load round-trip after the training loop (before
dist.destroy_process_group()) to exercise the DCP functionality, or explicitly note in the test docstring that DCP round-trip testing is deferred to integration tests.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Working on it! + GroupedLinear test case.