Skip to content

Commit e482625

Browse files
committed
wip
1 parent 38e75cd commit e482625

File tree

3 files changed

+32
-5
lines changed

3 files changed

+32
-5
lines changed

src/twinkle/model/transformers/strategy/accelerate.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,13 +21,14 @@ def __init__(
2121
mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
2222
ddp_config: Dict[str, Any] = None,
2323
fsdp_config: Dict[str, Any] = None,
24+
memory_efficient: bool = True,
2425
):
2526
from accelerate import Accelerator
2627

2728
self.device_mesh = device_mesh
2829
self.mixed_precision = mixed_precision
2930
parallelism_config = self._parallelism_config_from_device_mesh(device_mesh)
30-
fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config)
31+
fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config, memory_efficient)
3132

3233
kwargs_handlers = []
3334
if ddp_config is not None:
@@ -69,7 +70,7 @@ def _parallelism_config_from_device_mesh(device_mesh: DeviceMesh):
6970

7071
return parallelism_config
7172

72-
def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Dict[str, Any]):
73+
def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Dict[str, Any], memory_efficient: bool):
7374
from accelerate import FullyShardedDataParallelPlugin
7475
from torch.distributed.fsdp import BackwardPrefetch
7576
from torch.distributed.fsdp import ShardingStrategy as FSDPShardingStrategy
@@ -107,7 +108,7 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di
107108
activation_checkpointing=fsdp_config.pop('activation_checkpointing', False),
108109
auto_wrap_policy=fsdp_config.pop('auto_wrap_policy', 'transformer_based_wrap'), # noqa
109110
reshard_after_forward=fsdp_config.pop('reshard_after_forward', True),
110-
cpu_ram_efficient_loading=fsdp_config.pop('cpu_ram_efficient_loading', True),
111+
cpu_ram_efficient_loading=fsdp_config.pop('cpu_ram_efficient_loading', memory_efficient),
111112
**fsdp_config,
112113
)
113114
# The env vars (ACCELERATE_USE_FSDP, FSDP_CPU_RAM_EFFICIENT_LOADING) are set

src/twinkle/model/transformers/strategy/native_fsdp.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,12 @@ def wrap_model(self, model, optimizer=None, memory_efficient=True):
5959
if use_meta:
6060
original_sd = model.state_dict()
6161
saved_buffers = _get_non_persistent_buffers(model)
62+
# Drop optimizer references so old params can be freed on to('meta').
63+
# Without this, the optimizer holds strong refs to the full-size
64+
# parameter tensors, preventing GC even after the model moves to meta.
65+
# _rebind_optimizer will re-attach the new sharded params later.
66+
if optimizer is not None:
67+
_unbind_optimizer_params(optimizer)
6268
model = model.to(torch.device('meta'))
6369
if hasattr(model, 'tie_weights'):
6470
model.tie_weights()
@@ -506,6 +512,25 @@ def _get_non_persistent_buffers(model: nn.Module) -> Dict[str, torch.Tensor]:
506512
return {k: v.clone() for k, v in model.named_buffers() if k in non_persistent_fqns}
507513

508514

515+
def _unbind_optimizer_params(optimizer: torch.optim.Optimizer) -> None:
516+
"""Replace optimizer param references with ``torch.empty(1)`` placeholders.
517+
518+
This drops the optimizer's strong references to the full model parameters,
519+
allowing them to be freed when the model is moved to ``meta`` device.
520+
Without this, ``model.to('meta')`` cannot free the old parameter tensors
521+
because the optimizer still holds references to them.
522+
523+
Must be called BEFORE ``model.to('meta')``. After ``fully_shard`` and
524+
``_broadcast_sharded_state_dict``, call ``_rebind_optimizer`` to point
525+
the optimizer at the new sharded parameters.
526+
527+
This mirrors accelerate's approach in ``Accelerator._prepare_fsdp2``.
528+
"""
529+
for group in optimizer.param_groups:
530+
for i in range(len(group['params'])):
531+
group['params'][i] = torch.empty(1)
532+
533+
509534
def _restore_non_persistent_buffers(
510535
model: nn.Module,
511536
saved_buffers: Dict[str, torch.Tensor],

src/twinkle/model/transformers/transformers.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -196,8 +196,8 @@ def __init__(
196196
self.mixed_precision = mixed_precision
197197
self._fsdp_config = dict(fsdp_config or {})
198198
self._ddp_config = ddp_config or {}
199-
self._decide_strategy(strategy)
200199
self._memory_efficient_init = memory_efficient_init
200+
self._decide_strategy(strategy)
201201
self.grad_scaler_config = grad_scaler_config
202202
if isinstance(model_cls, str):
203203
model_cls = getattr(transformers, model_cls)
@@ -267,7 +267,8 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']):
267267
mixed_precision=self.mixed_precision,
268268
ddp_config=self._ddp_config,
269269
fsdp_config=self._fsdp_config,
270-
device_mesh=self.device_mesh)
270+
device_mesh=self.device_mesh,
271+
memory_efficient=self._memory_efficient_init)
271272

272273
# Sequence parallel ("ulysses") is derived from dp/fsdp ranks; it does not change world size.
273274
# We construct `sp_strategy` after the underlying HF model is initialized (see __init__).

0 commit comments

Comments (0)