@@ -49,20 +49,14 @@ def wrap_model(self, model, optimizer=None, memory_efficient=True):
4949 if fsdp_mesh is not None :
5050 ep_enabled = (self .enable_ep and self .ep_fsdp_device_mesh is not None )
5151
52- # EP path is not yet compatible with meta-device flow because
53- # _place_ep_experts_on_local_device requires experts on a real device.
52+ # EP path requires experts on a real device, incompatible with meta-device flow.
5453 use_meta = memory_efficient and not ep_enabled
5554
56- # --- Phase 1: save state before meta move ---
5755 original_sd = None
5856 saved_buffers = None
5957 if use_meta :
6058 original_sd = model .state_dict ()
6159 saved_buffers = _get_non_persistent_buffers (model )
62- # Drop optimizer references so old params can be freed on to('meta').
63- # Without this, the optimizer holds strong refs to the full-size
64- # parameter tensors, preventing GC even after the model moves to meta.
65- # _rebind_optimizer will re-attach the new sharded params later.
6660 if optimizer is not None :
6761 _unbind_optimizer_params (optimizer )
6862 model = model .to (torch .device ('meta' ))
@@ -78,19 +72,16 @@ def wrap_model(self, model, optimizer=None, memory_efficient=True):
7872 if ep_enabled :
7973 _ensure_ep_fsdp_supported (model )
8074
81- # Collect experts map and expert params
8275 experts_map = _collect_ep_experts_map (model ) if ep_enabled else {}
8376 expert_params = _collect_expert_params (model ) if self .enable_ep else None
8477
85- # Build layer_pairs: [(layer_mod, experts_mod_or_None)]
8678 layers = _get_decoder_layers (model )
8779 layer_pairs = []
8880 if layers is not None :
8981 for layer_mod in layers :
9082 experts_mod = _find_experts_in_layer (layer_mod , experts_map )
9183 layer_pairs .append ((layer_mod , experts_mod ))
9284
93- # FSDP2 wrapping per layer
9485 world_size = self .device_mesh .world_size
9586 ep_fsdp_mesh_1d = self .ep_fsdp_device_mesh ['ep_fsdp' ] if ep_enabled else None
9687
@@ -120,7 +111,6 @@ def wrap_model(self, model, optimizer=None, memory_efficient=True):
120111 )
121112 layer_mod ._fsdp_modules .append (layer_mod )
122113
123- # Root model
124114 fully_shard (
125115 model ,
126116 mesh = fsdp_mesh ,
@@ -129,7 +119,6 @@ def wrap_model(self, model, optimizer=None, memory_efficient=True):
129119 ignored_params = expert_params ,
130120 )
131121
132- # --- Phase 2: broadcast and restore ---
133122 if use_meta :
134123 device_type = self .device_mesh .device_type or 'cuda'
135124 is_rank0 = (dist .get_rank () == 0 )
@@ -143,11 +132,9 @@ def wrap_model(self, model, optimizer=None, memory_efficient=True):
143132 if hasattr (model , 'tie_weights' ):
144133 model .tie_weights ()
145134
146- # Manual prefetch
147135 if ep_enabled and layer_pairs :
148136 _setup_manual_prefetch ([lp [0 ] for lp in layer_pairs ])
149137
150- # Rebuild groups after wrapping so grad clip sees the live Parameter objects.
151138 if ep_enabled :
152139 _rebuild_ep_param_groups (model )
153140
@@ -436,27 +423,7 @@ def _broadcast_sharded_state_dict(
436423 full_sd : dict ,
437424 device_type : str = 'cuda' ,
438425) -> None :
439- """Broadcast full state dict from rank 0 and load as sharded parameters.
440-
441- After ``fully_shard`` on a meta-device model, every rank has DTensor
442- parameters whose ``device_mesh`` and ``placements`` describe the desired
443- sharding but whose storage is still on ``meta``. This function:
444-
445- 1. Rank 0 broadcasts each full parameter tensor.
446- 2. Every rank calls ``distribute_tensor`` to materialise only its local
447- shard, then collects the results into a new state dict.
448- 3. ``model.load_state_dict(..., assign=True)`` replaces the meta tensors
449- with the real sharded ones.
450-
451- This is the twinkle equivalent of accelerate's
452- ``fsdp2_load_full_state_dict``.
453-
454- Args:
455- model: The model whose parameters are on ``meta`` after ``fully_shard``.
456- full_sd: The full (unsharded) state dict. Must be populated on rank 0;
457- may be empty (``{}``) on other ranks.
458- device_type: The device type string (e.g. ``'cuda'``, ``'npu'``).
459- """
426+ """Broadcast full state dict from rank 0 and materialise local shards via distribute_tensor."""
460427 from torch .distributed .tensor import DTensor , distribute_tensor
461428
462429 meta_sharded_sd = model .state_dict ()
@@ -476,10 +443,6 @@ def _broadcast_sharded_state_dict(
476443 full_tensor = torch .empty (shape , device = device_type , dtype = dtype )
477444
478445 dist .broadcast (full_tensor , src = 0 )
479-
480- # Ensure the async broadcast completes before we consume the tensor.
481- # Without this, NPU (and potentially other async backends) may not
482- # have finished writing full_tensor when distribute_tensor reads it.
483446 torch_util .synchronize ()
484447
485448 device_mesh = sharded_param .device_mesh
@@ -492,17 +455,7 @@ def _broadcast_sharded_state_dict(
492455
493456
494457def _get_non_persistent_buffers (model : nn .Module ) -> Dict [str , torch .Tensor ]:
495- """Return {fqn: tensor} for all non-persistent buffers in the model.
496-
497- Non-persistent buffers are not included in ``state_dict()`` and will be
498- lost when the model is moved to ``meta`` device. We need to save them
499- before the move and re-register them after broadcast.
500-
501- Uses ``module._non_persistent_buffers_set`` (the same approach as
502- accelerate's ``get_non_persistent_buffers``) for precision — directly
503- reads PyTorch's internal tracking set rather than diffing against
504- ``state_dict()`` keys.
505- """
458+ """Return {fqn: tensor} for non-persistent buffers (lost on to('meta'))."""
506459 non_persistent_fqns : Set [str ] = set ()
507460 for fqn , module in model .named_modules ():
508461 for buf_name in getattr (module , '_non_persistent_buffers_set' , set ()):
@@ -513,19 +466,7 @@ def _get_non_persistent_buffers(model: nn.Module) -> Dict[str, torch.Tensor]:
513466
514467
515468def _unbind_optimizer_params (optimizer : torch .optim .Optimizer ) -> None :
516- """Replace optimizer param references with ``torch.empty(1)`` placeholders.
517-
518- This drops the optimizer's strong references to the full model parameters,
519- allowing them to be freed when the model is moved to ``meta`` device.
520- Without this, ``model.to('meta')`` cannot free the old parameter tensors
521- because the optimizer still holds references to them.
522-
523- Must be called BEFORE ``model.to('meta')``. After ``fully_shard`` and
524- ``_broadcast_sharded_state_dict``, call ``_rebind_optimizer`` to point
525- the optimizer at the new sharded parameters.
526-
527- This mirrors accelerate's approach in ``Accelerator._prepare_fsdp2``.
528- """
469+ """Drop optimizer param refs so model.to('meta') can free memory."""
529470 for group in optimizer .param_groups :
530471 for i in range (len (group ['params' ])):
531472 group ['params' ][i ] = torch .empty (1 )
@@ -536,13 +477,7 @@ def _restore_non_persistent_buffers(
536477 saved_buffers : Dict [str , torch .Tensor ],
537478 device : torch .device ,
538479) -> None :
539- """Re-register non-persistent buffers that were saved before ``to(meta)``.
540-
541- Args:
542- model: The model (may have meta-device buffers after sharding).
543- saved_buffers: ``{fqn: tensor}`` from ``_get_non_persistent_buffers``.
544- device: Target device for the restored buffers.
545- """
480+ """Re-register non-persistent buffers saved before to('meta')."""
546481 for fqn , buf_tensor in saved_buffers .items ():
547482 buf_tensor = buf_tensor .to (device )
548483 if '.' in fqn :
0 commit comments