Describe the bug
When using FSDP with dual RTX 3090s, I get an NCCL error with FP8 quantization on Wan Video. Ulysses and Ring both work without issue, so I suspect the problem is in how FSDP loads and broadcasts the weights compared to Ulysses or Ring.
System
- Ubuntu 24.04 LTS
- PyTorch 2.9, CUDA 13
- 2x RTX 3090 with 128 GB system RAM
Workflow:
Terminal Output
DPKSamplerAdvanced
ray::RayWorker.common_ksampler() (pid=758017, ip=192.168.1.100, actor_id=54e32a57fd0ec08dc0c25b9401000000, repr=<raylight.distributed_worker.ray_worker.RayWorker object at 0x7e0086f46f00>)
File "/tmp/ray/session_2025-11-03_17-13-12_108235_746278/runtime_resources/py_modules_files/_ray_pkg_776c58acf4413b7c/raylight/distributed_worker/ray_worker.py", line 347, in common_ksampler
self.model.patch_fsdp()
File "/tmp/ray/session_2025-11-03_17-13-12_108235_746278/runtime_resources/py_modules_files/_ray_pkg_776c58acf4413b7c/raylight/comfy_dist/fsdp_registry.py", line 93, in patch_fsdp
self.model = FSDPShardRegistry.wrap(
^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/ray/session_2025-11-03_17-13-12_108235_746278/runtime_resources/py_modules_files/_ray_pkg_776c58acf4413b7c/raylight/comfy_dist/fsdp_registry.py", line 44, in wrap
return shard_func(model, fsdp_state_dict, cpu_offload)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/ray/session_2025-11-03_17-13-12_108235_746278/runtime_resources/py_modules_files/_ray_pkg_776c58acf4413b7c/raylight/comfy_dist/fsdp_registry.py", line 52, in _wrap_wan
return wan_shard(model, sd, cpu_offload)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/ray/session_2025-11-03_17-13-12_108235_746278/runtime_resources/py_modules_files/_ray_pkg_776c58acf4413b7c/raylight/diffusion_models/wan/fsdp.py", line 30, in shard_model_fsdp2
set_model_state_dict(
File "/path/to/python/site-packages/torch/distributed/checkpoint/state_dict.py", line 1271, in set_model_state_dict
return _load_model_state_dict(model, model_state_dict, info)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/python/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/path/to/python/site-packages/torch/distributed/checkpoint/state_dict.py", line 585, in _load_model_state_dict
_broadcast_state_dict(
File "/path/to/python/site-packages/torch/distributed/_state_dict_utils.py", line 673, in _broadcast_state_dict
_broadcast_tensors(ret, local_state_dict, keys, device, pg)
File "/path/to/python/site-packages/torch/distributed/_state_dict_utils.py", line 573, in _broadcast_tensors
dist.broadcast(tensors[0], src=0, group=pg)
File "/path/to/python/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/path/to/python/site-packages/torch/distributed/distributed_c10d.py", line 2835, in broadcast
work = group.broadcast([tensor], opts)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.distributed.DistBackendError: NCCL error in: /pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:3690, invalid argument (run with NCCL_DEBUG=WARN for details), NCCL version 2.27.7
ncclInvalidArgument: Invalid value for an argument.
Last error:
FP8 reduction support begins with sm90 capable devices.
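For context: the RTX 3090 is an Ampere (sm86) card, and as the last line of the error says, NCCL only supports FP8 datatypes starting with sm90, so the broadcast of the FP8 weights from rank 0 inside set_model_state_dict gets rejected. A hypothetical minimal reproduction (my assumption, not taken from the repo) that mirrors the failing dist.broadcast call, launched with torchrun --nproc_per_node=2 repro.py:

# repro.py -- assumed minimal repro, two pre-sm90 GPUs on one node
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Broadcasting an FP8 tensor is what _broadcast_state_dict does for the quantized
# weights; on sm8x GPUs this should raise the same ncclInvalidArgument error.
t = torch.zeros(16, dtype=torch.float8_e4m3fn, device="cuda")
dist.broadcast(t, src=0)

dist.destroy_process_group()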
Additional context
I was able to vibe-code a fix in fsdp_registry.py with DeepSeek, shown below. However, it slowed performance significantly.
@classmethod
def wrap(cls, model, fsdp_state_dict=None, cpu_offload=False):
    """Find the right shard function based on model type."""
    # Convert FP8 state dict for all models if needed
    converted_state_dict = fsdp_state_dict
    if converted_state_dict is not None:
        converted_state_dict = {}
        for k, v in fsdp_state_dict.items():
            if hasattr(v, 'dtype') and v.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
                print(f"[FSDPRegistry] Converting FP8 tensor {k} to bfloat16")
                converted_state_dict[k] = v.to(torch.bfloat16)
            else:
                converted_state_dict[k] = v

    for registered_cls, shard_func in cls._REGISTRY.items():
        if isinstance(model, registered_cls):
            print(f"[FSDPRegistry] Wrapping {registered_cls.__name__}")
            return shard_func(model, converted_state_dict, cpu_offload)
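My guess is the slowdown comes from the upcast itself plus the doubled weight footprint (bfloat16 is 2 bytes per element vs 1 for FP8), which matters a lot with only 24 GB per card. An untested idea that might avoid the upcast entirely would be to keep the weights in FP8 and only reinterpret them as raw uint8 bytes for the NCCL broadcast, since a same-size view shares storage with the original tensor. The helper below is hypothetical and is not wired into set_model_state_dict; it just sketches the idea:

# Untested sketch (not a confirmed fix): broadcast FP8 tensors as raw bytes.
import torch
import torch.distributed as dist

def broadcast_fp8_as_bytes(t: torch.Tensor, src: int = 0, group=None):
    if t.is_cuda and t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # uint8 collectives work on pre-sm90 GPUs; the view shares storage,
        # so the broadcast fills the FP8 tensor in place.
        dist.broadcast(t.view(torch.uint8), src=src, group=group)
    else:
        dist.broadcast(t, src=src, group=group)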