
Fix for async dcp checkpointing with Float8Tensors#2721

Draft
pstjohn wants to merge 3 commits into NVIDIA:main from pstjohn:pstjohn/fix-async-dcp

Conversation

@pstjohn (Contributor)

@pstjohn pstjohn commented Mar 2, 2026

(includes changes from #2698)

dcp.async_save fails silently with QuantizedTensor (Float8Tensor) — staged tensors contain uninitialized (NaN) data instead of actual FP8 values.

PyTorch's async save stages tensors to CPU by allocating with new_empty() and then deep-copying raw storage. Float8Tensor is a wrapper subclass whose outer tensor has empty storage (data_ptr() == 0), so:

  1. new_empty() falls through to default dispatch, returning a plain tensor instead of a Float8Tensor
  2. The deep-copied _data/_scale_inv attributes land on the plain tensor but are ignored by DCP's write path
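The fallthrough in step 1 can be reproduced with a minimal wrapper subclass (a hypothetical stand-in for Float8Tensor, not the actual TE implementation): without a dedicated handler for aten.new_empty, the default unwrap-and-dispatch path returns a plain tensor and the subclass type is lost.

```python
import torch

class WrapperTensor(torch.Tensor):
    """Minimal wrapper subclass: the payload lives in an attribute,
    so the outer tensor owns no storage (like Float8Tensor)."""

    # Skip torch_function so all ops route through __torch_dispatch__,
    # the usual pattern for wrapper subclasses.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Default handling: unwrap to the payload and run the op on plain
        # tensors. Ops without a dedicated handler (aten.new_empty here)
        # therefore return a plain torch.Tensor, dropping the subclass type.
        unwrapped = [a._data if isinstance(a, WrapperTensor) else a for a in args]
        return func(*unwrapped, **(kwargs or {}))

t = WrapperTensor(torch.ones(2))
staged = t.new_empty(t.shape)  # falls through the default dispatch path
print(type(staged))            # plain torch.Tensor, not WrapperTensor
```

The fix described below adds an explicit aten.new_empty.default branch in torch_dispatch that constructs a fresh subclass instance instead of falling through.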

Changes

  • quantized_tensor.py: Handle aten.new_empty.default in torch_dispatch so staging preserves the Float8Tensor subclass type
  • float8_tensor_storage.py: Add a CPU fallback in dequantize() using PyTorch native FP8 dtypes, since tex.dequantize is CUDA-only and the staged tensor lives on CPU
  • run_fsdp2_fused_adam.py: Remove the _dequantize_state_dict workaround — dcp.async_save now works transparently

pstjohn added 3 commits March 2, 2026 07:23
…hard

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
