Add fused_adam, quantized_model_init, and fsdp2 example #2698

vthumbe1503 merged 4 commits into NVIDIA:main from
Conversation
Force-pushed from 22604c4 to 4d89e04
Greptile Summary

This PR adds FSDP2 (fully_shard) support. Key findings:
Confidence Score: 4/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant FSDP2 as FSDP2 (fully_shard)
    participant DTensor
    participant FusedAdam
    participant QT as QuantizedTensor (local)
    User->>FSDP2: wrap model (meta device)
    FSDP2->>DTensor: wrap each param as DTensor
    User->>QT: reset_parameters() → materialize + quantize local shard
    Note over DTensor,QT: DTensor._local_tensor = QuantizedTensor
    User->>FusedAdam: initialize_state(param=DTensor)
    FusedAdam->>DTensor: param._local_tensor → QuantizedTensor
    FusedAdam->>QT: dequantize() → plain float32
    FusedAdam-->>FusedAdam: allocate exp_avg, exp_avg_sq (float32)
    FusedAdam->>QT: dequantize(dtype=float32).clone() → master_param
    loop Training step
        User->>FSDP2: forward pass (all-gather)
        FSDP2->>DTensor: unshard params
        User->>User: loss.backward() → FSDP2 reduce-scatter gradients
        User->>FusedAdam: step()
        FusedAdam->>DTensor: unwrap → Float8Tensor._data
        FusedAdam->>FusedAdam: CUDA kernel: update FP8 weights + master_param + moments
        FSDP2->>DTensor: re-shard updated params
    end
```
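The init-time flow in the diagram can be sketched with plain-Python stubs. All names here (`DTensor`, `QuantizedTensor`, `initialize_state`) are illustrative stand-ins, not the real torch.distributed or TransformerEngine APIs:

```python
# Illustrative stubs -- not the real torch.distributed or TransformerEngine classes.
class QuantizedTensor:
    """Stand-in for a quantized wrapper tensor holding a payload plus a scale."""

    def __init__(self, data, scale_inv):
        self._data = data            # quantized payload
        self._scale_inv = scale_inv  # dequantization scale

    def dequantize(self):
        # Return plain float values (here: trivially descaled).
        return [x * self._scale_inv for x in self._data]


class DTensor:
    """Stand-in for a distributed tensor wrapping a local shard."""

    def __init__(self, local_tensor):
        self._local_tensor = local_tensor


def initialize_state(param):
    """Mimics the diagrammed flow: unwrap the DTensor, dequantize the local
    shard, and allocate float32 moments plus a master copy of the weights."""
    local = param._local_tensor if isinstance(param, DTensor) else param
    master = local.dequantize() if isinstance(local, QuantizedTensor) else list(local)
    return {
        "master_param": master,
        "exp_avg": [0.0] * len(master),
        "exp_avg_sq": [0.0] * len(master),
    }


shard = QuantizedTensor(data=[2, 4, 6], scale_inv=0.5)
state = initialize_state(DTensor(shard))
print(state["master_param"])  # [1.0, 2.0, 3.0]
```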
Last reviewed commit: 9565c4f
Force-pushed from 0103b53 to 3c3dbd2
@pstjohn Hi, thanks for the great work! Does this PR plan to also handle the BF16 path? I noticed the BF16 branch still operates on the original p/p_grad without unwrapping when they're DTensors. In my experiments with FSDP2 + BF16, I'm seeing non-trivial overhead during the optimizer step from repeated DTensor dispatch. Curious if that's intentional or a planned follow-up.
Force-pushed from a4d691f to 872caef
Force-pushed from 9ccc0c3 to eb8606a
Force-pushed from eb8606a to c2415e4
```python
pytest.xfail(
    "MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
    "MXFP8 quantized tensors, causing illegal memory access"
)
```
claude's analysis:
Root cause: MXFP8Tensor inherits from QuantizedTensor but NOT from Float8Tensor. In fused_adam.py step():
- Line 617: isinstance(p, Float8Tensor) → False for MXFP8Tensor
- Line 629: p.dtype in [torch.float16, torch.bfloat16] → True (nominal dtype is bfloat16)
- So p.data (which is still an MXFP8Tensor wrapper with MXFP8-layout storage) gets added to p_f16_model
- The multi_tensor_adam CUDA kernel treats this as plain bf16 memory → illegal memory access
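The dispatch bug can be reproduced with plain-Python stubs. The class names mirror TE's hierarchy, but the branching below only mimics (and is not) the real fused_adam.py code:

```python
class QuantizedTensor:
    # The wrapper exposes a nominal high-precision dtype, not the storage dtype.
    dtype = "bfloat16"


class Float8Tensor(QuantizedTensor):
    pass


class MXFP8Tensor(QuantizedTensor):  # inherits QuantizedTensor, NOT Float8Tensor
    pass


def bucket_for(p):
    """Mimics the branching in FusedAdam.step()."""
    if isinstance(p, Float8Tensor):
        return "p_fp8_model"   # correct: kernel receives the raw FP8 payload
    if p.dtype in ("float16", "bfloat16"):
        return "p_f16_model"   # WRONG for MXFP8: storage is MXFP8-layout bytes
    return "p_fp32_model"


assert bucket_for(Float8Tensor()) == "p_fp8_model"
# MXFP8Tensor fails the isinstance check, then passes the dtype check,
# so it lands in the f16 bucket; the real CUDA kernel would then read
# MXFP8-layout storage as bf16 memory -> illegal memory access.
assert bucket_for(MXFP8Tensor()) == "p_f16_model"
```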
/te-ci L1 pytorch
```python
if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling):
    model_state = {
        k: v for k, v in model.state_dict().items() if not k.endswith("_extra_state")
    }
```
Can you also please comment on why we should avoid saving `_extra_state`? That is, what error do we get with DCP if we don't?
This one I remember, the others I'll need to comment out and run the test suite again 😅
But this is a known hassle where torch DCP needs the sizes of these tensors to remain consistent during saving & loading, and since this is pickled data, it changes when there's data in that field.
The alternative is a detailed load_planner for DelayedScaling that reads and allocates the extra-state data tensor.
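The size instability is easy to see with pickle alone. This is a toy stand-in for an `_extra_state`-like payload, not TE's actual format; the point is that a checkpoint planner that requires fixed tensor sizes cannot tolerate a blob whose byte length depends on its contents:

```python
import pickle

# Toy _extra_state-like payload: pickled metadata whose byte size
# depends on how much history has accumulated.
empty_state = pickle.dumps({"scale_history": []})
populated_state = pickle.dumps({"scale_history": [0.5, 0.25, 0.125]})

# The serialized blob grows once data accumulates in the field, so a
# checkpoint framework that recorded the old size cannot load the new
# blob into the same-sized buffer.
print(len(empty_state), len(populated_state))
```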
Force-pushed from 099a3ab to 5f0ebab
Additional Comments (1)
```python
pytest.xfail(
    f"async DCP save/load with {fp_recipe} produces different outputs: "
    "the async staging may capture stale tensor state for FP8 scaling "
    "factors, causing numerical divergence after reload"
)
```
this isn't right -- it's actually NaNs when we load back. I'm digging on this
- Sync save (`dcp.save`): uses `_SerialCpuLoader.values()` (line 131), which calls `tensor.detach().cpu()`. For Float8Tensor, `.cpu()` calls `self.dequantize().cpu()`, properly converting FP8 data to a plain bfloat16 tensor.
- Async save (`dcp.async_save`): uses `StateDictStager._offload_tensor()`, which tries to deep-copy the tensor's underlying storage. Float8Tensor is a wrapper subclass (`_make_wrapper_subclass`) with `data_ptr() == 0` (empty storage). The staging code at line 215 skips the storage copy for wrapper subclasses, creating a plain tensor with uninitialized garbage data. The actual FP8 data (in the `_data` and `_scale_inv` attributes) is deep-copied but ignored by DCP when writing.
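A pure-Python analogue of the two paths. `Float8Like`, `staged_bytes`, and `dequantized_values` are toy stand-ins I made up to illustrate the failure mode, not the real torch DCP or TE code:

```python
class Float8Like:
    """Toy analogue of a wrapper-subclass tensor: the storage a storage-level
    copier sees is empty, while the real payload lives in attributes that
    the copier ignores."""

    def __init__(self, payload, scale_inv):
        self.storage = b""           # wrapper subclass: data_ptr() == 0
        self._data = payload         # actual quantized bytes
        self._scale_inv = scale_inv  # dequantization scale


def staged_bytes(tensor):
    """Mimics async staging: copies only the underlying storage."""
    return bytes(tensor.storage)  # copies nothing -- the storage is empty


def dequantized_values(tensor):
    """Mimics the sync path: materialize plain values first, then save those."""
    return [b * tensor._scale_inv for b in tensor._data]


t = Float8Like(payload=b"\x02\x04", scale_inv=0.5)
print(staged_bytes(t))        # b'' -> the async path has nothing real to write
print(dequantized_values(t))  # [1.0, 2.0] -> the sync path preserves the values
```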
I'll put a proposed fix for this in #2721.
Additional Comments (3)
Both call sites use:

```python
if fp_recipe in ("Float8BlockScaling", "MXFP8BlockScaling", "NVFP4BlockScaling"):
    pytest.xfail(
        f"{fp_recipe}: quantized_model_init and FSDP2 is not currently supported, since the "
        "block tensor is dequantized before we flatten it for FSDP2."
    )
```
The previous behavior raised a hard error. The warning message acknowledges the downgrade: "If you are using this for FSDP2 … then ignore this warning." But the message doesn't tell non-FSDP2 users that data loss has occurred, and the same change is replicated in other call sites. A more defensive approach would be to detect the FSDP2 call context before downgrading the error, or at least emit a warning that states the data loss explicitly.
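One way to keep the hard error for general callers while letting a known FSDP2 codepath opt in. This is a hypothetical sketch; `quantized_view` and its `allow_data_loss` flag are invented names, not TE's actual API:

```python
import warnings


def quantized_view(is_quantized, allow_data_loss=False):
    """Hypothetical helper: raise for general callers, and only downgrade to
    a warning when the caller explicitly opts in (e.g., an FSDP2 codepath
    that immediately re-quantizes the flattened shard)."""
    if is_quantized:
        if not allow_data_loss:
            raise RuntimeError(
                "viewing a quantized tensor discards its quantization state"
            )
        warnings.warn(
            "quantized view: quantization state discarded (expected under FSDP2)"
        )
    return "view"
```

The same effect could also be achieved by routing only the FSDP2 codepath through a separate internal helper, keeping the public function strict.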
The module-level string literal describes a pending task that has already been completed.
Additional Comments (4)
Changing this from an error to a warning raises the same data-loss concern. The warning message itself is also inaccurate: it claims "this view is not going to be used anywhere", but the function does return the view to the caller.
Same concern as above. The message "this view is not going to be used anywhere" is also misleading: the result is returned and will be used by the caller.
The trailing
…hard Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Force-pushed from 456df05 to d09ffb2
Additional Comments (3)
Converting the hard error into a warning affects all callers, not just FSDP2. In a non-FSDP2 setting, a caller that previously failed loudly now silently loses quantization state. The same pattern is repeated in other call sites. To preserve safety for non-FSDP2 callers, consider one of these approaches:
A test with such tight convergence criteria risks flakiness. Consider one of these approaches:
The checkpoint is saved to a fixed path and never removed. Consider adding a cleanup step at the end of the script, after all assertions:

```python
import shutil

if int(os.environ.get("RANK", "0")) == 0:
    shutil.rmtree(checkpoint_dir, ignore_errors=True)
```

This mirrors the cleanup already present elsewhere.
/te-ci L1 pytorch

/te-ci L1 pytorch
Additional Comments (3)
The relevant branch:

```python
elif isinstance(value, QuantizedTensor):
    fp32_state[key] = value.dequantize().float()
else:
    fp32_state[key] = value.float()
```

The test should mirror this logic to avoid silently writing incorrect FP32 weights when master weights are unavailable (e.g., bias terms that happen to be QuantizedTensors).
The change from a hard error to a warning weakens safety for non-FSDP2 callers. Consider narrowing the fallback to FSDP2 contexts specifically, e.g., by checking whether the caller is inside an FSDP-managed dispatch (or by keeping the hard error and only routing the FSDP2 codepath differently). This same pattern is repeated in other call sites.
Both parametrizations currently share the same checkpoint path. Consider using a unique path that incorporates the `async_save` flag:

```python
checkpoint_dir = f"/tmp/te_test_fsdp2_dcp_parity{'_async' if async_save else ''}"
```
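An alternative sketch using only the standard library: `tempfile.mkdtemp` guarantees a fresh, collision-free directory per call. In a real multi-rank run, rank 0 would create the directory and broadcast the name to the other ranks; that step is assumed, not shown:

```python
import shutil
import tempfile

# Each parametrization gets its own directory instead of a fixed /tmp path,
# so concurrent or repeated test runs cannot collide on stale checkpoints.
sync_dir = tempfile.mkdtemp(prefix="te_test_fsdp2_dcp_parity_")
async_dir = tempfile.mkdtemp(prefix="te_test_fsdp2_dcp_parity_async_")
print(sync_dir != async_dir)  # mkdtemp never returns the same path twice

# Clean up after the test assertions would have run.
shutil.rmtree(sync_dir, ignore_errors=True)
shutil.rmtree(async_dir, ignore_errors=True)
```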
Additional Comments (4)
Consider either increasing
The comment says "Allocate main_grad buffers on the DTensor params", but the model has not been sharded yet. The parameters at this point are still plain tensors, not DTensors.
Although
Merging the PR since all tests pass, except the one that is expected to fail on the main branch.
Summary

- Enables `FusedAdam` to work with PyTorch-native FSDP2 (`fully_shard`) when parameters are DTensor-wrapped `Float8Tensor`/`QuantizedTensor`
- Adds a `fuse_wgrad_accumulation` guard to avoid crashing with vanilla FSDP2 (previously the code assumed Megatron-Core FSDP exclusively)
- Adds examples of `quantized_model_init` on single-GPU (`main.py`) and multi-GPU FSDP2 (`fully_shard.py`)

Note: `fuse_wgrad_accumulation` still cannot be used with vanilla FSDP2. The feature writes weight gradients directly into `main_grad` and returns `None` to autograd, bypassing FSDP2's reduce-scatter, so each rank ends up with an unreduced gradient. Megatron-Core FSDP solves this by wiring `get_main_grad()` into its own reduce-scatter infrastructure. Vanilla FSDP2 does not yet expose an equivalent hook.

Fixes #2682
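The incompatibility described in the note can be illustrated with a toy single-process simulation of two ranks. This is only an illustration of the data flow, not the real autograd or FSDP2 machinery:

```python
def optimizer_visible_grads(rank_grads, fuse_wgrad_accumulation):
    """Toy two-rank model of the interaction: FSDP2 reduce-scatters only the
    gradients handed back to autograd; gradients written straight into
    main_grad bypass it and stay unreduced on each rank."""
    if fuse_wgrad_accumulation:
        # The fused hook stashes each local grad into main_grad and returns
        # None to autograd: FSDP2 sees nothing to reduce, so each rank keeps
        # its own unreduced value and the ranks silently diverge.
        return list(rank_grads)
    # Normal path: FSDP2 reduces (here: averages) the returned grads,
    # so every rank applies the same update.
    avg = sum(rank_grads) / len(rank_grads)
    return [avg] * len(rank_grads)


print(optimizer_visible_grads([1.0, 3.0], fuse_wgrad_accumulation=True))   # [1.0, 3.0]
print(optimizer_visible_grads([1.0, 3.0], fuse_wgrad_accumulation=False))  # [2.0, 2.0]
```

Megatron-Core FSDP avoids the divergence by feeding `get_main_grad()` into its own reduce-scatter; vanilla FSDP2 currently has no hook where that wiring could happen.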