Skip to content

Commit 00fd199

Browse files
committed
clean
1 parent e482625 commit 00fd199

File tree

3 files changed

+8
-89
lines changed

3 files changed

+8
-89
lines changed

src/twinkle/model/transformers/strategy/accelerate.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def _parallelism_config_from_device_mesh(device_mesh: DeviceMesh):
7070

7171
return parallelism_config
7272

73-
def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Dict[str, Any], memory_efficient: bool):
73+
def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Dict[str, Any],
74+
memory_efficient: bool):
7475
from accelerate import FullyShardedDataParallelPlugin
7576
from torch.distributed.fsdp import BackwardPrefetch
7677
from torch.distributed.fsdp import ShardingStrategy as FSDPShardingStrategy
@@ -111,10 +112,6 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di
111112
cpu_ram_efficient_loading=fsdp_config.pop('cpu_ram_efficient_loading', memory_efficient),
112113
**fsdp_config,
113114
)
114-
# The env vars (ACCELERATE_USE_FSDP, FSDP_CPU_RAM_EFFICIENT_LOADING) are set
115-
# in TransformersModel.__init__ before from_pretrained, and the plugin's
116-
# __post_init__ also sets FSDP_CPU_RAM_EFFICIENT_LOADING when
117-
# cpu_ram_efficient_loading=True.
118115
return fsdp_plugin
119116

120117
def wrap_model(self, model, *args):

src/twinkle/model/transformers/strategy/native_fsdp.py

Lines changed: 5 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,14 @@ def wrap_model(self, model, optimizer=None, memory_efficient=True):
4949
if fsdp_mesh is not None:
5050
ep_enabled = (self.enable_ep and self.ep_fsdp_device_mesh is not None)
5151

52-
# EP path is not yet compatible with meta-device flow because
53-
# _place_ep_experts_on_local_device requires experts on a real device.
52+
# EP path requires experts on a real device, incompatible with meta-device flow.
5453
use_meta = memory_efficient and not ep_enabled
5554

56-
# --- Phase 1: save state before meta move ---
5755
original_sd = None
5856
saved_buffers = None
5957
if use_meta:
6058
original_sd = model.state_dict()
6159
saved_buffers = _get_non_persistent_buffers(model)
62-
# Drop optimizer references so old params can be freed on to('meta').
63-
# Without this, the optimizer holds strong refs to the full-size
64-
# parameter tensors, preventing GC even after the model moves to meta.
65-
# _rebind_optimizer will re-attach the new sharded params later.
6660
if optimizer is not None:
6761
_unbind_optimizer_params(optimizer)
6862
model = model.to(torch.device('meta'))
@@ -78,19 +72,16 @@ def wrap_model(self, model, optimizer=None, memory_efficient=True):
7872
if ep_enabled:
7973
_ensure_ep_fsdp_supported(model)
8074

81-
# Collect experts map and expert params
8275
experts_map = _collect_ep_experts_map(model) if ep_enabled else {}
8376
expert_params = _collect_expert_params(model) if self.enable_ep else None
8477

85-
# Build layer_pairs: [(layer_mod, experts_mod_or_None)]
8678
layers = _get_decoder_layers(model)
8779
layer_pairs = []
8880
if layers is not None:
8981
for layer_mod in layers:
9082
experts_mod = _find_experts_in_layer(layer_mod, experts_map)
9183
layer_pairs.append((layer_mod, experts_mod))
9284

93-
# FSDP2 wrapping per layer
9485
world_size = self.device_mesh.world_size
9586
ep_fsdp_mesh_1d = self.ep_fsdp_device_mesh['ep_fsdp'] if ep_enabled else None
9687

@@ -120,7 +111,6 @@ def wrap_model(self, model, optimizer=None, memory_efficient=True):
120111
)
121112
layer_mod._fsdp_modules.append(layer_mod)
122113

123-
# Root model
124114
fully_shard(
125115
model,
126116
mesh=fsdp_mesh,
@@ -129,7 +119,6 @@ def wrap_model(self, model, optimizer=None, memory_efficient=True):
129119
ignored_params=expert_params,
130120
)
131121

132-
# --- Phase 2: broadcast and restore ---
133122
if use_meta:
134123
device_type = self.device_mesh.device_type or 'cuda'
135124
is_rank0 = (dist.get_rank() == 0)
@@ -143,11 +132,9 @@ def wrap_model(self, model, optimizer=None, memory_efficient=True):
143132
if hasattr(model, 'tie_weights'):
144133
model.tie_weights()
145134

146-
# Manual prefetch
147135
if ep_enabled and layer_pairs:
148136
_setup_manual_prefetch([lp[0] for lp in layer_pairs])
149137

150-
# Rebuild groups after wrapping so grad clip sees the live Parameter objects.
151138
if ep_enabled:
152139
_rebuild_ep_param_groups(model)
153140

@@ -436,27 +423,7 @@ def _broadcast_sharded_state_dict(
436423
full_sd: dict,
437424
device_type: str = 'cuda',
438425
) -> None:
439-
"""Broadcast full state dict from rank 0 and load as sharded parameters.
440-
441-
After ``fully_shard`` on a meta-device model, every rank has DTensor
442-
parameters whose ``device_mesh`` and ``placements`` describe the desired
443-
sharding but whose storage is still on ``meta``. This function:
444-
445-
1. Rank 0 broadcasts each full parameter tensor.
446-
2. Every rank calls ``distribute_tensor`` to materialise only its local
447-
shard, then collects the results into a new state dict.
448-
3. ``model.load_state_dict(..., assign=True)`` replaces the meta tensors
449-
with the real sharded ones.
450-
451-
This is the twinkle equivalent of accelerate's
452-
``fsdp2_load_full_state_dict``.
453-
454-
Args:
455-
model: The model whose parameters are on ``meta`` after ``fully_shard``.
456-
full_sd: The full (unsharded) state dict. Must be populated on rank 0;
457-
may be empty (``{}``) on other ranks.
458-
device_type: The device type string (e.g. ``'cuda'``, ``'npu'``).
459-
"""
426+
"""Broadcast full state dict from rank 0 and materialise local shards via distribute_tensor."""
460427
from torch.distributed.tensor import DTensor, distribute_tensor
461428

462429
meta_sharded_sd = model.state_dict()
@@ -476,10 +443,6 @@ def _broadcast_sharded_state_dict(
476443
full_tensor = torch.empty(shape, device=device_type, dtype=dtype)
477444

478445
dist.broadcast(full_tensor, src=0)
479-
480-
# Ensure the async broadcast completes before we consume the tensor.
481-
# Without this, NPU (and potentially other async backends) may not
482-
# have finished writing full_tensor when distribute_tensor reads it.
483446
torch_util.synchronize()
484447

485448
device_mesh = sharded_param.device_mesh
@@ -492,17 +455,7 @@ def _broadcast_sharded_state_dict(
492455

493456

494457
def _get_non_persistent_buffers(model: nn.Module) -> Dict[str, torch.Tensor]:
495-
"""Return {fqn: tensor} for all non-persistent buffers in the model.
496-
497-
Non-persistent buffers are not included in ``state_dict()`` and will be
498-
lost when the model is moved to ``meta`` device. We need to save them
499-
before the move and re-register them after broadcast.
500-
501-
Uses ``module._non_persistent_buffers_set`` (the same approach as
502-
accelerate's ``get_non_persistent_buffers``) for precision — directly
503-
reads PyTorch's internal tracking set rather than diffing against
504-
``state_dict()`` keys.
505-
"""
458+
"""Return {fqn: tensor} for non-persistent buffers (lost on to('meta'))."""
506459
non_persistent_fqns: Set[str] = set()
507460
for fqn, module in model.named_modules():
508461
for buf_name in getattr(module, '_non_persistent_buffers_set', set()):
@@ -513,19 +466,7 @@ def _get_non_persistent_buffers(model: nn.Module) -> Dict[str, torch.Tensor]:
513466

514467

515468
def _unbind_optimizer_params(optimizer: torch.optim.Optimizer) -> None:
516-
"""Replace optimizer param references with ``torch.empty(1)`` placeholders.
517-
518-
This drops the optimizer's strong references to the full model parameters,
519-
allowing them to be freed when the model is moved to ``meta`` device.
520-
Without this, ``model.to('meta')`` cannot free the old parameter tensors
521-
because the optimizer still holds references to them.
522-
523-
Must be called BEFORE ``model.to('meta')``. After ``fully_shard`` and
524-
``_broadcast_sharded_state_dict``, call ``_rebind_optimizer`` to point
525-
the optimizer at the new sharded parameters.
526-
527-
This mirrors accelerate's approach in ``Accelerator._prepare_fsdp2``.
528-
"""
469+
"""Drop optimizer param refs so model.to('meta') can free memory."""
529470
for group in optimizer.param_groups:
530471
for i in range(len(group['params'])):
531472
group['params'][i] = torch.empty(1)
@@ -536,13 +477,7 @@ def _restore_non_persistent_buffers(
536477
saved_buffers: Dict[str, torch.Tensor],
537478
device: torch.device,
538479
) -> None:
539-
"""Re-register non-persistent buffers that were saved before ``to(meta)``.
540-
541-
Args:
542-
model: The model (may have meta-device buffers after sharding).
543-
saved_buffers: ``{fqn: tensor}`` from ``_get_non_persistent_buffers``.
544-
device: Target device for the restored buffers.
545-
"""
480+
"""Re-register non-persistent buffers saved before to('meta')."""
546481
for fqn, buf_tensor in saved_buffers.items():
547482
buf_tensor = buf_tensor.to(device)
548483
if '.' in fqn:

src/twinkle/model/transformers/transformers.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -205,16 +205,7 @@ def __init__(
205205
self.model = model_cls.from_config(config, **kwargs)
206206
else:
207207
model_id = HubOperation.download_model(model_id)
208-
# Memory-efficient init: set env vars so transformers' from_pretrained
209-
# uses its built-in FSDP-aware loading path.
210-
# When is_fsdp_enabled() returns True inside transformers:
211-
# - All ranks: model created on meta device
212-
# - Rank 0: loads real weights from disk
213-
# - Non-rank-0: replaces params with torch.empty_like (no disk I/O)
214-
# This works for BOTH strategies:
215-
# - NativeFSDPStrategy: wrap_model does meta → broadcast (Task 4)
216-
# - AccelerateStrategy: accelerator.prepare() → fsdp2_prepare_model()
217-
# does its own meta → broadcast (accelerate built-in)
208+
# Trigger transformers' FSDP-aware loading: meta-device init + rank-0-only weight load.
218209
use_efficient_loading = (memory_efficient_init and self.device_mesh is not None)
219210
_saved_env = {}
220211
if use_efficient_loading:
@@ -225,10 +216,6 @@ def __init__(
225216
try:
226217
self.model = model_cls.from_pretrained(model_id, config=config, **kwargs)
227218
finally:
228-
# Restore original env vars to avoid polluting other code paths.
229-
# For AccelerateStrategy, Accelerator.__init__ already sets
230-
# ACCELERATE_USE_FSDP=true when fsdp_plugin is provided, so
231-
# restoring here is safe — accelerate will re-set it as needed.
232219
if use_efficient_loading:
233220
for key, old_val in _saved_env.items():
234221
if old_val is None:

0 commit comments

Comments
 (0)