
Add fused_adam, quantized_model_init, and fsdp2 example#2698

Merged
vthumbe1503 merged 4 commits into NVIDIA:main from pstjohn:pstjohn/fsdp2-fused-adam
Mar 4, 2026

Conversation

@pstjohn
Contributor

@pstjohn pstjohn commented Feb 22, 2026

Summary

  • Fix FusedAdam to work with PyTorch-native FSDP2 (fully_shard) when parameters are DTensor-wrapped Float8Tensor/QuantizedTensor
  • Fix fuse_wgrad_accumulation guard to avoid crashing with vanilla FSDP2 (previously assumed Megatron-Core FSDP exclusively)
  • Add examples for quantized_model_init on single-GPU (main.py) and multi-GPU FSDP2 (fully_shard.py)

Note: fuse_wgrad_accumulation remains incompatible with vanilla FSDP2

fuse_wgrad_accumulation still cannot be used with vanilla FSDP2. The feature writes weight gradients directly into main_grad and returns None to autograd, bypassing FSDP2's reduce-scatter. Each rank ends up with an unreduced gradient. Megatron-Core FSDP solves this by wiring get_main_grad() into its own reduce-scatter infrastructure. Vanilla FSDP2 does not yet expose an equivalent hook.
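The incompatibility can be illustrated with a toy autograd function — a hypothetical sketch of the pattern, not TE's actual implementation. The weight gradient is written into a `main_grad` side buffer and `None` is returned to autograd, so `weight.grad` is never populated and there is nothing for FSDP2's reduce-scatter hook to see:

```python
import torch

class WgradIntoMainGrad(torch.autograd.Function):
    """Toy sketch of the fuse_wgrad_accumulation contract: write the weight
    gradient straight into weight.main_grad, report None to autograd."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x)
        ctx.weight = weight
        return x * weight

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Accumulate the wgrad directly into the side buffer...
        ctx.weight.main_grad += grad_out * x
        # ...and hand autograd None for the weight, so weight.grad stays
        # unset and any hook watching autograd-produced gradients
        # (e.g. FSDP2's reduce-scatter) never fires for it.
        return grad_out * ctx.weight, None

w = torch.nn.Parameter(torch.tensor(2.0))
w.main_grad = torch.zeros(())
x = torch.tensor(3.0, requires_grad=True)
WgradIntoMainGrad.apply(x, w).backward()

assert w.grad is None             # autograd never populated .grad
assert w.main_grad.item() == 3.0  # dL/dw = x landed only in main_grad
```

With FSDP2, each rank's `main_grad` would hold only its local, unreduced gradient — which is exactly the failure mode described above.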

Fixes #2682

@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch 2 times, most recently from 22604c4 to 4d89e04 Compare February 23, 2026 15:28
@pstjohn pstjohn marked this pull request as ready for review February 23, 2026 17:27
@greptile-apps
Contributor

greptile-apps bot commented Feb 23, 2026

Greptile Summary

This PR adds FSDP2 (fully_shard) compatibility to FusedAdam with DTensor-wrapped Float8Tensor/QuantizedTensor parameters, introduces __getstate__ pickle fixes for quantizer process groups, and provides new single-GPU (main.py) and multi-GPU (fully_shard.py) quantized_model_init examples along with a comprehensive FSDP2+FusedAdam test suite.

Key findings:

  • fused_adam.py: Correctly unwraps DTensor._local_tensor before dequantizing QuantizedTensor in both initialize_state and master-weight initialization. The fix is targeted and logically sound.
  • float8_blockwise_tensor.py / nvfp4_tensor.py: The change from RuntimeError to warnings.warn + silent dequantize() fallback in _ViewFunc and _ReshapeFunc is required for FSDP2 internal parameter-flattening. The warnings explicitly mention the FSDP2 context.
  • run_fsdp2_fused_adam.py: The test_fused_adam_bf16 loss-decrease assertion (assert losses[-1] < losses[0]) may be flaky with only 3 training steps on random data. The test_fuse_wgrad_accumulation function allocates main_grad buffers with incorrect shape before FSDP2 sharding; while the test correctly fails during backward before main_grad is consumed, this logical mismatch could confuse future maintainers. The misleading comment at line 343 incorrectly claims params are already DTensors.
  • __getstate__ additions: Correctly exclude unpicklable amax_reduction_group from pickle state across three quantizer classes.
  • fully_shard.py example: The checkpoint directory /tmp/te_fsdp2_example_checkpoint is never cleaned up between runs.
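The core fix can be sketched with stand-in classes (the real code uses `torch.distributed.tensor.DTensor` and TE's `QuantizedTensor`; the helper name here is hypothetical). The order matters: unwrap `DTensor._local_tensor` first, then dequantize the local shard:

```python
class QuantizedTensor:
    """Stand-in for TE's QuantizedTensor; the real dequantize() returns a torch.Tensor."""
    def __init__(self, value):
        self.value = value
    def dequantize(self):
        return self.value

class DTensorLike:
    """Stand-in for torch.distributed.tensor.DTensor, which keeps the
    local shard on ._local_tensor."""
    def __init__(self, local):
        self._local_tensor = local

def unwrap_and_dequantize(param):
    # Hypothetical helper mirroring the order FusedAdam now uses:
    # unwrap the DTensor wrapper first, then dequantize the local shard.
    local = getattr(param, "_local_tensor", param)
    return local.dequantize() if isinstance(local, QuantizedTensor) else local

# DTensor-wrapped quantized shard (the FSDP2 case this PR fixes):
assert unwrap_and_dequantize(DTensorLike(QuantizedTensor(3.0))) == 3.0
# Bare quantized param (the pre-existing single-GPU path):
assert unwrap_and_dequantize(QuantizedTensor(2.0)) == 2.0
```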

Confidence Score: 4/5

  • Core optimizer fix is correct and targeted. Remaining findings are concrete but non-critical: flaky test assertion, misleading comments, and cleanup issues.
  • The central FSDP2+FusedAdam fix (DTensor unwrapping before dequantization) is well-designed and verified. The behavioral change in float8_blockwise_tensor.py and nvfp4_tensor.py (RuntimeError→warning) is necessary for FSDP2's parameter-flattening and explicitly documented. The remaining issues are all concrete and fixable: the loss-decrease assertion can flake with insufficient training steps, comments can be corrected to accurately reflect the code, the main_grad shape mismatch is a logical inconsistency in an explicitly expected-to-fail test, and the checkpoint directory should be cleaned up. None of these are blocking problems, and the test suite provides good coverage of the new functionality.
  • tests/pytorch/distributed/run_fsdp2_fused_adam.py (flaky assertion + comment clarity + logical shape issue) and examples/pytorch/quantized_model_init/fully_shard.py (cleanup).

Sequence Diagram

sequenceDiagram
    participant User
    participant FSDP2 as FSDP2 (fully_shard)
    participant DTensor
    participant FusedAdam
    participant QT as QuantizedTensor (local)

    User->>FSDP2: wrap model (meta device)
    FSDP2->>DTensor: wrap each param as DTensor

    User->>QT: reset_parameters() → materialize + quantize local shard
    Note over DTensor,QT: DTensor._local_tensor = QuantizedTensor

    User->>FusedAdam: initialize_state(param=DTensor)
    FusedAdam->>DTensor: param._local_tensor → QuantizedTensor
    FusedAdam->>QT: dequantize() → plain float32
    FusedAdam-->>FusedAdam: allocate exp_avg, exp_avg_sq (float32)
    FusedAdam->>QT: dequantize(dtype=float32).clone() → master_param

    loop Training step
        User->>FSDP2: forward pass (all-gather)
        FSDP2->>DTensor: unshard params
        User->>User: loss.backward() → FSDP2 reduce-scatter gradients
        User->>FusedAdam: step()
        FusedAdam->>DTensor: unwrap → Float8Tensor._data
        FusedAdam->>FusedAdam: CUDA kernel: update FP8 weights + master_param + moments
        FSDP2->>DTensor: re-shard updated params
    end

Last reviewed commit: 9565c4f

Contributor

@greptile-apps greptile-apps bot left a comment


12 files reviewed, 13 comments


Contributor

@greptile-apps greptile-apps bot left a comment


6 files reviewed, no comments


Member

@cspades cspades left a comment


LGTM, clean edits.

@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch from 0103b53 to 3c3dbd2 on February 24, 2026 20:06
Contributor

@greptile-apps greptile-apps bot left a comment


6 files reviewed, no comments


Contributor

@greptile-apps greptile-apps bot left a comment


6 files reviewed, no comments


@XueSongTap

@pstjohn Hi, thanks for the great work! Does this PR plan to also handle the BF16 path? I noticed the BF16 branch still operates on the original p/p_grad without unwrapping when they're DTensors. In my experiments with FSDP2 + BF16, I'm seeing non-trivial overhead during the optimizer step from repeated DTensor dispatch. Curious if that's intentional or a planned follow-up.

@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch from a4d691f to 872caef on February 26, 2026 15:11
@pstjohn pstjohn marked this pull request as draft February 26, 2026 20:08
@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch 4 times, most recently from 9ccc0c3 to eb8606a on February 26, 2026 21:55
@pstjohn pstjohn marked this pull request as ready for review February 26, 2026 21:56
@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch from eb8606a to c2415e4 on February 26, 2026 22:50
Comment on lines +167 to +170
pytest.xfail(
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
"MXFP8 quantized tensors, causing illegal memory access"
)
Contributor Author


claude's analysis:

Root cause: MXFP8Tensor inherits from QuantizedTensor but NOT from Float8Tensor. In fused_adam.py step():

  1. Line 617: isinstance(p, Float8Tensor) → False for MXFP8Tensor
  2. Line 629: p.dtype in [torch.float16, torch.bfloat16] → True (nominal dtype is bfloat16)
  3. So p.data (which is still an MXFP8Tensor wrapper with MXFP8-layout storage) gets added to p_f16_model
  4. The multi_tensor_adam CUDA kernel treats this as plain bf16 memory → illegal memory access
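The misdispatch can be sketched with stand-in classes (hypothetical names mirroring the hierarchy described above): MXFP8Tensor passes the nominal-dtype check and falls into the bf16 bucket even though its storage is MXFP8-laid-out.

```python
class QuantizedTensor:
    dtype = "bfloat16"  # the wrapper's nominal dtype, not its storage layout

class Float8Tensor(QuantizedTensor):
    pass

class MXFP8Tensor(QuantizedTensor):
    # Inherits QuantizedTensor but NOT Float8Tensor.
    pass

def route(p):
    # Simplified mirror of the step() dispatch described in the analysis.
    if isinstance(p, Float8Tensor):
        return "fp8_path"
    if p.dtype in ("float16", "bfloat16"):
        return "f16_path"
    return "f32_path"

assert route(Float8Tensor()) == "fp8_path"
# The bug: MXFP8 storage falls through to the bf16 kernel path.
assert route(MXFP8Tensor()) == "f16_path"
```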

@vthumbe1503
Collaborator

/te-ci L1 pytorch

vthumbe1503 previously approved these changes Feb 28, 2026
Collaborator

@vthumbe1503 vthumbe1503 left a comment


Thanks for the clean PR and great work. Left a few minor comments. LGTM post CI success.

if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling):
model_state = {
k: v for k, v in model.state_dict().items() if not k.endswith("_extra_state")
}
Collaborator

@vthumbe1503 vthumbe1503 Feb 28, 2026


Can you also please comment on why we should avoid saving _extra_state? That is, what error do we get with DCP if we don't?

Contributor Author


This one I remember, the others I'll need to comment out and run the test suite again 😅

But this is a known hassle where torch DCP needs the sizes of these tensors to remain consistent during saving & loading, and since this is pickled data, it changes when there's data in that field.

The alternative is a detailed load_planner for DelayedScaling that reads and allocates the extra state data tensor

@greptile-apps
Contributor

greptile-apps bot commented Mar 2, 2026

Additional Comments (1)

tests/pytorch/distributed/test_torch_fsdp2.py, line 84
The default value uses snake_case, but the run_fsdp2_fused_adam.py argparse expects PascalCase. Change it to "DelayedScaling" for consistency:

def _run_fused_adam_test(test_name, recipe="DelayedScaling"):

Comment on lines +150 to +154
pytest.xfail(
f"async DCP save/load with {fp_recipe} produces different outputs: "
"the async staging may capture stale tensor state for FP8 scaling "
"factors, causing numerical divergence after reload"
)
Contributor Author


This isn't right -- it's actually NaNs when we load back. I'm digging into this.

Contributor Author


  • Sync save (dcp.save): Uses _SerialCpuLoader.values() (line 131) which calls tensor.detach().cpu(). For Float8Tensor, .cpu() calls self.dequantize().cpu(), properly converting FP8 data to a plain bfloat16 tensor.

  • Async save (dcp.async_save): Uses StateDictStager._offload_tensor() which tries to deep-copy the tensor's underlying storage. Float8Tensor is a wrapper subclass (_make_wrapper_subclass) with data_ptr() == 0 (empty storage). The staging code at line 215 skips the storage copy for wrapper subclasses, creating a plain tensor with uninitialized garbage data. The actual FP8 data (in _data, _scale_inv attributes) is deep-copied but ignored by DCP when writing.

Contributor Author


I'll put a proposed fix for this in #2721

@greptile-apps
Contributor

greptile-apps bot commented Mar 3, 2026

Additional Comments (3)

tests/pytorch/distributed/test_torch_fsdp2.py, line 176
Missing xfail for Float8BlockScaling and NVFP4BlockScaling

Both test_fsdp2_dcp_output_parity (line 136) and test_fsdp2_safetensors_fp32_export (line 169) call _run_fused_adam_test which internally uses _build_model(fp8_init=True) combined with a master_weights=True optimizer — the exact same configuration that is xfail-ed in test_fsdp2_fused_adam_fp8_master_weights (line 104–108) for Float8BlockScaling, MXFP8BlockScaling, and NVFP4BlockScaling:

if fp_recipe in ("Float8BlockScaling", "MXFP8BlockScaling", "NVFP4BlockScaling"):
    pytest.xfail(
        f"{fp_recipe}: quantized_model_init and FSDP2 is not currently supported, since the "
        "block tensor is dequantized before we flatten it for FSDP2."
    )

MXFP8BlockScaling is already xfail-ed in both tests, but Float8BlockScaling and NVFP4BlockScaling are missing. Without these guards, those parametrized test cases are expected to fail unexpectedly in CI instead of being marked as known failures.


transformer_engine/pytorch/tensor/float8_blockwise_tensor.py, line 628
Silent dequantization in view/reshape forward may mask real errors outside FSDP2

The previous RuntimeError was a hard guard against callers applying dimension-incompatible views to a quantized tensor — a contract violation that would silently corrupt the tensor type. The replacement warning + dequantize fallback is necessary for FSDP2 (where the result is indeed discarded), but in any non-FSDP2 context the caller now silently receives a plain float tensor instead of a Float8BlockwiseQTensor, without knowing the quantization was lost.

The warning message acknowledges this: "If you are using this for FSDP2 … then ignore this warning." But the message doesn't tell non-FSDP2 users that data loss has occurred. Considering that the same change is replicated in nvfp4_tensor.py, this affects two tensor types.

A more defensive approach would be to detect the FSDP2 call context before downgrading the error, or at least emit a RuntimeWarning (instead of the default UserWarning) and document the data-loss risk in the message so non-FSDP2 callers are not silently affected.


tests/pytorch/distributed/test_torch_fsdp2.py, line 208
Stale TODO comment — async DCP tests already added

The module-level string literal describes a pending task that has already been completed: test_fsdp2_dcp_output_parity_async was added in this same PR. The orphaned comment should be removed.

@greptile-apps
Contributor

greptile-apps bot commented Mar 3, 2026

Additional Comments (4)

transformer_engine/pytorch/tensor/float8_blockwise_tensor.py, line 636
Silent dequantization affects all callers, not just FSDP2

Changing from RuntimeError to UserWarning + dequantize fallback applies to every call-site of _ViewFunc.forward and _ReshapeFunc.forward, not just FSDP2 internals. Non-FSDP2 code that accidentally violates the last-dimension constraint will now silently receive a dequantized float32 tensor instead of a hard error, losing FP8 precision with no notice beyond a warning.

The warning message itself is also inaccurate: "this view is not going to be used anywhere" — but the function does return tensor.dequantize().view(*shape), so the result is returned and used by the caller. Consider adding a note like "the returned dequantized tensor will be used as-is by the caller" instead, or differentiating between FSDP2-internal paths and end-user paths if a more robust solution is desired.


transformer_engine/pytorch/tensor/nvfp4_tensor.py, line 820
Silent dequantization affects all callers, not just FSDP2

Same concern as in float8_blockwise_tensor.py: changing from RuntimeError to UserWarning + dequantize().view() / dequantize().reshape() makes this fallback active for all callers, not just FSDP2 internals. Code that legitimately should fail on an unsupported inner-dimension reshape will now silently get dequantized results with only a warning.

The message "this view is not going to be used anywhere" is also misleading — the result is returned and will be used by the caller. The _ViewFunc forward and _ReshapeFunc forward both fall into this same pattern (lines 807–820 and lines 933–946).


tests/pytorch/distributed/test_torch_fsdp2.py, line 1439
TODO block is a bare string literal, not a comment

The trailing """TODO: ...""" at the end of the file is a module-level string expression, not a Python comment. It is a no-op (discarded at runtime) but won't be picked up by tooling that searches for # TODO markers, and it is non-idiomatic. Replace with a standard comment:

# TODO:
#  - async DCP tests

tests/pytorch/distributed/run_fsdp2_fused_adam.py, line 807
dist.destroy_process_group() is unreachable in the expected-failure path

loss.backward() is documented in the function body as "Expected to raise AttributeError". If it does raise, dist.destroy_process_group() on the next line is never called. Since this function runs inside a torchrun subprocess that will exit abnormally anyway, this does not affect correctness; however, it could confuse readers into thinking the process group is always cleaned up. Consider wrapping the backward in a try/finally for clarity, or adding a comment noting that cleanup is intentionally skipped on expected failure.
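The suggested try/finally shape can be sketched as follows (the function names here are stand-ins, not the test's actual helpers): cleanup runs even when the backward raises the documented AttributeError.

```python
def run_expected_failure(backward, destroy_process_group):
    # Ensure process-group cleanup runs even on the expected failure path.
    try:
        backward()
    except AttributeError:
        pass  # the documented, expected failure
    finally:
        destroy_process_group()

calls = []

def failing_backward():
    calls.append("backward")
    raise AttributeError("simulated expected failure")

run_expected_failure(failing_backward, lambda: calls.append("cleanup"))
assert calls == ["backward", "cleanup"]
```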

@pstjohn pstjohn marked this pull request as draft March 3, 2026 14:07
…hard

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch from 456df05 to d09ffb2 on March 3, 2026 14:09
@pstjohn pstjohn marked this pull request as ready for review March 3, 2026 15:01
@greptile-apps
Contributor

greptile-apps bot commented Mar 3, 2026

Additional Comments (3)

transformer_engine/pytorch/tensor/float8_blockwise_tensor.py, line 646
Silent dequantize fallback silences legitimate non-FSDP2 errors

Converting RuntimeError to a warnings.warn + dequantize().view() fallback is a meaningful behavioral regression for non-FSDP2 callers. The warning message explicitly says "ignore this warning since this view is not going to be used anywhere" — but that advice is only accurate inside FSDP2's internal reshape path. There is no guard here to confirm we're in an FSDP2 context.

In a non-FSDP2 setting, if a Float8BlockwiseQTensor is accidentally passed an incompatible shape, the forward now silently falls back to a plain (dequantized) tensor. That plain tensor continues through the rest of the model without error, losing all FP8 compute benefits. Previously the RuntimeError would have caught this immediately.

The same pattern is repeated in _ReshapeFunc.forward (lines 738–758) and in nvfp4_tensor.py _ViewFunc/_ReshapeFunc (all four sites have the same issue).

To preserve safety for non-FSDP2 callers, consider one of these approaches:

  • Propagate a context flag (e.g., thread-local or module-level _in_fsdp2_unshard context manager) and only suppress the error inside that context
  • Mark the dequantize fallback result with a flag that causes a hard error if it is subsequently used in a quantized compute kernel, making the "this will not be used" assumption explicit and enforceable

tests/pytorch/distributed/run_fsdp2_fused_adam.py, line 221
Loss-decrease assertion is brittle with only 3 training steps

With NUM_STEPS = 3 and target = torch.randn_like(x) (a fresh random target each run), a monotonically decreasing loss over just three steps is not guaranteed. The RNG is seeded once in _setup(), but since the same seed is shared across all test functions in the same process, different test orderings will leave the RNG in different states — making the actual target values non-deterministic across different invocation orders.

A test with such tight convergence criteria risks flakiness.

Consider one of these approaches:

  • Relax the assertion to assert losses[-1] <= losses[0] + epsilon with a small tolerance
  • Use a fixed, non-random target (e.g., torch.zeros_like(x)) to guarantee convergence in a few steps
  • Increase NUM_STEPS to give the optimizer sufficient iterations to reduce loss regardless of starting point

examples/pytorch/quantized_model_init/fully_shard.py, line 195
Checkpoint directory not cleaned up on success

The checkpoint is saved to /tmp/te_fsdp2_example_checkpoint but never removed. After each successful run the directory accumulates on disk. On repeated runs (e.g., in CI), stale shards from a previous run may conflict with or be silently overwritten by DCP.

Consider adding a cleanup step at the end of the script, after all assertions:

import shutil
if int(os.environ.get("RANK", "0")) == 0:
    shutil.rmtree(checkpoint_dir, ignore_errors=True)

This mirrors the cleanup already present in run_fsdp2_fused_adam.py lines 615–618.

@vthumbe1503
Collaborator

/te-ci L1 pytorch

@vthumbe1503
Collaborator

/te-ci L1 pytorch

@greptile-apps
Contributor

greptile-apps bot commented Mar 3, 2026

Additional Comments (3)

tests/pytorch/distributed/run_fsdp2_fused_adam.py, line 869
QuantizedTensor not dequantized before .float() cast

value.float() is called unconditionally on full_model_state values. However, when a parameter is a QuantizedTensor (FP8 weight), .float() may not correctly dequantize its internal representation — calling .float() on a QuantizedTensor does not go through the proper dequantization path.

The fully_shard.py example (lines 249–251) handles this correctly with an explicit isinstance check:

elif isinstance(value, QuantizedTensor):
    fp32_state[key] = value.dequantize().float()
else:
    fp32_state[key] = value.float()

The test should mirror this logic to avoid silently writing incorrect FP32 weights when master weights are unavailable (e.g., bias terms that happen to be QuantizedTensors).


transformer_engine/pytorch/tensor/float8_blockwise_tensor.py, line 646
Silent type-change fallback may mask non-FSDP2 bugs

The change from RuntimeError to warnings.warn + dequantize fallback is a broad behavioral change. Previously, any caller (FSDP2 or otherwise) that attempted an incompatible view/reshape got a hard error that surfaced immediately. Now, all callers silently receive a plain float32 tensor in place of the expected Float8BlockwiseQTensor. This could:

  1. Mask bugs in non-FSDP2 paths: code that relied on the RuntimeError to catch invalid views will now silently continue with a dequantized tensor and incorrect precision.
  2. Confuse the warning reader: the message "If you are using this for FSDP2 without compiled_autograd_enabled, then ignore this warning since this view is not going to be used anywhere" is technically correct only for FSDP2. For non-FSDP2 callers the dequantized output is used downstream, so silently accepting it is wrong.

Consider narrowing the fallback to FSDP2 contexts specifically — e.g., by checking whether the caller is inside an FSDP-managed dispatch (or by keeping the hard error and only routing the FSDP2 codepath differently). This same pattern is repeated in _ReshapeFunc.forward (lines 738–761) and in nvfp4_tensor.py _ViewFunc/_ReshapeFunc.


tests/pytorch/distributed/run_fsdp2_fused_adam.py, line 931
Hardcoded checkpoint path shared between sync and async tests

Both test_dcp_output_parity(async_save=False) and test_dcp_output_parity(async_save=True) hardcode the same checkpoint directory /tmp/te_test_fsdp2_dcp_parity. If these tests run concurrently (e.g., via pytest-xdist or when parametrized over multiple recipes), one test's save can overwrite or corrupt the other's. Only rank 0 cleans up at the end, leaving a small race window.

Consider using a unique path that incorporates the async_save flag (or a tempfile.mkdtemp()-style approach) to prevent collisions:

checkpoint_dir = f"/tmp/te_test_fsdp2_dcp_parity{'_async' if async_save else ''}"

@greptile-apps
Contributor

greptile-apps bot commented Mar 4, 2026

Additional Comments (4)

tests/pytorch/distributed/run_fsdp2_fused_adam.py, line 221
Flaky loss-decrease assertion with only 3 training steps

assert losses[-1] < losses[0] is not guaranteed to hold with only NUM_STEPS = 3 steps of Adam, random synthetic data, and random targets. While the global seed (torch.manual_seed(42)) makes results reproducible on a given platform, floating-point non-determinism across GPU architectures or CUDA versions can cause this assertion to fail, making the test brittle.

Consider either increasing NUM_STEPS (e.g., to 10+) or replacing this assertion with a looser sanity check that does not depend on strict monotonic decrease.


tests/pytorch/distributed/run_fsdp2_fused_adam.py, line 345
Misleading comment — params are not DTensors at this point

The comment says "Allocate main_grad buffers on the DTensor params", but the model has not been sharded yet. The parameters at this point are still plain nn.Parameter / QuantizedTensor objects — they only become DTensor-wrapped after _shard_model() is called on line 347.

    # Allocate main_grad buffers on the (pre-shard) params.
    # _shard_model() will restore these via save_custom_attrs/restore_custom_attrs.

tests/pytorch/distributed/run_fsdp2_fused_adam.py, line 347
main_grad shape mismatch after FSDP2 sharding

main_grad buffers are allocated with the full (unsharded) parameter shape before _shard_model() on line 344-345. After FSDP2 shards the model, the DTensor's local shard has a smaller shape (shape / world_size along the sharding dimension), but restore_custom_attrs blindly re-attaches the full-shape main_grad buffer to the DTensor parameter. This means the main_grad tensor will have the wrong shape for the local shard.

Although test_fuse_wgrad_accumulation is expected to fail during loss.backward() before main_grad would actually be consumed by the optimizer, this setup is logically incorrect and could mislead anyone trying to extend or fix the test later. The main_grad allocation should happen after sharding (with sharded shapes), or through the same save_custom_attrs/restore_custom_attrs mechanism.


examples/pytorch/quantized_model_init/fully_shard.py, line 263
No cleanup of checkpoint directory between runs

checkpoint_dir = "/tmp/te_fsdp2_example_checkpoint" is never removed at the end of the script. Running the example a second time will succeed (DCP overwrites the checkpoint), but the stale directory may accumulate if the script is run many times, or it might confuse readers who inspect /tmp manually. A simple shutil.rmtree(checkpoint_dir, ignore_errors=True) call at the end (on rank 0) would keep the environment clean.

    dist.destroy_process_group()

    import shutil
    if int(os.environ.get("RANK", "0")) == 0:
        shutil.rmtree("/tmp/te_fsdp2_example_checkpoint", ignore_errors=True)


@vthumbe1503
Collaborator

Merging the PR since all tests pass except the one that is expected to fail on the main branch

@vthumbe1503 vthumbe1503 merged commit 139c863 into NVIDIA:main Mar 4, 2026
8 of 12 checks passed


Development

Successfully merging this pull request may close these issues.

Example of quantized_model_init for low-precision compute weights, fp32 main weights using fused_adam with fsdp2

4 participants