
Error with FSDP and FP8 quantization on dual RTX 3090s #29

@hp1337

Describe the bug
When using FSDP on dual RTX 3090s, loading Wan Video with FP8-quantized weights fails with the NCCL error below. The Ulysses and Ring modes work without issue on the same setup. I suspect the problem lies in how FSDP loads the weights compared to Ulysses or Ring.

System

  • Ubuntu 24.04 LTS
  • PyTorch 2.9, CUDA 13
  • 2x RTX 3090 with 128 GB system RAM
  • Workflow: [workflow screenshot]

Terminal Output

DPKSamplerAdvanced

ray::RayWorker.common_ksampler() (pid=758017, ip=192.168.1.100, actor_id=54e32a57fd0ec08dc0c25b9401000000, repr=<raylight.distributed_worker.ray_worker.RayWorker object at 0x7e0086f46f00>)
  File "/tmp/ray/session_2025-11-03_17-13-12_108235_746278/runtime_resources/py_modules_files/_ray_pkg_776c58acf4413b7c/raylight/distributed_worker/ray_worker.py", line 347, in common_ksampler
    self.model.patch_fsdp()
  File "/tmp/ray/session_2025-11-03_17-13-12_108235_746278/runtime_resources/py_modules_files/_ray_pkg_776c58acf4413b7c/raylight/comfy_dist/fsdp_registry.py", line 93, in patch_fsdp
    self.model = FSDPShardRegistry.wrap(
                 ^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2025-11-03_17-13-12_108235_746278/runtime_resources/py_modules_files/_ray_pkg_776c58acf4413b7c/raylight/comfy_dist/fsdp_registry.py", line 44, in wrap
    return shard_func(model, fsdp_state_dict, cpu_offload)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2025-11-03_17-13-12_108235_746278/runtime_resources/py_modules_files/_ray_pkg_776c58acf4413b7c/raylight/comfy_dist/fsdp_registry.py", line 52, in _wrap_wan
    return wan_shard(model, sd, cpu_offload)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2025-11-03_17-13-12_108235_746278/runtime_resources/py_modules_files/_ray_pkg_776c58acf4413b7c/raylight/diffusion_models/wan/fsdp.py", line 30, in shard_model_fsdp2
    set_model_state_dict(
  File "/path/to/python/site-packages/torch/distributed/checkpoint/state_dict.py", line 1271, in set_model_state_dict
    return _load_model_state_dict(model, model_state_dict, info)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/python/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/python/site-packages/torch/distributed/checkpoint/state_dict.py", line 585, in _load_model_state_dict
    _broadcast_state_dict(
  File "/path/to/python/site-packages/torch/distributed/_state_dict_utils.py", line 673, in _broadcast_state_dict
    _broadcast_tensors(ret, local_state_dict, keys, device, pg)
  File "/path/to/python/site-packages/torch/distributed/_state_dict_utils.py", line 573, in _broadcast_tensors
    dist.broadcast(tensors[0], src=0, group=pg)
  File "/path/to/python/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/python/site-packages/torch/distributed/distributed_c10d.py", line 2835, in broadcast
    work = group.broadcast([tensor], opts)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.distributed.DistBackendError: NCCL error in: /pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:3690, invalid argument (run with NCCL_DEBUG=WARN for details), NCCL version 2.27.7
ncclInvalidArgument: Invalid value for an argument.
Last error:
FP8 reduction support begins with sm90 capable devices.
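The last line of the NCCL error is the key detail: the RTX 3090 has compute capability 8.6 (sm86), below the sm90 minimum the message names. A minimal sketch, assuming the FP8 workaround should only run on hardware that needs it, is to gate it on compute capability. `NCCL_FP8_MIN_CAPABILITY` and `nccl_fp8_supported` are hypothetical names, not part of Raylight or PyTorch; on a real system the capability tuples could come from `torch.cuda.get_device_capability(i)`.

```python
from typing import Iterable, Tuple

# Hypothetical constant: the NCCL error says FP8 support begins with sm90,
# i.e. CUDA compute capability (9, 0).
NCCL_FP8_MIN_CAPABILITY: Tuple[int, int] = (9, 0)


def nccl_fp8_supported(capabilities: Iterable[Tuple[int, int]]) -> bool:
    """Return True only if every GPU in the group meets the sm90 minimum.

    Hypothetical helper: `capabilities` holds one (major, minor) tuple per GPU,
    e.g. as returned by torch.cuda.get_device_capability(i).
    """
    caps = [tuple(c) for c in capabilities]
    # An empty group has nothing to check; report False to stay on the safe side.
    return bool(caps) and all(c >= NCCL_FP8_MIN_CAPABILITY for c in caps)


if __name__ == "__main__":
    # Two RTX 3090s (sm86): FP8 tensors would need upcasting before loading.
    print(nccl_fp8_supported([(8, 6), (8, 6)]))
    # Two sm90 GPUs (e.g. H100): FP8 tensors could be broadcast as-is.
    print(nccl_fp8_supported([(9, 0), (9, 0)]))
```

Tuple comparison is lexicographic, so a card reporting (8, 9) also falls below the threshold, while any (9, x) or newer passes.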

Additional context
With help from DeepSeek, I vibe-coded the workaround below in fsdp_registry.py. It avoids the error, but performance drops significantly.

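@classmethod
def wrap(cls, model, fsdp_state_dict=None, cpu_offload=False):
    """Find the right shard function based on model type."""
    # Workaround: NCCL on pre-sm90 GPUs (e.g. RTX 3090, sm86) rejects FP8 tensors
    # in the rank-0 state-dict broadcast (see the error above), so FP8 weights
    # are upcast to bfloat16 first. This doubles the size of every converted
    # weight (1 byte -> 2 bytes per element).
    # Assumes `import torch` at the top of fsdp_registry.py.
    fp8_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
    converted_state_dict = fsdp_state_dict
    if fsdp_state_dict is not None:
        converted_state_dict = {}
        for k, v in fsdp_state_dict.items():
            if isinstance(v, torch.Tensor) and v.dtype in fp8_dtypes:
                print(f"[FSDPRegistry] Converting FP8 tensor {k} to bfloat16")
                converted_state_dict[k] = v.to(torch.bfloat16)
            else:
                # Non-FP8 entries are passed through unchanged.
                converted_state_dict[k] = v

    # Dispatch to the shard function registered for this model class.
    for registered_cls, shard_func in cls._REGISTRY.items():
        if isinstance(model, registered_cls):
            print(f"[FSDPRegistry] Wrapping {registered_cls.__name__}")
            return shard_func(model, converted_state_dict, cpu_offload)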
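One plausible contributor to the slowdown (not confirmed by the log) is that the workaround stores every converted FP8 weight in bfloat16, doubling it from 1 byte to 2 bytes per element. The rough arithmetic below sketches the per-GPU weight footprint when FSDP shards parameters evenly across ranks, for comparison against the 24 GB on each RTX 3090. The helper name and the 14B parameter count are hypothetical illustrations, not measurements of the Wan model used here.

```python
# Bytes per element for the dtypes involved in the workaround.
BYTES_PER_ELEMENT = {"float8_e4m3fn": 1, "float8_e5m2": 1, "bfloat16": 2}


def per_gpu_weight_gib(num_params: int, dtype: str, world_size: int) -> float:
    """Approximate weight memory per GPU under even FSDP sharding (hypothetical helper).

    Ignores shard padding, activations, buffers and any tensors that were not FP8.
    """
    total_bytes = num_params * BYTES_PER_ELEMENT[dtype]
    return total_bytes / world_size / 2**30


if __name__ == "__main__":
    n = 14_000_000_000  # hypothetical 14B-parameter model, for illustration only
    for dtype in ("float8_e4m3fn", "bfloat16"):
        gib = per_gpu_weight_gib(n, dtype, world_size=2)
        print(f"{dtype}: {gib:.1f} GiB per GPU")
```

Whatever the model size, the bfloat16 figure is exactly twice the FP8 one, which leaves less headroom on a 24 GB card and may push a workflow toward offloading.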
