We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c944a4b · commit ce71139 (copy full SHA for ce71139)
src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -193,6 +193,7 @@ def _maybe_shard_ep_expert_blocks(model: nn.Module,
193
reshard_after_forward: Optional[bool],
194
mp_policy: 'MixedPrecisionPolicy') -> int:
195
from torch.distributed.fsdp import fully_shard
196
+ from torch.distributed.tensor import Shard
197
sharded_blocks = 0
198
for module in model.modules():
199
if not getattr(module, "_ep_patched", False):
@@ -207,6 +208,7 @@ def _maybe_shard_ep_expert_blocks(model: nn.Module,
207
208
mesh=mesh,
209
reshard_after_forward=reshard_after_forward,
210
mp_policy=mp_policy,
211
+ shard_placement_fn=lambda param: Shard(1),
212
)
213
sharded_blocks += 1
214
return sharded_blocks
0 commit comments