Skip to content

Commit 9d97d84

Browse files
committed
wip
1 parent 9158465 commit 9d97d84

File tree

5 files changed

+40
-19
lines changed

5 files changed

+40
-19
lines changed

src/twinkle/model/transformers/strategy/accelerate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2-
import os
32
from typing import Any, Dict, Literal, Optional
43

54
from twinkle import DeviceMesh
5+
from .load_context import fsdp_pretrained_load_context
66

77

88
class AccelerateStrategy:
@@ -27,6 +27,7 @@ def __init__(
2727

2828
self.device_mesh = device_mesh
2929
self.mixed_precision = mixed_precision
30+
self._memory_efficient_init = memory_efficient_init
3031
parallelism_config = self._parallelism_config_from_device_mesh(device_mesh)
3132
fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config, memory_efficient_init)
3233

@@ -43,6 +44,9 @@ def __init__(
4344
kwargs_handlers=kwargs_handlers,
4445
)
4546

47+
def pretrained_load_context(self):
    """Return a context manager for loading pretrained weights.

    Triggers transformers' FSDP-aware loading (meta-device init plus
    rank-0-only weight load) only when memory-efficient init was requested
    and a device mesh is present; otherwise the context is a no-op.
    """
    use_fsdp_loading = self._memory_efficient_init and self.device_mesh is not None
    return fsdp_pretrained_load_context(use_fsdp_loading)
49+
4650
@staticmethod
4751
def _parallelism_config_from_device_mesh(device_mesh: DeviceMesh):
4852
# TODO should test with transformers v5.0
src/twinkle/model/transformers/strategy/load_context.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) ModelScope Contributors. All rights reserved.
2+
import contextlib
3+
import os
4+
5+
# Environment variables that switch transformers/accelerate into FSDP-aware
# checkpoint loading (meta-device init, rank-0-only weight load).
_FSDP_EFFICIENT_LOADING_ENV = {
    'ACCELERATE_USE_FSDP': 'true',
    'FSDP_CPU_RAM_EFFICIENT_LOADING': 'true',
}


@contextlib.contextmanager
def fsdp_pretrained_load_context(enabled: bool):
    """Enable the env flags required for transformers FSDP-aware loading when needed.

    When ``enabled`` is falsy this is a transparent no-op. Otherwise the flags
    in ``_FSDP_EFFICIENT_LOADING_ENV`` are exported for the duration of the
    ``with`` body, and the prior environment state (including unset keys) is
    restored afterwards, even if the body raises.
    """
    if not enabled:
        yield
        return

    # Snapshot current values so we can undo our changes exactly.
    previous = {name: os.environ.get(name) for name in _FSDP_EFFICIENT_LOADING_ENV}
    for name, value in _FSDP_EFFICIENT_LOADING_ENV.items():
        os.environ[name] = value
    try:
        yield
    finally:
        for name, prior in previous.items():
            if prior is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = prior

src/twinkle/model/transformers/strategy/native_fsdp.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Set
88

99
from twinkle.utils import DeviceMesh, Platform, torch_util
10+
from .load_context import fsdp_pretrained_load_context
1011

1112
if TYPE_CHECKING:
1213
from torch.distributed.fsdp import MixedPrecisionPolicy
@@ -28,6 +29,9 @@ def __init__(self,
2829
self.enable_ep = enable_ep
2930
self.ep_fsdp_device_mesh = self._build_ep_fsdp_device_mesh(ep_size) if enable_ep else None
3031

32+
def pretrained_load_context(self):
    """Return a context manager enabling transformers' FSDP-aware pretrained loading.

    The context is active only when memory-efficient init was requested and a
    device mesh exists; otherwise it is a no-op.
    """
    # NOTE(review): this commit adds `self._memory_efficient_init = ...` to
    # AccelerateStrategy.__init__ but the corresponding assignment is not
    # visible in NativeFSDPStrategy.__init__ — confirm the attribute is set,
    # or this raises AttributeError at load time.
    return fsdp_pretrained_load_context(self._memory_efficient_init and self.device_mesh is not None)
34+
3135
def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[TorchDeviceMesh]:
3236
if self.device_mesh is None:
3337
return None

src/twinkle/model/transformers/transformers.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -212,22 +212,8 @@ def __init__(
212212
else:
213213
model_id = HubOperation.download_model(model_id)
214214
# Trigger transformers' FSDP-aware loading: meta-device init + rank-0-only weight load.
215-
use_efficient_loading = (memory_efficient_init and self.device_mesh is not None)
216-
_saved_env = {}
217-
if use_efficient_loading:
218-
_saved_env['ACCELERATE_USE_FSDP'] = os.environ.get('ACCELERATE_USE_FSDP')
219-
_saved_env['FSDP_CPU_RAM_EFFICIENT_LOADING'] = os.environ.get('FSDP_CPU_RAM_EFFICIENT_LOADING')
220-
os.environ['ACCELERATE_USE_FSDP'] = 'true'
221-
os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING'] = 'true'
222-
try:
215+
with self.strategy.pretrained_load_context():
223216
self.model = model_cls.from_pretrained(model_id, config=config, **kwargs)
224-
finally:
225-
if use_efficient_loading:
226-
for key, old_val in _saved_env.items():
227-
if old_val is None:
228-
os.environ.pop(key, None)
229-
else:
230-
os.environ[key] = old_val
231217
self.model.gradient_checkpointing_enable()
232218
self.sp_strategy = None
233219
self._model_wrapped = False

tests/strategy/test_fsdp2_memory_efficient_init.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def _worker_wrap_model_memory_efficient(rank, world_size, port, ref_sd):
211211
mesh_dim_names=('fsdp', ),
212212
device_type=_DEVICE_TYPE,
213213
)
214-
strategy = NativeFSDPStrategy(device_mesh=mesh, mixed_precision='no', memory_efficient=True)
214+
strategy = NativeFSDPStrategy(device_mesh=mesh, mixed_precision='no', memory_efficient_init=True)
215215

216216
model = TinyModel(dim=32).to(_DEVICE_TYPE)
217217
if rank == 0:
@@ -269,7 +269,7 @@ def _worker_wrap_model_legacy(rank, world_size, port, ref_sd):
269269
mesh_dim_names=('fsdp', ),
270270
device_type=_DEVICE_TYPE,
271271
)
272-
strategy = NativeFSDPStrategy(device_mesh=mesh, mixed_precision='no', memory_efficient=False)
272+
strategy = NativeFSDPStrategy(device_mesh=mesh, mixed_precision='no', memory_efficient_init=False)
273273

274274
model = TinyModel(dim=32).to(_DEVICE_TYPE)
275275
model.load_state_dict(ref_sd)
@@ -324,7 +324,7 @@ def _worker_wrap_model_per_layer(rank, world_size, port, ref_sd):
324324
mesh_dim_names=('fsdp', ),
325325
device_type=_DEVICE_TYPE,
326326
)
327-
strategy = NativeFSDPStrategy(device_mesh=mesh, mixed_precision='no', memory_efficient=True)
327+
strategy = NativeFSDPStrategy(device_mesh=mesh, mixed_precision='no', memory_efficient_init=True)
328328

329329
model = TinyTransformerModel(dim=32, num_layers=2).to(_DEVICE_TYPE)
330330
if rank == 0:

0 commit comments

Comments
 (0)