Skip to content

Commit 85d31f1

Browse files
authored
[feat] recover cp sequence before loss (#88)
1 parent cdb8342 commit 85d31f1

File tree

6 files changed

+77
-63
lines changed

6 files changed

+77
-63
lines changed

cookbook/megatron/tp.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@ def train():
6565
for step, batch in enumerate(dataloader):
6666
# Do forward and backward
6767
model.forward_backward(inputs=batch)
68-
_inputs = [input_feature_to_datum(b) for b in batch]
69-
_temp = TwinkleCompatModelBase._get_forward_output(_inputs, model.optimizer_group['default'].outputs['logits'])
7068
# Step
7169
model.clip_grad_and_step()
7270
if step % 5 == 0:

src/twinkle/infra/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,15 @@ def _collect_func(method: Union[Literal['none', 'flatten', 'mean', 'sum', 'first
280280
return np.array(flatten)
281281
return type(result[0])(flatten)
282282
elif method in ('avg', 'mean'):
283+
if isinstance(result[0], dict):
284+
output = {}
285+
for key in result[0]:
286+
vals = [r[key] for r in result if key in r]
287+
try:
288+
output[key] = np.mean(vals)
289+
except (TypeError, ValueError):
290+
output[key] = vals
291+
return output
283292
return np.mean(result)
284293
elif method == 'sum':
285294
return np.sum(result)

src/twinkle/loss/grpo.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -289,18 +289,18 @@ def __call__(
289289
if labels.dim() == 1:
290290
labels = labels.unsqueeze(0)
291291

292-
logits = outputs.get('logits')
293-
if logits.shape[1] != labels.shape[1]:
294-
# some mllm return logits with image tokens, exclude here
295-
logits = logits[:, -labels.shape[1]:]
296-
297-
# labels = torch.roll(labels, shifts=-1, dims=1)
298-
loss_mask = (labels != self.ignore_index).bool()
299-
masked_labels = labels.clone()
300-
masked_labels[~loss_mask] = 0
301-
logps = selective_log_softmax(logits, masked_labels)
302-
303-
del logits
292+
logps = outputs.get('logps')
293+
if logps is None:
294+
logits = outputs.get('logits')
295+
if logits.shape[1] != labels.shape[1]:
296+
# some mllm return logits with image tokens, exclude here
297+
logits = logits[:, -labels.shape[1]:]
298+
299+
# labels = torch.roll(labels, shifts=-1, dims=1)
300+
loss_mask = (labels != self.ignore_index).bool()
301+
masked_labels = labels.clone()
302+
masked_labels[~loss_mask] = 0
303+
logps = selective_log_softmax(logits, masked_labels)
304304

305305
device = logps.device
306306

src/twinkle/loss/vocab_parallel_cross_entropy.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,17 @@
44

55

66
class VocabParallelCrossEntropyLoss(Loss):
    """Cross entropy loss computed from precomputed per-token log-probs.

    The forward step is expected to attach per-token log-probabilities to the
    model outputs as ``outputs['logps']`` (e.g. via selective_log_softmax over
    the possibly vocab-parallel logits); this loss then reduces them to a
    summed negative log-likelihood over non-ignored tokens.

    NOTE(review): labels are presumably pre-shifted by the template (np.roll),
    so no additional shifting happens here — confirm against the template.

    Args:
        ignore_index: Label value excluded from the loss. Default: -100.
    """

    def __init__(self, ignore_index: int = -100):
        super().__init__()
        self.ignore_index = ignore_index

    def __call__(self, inputs, outputs, **kwargs):
        labels = inputs['labels']
        logps = outputs.get('logps')
        if logps is None:
            # Fail fast with an actionable message instead of the opaque
            # TypeError that `-None` would raise below.
            raise ValueError(
                "VocabParallelCrossEntropyLoss requires outputs['logps']; "
                'make sure the forward step computes per-token log-probs.')

        # Tokens labelled ignore_index contribute neither loss nor count.
        loss_mask = (labels != self.ignore_index).float()
        return LossOutput(
            # logps are log-likelihoods, so the NLL is their negation.
            loss=(-logps * loss_mask).sum(),
            # clamp avoids division by zero downstream on fully-masked batches.
            num_tokens=loss_mask.sum().clamp(min=1),
        )

src/twinkle/model/megatron/megatron.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,51 @@ def _not_encoded(inputs):
290290
assert isinstance(inputs, dict)
291291
return 'input_ids' not in inputs and 'input_embedding' not in inputs
292292

293+
def _postprocess_tensor_cp(self, tensor):
    """All-gather and reconstruct the full sequence from a CP-split tensor.

    Uses the load-balanced split pattern: of the original ``2 * cp_size``
    chunks, CP rank ``r`` holds chunks ``r`` and ``2 * cp_size - r - 1``.

    Only the current rank's slice retains the original tensor (and its
    gradient graph); the other ranks' slices are plain all-gather copies.
    Backward through the reconstructed tensor therefore only produces
    gradients for the local chunk, naturally distributing the gradient
    across CP ranks without extra scaling.

    Args:
        tensor: [batch_size, seq_len / cp_size] CP-split tensor.

    Returns:
        [batch_size, full_seq_len] reconstructed full tensor.

    Raises:
        ValueError: if the per-rank sequence length is odd — the chunk
            arithmetic would silently drop tokens in that case.
    """
    from megatron.core import parallel_state as mpu
    cp_size = mpu.get_context_parallel_world_size()
    if cp_size <= 1:
        # No context parallelism: the tensor already covers the full sequence.
        return tensor

    cp_rank = mpu.get_context_parallel_rank()
    cp_group = mpu.get_context_parallel_group()

    gathered = [torch.empty_like(tensor) for _ in range(cp_size)]
    torch.distributed.all_gather(gathered, tensor.contiguous(), group=cp_group)
    # Swap the local copy back in so autograd flows through the local chunk;
    # the remaining slices are detached copies produced by all_gather.
    gathered[cp_rank] = tensor

    batch_size, seq_len_per_cp = tensor.shape
    if seq_len_per_cp % 2 != 0:
        raise ValueError(
            f'CP-split sequence length must be even, got {seq_len_per_cp}')
    full_seq_len = seq_len_per_cp * cp_size
    # Each rank holds two chunks of equal length: full_seq_len // (2*cp_size)
    # == seq_len_per_cp // 2, so a single chunk length suffices.
    chunk_len = seq_len_per_cp // 2

    output = tensor.new_zeros(batch_size, full_seq_len)
    for rank, piece in enumerate(gathered):
        # First half of the rank's slice is its "forward" chunk ...
        output[:, rank * chunk_len:(rank + 1) * chunk_len] = piece[:, :chunk_len]
        # ... second half is the mirrored chunk from the tail of the sequence.
        mirror = 2 * cp_size - rank - 1
        output[:, mirror * chunk_len:(mirror + 1) * chunk_len] = piece[:, chunk_len:]

    return output
337+
293338
@remote_function()
294339
def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
295340
raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`')
@@ -420,13 +465,13 @@ def post_loss_function(output_tensor, inputs, logps):
420465
mb_idx = _mb_counter[0]
421466
_mb_counter[0] += 1
422467
current_kwargs = loss_extra_kwargs_per_mb[mb_idx % len(loss_extra_kwargs_per_mb)]
423-
outputs = ModelOutput(logits=output_tensor)
468+
outputs = ModelOutput(logits=output_tensor, logps=logps)
424469
result = loss_instance(inputs, outputs, **current_kwargs)
425470
losses = result['loss']
426471
counts = result['num_tokens']
427472
if not counts:
428473
counts = torch.tensor(1, device=losses.device)
429-
return self.strategy.gather_loss_for_cp(losses, counts, output_tensor, logps)
474+
return self.strategy.reduce_loss(losses, counts, output_tensor, logps)
430475

431476
# Define forward step function for Megatron
432477
# forward_step_func(data_iterator, model) -> (output_tensor, partial(loss_func))
@@ -435,11 +480,15 @@ def forward_step_func(data_iterator, model):
435480
labels = batch.pop('labels', None)
436481
output_tensor = model(**batch)
437482
batch['labels'] = labels
483+
logps = None
438484
if labels is not None:
439485
loss_mask = (labels != -100).bool()
440486
masked_labels = labels.clone()
441487
masked_labels[~loss_mask] = 0
442488
logps = selective_log_softmax(output_tensor, masked_labels)
489+
if cp_size > 1:
490+
logps = self._postprocess_tensor_cp(logps)
491+
batch['labels'] = self._postprocess_tensor_cp(labels)
443492
return output_tensor, partial(post_loss_function, inputs=batch, logps=logps)
444493

445494
# Get Megatron's forward-backward function
@@ -514,7 +563,7 @@ def forward_step_func(data_iterator, model):
514563
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG, group=dp_cp_group)
515564

516565
optimizer_config.inputs = inputs
517-
if len({_logps.shape[1] for _logps in logps}) == 1:
566+
if logps and len({_logps.shape[1] for _logps in logps}) == 1:
518567
logps = torch.cat(logps, dim=0)
519568
if isinstance(loss, torch.Tensor):
520569
loss = loss.detach().cpu().float().numpy()

src/twinkle/model/megatron/strategy/megatron.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -133,28 +133,8 @@ def _wrap_with_megatron_ddp(
133133

134134
return wrapped_models
135135

136-
def gather_loss_for_cp(self, local_loss_sum, local_count, logits, logps):
137-
import torch
138-
from megatron.core import parallel_state as mpu
139-
cp_size = mpu.get_context_parallel_world_size()
140-
141-
# For CP > 1, aggregate loss across CP ranks
142-
if cp_size > 1:
143-
# All-reduce the count across CP ranks
144-
total_count = local_count.clone()
145-
torch.distributed.nn.all_reduce(
146-
total_count, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group())
147-
148-
# All-reduce the loss sum
149-
total_loss_sum = local_loss_sum.clone()
150-
torch.distributed.nn.all_reduce(
151-
total_loss_sum, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group())
152-
153-
# Return global mean, divided by cp_size to counteract Megatron's multiplication
154-
loss = (total_loss_sum / total_count.clamp(min=1)) / cp_size
155-
else:
156-
loss = local_loss_sum / local_count.clamp(min=1)
157-
136+
def reduce_loss(self, local_loss, local_count, logits, logps):
    """Reduce a microbatch's summed loss to a mean, with detached diagnostics.

    Args:
        local_loss: Sum of per-token losses on this rank (scalar tensor).
        local_count: Number of contributing tokens (scalar tensor).
        logits: Model logits, passed through detached for logging.
        logps: Per-token log-probs, or None when the batch had no labels.

    Returns:
        Tuple of the mean-loss tensor and a dict of detached diagnostics.
    """
    # clamp(min=1) guards against a fully-masked microbatch (zero tokens).
    loss = local_loss / local_count.clamp(min=1)
    stats = {
        'loss': loss.detach(),
        'logits': logits.detach(),
        # logps is None when the forward step had no labels; calling
        # .detach() on it would raise AttributeError, so pass None through.
        'logps': logps.detach() if logps is not None else None,
    }
    return loss, stats
159139

160140
def get_model_config(

0 commit comments

Comments (0)