
Add DCP compatibility for FSDP2-TP sharding in TransformerEngine.#2713

Open
cspades wants to merge 4 commits into NVIDIA:main from cspades:cye/fsdp2-tp-dcp

Conversation


@cspades cspades commented Feb 26, 2026

Summary

Details

  • Fix a bug where "shard" was presumed to be the name of the weight-sharding sub-mesh in the DTensor.device_mesh. Users can now precisely specify their own custom weight-sharding DeviceMesh for the per-tensor amax_reduction_group via the set_device_mesh API.
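The shape of that fallback can be sketched in plain Python. `ToyMesh` and `amax_group` below are illustrative stand-ins (not TE or PyTorch source): the real `DeviceMesh.get_group(name)` resolves process groups and raises when `name` is not a mesh dimension, which is what the old hardcoded `"shard"` lookup tripped over.

```python
# Illustrative sketch of the mesh-selection fallback described above.
# ToyMesh stands in for torch.distributed.device_mesh.DeviceMesh; the
# real API resolves process groups, which needs a distributed runtime.

class ToyMesh:
    def __init__(self, dim_names):
        self.mesh_dim_names = tuple(dim_names)

    def get_group(self, name=None):
        # Real DeviceMesh.get_group(name) raises if `name` is not a dim name.
        if name is not None and name not in self.mesh_dim_names:
            raise KeyError(f"no mesh dimension named {name!r}")
        return f"group({name or 'all'})"


def amax_group(weight_mesh, dtensor_mesh):
    """Pick the amax reduction group for per-tensor scaling."""
    if weight_mesh is not None:
        # User-specified weight-sharding mesh takes precedence (set_device_mesh).
        return weight_mesh.get_group()
    # Fallback: use the full DTensor mesh instead of assuming a "shard" dim.
    return dtensor_mesh.get_group()


# The old behavior amounted to dtensor_mesh.get_group("shard"), which fails
# on a mesh whose dims are named e.g. ("dp", "tp").
mesh = ToyMesh(("dp", "tp"))
print(amax_group(None, mesh))               # falls back to the full mesh group
print(amax_group(ToyMesh(("dp",)), mesh))   # honors the user's weight mesh
```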

Testing

# TransformerEngine Main
[Rank 0] (after 1 iterations) memory (MB) | allocated: 23511.65 | max allocated: 25189.68 | reserved: 25678.00 | max reserved: 25678.00
 [2026-03-02 09:55:17.189564] iteration       99/15258789 | consumed samples:        12672 | elapsed time per iteration (ms): 12715.7 | throughput per GPU (TFLOP/s/GPU): 530.6 | learning rate: 4.866046E-07 | global batch size:   128 | lm loss: 1.124915E+00 | loss scale: 1.0 | grad norm: 5.474 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-03-02 09:55:29.768521] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 12578.7 | throughput per GPU (TFLOP/s/GPU): 536.4 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.143806E+00 | loss scale: 1.0 | grad norm: 5.366 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Post-DCP Modifications (This PR)
[Rank 0] (after 2 iterations) memory (MB) | allocated: 23511.65 | max allocated: 29783.24 | reserved: 25678.00 | max reserved: 31510.00
 [2026-03-02 09:29:36.550070] iteration       99/15258789 | consumed samples:        12672 | elapsed time per iteration (ms): 12556.5 | throughput per GPU (TFLOP/s/GPU): 537.3 | learning rate: 4.866046E-07 | global batch size:   128 | lm loss: 1.124463E+00 | loss scale: 1.0 | grad norm: 5.471 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-03-02 09:29:49.216068] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 12665.7 | throughput per GPU (TFLOP/s/GPU): 532.7 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.142863E+00 | loss scale: 1.0 | grad norm: 5.355 | number of skipped iterations:   0 | number of nan iterations:   0 |

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


greptile-apps bot commented Mar 4, 2026

Greptile Summary

This PR adds FSDP2 + Tensor Parallelism (TP) strided-sharding support and DCP checkpoint compatibility across all TransformerEngineBaseModule subclasses by introducing a set_device_mesh(tp_mesh, weight_mesh) API that converts module parameters into DTensor shards and wires the correct amax_reduction_group for FP8 per-tensor scaling recipes.

Key changes:

  • New set_device_mesh() API added to Linear, LayerNormLinear, LayerNormMLP, GroupedLinear, MultiheadAttention, TransformerLayer, LayerNorm, RMSNorm, and DotProductAttention — each class converts its parameters to appropriately-placed DTensors (column-Shard(0), row-Shard(1), or Replicate()) for seamless FSDP2-TP strided-sharding.
  • Fixes the hardcoded "shard" mesh dimension in amax_reduction_group fallback — now uses device_mesh.get_group() on the full mesh rather than assuming a "shard" named dimension.
  • _convert_param_to_dtensor_param utility added in distributed.py to safely convert a Parameter (including those with custom attributes like param_init_meta) into a DTensor-wrapped Parameter.
  • DTensor localization added throughout forward/backward paths (base.py, ops/basic/layer_norm.py, ops/basic/rmsnorm.py, all _get_weight_tensors/_get_bias_tensors helpers) so TE's C++ kernels always receive plain local tensors.
  • FP8 all-gather output DTensor workaround in float8_tensor.py and mxfp8_tensor.py uses _local_tensor (with a FIXME annotation) to avoid a Torch Dispatch limitation.
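The placement convention in the first bullet can be summarized as a lookup. The mapping below is an illustration of that convention (placement names mirror `torch.distributed.tensor`'s `Shard`/`Replicate`), not TE source code:

```python
# Sketch of the FSDP2-TP placement convention listed above. The string
# values stand in for torch.distributed.tensor Shard/Replicate placements;
# the parameter-kind names are illustrative labels, not TE identifiers.

def tp_placement(param_kind):
    """Map a parameter's role to its tensor-parallel placement."""
    table = {
        "column_parallel_weight": "Shard(dim=0)",  # e.g. fc1/QKV weights
        "column_parallel_bias": "Shard(dim=0)",    # bias follows the sharded output dim
        "row_parallel_weight": "Shard(dim=1)",     # e.g. fc2/proj weights
        "row_parallel_bias": "Replicate()",        # full bias on every TP rank
        "layernorm_weight": "Replicate()",         # norms are not TP-sharded
    }
    return table[param_kind]


for kind in ("column_parallel_weight", "row_parallel_weight", "layernorm_weight"):
    print(kind, "->", tp_placement(kind))
```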

Concern: The test script implements a complete AppState DCP class and imports checkpoint utilities, but the training loop never calls save()/load() to actually validate the checkpoint round-trip. Since DCP compatibility is the PR's headline claim, the test coverage for this core feature is missing.

Confidence Score: 2/5

  • Core FSDP2-TP sharding logic is well-designed and parity tests pass, but the PR's headline DCP checkpoint feature is untested in the test script.
  • The set_device_mesh() API implementation across all module types is architecturally sound, DTensor handling is consistent, and parameter conversion utilities are correct. Megatron-LM parity results demonstrate the sharding logic works in practice. However, the PR title is "Add DCP compatibility"—this is the primary feature claim—yet the test script defines a complete AppState DCP class but never exercises the save/load round-trip. Without validating the checkpoint functionality, confidence in the core feature cannot be high.
  • tests/pytorch/distributed/run_fsdp2_model.py requires integration of an actual DCP checkpoint save/load cycle in _train() to validate the headline feature.

Sequence Diagram

sequenceDiagram
    participant User
    participant TransformerLayer
    participant MultiheadAttention
    participant LayerNormLinear
    participant base as TransformerEngineBaseModule
    participant distributed as distributed.py

    User->>TransformerLayer: __init__(tp_mesh, weight_mesh)
    TransformerLayer->>MultiheadAttention: __init__(tp_mesh, weight_mesh)
    MultiheadAttention->>LayerNormLinear: __init__(tp_mesh, weight_mesh)
    LayerNormLinear->>LayerNormLinear: init_fp8_metadata()
    LayerNormLinear->>LayerNormLinear: set_device_mesh(tp_mesh, weight_mesh)
    LayerNormLinear->>distributed: _convert_param_to_dtensor_param(weight, tp_mesh, Shard/Replicate)
    distributed-->>LayerNormLinear: nn.Parameter(DTensor)
    LayerNormLinear->>LayerNormLinear: quantizer.amax_reduction_group = weight_mesh.get_group()
    LayerNormLinear->>LayerNormLinear: reset_parameters()
    LayerNormLinear->>base: reset_parameters() [handles DTensor params]

    User->>TransformerLayer: forward(inp)
    TransformerLayer->>base: prepare_forward(inp)
    Note over base: if isinstance(inp, DTensor): inp = inp.to_local()
    base-->>TransformerLayer: local tensor inp
    TransformerLayer->>LayerNormLinear: forward(inp)
    LayerNormLinear->>LayerNormLinear: _get_weight_tensors() → weight.to_local()
    LayerNormLinear->>LayerNormLinear: _get_layernorm_weight_and_bias() → to_local()

Last reviewed commit: dbb9d14

Comment on lines 151 to 158

```diff
 if args.sharding_dims:
-    assert len(args.sharding_dims) <= 2
+    assert len(args.sharding_dims) <= 3
 if len(args.sharding_dims) >= 3:
     # Set the TP size in args.
     args.tp_size = args.sharding_dims[2]
 else:
     args.tp_size = 1
 return args
```

args.sharding_dims not guarded against None

At line 153, len(args.sharding_dims) is called unconditionally, but args.sharding_dims can be None when the --sharding-dims flag is omitted (since the argument is not marked required=True and uses nargs="+"). This will raise TypeError: object of type 'NoneType' has no len().

The if len(args.sharding_dims) >= 3: block should be nested inside the existing if args.sharding_dims: guard:

Suggested change

```diff
 if args.sharding_dims:
     assert len(args.sharding_dims) <= 3
-if len(args.sharding_dims) >= 3:
-    # Set the TP size in args.
-    args.tp_size = args.sharding_dims[2]
-else:
-    args.tp_size = 1
+    if len(args.sharding_dims) >= 3:
+        # Set the TP size in args.
+        args.tp_size = args.sharding_dims[2]
+    else:
+        args.tp_size = 1
+else:
+    args.tp_size = 1
 return args
```
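The guard can be verified standalone. The parser below is a minimal stand-in for the test script's (same flag and attribute names as the snippet above, everything else illustrative), showing that an omitted `--sharding-dims` no longer hits `len(None)`:

```python
import argparse

# Standalone version of the suggested guard: len() is only called when
# --sharding-dims was actually provided. The parser is a minimal stand-in
# for the one in tests/pytorch/distributed/run_fsdp2_model.py.

def parse(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--sharding-dims", dest="sharding_dims", nargs="+", type=int)
    args = parser.parse_args(argv)
    if args.sharding_dims:
        assert len(args.sharding_dims) <= 3
        if len(args.sharding_dims) >= 3:
            # Set the TP size in args.
            args.tp_size = args.sharding_dims[2]
        else:
            args.tp_size = 1
    else:
        # Flag omitted: sharding_dims is None, so len() must not be called.
        args.tp_size = 1
    return args


print(parse([]).tp_size)                                  # 1 (no TypeError)
print(parse(["--sharding-dims", "2", "4"]).tp_size)       # 1
print(parse(["--sharding-dims", "2", "2", "4"]).tp_size)  # 4
```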

Comment on lines +2077 to +2087
```python
self.fc1_bias = _convert_param_to_dtensor_param(
    self.fc1_bias, tp_mesh, placements=(Shard(dim=0),)
)
# FC2 Weight -> Row-Parallel -> Shard(dim=1)
self.fc2_weight = _convert_param_to_dtensor_param(
    self.fc2_weight, tp_mesh, placements=(Shard(dim=1),)
)
# LN & FC2 Bias -> Replicate()
self.fc2_bias = _convert_param_to_dtensor_param(
    self.fc2_bias, tp_mesh, placements=(Replicate(),)
)
```

Bias converted unconditionally when use_bias=False

When use_bias=False, self.fc1_bias and self.fc2_bias are initialized as plain torch.Tensor objects (not nn.Parameter, see lines 1940 and 1958):

```python
else:
    self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device)
```

Calling _convert_param_to_dtensor_param on them returns nn.Parameter(DTensor.from_local(...)). When this is then assigned back via self.fc1_bias = new_param, PyTorch's Module.__setattr__ will detect the nn.Parameter type and register the bias as a named module parameter, even though biases are disabled. This would pollute model.named_parameters(), the optimizer parameter list, and checkpoint state.

The fix is to guard these two conversions behind if self.use_bias:, following the same pattern already used for layer_norm_bias at line 2091.
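The registration hazard can be demonstrated without torch. `Parameter`, `ToyModule`, and `convert` below are illustrative stand-ins for `torch.nn.Parameter`, `nn.Module.__setattr__`'s auto-registration, and `_convert_param_to_dtensor_param`:

```python
# Sketch of why the unguarded conversion pollutes named_parameters().
# All classes here are toy stand-ins for the torch machinery they mimic.

class Parameter:                        # stand-in for torch.nn.Parameter
    pass

class PlainTensor:                      # stand-in for a bare torch.Tensor
    pass

class ToyModule:                        # mimics nn.Module.__setattr__ registration
    def __init__(self):
        object.__setattr__(self, "_parameters", {})

    def __setattr__(self, name, value):
        if isinstance(value, Parameter):
            self._parameters[name] = value  # Parameters are auto-registered
        object.__setattr__(self, name, value)

def convert(param):
    return Parameter()                  # conversion always yields a Parameter


mod = ToyModule()
mod.use_bias = False
mod.fc1_bias = PlainTensor()            # disabled bias: plain tensor, not registered

# Unguarded conversion registers a bias parameter despite use_bias=False:
mod.fc1_bias = convert(mod.fc1_bias)
print("fc1_bias" in mod._parameters)    # True -> pollutes the parameter list

# Guarded version (the suggested fix) leaves disabled biases untouched:
mod2 = ToyModule()
mod2.use_bias = False
mod2.fc1_bias = PlainTensor()
if mod2.use_bias:
    mod2.fc1_bias = convert(mod2.fc1_bias)
print("fc1_bias" in mod2._parameters)   # False
```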

```python
    weight_mesh : Optional[DeviceMesh]
                  Not used for DotProductAttention as there are no quantized weights.
    """
    warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
```

Spurious warning when weight_mesh is None

warnings.warn(...) is emitted unconditionally every time set_device_mesh is called, even when weight_mesh=None. The calling code invokes this method whenever tp_mesh is not None or weight_mesh is not None, so a normal call with only tp_mesh provided will generate a misleading warning like "weight_mesh not necessary for DotProductAttention: None".

The warning should only fire when the caller explicitly passes a non-None weight_mesh. The same spurious warning exists in transformer_engine/pytorch/module/layernorm.py (line 171) and transformer_engine/pytorch/module/rmsnorm.py (line 174).

Suggested change

```diff
-warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
+if weight_mesh is not None:
+    warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
```

cspades and others added 4 commits March 4, 2026 10:10
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch from 4ec2947 to dbb9d14 on March 4, 2026 18:10
Comment on lines 22 to 97
@@ -30,6 +38,61 @@

```python
LOCAL_RANK = None


@dataclass
class AppState(Stateful):
    """AppState for FSDP2 checkpoint via Torch DCP.

    Adapted from https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
    """

    model: torch.nn.Module
    optimizer: torch.optim.Optimizer

    def state_dict(self):
        """
        Get the state dict for the model, optimizer, scheduler, and step.
        This factory both retrieves the model state dictionary when saving
        checkpoints and initializes a destination for the state read from
        DCP checkpoint files when loading checkpoints.
        """
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        for fqn in list(model_state_dict.keys()):
            # Get the model parameter.
            model_param = model_state_dict[fqn]
            if isinstance(model_param, DTensor):
                model_param = model_param.to_local()
            if model_param.numel() == 0 and fqn in optimizer_state_dict["state"]:
                # Empty model parameter. Clear the associated optimizer state
                # when initializing the optimizer state upon DCP load, because
                # empty optimizer state DTensors are not checkpointed with DCP,
                # yet get_state_dict / _init_optim_state produce empty Tensors.
                # TransformerEngine uses empty Tensors for dummy Parameters.
                optimizer_state_dict["state"][fqn] = {}
            if fqn.endswith("._extra_state"):
                # Evict `_extra_state` quantization data from model checkpoint.
                model_state_dict.pop(fqn)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
        }

    def load_state_dict(self, state_dict: dict):
        """
        Load the state dict for the model, optimizer, scheduler, and step.
        Given the checkpoint-loaded state_dict, set the state of the model,
        optimizer, scheduler, step, and epoch to the values in state_dict.
        """
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
            # Non-strict checkpoint loading ignores empty optimizer states,
            # skips loading non-FP8 checkpoint weights (e.g. _extra_state).
            options=StateDictOptions(strict=False),
        )


def dist_print(msg):
    if LOCAL_RANK == 0:
```

DCP checkpoint functionality is not exercised in the test

The AppState class (lines 42–93) and DCP checkpoint operations (save, load, get_state_dict, set_state_dict) are imported and fully implemented, but the training loop in _train() (lines 480–490) does not call any checkpoint save/load operations. The function ends at line 497 with dist.destroy_process_group() and no checkpoint round-trip.

Since the PR title is "Add DCP compatibility for FSDP2-TP sharding," the checkpoint functionality is the headline feature. Without an actual save/load call in the test, neither the AppState.state_dict() eviction logic nor the set_state_dict(strict=False) reload path is validated.

Recommendation: Add a checkpoint save/load round-trip after the training loop (before dist.destroy_process_group()) to exercise the DCP functionality, or explicitly note in the test docstring that DCP round-trip testing is deferred to integration tests.
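The shape of that missing round-trip can be sketched without a distributed runtime. `ToyAppState`, `fake_save`, and `fake_load` below are dict-based stand-ins for the `Stateful` protocol and `torch.distributed.checkpoint`'s save/load; the real test would call `dcp.save`/`dcp.load` on the actual `AppState`:

```python
# Dict-based stand-in for the missing round-trip. ToyAppState mirrors the
# Stateful protocol (state_dict / load_state_dict) from the diff above;
# fake_save/fake_load play the role of torch.distributed.checkpoint.

class ToyAppState:
    def __init__(self, model_state):
        self.model_state = dict(model_state)

    def state_dict(self):
        # Evict `_extra_state` quantization blobs, as AppState does above.
        return {
            "model": {
                k: v for k, v in self.model_state.items()
                if not k.endswith("._extra_state")
            }
        }

    def load_state_dict(self, state_dict):
        # Non-strict load: entries absent from the checkpoint are kept as-is.
        self.model_state.update(state_dict["model"])


def fake_save(stateful):
    return stateful.state_dict()          # real code: dcp.save({"app": app}, ...)

def fake_load(stateful, checkpoint):
    stateful.load_state_dict(checkpoint)  # real code: dcp.load({"app": app}, ...)


app = ToyAppState({"fc1.weight": 1.0, "fc1._extra_state": b"fp8-meta"})
ckpt = fake_save(app)
print("fc1._extra_state" in ckpt["model"])  # False: evicted from checkpoint

restored = ToyAppState({"fc1.weight": 0.0, "fc1._extra_state": b"fp8-meta"})
fake_load(restored, ckpt)
print(restored.model_state["fc1.weight"])   # 1.0: weight round-trips
```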

@cspades (Member Author) replied:

Working on it! + GroupedLinear test case.

