Add DCP compatibility for FSDP2-TP sharding in TransformerEngine. #2713

cspades wants to merge 4 commits into NVIDIA:main from cspades:cye/fsdp2-tp-dcp
Conversation
Force-pushed from 50da1dc to 925d022
Greptile Summary

This PR adds FSDP2 + Tensor Parallelism (TP) strided-sharding support and DCP checkpoint compatibility across all `TransformerEngineBaseModule`(s).

Concern: The test script implements a complete DCP `AppState`, but the training loop never exercises a checkpoint save/load round-trip.

Confidence Score: 2/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant TransformerLayer
    participant MultiheadAttention
    participant LayerNormLinear
    participant base as TransformerEngineBaseModule
    participant distributed as distributed.py
    User->>TransformerLayer: __init__(tp_mesh, weight_mesh)
    TransformerLayer->>MultiheadAttention: __init__(tp_mesh, weight_mesh)
    MultiheadAttention->>LayerNormLinear: __init__(tp_mesh, weight_mesh)
    LayerNormLinear->>LayerNormLinear: init_fp8_metadata()
    LayerNormLinear->>LayerNormLinear: set_device_mesh(tp_mesh, weight_mesh)
    LayerNormLinear->>distributed: _convert_param_to_dtensor_param(weight, tp_mesh, Shard/Replicate)
    distributed-->>LayerNormLinear: nn.Parameter(DTensor)
    LayerNormLinear->>LayerNormLinear: quantizer.amax_reduction_group = weight_mesh.get_group()
    LayerNormLinear->>LayerNormLinear: reset_parameters()
    LayerNormLinear->>base: reset_parameters() [handles DTensor params]
    User->>TransformerLayer: forward(inp)
    TransformerLayer->>base: prepare_forward(inp)
    Note over base: if isinstance(inp, DTensor): inp = inp.to_local()
    base-->>TransformerLayer: local tensor inp
    TransformerLayer->>LayerNormLinear: forward(inp)
    LayerNormLinear->>LayerNormLinear: _get_weight_tensors() → weight.to_local()
    LayerNormLinear->>LayerNormLinear: _get_layernorm_weight_and_bias() → to_local()
```

Last reviewed commit: dbb9d14
```diff
 if args.sharding_dims:
-    assert len(args.sharding_dims) <= 2
+    assert len(args.sharding_dims) <= 3
+if len(args.sharding_dims) >= 3:
+    # Set the TP size in args.
+    args.tp_size = args.sharding_dims[2]
+else:
+    args.tp_size = 1
 return args
```
`args.sharding_dims` not guarded against `None`

At line 153, `len(args.sharding_dims)` is called unconditionally, but `args.sharding_dims` can be `None` when the `--sharding-dims` flag is omitted (since the argument is not marked `required=True` and uses `nargs="+"`). This will raise `TypeError: object of type 'NoneType' has no len()`.

The `if len(args.sharding_dims) >= 3:` block should be nested inside the existing `if args.sharding_dims:` guard:
```diff
 if args.sharding_dims:
     assert len(args.sharding_dims) <= 3
-if len(args.sharding_dims) >= 3:
-    # Set the TP size in args.
-    args.tp_size = args.sharding_dims[2]
-else:
-    args.tp_size = 1
+    if len(args.sharding_dims) >= 3:
+        # Set the TP size in args.
+        args.tp_size = args.sharding_dims[2]
+    else:
+        args.tp_size = 1
+else:
+    args.tp_size = 1
 return args
```
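A standalone repro of the reported failure mode; the parser here is hypothetical, mirroring only the flag definition described in the comment above (optional, `nargs="+"`, default `None`):

```python
import argparse

parser = argparse.ArgumentParser()
# Optional flag with nargs="+": when omitted, the attribute stays None.
parser.add_argument("--sharding-dims", nargs="+", type=int, default=None)
args = parser.parse_args([])  # flag omitted on the command line

print(args.sharding_dims)  # None
try:
    len(args.sharding_dims)  # the unguarded call
except TypeError as err:
    print(err)  # object of type 'NoneType' has no len()
```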
```python
self.fc1_bias = _convert_param_to_dtensor_param(
    self.fc1_bias, tp_mesh, placements=(Shard(dim=0),)
)
# FC2 Weight -> Row-Parallel -> Shard(dim=1)
self.fc2_weight = _convert_param_to_dtensor_param(
    self.fc2_weight, tp_mesh, placements=(Shard(dim=1),)
)
# LN & FC2 Bias -> Replicate()
self.fc2_bias = _convert_param_to_dtensor_param(
    self.fc2_bias, tp_mesh, placements=(Replicate(),)
)
```
Bias converted unconditionally when `use_bias=False`

When `use_bias=False`, `self.fc1_bias` and `self.fc2_bias` are initialized as plain `torch.Tensor` objects (not `nn.Parameter`, see lines 1940 and 1958):

```python
else:
    self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device)
```

Calling `_convert_param_to_dtensor_param` on them returns `nn.Parameter(DTensor.from_local(...))`. When this is then assigned back via `self.fc1_bias = new_param`, PyTorch's `Module.__setattr__` will detect the `nn.Parameter` type and register the bias as a named module parameter, even though biases are disabled. This would pollute `model.named_parameters()`, the optimizer parameter list, and checkpoint state.

The fix is to guard these two conversions behind `if self.use_bias:`, following the same pattern already used for `layer_norm_bias` at line 2091.
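The registration side effect described above can be reproduced with a small standalone module; the class and attribute names here are illustrative, not TE's actual code:

```python
import torch
import torch.nn as nn

class BiaslessDemo(nn.Module):
    """Hypothetical module mimicking TE's use_bias=False path."""
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4, 4))
        # Plain (non-Parameter) placeholder tensor, like TE's dummy bias.
        self.fc1_bias = torch.Tensor()

m = BiaslessDemo()
print([n for n, _ in m.named_parameters()])  # ['weight']

# Re-assigning an nn.Parameter triggers Module.__setattr__ registration,
# so the "disabled" bias now shows up as a trainable parameter.
m.fc1_bias = nn.Parameter(torch.zeros(4))
print([n for n, _ in m.named_parameters()])  # ['weight', 'fc1_bias']
```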
```python
weight_mesh : Optional[DeviceMesh]
    Not used for DotProductAttention as there are no quantized weights.
"""
warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
```
Spurious warning when `weight_mesh` is `None`

`warnings.warn(...)` is emitted unconditionally every time `set_device_mesh` is called, even when `weight_mesh=None`. The calling code invokes this method whenever `tp_mesh is not None or weight_mesh is not None`, so a normal call with only `tp_mesh` provided will generate a misleading warning like `"weight_mesh not necessary for DotProductAttention: None"`.

The warning should only fire when the caller explicitly passes a non-`None` `weight_mesh`. The same spurious warning exists in `transformer_engine/pytorch/module/layernorm.py` (line 171) and `transformer_engine/pytorch/module/rmsnorm.py` (line 174).
```diff
-warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
+if weight_mesh is not None:
+    warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
```
Signed-off-by: Cory Ye <cye@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Cory Ye <cye@nvidia.com>
Force-pushed from 4ec2947 to dbb9d14
@@ -30,6 +38,61 @@

```python
LOCAL_RANK = None


@dataclass
class AppState(Stateful):
    """AppState for FSDP2 checkpoint via Torch DCP.

    Adapted from https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
    """

    model: torch.nn.Module
    optimizer: torch.optim.Optimizer

    def state_dict(self):
        """
        Get the state dict for the model, optimizer, scheduler, and step.
        This factory both retrieves the model state dictionary when saving
        checkpoints and initializes a destination for the state read from
        DCP checkpoint files when loading checkpoints.
        """
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        for fqn in list(model_state_dict.keys()):
            # Get the model parameter.
            model_param = model_state_dict[fqn]
            if isinstance(model_param, DTensor):
                model_param = model_param.to_local()
            if model_param.numel() == 0 and fqn in optimizer_state_dict["state"]:
                # Empty model parameter. Clear the associated optimizer state
                # when initializing the optimizer state upon DCP load, because
                # empty optimizer state DTensors are not checkpointed with DCP,
                # yet get_state_dict / _init_optim_state produce empty Tensors.
                # TransformerEngine uses empty Tensors for dummy Parameters.
                optimizer_state_dict["state"][fqn] = {}
            if fqn.endswith("._extra_state"):
                # Evict `_extra_state` quantization data from model checkpoint.
                model_state_dict.pop(fqn)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
        }

    def load_state_dict(self, state_dict: dict):
        """
        Load the state dict for the model, optimizer, scheduler, and step.
        Given the checkpoint-loaded state_dict, set the state of the model,
        optimizer, scheduler, step, and epoch to the values in state_dict.
        """
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
            # Non-strict checkpoint loading ignores empty optimizer states,
            # skips loading non-FP8 checkpoint weights (e.g. _extra_state).
            options=StateDictOptions(strict=False),
        )


def dist_print(msg):
    if LOCAL_RANK == 0:
```
DCP checkpoint functionality is not exercised in the test
The `AppState` class (lines 42–93) and DCP checkpoint operations (`save`, `load`, `get_state_dict`, `set_state_dict`) are imported and fully implemented, but the training loop in `_train()` (lines 480–490) does not call any checkpoint save/load operations. The function ends at line 497 with `dist.destroy_process_group()` and no checkpoint round-trip.
Since the PR title is "Add DCP compatibility for FSDP2-TP sharding," the checkpoint functionality is the headline feature. Without an actual save/load call in the test, neither the AppState.state_dict() eviction logic nor the set_state_dict(strict=False) reload path is validated.
Recommendation: Add a checkpoint save/load round-trip after the training loop (before dist.destroy_process_group()) to exercise the DCP functionality, or explicitly note in the test docstring that DCP round-trip testing is deferred to integration tests.
Working on it! + GroupedLinear test case.
Summary
- FSDP2 + TP strided sharding of `DTensor` parameters with FP8, across all `TransformerEngineBaseModule`(s).

Details
"shard"was the presumed weight sharding sub-mesh in theDTensor.device_mesh. Now, users can precisely specify their own custom weight-shardingDeviceMeshfor per-tensoramax_reduction_groupvia theset_device_meshAPI.Testing
- Fails on both `main` and `cspades:cye/fsdp2-tp-dcp`, so we can assume it is not associated to my change: https://github.com/NVIDIA/Megatron-LM/actions/runs/22637904520/job/65636890955?pr=3661 (TransformerEngine `main`)
- `main` vs. `cspades:cye/fsdp2-tp-dcp` with Megatron-LM `main` on PyTorch 25.11