torch.distributed.fsdp.fully_shard

PyTorch FSDP2 (fully_shard)

PyTorch FSDP2 (RFC) provides a fully sharded data parallelism (FSDP) implementation that targets performant eager-mode execution while using per-parameter sharding for improved usability.

The user contract for fully_shard(model) is as follows:

  • For model initialization, fully_shard converts model.parameters() from plain torch.Tensor to DTensor in-place. The parameters are moved to the appropriate device according to the device mesh.

  • Before forward and backward passes, pre-forward/backward hooks are responsible for all-gathering the parameters and converting model.parameters() from DTensor to plain torch.Tensor.

  • After forward and backward passes, post-forward/backward hooks free the unsharded parameters (no communication needed) and convert model.parameters() from plain torch.Tensor back to DTensor.

  • The optimizer must be initialized with the DTensor model.parameters(), and the optimizer step should be performed on DTensor parameters.

  • Call model(input) instead of model.forward(input) to trigger pre-forward hooks to all-gather parameters. To make model.forward(input) work, users must either call model.unshard() explicitly or use register_fsdp_forward_method(model, "forward") to register the forward method for hooking.

  • fully_shard groups parameters together for a single all-gather. Users should apply fully_shard in a bottom-up manner. For example, in a Transformer model, fully_shard should be applied to each layer before applying it to the root model. When applied to the root model, fully_shard excludes model.parameters() from each layer and groups the remaining parameters (e.g., embeddings, output projection) into a single all-gather group.

  • type(model) is "unioned" with FSDPModule in-place. For example, if model is originally of type nn.Linear, then fully_shard changes type(model) from nn.Linear to FSDPLinear in-place. FSDPLinear is a subclass of both nn.Linear and FSDPModule, so model remains an instance of nn.Linear and also becomes an instance of FSDPModule. It retains all methods of nn.Linear while also exposing FSDP2-specific APIs under FSDPModule, such as reshard() and unshard().

  • Fully Qualified Names (FQNs) for parameters remain unchanged. If we call model.state_dict(), the FQNs are the same before and after applying fully_shard. This is because fully_shard does not wrap the module but only registers hooks to the original module.
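The in-place type "union" described in the contract can be illustrated with plain Python. The following is a toy sketch, not FSDP2's code: `ToyFSDPModule`, `Linear`, and `toy_fully_shard` are hypothetical stand-ins for `FSDPModule`, `nn.Linear`, and `fully_shard`, and no sharding or communication takes place. It only shows how an object can gain a new base class without being wrapped or replaced, which is also why attribute names (and hence FQNs) stay the same.

```python
class ToyFSDPModule:
    """Hypothetical stand-in for FSDPModule; methods only record state."""

    def unshard(self):
        self._unsharded = True

    def reshard(self):
        self._unsharded = False


class Linear:
    """Hypothetical stand-in for nn.Linear."""

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features


def toy_fully_shard(module):
    """Swap the module's class for a dynamically created subclass, in place."""
    cls = type(module)
    # The new class inherits from both the mixin and the original class.
    new_cls = type(f"FSDP{cls.__name__}", (ToyFSDPModule, cls), {})
    module.__class__ = new_cls
    return module


layer = Linear(4, 8)
same_object = toy_fully_shard(layer)  # no wrapper: the same object comes back
```

After the call, `type(layer).__name__` is `"FSDPLinear"`, `layer` passes `isinstance` checks against both `Linear` and `ToyFSDPModule`, and its existing attributes are untouched.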

Communication Grouping and Scheduling

Each call to fully_shard creates one communication group containing all parameters in the module that are not already assigned to a group from an earlier call on a submodule. Each group's parameters are all-gathered together in one collective before forward, and their gradients are reduce-scattered together in one collective after backward. Unlike DDP, FSDP2 has no bucket_cap_mb parameter — the communication boundaries are determined entirely by which modules you apply fully_shard to.

Consider a model with four submodules where a, b, c, and d denote the number of parameters in each:

model[ m1[a] -> m2[b] -> m3[c] -> m4[d] ]

If you only call fully_shard(model) (root only), all parameters are in a single group. This means the entire forward and backward pass looks like:

all-gather(a+b+c+d) -> forward(m1,m2,m3,m4) -> backward(m4,m3,m2,m1) -> reduce-scatter(a+b+c+d)

All communication happens as two large blocking operations with no overlap with compute. This is almost never what you want.

If you apply fully_shard per submodule — for example, calling fully_shard(m2), fully_shard(m3), and then fully_shard(model) — the remaining parameters (a and d) form the root group, while m2 and m3 each get their own group.
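The grouping rule can be modeled in a few lines of plain Python. This is a sketch of the rule as stated above, not FSDP2's implementation: `assign_groups` and its dictionary-based module tree are hypothetical, and parameters are represented only by their names.

```python
def assign_groups(children, params, shard_order):
    """Model which parameters each fully_shard call would group.

    children:    module name -> list of direct child module names
    params:      module name -> parameter names the module owns directly
    shard_order: module names, in the order fully_shard is called on them
    """
    claimed = set()  # modules already covered by an earlier call
    groups = {}
    for root in shard_order:
        group, stack = [], [root]
        while stack:
            name = stack.pop()
            if name in claimed:
                continue  # this subtree already belongs to another group
            claimed.add(name)
            group.extend(params.get(name, []))
            # reversed() so children are visited in their listed order
            stack.extend(reversed(children.get(name, [])))
        groups[root] = group
    return groups


children = {"model": ["m1", "m2", "m3", "m4"]}
params = {"m1": ["a"], "m2": ["b"], "m3": ["c"], "m4": ["d"]}

root_only = assign_groups(children, params, ["model"])
bottom_up = assign_groups(children, params, ["m2", "m3", "model"])
```

Here `root_only` puts `a`, `b`, `c`, and `d` into one group, while `bottom_up` gives `m2` and `m3` their own groups and leaves `a` and `d` to the root.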

In forward, all-gathers run on a separate CUDA stream, so the next module's all-gather can overlap with the current module's forward compute. Each module's pre-forward hook issues its own all-gather and waits for it to complete before running the module. Because the CPU typically runs ahead of the GPU, the next module's all-gather is issued on the AG stream while the current module's forward is still executing on the compute stream:

              time ──────────────────────────────────────────────►

compute:      [wait] [ fwd(m1)   | fwd(m2)    | fwd(m3,m4)     ]
AG stream:    [AG(a,d)]  [AG(b)  |    AG(c)   ]

While fwd(m1) runs on the compute stream, the CPU fires m2's pre-forward hook, which issues AG(b) on the AG stream. To make this overlap more robust (e.g. when CPU-side overhead reduces the lead), use set_modules_to_forward_prefetch to issue the next all-gather earlier — inside the current module's pre-forward hook rather than waiting for the next module's hook to fire.
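To see why finer-grained groups help in forward, consider a deliberately simple two-stream timing model. It is an illustration under stated assumptions, not a description of FSDP2's scheduler: `forward_time` is a hypothetical helper, all-gathers on the AG stream are assumed to run back to back, each group's compute waits only for its own all-gather, and the time units are arbitrary.

```python
def forward_time(groups, overlap=True):
    """Finish time of forward for groups given as (ag_time, compute_time).

    overlap=True:  all-gathers run on their own stream, back to back.
    overlap=False: each all-gather also waits for the preceding compute.
    """
    ag_done = 0.0
    compute_done = 0.0
    for ag, compute in groups:
        if overlap:
            ag_done = ag_done + ag
        else:
            ag_done = max(ag_done, compute_done) + ag
        # Compute for this group starts once its parameters are gathered
        # and the previous group's compute has finished.
        compute_done = max(compute_done, ag_done) + compute
    return compute_done


one_big_group = forward_time([(4.0, 4.0)])            # root-only sharding
four_groups = forward_time([(1.0, 1.0)] * 4)          # per-module sharding
four_serialized = forward_time([(1.0, 1.0)] * 4, overlap=False)
```

With these numbers the single group takes 8.0 time units, four overlapped groups take 5.0 (only the first all-gather is exposed), and four groups without overlap take 8.0 again.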

In backward, FSDP2 additionally prefetches the next module's all-gather explicitly and runs reduce-scatters on a separate CUDA stream, all without any additional configuration:

              time ──────────────────────────────────────────────►

compute:      [ bwd(m4,m3)     | bwd(m2)        | bwd(m1)       ]
AG stream:    [AG(c)] [ AG(b)  |   AG(a,d)      ]
RS stream:                     |[RS(c)]  [ RS(b)|     RS(a,d)   ]

While bwd(m4,m3) runs on the compute stream, the all-gather for b (needed by m2) is prefetched on the AG stream. While bwd(m2) runs, both AG(a,d) and RS(c) overlap with compute. This pipelining is why the recommended pattern is to apply fully_shard bottom-up to each layer before applying it to the root.

To control the size of each communication group, choose which modules to wrap: wrapping more fine-grained modules produces smaller, more overlappable groups (similar to smaller DDP buckets), while wrapping fewer modules produces larger groups. There is no automatic bucketing — the grouping is explicit and determined by the module structure.

Compared to PyTorch FSDP1 (FullyShardedDataParallel):

  • FSDP2 uses DTensor-based dim-0 per-parameter sharding for a simpler sharding representation compared to FSDP1's flat-parameter sharding, while preserving similar throughput performance. More specifically, FSDP2 chunks each parameter on dim-0 across the data parallel workers (using torch.chunk(dim=0)), whereas FSDP1 flattens, concatenates, and chunks a group of tensors together, making reasoning about what data is present on each worker and resharding to different parallelisms complex. Per-parameter sharding provides a more intuitive user experience, relaxes constraints around frozen parameters, and allows for communication-free (sharded) state dicts, which otherwise require all-gathers in FSDP1.
  • FSDP2 implements a different memory management approach for multi-stream usage that avoids torch.Tensor.record_stream. This ensures deterministic and expected memory usage and does not require blocking the CPU like in FSDP1's limit_all_gathers=True.
  • FSDP2 exposes APIs for manual control over prefetching and collective scheduling, allowing power users more customization. See the methods on FSDPModule below for details.
  • FSDP2 simplifies some of the API surface: e.g. FSDP2 does not directly support full state dicts. Instead, users can reshard the sharded state dicts containing DTensors to full state dicts themselves using DTensor APIs like DTensor.full_tensor() or by using higher-level APIs like PyTorch Distributed Checkpoint's distributed state dict APIs. Also, some other args have been removed; see here for details.
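The dim-0 chunking mentioned in the comparison can be made concrete with a small pure-Python model of how many rows each rank holds. `dim0_shard_rows` is a hypothetical helper, not part of PyTorch; it follows the documented `torch.chunk` rule that every chunk has the ceiling-divided size except possibly the last, and it assumes that ranks beyond the last chunk hold zero rows.

```python
import math


def dim0_shard_rows(dim0, world_size):
    """Rows of a parameter held by each rank under chunk-style dim-0 sharding."""
    chunk = math.ceil(dim0 / world_size)  # torch.chunk uses ceil-sized chunks
    rows = []
    for rank in range(world_size):
        start = min(rank * chunk, dim0)
        end = min(start + chunk, dim0)
        rows.append(end - start)
    return rows


even = dim0_shard_rows(8, 4)    # dim-0 size divisible by the world size
uneven = dim0_shard_rows(5, 4)  # last chunk is smaller; one rank holds nothing
```

For an 8-row parameter on 4 ranks each rank holds 2 rows; for a 5-row parameter the split is 2, 2, 1, 0. Because each parameter is split independently along dim-0, the data on a given rank can be read off directly, which is the simplification over flat-parameter sharding described above.
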
.. currentmodule:: torch.distributed.fsdp

The frontend API is fully_shard, which can be called on a module:

.. autofunction:: fully_shard
.. autoclass:: FSDPModule
    :members:
    :member-order: bysource
.. autoclass:: UnshardHandle
    :members:
.. autofunction:: register_fsdp_forward_method
.. autoclass:: MixedPrecisionPolicy
    :members:
.. autoclass:: OffloadPolicy
    :members:
.. autoclass:: CPUOffloadPolicy
    :members:
.. autofunction:: share_comm_ctx