Skip to content

Commit ce71139

Browse files
committed
wip
1 parent c944a4b commit ce71139

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/twinkle/model/transformers/strategy/native_fsdp.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -193,6 +193,7 @@ def _maybe_shard_ep_expert_blocks(model: nn.Module,
193193
reshard_after_forward: Optional[bool],
194194
mp_policy: 'MixedPrecisionPolicy') -> int:
195195
from torch.distributed.fsdp import fully_shard
196+
from torch.distributed.tensor import Shard
196197
sharded_blocks = 0
197198
for module in model.modules():
198199
if not getattr(module, "_ep_patched", False):
@@ -207,6 +208,7 @@ def _maybe_shard_ep_expert_blocks(model: nn.Module,
207208
mesh=mesh,
208209
reshard_after_forward=reshard_after_forward,
209210
mp_policy=mp_policy,
211+
shard_placement_fn=lambda param: Shard(1),
210212
)
211213
sharded_blocks += 1
212214
return sharded_blocks

0 commit comments

Comments (0)