PyTorch FSDP2 (RFC) provides a fully sharded data parallelism (FSDP) implementation targeting performant eager-mode while using per-parameter sharding for improved usability
-
See the Getting Started with FSDP2 tutorial for more information.
-
If you are currently using FSDP1, consider migrating to FSDP2 using our migration guide.
The user contract for fully_shard(model) is as follows
-
For model initialization, fully_shard converts model.parameters() from plain torch.Tensor to DTensor in-place. The parameters are moved to the appropriate device according to the device mesh.
-
Before forward and backward passes, pre-forward/backward hooks are responsible for all-gathering the parameters and converting model.parameters() from DTensor to plain torch.Tensor.
-
After forward and backward passes, post-forward/backward hooks free the unsharded parameters (no communication needed) and convert model.parameters() from plain torch.Tensor back to DTensor.
-
For the optimizer, it must be initialized with the DTensor model.parameters(), and the optimizer step should be performed on DTensor parameters.
-
Call
model(input)instead ofmodel.forward(input)to trigger pre-forward hooks to all-gather parameters. To make model.forward(input) work, users must either callmodel.unshard()explicitly or useregister_fsdp_forward_method(model, "forward")to register the forward method for hooking. -
fully_shard groups parameters together for a single all-gather. User should apply fully_shard in a bottom-up manner. For example, in a Transformer model, fully_shard should be applied to each layer before applying it to the root model. When applied to the root model, fully_shard excludes model.parameters() from each layer and groups the remaining parameters (e.g., embeddings, output projection) into a single all-gather group.
-
type(model)is "unioned" withFSDPModulein-place. For example, if model is originally of type nn.Linear, then fully_shard changestype(model)from nn.Linear toFSDPLinearin-place.FSDPLinearis an instance of both nn.Linear andFSDPModule. It retains all methods of nn.Linear while also exposing FSDP2-specific APIs under FSDPModule, such asreshard()andunshard(). -
Fully Qualified Names (FQNs) for parameters remain unchanged. If we call
model.state_dict(), the FQNs are the same before and after applying fully_shard. This is because fully_shard does not wrap the module but only registers hooks to the original module.
Each call to fully_shard creates one communication group containing all
parameters in the module that are not already assigned to a group from an
earlier call on a submodule. Each group's parameters are all-gathered together
in one collective before forward, and their gradients are reduce-scattered
together in one collective after backward. Unlike DDP, FSDP2 has no
bucket_cap_mb parameter — the communication boundaries are determined
entirely by which modules you apply fully_shard to.
Consider a model with four submodules where a, b, c, and d
denote the number of parameters in each:
model[ m1[a] -> m2[b] -> m3[c] -> m4[d] ]
If you only call fully_shard(model) (root only), all parameters are
in a single group. This means the entire forward and backward look like:
all-gather(a+b+c+d) -> forward(m1,m2,m3,m4) -> backward(m4,m3,m2,m1) -> reduce-scatter(a+b+c+d)
All communication happens as two large blocking operations with no overlap with compute. This is almost never what you want.
If you apply fully_shard per submodule — for example, calling
fully_shard(m2), fully_shard(m3), and then fully_shard(model) —
the remaining parameters (a and d) form the root group, while m2
and m3 each get their own group.
In forward, all-gathers run on a separate CUDA stream, so the next module's all-gather can overlap with the current module's forward compute. Each module's pre-forward hook issues its own all-gather and waits for it to complete before running the module. Because the CPU typically runs ahead of the GPU, the next module's all-gather is issued on the AG stream while the current module's forward is still executing on the compute stream:
time ──────────────────────────────────────────────►
compute: [wait] [ fwd(m1) | fwd(m2) | fwd(m3,m4) ]
AG stream: [AG(a,d)] [AG(b) | AG(c) ]
While fwd(m1) runs on the compute stream, the CPU fires m2's
pre-forward hook, which issues AG(b) on the AG stream. To make this
overlap more robust (e.g. when CPU-side overhead reduces the lead), use
set_modules_to_forward_prefetch to issue the next all-gather earlier —
inside the current module's pre-forward hook rather than waiting for the next
module's hook to fire.
In backward, FSDP2 additionally prefetches the next module's all-gather explicitly and runs reduce-scatters on a separate CUDA stream, all without any additional configuration:
time ──────────────────────────────────────────────►
compute: [ bwd(m4,m3) | bwd(m2) | bwd(m1) ]
AG stream: [AG(c)] [ AG(b) | AG(a,d) ]
RS stream: |[RS(c)] [ RS(b)| RS(a,d) ]
While bwd(m4,m3) runs on the compute stream, the all-gather for b
(needed by m2) is prefetched on the AG stream. While bwd(m2) runs,
both AG(a,d) and RS(c) overlap with compute. This pipelining is why
the recommended pattern is to apply fully_shard bottom-up to each layer
before applying it to the root.
To control the size of each communication group, choose which modules to wrap: wrapping more fine-grained modules produces smaller, more overlappable groups (similar to smaller DDP buckets), while wrapping fewer modules produces larger groups. There is no automatic bucketing — the grouping is explicit and determined by the module structure.
Compared to PyTorch FSDP1 (FullyShardedDataParallel):
- FSDP2 uses
DTensor-based dim-0 per-parameter sharding for a simpler sharding representation compared to FSDP1's flat-parameter sharding, while preserving similar throughput performance. More specifically, FSDP2 chunks each parameter on dim-0 across the data parallel workers (usingtorch.chunk(dim=0)), whereas FSDP1 flattens, concatenates, and chunks a group of tensors together, making reasoning about what data is present on each worker and resharding to different parallelisms complex. Per-parameter sharding provides a more intuitive user experience, relaxes constraints around frozen parameters, and allows for communication-free (sharded) state dicts, which otherwise require all-gathers in FSDP1. - FSDP2 implements a different memory management approach to handle the
multi-stream usages that avoids
torch.Tensor.record_stream. This ensures deterministic and expected memory usage and does not require blocking the CPU like in FSDP1'slimit_all_gathers=True. - FSDP2 exposes APIs for manual control over prefetching and collective
scheduling, allowing power users more customization. See the methods on
FSDPModulebelow for details. - FSDP2 simplifies some of the API surface: e.g. FSDP2 does not directly
support full state dicts. Instead, users can reshard the sharded state dicts
containing
DTensors to full state dicts themselves usingDTensorAPIs likeDTensor.full_tensor()or by using higher-level APIs like PyTorch Distributed Checkpoint 's distributed state dict APIs. Also, some other args have been removed; see here for details.
.. currentmodule:: torch.distributed.fsdp
The frontend API is fully_shard that can be called on a module:
.. autofunction:: fully_shard
.. autoclass:: FSDPModule
:members:
:member-order: bysource
.. autoclass:: UnshardHandle
:members:
.. autofunction:: register_fsdp_forward_method
.. autoclass:: MixedPrecisionPolicy
:members:
.. autoclass:: OffloadPolicy
:members:
.. autoclass:: CPUOffloadPolicy
:members:
.. autofunction:: share_comm_ctx