Skip to content

Commit 1d5eb8f

Browse files
committed
support cp and gdn sp
1 parent ccf0d79 commit 1d5eb8f

File tree

13 files changed

+2732
-1190
lines changed

13 files changed

+2732
-1190
lines changed

cookbook/transformers/sp_fsdp_dense.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919
device_type=Platform.get_platform().device_prefix(),
2020
)]
2121

22-
# FSDP + SP validation over 4 GPUs: dp=2, fsdp=2 (SP only affects input slicing)
22+
# FSDP + sequence-parallel validation over 4 GPUs: dp=2, fsdp=2.
23+
# In Transformers route, ulysses_size is the total sequence-parallel degree.
2324
device_mesh = DeviceMesh(
24-
device_type='cuda',
25+
device_type=Platform.get_platform().device_prefix(),
2526
mesh=np.arange(4).reshape(2, 2),
2627
mesh_dim_names=('dp', 'fsdp'),
2728
ulysses_size=2,

cookbook/transformers/sp_fsdp_dense.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/bin/bash
2-
# To enabele sequence parallelism, please set ulysses_size > 1
2+
# To enable Transformers sequence parallelism, please set ulysses_size > 1.
3+
# ulysses_size is interpreted as the total sequence-parallel degree.
34
# device_mesh = DeviceMesh(
45
# device_type="cuda",
56
# mesh=np.arange(4).reshape(2, 2),

src/twinkle/metric/loss.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,19 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M
2525
return
2626
loss = outputs['loss']
2727
loss_reduction = kwargs.get('loss_reduction', 'mean')
28+
ulysses_size = getattr(self.device_mesh, 'ulysses_size', None) or 1
2829
if loss_reduction == 'sum':
2930
if not isinstance(inputs, list):
3031
inputs = [inputs]
3132
for input in inputs:
3233
# `Transformers` models may use reduction=sum, to average grads before step
3334
labels = input['labels']
3435
self.num_tokens += (labels >= 0).sum().item()
36+
# Sequence-parallel gathered loss is replicated on each ulysses rank, while
37+
# local labels still count only the shard-local tokens. Normalize the loss
38+
# contribution here so metric-side averaging matches the non-SP path.
39+
if ulysses_size > 1:
40+
loss = loss / float(ulysses_size)
3541
grad_norm = kwargs.get('grad_norm')
3642
if grad_norm is not None:
3743
self.grad_norm = grad_norm

0 commit comments

Comments (0)