Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 177 additions & 24 deletions tests/pytorch/distributed/run_fsdp2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import sys
import argparse
from dataclasses import dataclass

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import (
Expand All @@ -18,6 +19,13 @@

import torch
import torch.distributed as dist
from torch.distributed.checkpoint import save, load
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_state_dict,
set_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.tensor import DTensor
import torch.nn.functional as F
from torch import nn, optim
Expand All @@ -30,6 +38,61 @@
LOCAL_RANK = None


@dataclass
class AppState(Stateful):
    """Stateful application wrapper for saving/loading FSDP2 checkpoints with Torch DCP.

    Bundles a model and its optimizer so that ``torch.distributed.checkpoint``
    can treat both as a single checkpointable unit.

    Adapted from https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
    """

    model: torch.nn.Module
    optimizer: torch.optim.Optimizer

    def state_dict(self):
        """
        Get the state dict for the model, optimizer, scheduler, and step.
        This factory both retrieves the model state dictionary when saving
        checkpoints and initializes a destination for the state read from
        DCP checkpoint files when loading checkpoints.
        """
        model_sd, optim_sd = get_state_dict(self.model, self.optimizer)
        for name in list(model_sd):
            param = model_sd[name]
            # Unwrap DTensors so we can inspect the local shard's size.
            local = param.to_local() if isinstance(param, DTensor) else param
            if local.numel() == 0 and name in optim_sd["state"]:
                # Empty model parameter. Clear the associated optimizer state
                # when initializing the optimizer state upon DCP load, because
                # empty optimizer state DTensors are not checkpointed with DCP,
                # yet get_state_dict / _init_optim_state produce empty Tensors.
                # TransformerEngine uses empty Tensors for dummy Parameters.
                optim_sd["state"][name] = {}
            if name.endswith("._extra_state"):
                # Evict `_extra_state` quantization data from model checkpoint.
                del model_sd[name]
        return {
            "model": model_sd,
            "optim": optim_sd,
        }

    def load_state_dict(self, state_dict: dict):
        """
        Load the state dict for the model, optimizer, scheduler, and step.
        Given the checkpoint-loaded state_dict, set the state of the model,
        optimizer, scheduler, step, and epoch to the values in state_dict.
        """
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
            # Non-strict checkpoint loading ignores empty optimizer states,
            # skips loading non-FP8 checkpoint weights (e.g. _extra_state).
            options=StateDictOptions(strict=False),
        )


def dist_print(msg):
if LOCAL_RANK == 0:
Comment on lines 22 to 97
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DCP checkpoint functionality is not exercised in the test

The AppState class (lines 42–93) and DCP checkpoint operations (save, load, get_state_dict, set_state_dict) are imported and fully implemented, but the training loop in _train() (lines 480–490) does not call any checkpoint save/load operations. The function ends at line 497 with dist.destroy_process_group() and no checkpoint round-trip.

Since the PR title is "Add DCP compatibility for FSDP2-TP sharding," the checkpoint functionality is the headline feature. Without an actual save/load call in the test, neither the AppState.state_dict() eviction logic nor the set_state_dict(strict=False) reload path is validated.

Recommendation: Add a checkpoint save/load round-trip after the training loop (before dist.destroy_process_group()) to exercise the DCP functionality, or explicitly note in the test docstring that DCP round-trip testing is deferred to integration tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Working on it! + GroupedLinear test case.

print(msg)
Expand Down Expand Up @@ -82,11 +145,16 @@ def _parse_args(argv=None, namespace=None):
"--sharding-dims",
type=int,
nargs="+",
help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
help='FSDP/HSDP sharding dimensions ("dp_replicate", "dp_shard", "tp")',
)
args = parser.parse_args(argv, namespace)
if args.sharding_dims:
assert len(args.sharding_dims) <= 2
assert len(args.sharding_dims) <= 3
if len(args.sharding_dims) >= 3:
# Set the TP size in args.
args.tp_size = args.sharding_dims[2]
else:
args.tp_size = 1
return args
Comment on lines 151 to 158
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

args.sharding_dims not guarded against None

At line 153, len(args.sharding_dims) is called unconditionally, but args.sharding_dims can be None when the --sharding-dims flag is omitted (since the argument is not marked required=True and uses nargs="+"). This will raise TypeError: object of type 'NoneType' has no len().

The if len(args.sharding_dims) >= 3: block should be nested inside the existing if args.sharding_dims: guard:

Suggested change
if args.sharding_dims:
assert len(args.sharding_dims) <= 2
assert len(args.sharding_dims) <= 3
if len(args.sharding_dims) >= 3:
# Set the TP size in args.
args.tp_size = args.sharding_dims[2]
else:
args.tp_size = 1
return args
if args.sharding_dims:
assert len(args.sharding_dims) <= 3
if len(args.sharding_dims) >= 3:
# Set the TP size in args.
args.tp_size = args.sharding_dims[2]
else:
args.tp_size = 1
else:
args.tp_size = 1



Expand Down Expand Up @@ -136,11 +204,17 @@ def init_te_model(config):
"params_dtype": params_dtype,
}
kwargs["device"] = config.device
kwargs["tp_size"] = config.tp_size

layer_type = get_te_layer_from_string(config.layer_type)
# We are creating model in a way so that we can test both reshard_after_forward=True/False cases.
# more details below.
if layer_type in [te.MultiheadAttention, te.TransformerLayer]:
if layer_type in [
te.TransformerLayer,
te.MultiheadAttention,
te.LayerNormMLP,
# TODO(@cspades): GroupedLinear testing.
]:
# For this case, we are creating a model that resemebles production use-cases
# wherein there are mltiple TransformerLayers in the model. And we would need
# to shard each transformer layer. Since each transformer layer is not a root module,
Expand All @@ -150,44 +224,102 @@ def init_te_model(config):
kwargs["fuse_qkv_params"] = True
if layer_type is te.MultiheadAttention:
kwargs["input_layernorm"] = True
# DeviceMesh / DTensor-related model parameter operations!
# NOTE(@cspades): `set_device_mesh` works, but needs to be called before reset_parameters.
# If not using meta device initialization, reset_parameters is called during __init__.
if config.tp_size > 1:
assert "dp_shard" in config.mesh.mesh_dim_names
assert "tp" in config.mesh.mesh_dim_names
dist_print(f"Tensor parallelism activated with size: {config.tp_size}")
# Activate TP in TE.
kwargs["set_parallel_mode"] = True
# For TP shards as DTensors.
kwargs["tp_mesh"] = config.mesh["tp"]
# For per-tensor quantization recipes with TP.
kwargs["weight_mesh"] = config.mesh["dp_shard", "tp"]._flatten("weight_mesh")
elif len(config.mesh.mesh_dim_names) > 1:
assert "dp_shard" in config.mesh.mesh_dim_names
# HSDP (DP-Repl, DP-Shard) requires a call to `set_device_mesh(weight_mesh)`.
# Used for per-tensor quantization recipes like Float8CurrentScaling.
kwargs["weight_mesh"] = config.mesh["dp_shard"] # Only sharding with FSDP.
# Initialize model.
model = nn.Sequential(*[layer_type(*args, **kwargs) for _ in range(config.num_layers)])
elif layer_type == te.LayerNormLinear:
elif layer_type in [te.LayerNormLinear, te.Linear]:
# For this case, we are creating a model with just one LayerNormLinear layer
# so that the model itself is a root module, and FSDP2's fully_shard assigns
# reshard_after_forward=True for the parameters of these model.
args[1] *= 3 # QKV projection
out_shape[-1] *= 3
# DeviceMesh / DTensor-related model parameter operations!
# NOTE(@cspades): `set_device_mesh` works, but needs to be called before reset_parameters.
# If not using meta device initialization, reset_parameters is called during __init__.
if config.tp_size > 1:
assert "dp_shard" in config.mesh.mesh_dim_names
assert "tp" in config.mesh.mesh_dim_names
dist_print(f"Tensor parallelism activated with size: {config.tp_size}")
# Activate TP in TE.
kwargs["parallel_mode"] = "column"
# For TP shards as DTensors.
kwargs["tp_mesh"] = config.mesh["tp"]
# For per-tensor quantization recipes with TP.
kwargs["weight_mesh"] = config.mesh["dp_shard", "tp"]._flatten("weight_mesh")
# Modify output shape for column-parallel Linear.
out_shape[-1] //= config.tp_size
elif len(config.mesh.mesh_dim_names) > 1:
assert "dp_shard" in config.mesh.mesh_dim_names
# HSDP (DP-Repl, DP-Shard) requires a call to `set_device_mesh(weight_mesh)`.
# Used for per-tensor quantization recipes like Float8CurrentScaling.
kwargs["weight_mesh"] = config.mesh["dp_shard"] # Only sharding with FSDP.
# Initialize model.
model = layer_type(*args, **kwargs)
else:
# Other TE module. Just ambiguously initialize it.
model = layer_type(*args, **kwargs)

return model, inp_shape, out_shape


def get_device_mesh(world_size, sharding_dims):
    """Build the DeviceMesh for the requested sharding topology.

    Args:
        world_size: Total number of ranks; must equal the product of
            ``sharding_dims`` when provided.
        sharding_dims: None or a list of 1-3 ints:
            [dp_shard]                      -> FSDP
            [dp_replicate, dp_shard]        -> HSDP
            [dp_replicate, dp_shard, tp]    -> (H/F)SDP + TP

    Returns:
        torch.distributed.device_mesh.DeviceMesh with named dimensions.
    """
    dist_print(f"sharding-dims: {sharding_dims}")
    # FSDP: a single flat "dp_shard" dimension spanning all ranks.
    if sharding_dims is None or len(sharding_dims) == 1:
        assert sharding_dims is None or sharding_dims[0] == world_size
        mesh = init_device_mesh(
            "cuda",
            (world_size,),
            mesh_dim_names=("dp_shard",),
        )
    # HSDP: replicate across the outer dim, shard across the inner.
    elif len(sharding_dims) == 2:
        assert sharding_dims[0] * sharding_dims[1] == world_size
        mesh = init_device_mesh(
            "cuda",
            (sharding_dims[0], sharding_dims[1]),
            mesh_dim_names=("dp_replicate", "dp_shard"),
        )
    # (H/F)SDP-TP: innermost dimension is tensor parallelism.
    elif len(sharding_dims) == 3:
        assert sharding_dims[0] * sharding_dims[1] * sharding_dims[2] == world_size
        mesh = init_device_mesh(
            "cuda",
            (sharding_dims[0], sharding_dims[1], sharding_dims[2]),
            mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
        )
    else:
        # Unsupported topology.
        assert False, f"Unsupported sharding dims: {sharding_dims}"
    return mesh


def shard_model_with_fsdp2(model, mesh):
    """Apply FSDP2 ``fully_shard`` to the model over the data-parallel mesh dims.

    Shards each child module individually (so non-root layers get their own
    FSDP2 state) and then the root module itself. Only the data-parallel
    dimensions of the mesh are used for sharding; a trailing "tp" dimension,
    if present, is left for tensor parallelism.

    Args:
        model: Module whose parameters will be sharded (possibly already
            TP-sharded DTensors).
        mesh: DeviceMesh containing a "dp_shard" dim and optionally
            "dp_replicate" (HSDP) and "tp" dims.

    Returns:
        The same model, with FSDP2 applied in place.
    """
    assert "dp_shard" in mesh.mesh_dim_names
    # HSDP shards over (dp_replicate, dp_shard); plain FSDP over (dp_shard,).
    dp_dims = (
        ("dp_replicate", "dp_shard") if "dp_replicate" in mesh.mesh_dim_names else ("dp_shard",)
    )
    for child in model.children():
        fully_shard(child, mesh=mesh[dp_dims])
    fully_shard(model, mesh=mesh[dp_dims])
    return model


Expand Down Expand Up @@ -216,16 +348,18 @@ def restore_custom_attrs(module, custom_attrs):

@torch.no_grad()
def test_fp8_fsdp2_allgather(model):
# Do manual allgather in fp32 and match against fp8 allgather done
# with fsdp2
"""
Compare the result of the FP8 AG by FSDP2 with a manual AG in FP32
after dequantizing the FP8 values.
"""
# FP32 manual weight allgather
fp32_allgathered_params = {}
for name, param in model.named_parameters():
assert isinstance(param, DTensor)
local_tensor = param._local_tensor
device_mesh = param.device_mesh
dist_group = (
device_mesh.get_group(mesh_dim="shard")
device_mesh.get_group(mesh_dim="dp_shard")
if device_mesh.ndim > 1
else device_mesh.get_group()
)
Expand All @@ -244,6 +378,10 @@ def test_fp8_fsdp2_allgather(model):
module.unshard()
# Make sure allgathered parameters match exactly
for name, param in model.named_parameters():
if isinstance(param, DTensor):
# Will still be a DTensor in the case of TP, even after FSDP2 AG,
# because we wrap our weights as DTensor shards over the TP group.
param = param._local_tensor
assert torch.allclose(param.dequantize(), fp32_allgathered_params[name])
# Revert model to original sharded state
for module in model.modules():
Expand All @@ -253,6 +391,9 @@ def test_fp8_fsdp2_allgather(model):


def _train(args):
"""
Torch Distributed Initialization
"""
global LOCAL_RANK
assert "TORCHELASTIC_RUN_ID" in os.environ
WORLD_RANK = int(os.getenv("RANK", "0"))
Expand All @@ -277,10 +418,20 @@ def _train(args):
nccl_world = dist.new_group(backend="nccl")
device = torch.device(f"cuda:{LOCAL_RANK}")

# Create a DeviceMesh for fully_shard.
world_size = int(WORLD_SIZE)
# Setup the sharding mesh for FSDP/HSDP.
mesh = get_device_mesh(world_size, args.sharding_dims)
args.mesh = mesh

"""
TransformerEngine Model Initialization
"""
# FP8 Configuration
fp8_format = Format.HYBRID
fp8_recipe = get_recipe_from_string(args.recipe, fp8_format)

# Model initialization context.
build_model_context_args = {}
if not args.fp8_init:
# Build model context (FP8 init)
Expand All @@ -301,29 +452,31 @@ def _train(args):
f" {torch.cuda.memory_allocated(device)/1e6} MB"
)

# Creating a DeviceMesh for fully_shard
world_size = int(WORLD_SIZE)
# Setup the sharding mesh for FSDP/HSDP
mesh = get_device_mesh(world_size, args.sharding_dims)
# Avoid passing custom attributes to FSDP2.
custom_attrs = save_custom_attrs(model)
# Fully-shard the model. Will convert model parameters into DTensor
# if not already converted by TP.
model = shard_model_with_fsdp2(model, mesh)
# Restore custom attributes on parameters.
restore_custom_attrs(model, custom_attrs)
# model now has DTensors as its parameters

if args.device == "meta":
# After FSDP2 has been applied, materialize and initialize the sharded parameters
# TE base.py's reset_parameters() handles DTensors with FP8 initialization
# TE base.py's reset_parameters() handles DTensors with FP8 initialization.
for module in model.modules():
if hasattr(module, "reset_parameters"):
module.reset_parameters()
dist_print(f" Sharded parameters materialized and initialized on cuda device.")

dist_print(
f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB"
f"FSDP2 model in CUDA, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB"
)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

"""
Pre-Save Training
"""
for iteration in range(args.iter):
# Zero the parameter gradients
optimizer.zero_grad()
Expand Down
26 changes: 16 additions & 10 deletions tests/pytorch/distributed/test_torch_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,32 @@
def _run_test(fp_init, sharding_dims, recipe, layer_type):
    """Launch run_fsdp2_model.py under torchrun with the given configuration.

    Args:
        fp_init: If True, initialize model weights directly in FP8.
        sharding_dims: 1-3 ints forwarded to --sharding-dims
            (FSDP / HSDP / (H/F)SDP-TP topologies).
        recipe: Quantization recipe name forwarded to --recipe.
        layer_type: TE layer class name forwarded to --layer-type.

    Raises:
        subprocess.CalledProcessError: If the distributed test script fails
            (check=True).
    """
    test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
    test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)]

    if fp_init:
        test_cmd += ["--fp8-init"]

    # Forward every sharding dimension; the script validates the count.
    test_cmd += ["--sharding-dims"]
    test_cmd += [str(dim) for dim in sharding_dims]
    test_cmd += ["--recipe", recipe]
    test_cmd += ["--layer-type", layer_type]

    subprocess.run(test_cmd, env=os.environ, check=True)


@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
@pytest.mark.parametrize(
"sharding_dims",
(
# FSDP
[NUM_PROCS],
# HSDP
[2, NUM_PROCS // 2],
# FSDP-TP
[1, 2, NUM_PROCS // 2],
# HSDP-TP
[NUM_PROCS // 4, 2, 2],
),
)
@pytest.mark.parametrize("fp8_init", (False, True))
@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
@pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
Expand Down
Loading
Loading