diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py index 5df9c2e95c0..f18a21df6c1 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py @@ -175,6 +175,11 @@ def validate_uneven_dtensor(dtensor: DTensor) -> None: ) # Check that all boundaries (start and end) are touched. + # Skip under fake process group — all_reduce is a no-op so only rank 0's + # boundaries are visible, which makes the end-boundary check always fail. + if torch.distributed.is_initialized() and torch.distributed.get_backend() == 'fake': + return + boundary_checks = torch.tensor( [ [offset == 0, offset + size == dtensor.shape[dim]]