Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions cookbook/megatron/tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ def train():
for step, batch in enumerate(dataloader):
# Do forward and backward
model.forward_backward(inputs=batch)
_inputs = [input_feature_to_datum(b) for b in batch]
_temp = TwinkleCompatModelBase._get_forward_output(_inputs, model.optimizer_group['default'].outputs['logits'])
# Step
model.clip_grad_and_step()
if step % 5 == 0:
Expand Down
9 changes: 9 additions & 0 deletions src/twinkle/infra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,15 @@ def _collect_func(method: Union[Literal['none', 'flatten', 'mean', 'sum', 'first
return np.array(flatten)
return type(result[0])(flatten)
elif method in ('avg', 'mean'):
if isinstance(result[0], dict):
output = {}
for key in result[0]:
vals = [r[key] for r in result if key in r]
try:
output[key] = np.mean(vals)
except (TypeError, ValueError):
output[key] = vals
return output
return np.mean(result)
elif method == 'sum':
return np.sum(result)
Expand Down
24 changes: 12 additions & 12 deletions src/twinkle/loss/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,18 +289,18 @@ def __call__(
if labels.dim() == 1:
labels = labels.unsqueeze(0)

logits = outputs.get('logits')
if logits.shape[1] != labels.shape[1]:
# some mllm return logits with image tokens, exclude here
logits = logits[:, -labels.shape[1]:]

# labels = torch.roll(labels, shifts=-1, dims=1)
loss_mask = (labels != self.ignore_index).bool()
masked_labels = labels.clone()
masked_labels[~loss_mask] = 0
logps = selective_log_softmax(logits, masked_labels)

del logits
logps = outputs.get('logps')
if logps is None:
logits = outputs.get('logits')
if logits.shape[1] != labels.shape[1]:
# some mllm return logits with image tokens, exclude here
logits = logits[:, -labels.shape[1]:]

# labels = torch.roll(labels, shifts=-1, dims=1)
loss_mask = (labels != self.ignore_index).bool()
masked_labels = labels.clone()
masked_labels[~loss_mask] = 0
logps = selective_log_softmax(logits, masked_labels)

device = logps.device

Expand Down
26 changes: 2 additions & 24 deletions src/twinkle/loss/vocab_parallel_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,17 @@


class VocabParallelCrossEntropyLoss(Loss):
"""Vocab-parallel cross entropy loss for Megatron training with TP > 1.

This loss uses Megatron's tensor_parallel.vocab_parallel_cross_entropy to
correctly compute cross entropy when vocabulary is sharded across TP ranks.

NOTE: Labels are expected to be pre-shifted by the template (using np.roll).
This loss does NOT perform additional shifting.

Args:
ignore_index: The label value to ignore when computing loss. Default: -100.
"""

def __init__(self, ignore_index: int = -100):
super().__init__()
self.ignore_index = ignore_index

def __call__(self, inputs, outputs, **kwargs):
from megatron.core import tensor_parallel

logits = outputs['logits']
labels = inputs['labels']
logps = outputs.get('logps')

# Transpose: [batch, seq, vocab] -> [seq, batch, vocab]
logits_sbv = logits.transpose(0, 1).contiguous()
labels_sb = labels.transpose(0, 1).contiguous()

# Compute vocab-parallel cross entropy
per_token_loss = tensor_parallel.vocab_parallel_cross_entropy(logits_sbv, labels_sb)
per_token_loss = per_token_loss.transpose(0, 1).contiguous()

# Apply loss mask
loss_mask = (labels != self.ignore_index).float()
return LossOutput(
loss=(per_token_loss * loss_mask).sum(),
loss=(-logps * loss_mask).sum(),
num_tokens=loss_mask.sum().clamp(min=1),
)
55 changes: 52 additions & 3 deletions src/twinkle/model/megatron/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,51 @@ def _not_encoded(inputs):
assert isinstance(inputs, dict)
return 'input_ids' not in inputs and 'input_embedding' not in inputs

def _postprocess_tensor_cp(self, tensor):
    """All-gather and reconstruct full sequence from CP-split tensor.

    Uses load-balanced split pattern: each CP rank holds chunks [rank] and
    [2*cp_size - rank - 1] from the original 2*cp_size chunks.

    Only the current rank's slice retains the original tensor (and its
    gradient graph); other ranks' slices are plain copies. This means
    backward through the reconstructed tensor only produces gradients for
    the local chunk, naturally distributing the gradient across CP ranks
    without extra scaling.

    Args:
        tensor: [batch_size, seq_len/cp_size] CP-split tensor.
            NOTE(review): assumes seq_len/cp_size is even (two equal
            half-chunks per rank) — confirm upstream padding guarantees this.

    Returns:
        [batch_size, full_seq_len] reconstructed full tensor
    """
    from megatron.core import parallel_state as mpu
    cp_size = mpu.get_context_parallel_world_size()
    # No context parallelism: the tensor already covers the full sequence.
    if cp_size <= 1:
        return tensor

    cp_rank = mpu.get_context_parallel_rank()
    cp_group = mpu.get_context_parallel_group()

    # Collect every rank's slice. all_gather fills detached buffers, so the
    # local slice is re-inserted below to keep its autograd graph alive.
    gathered = [torch.empty_like(tensor) for _ in range(cp_size)]
    torch.distributed.all_gather(gathered, tensor.contiguous(), group=cp_group)
    gathered[cp_rank] = tensor

    batch_size = tensor.shape[0]
    seq_len_per_cp = tensor.shape[1]
    full_seq_len = seq_len_per_cp * cp_size
    # Load-balanced CP divides the full sequence into 2*cp_size equal chunks;
    # each rank's slice is two of them (one from each end of the sequence).
    chunk_len = full_seq_len // (2 * cp_size)
    half_len = seq_len_per_cp // 2

    output = tensor.new_zeros(batch_size, full_seq_len)
    for j in range(cp_size):
        o = gathered[j]
        # First half of rank j's slice is forward chunk [j] ...
        output[:, j * chunk_len:(j + 1) * chunk_len] = o[:, :half_len]
        reverse_idx = 2 * cp_size - j - 1
        # ... second half is the mirrored chunk [2*cp_size - j - 1].
        output[:, reverse_idx * chunk_len:(reverse_idx + 1) * chunk_len] = o[:, half_len:]

    return output

@remote_function()
def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
    """Unsupported on the Megatron backend.

    Raises:
        NotImplementedError: always — Megatron's schedule requires the
            paired `forward_backward` / `forward_only` entry points instead.
    """
    raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`')
Expand Down Expand Up @@ -420,13 +465,13 @@ def post_loss_function(output_tensor, inputs, logps):
mb_idx = _mb_counter[0]
_mb_counter[0] += 1
current_kwargs = loss_extra_kwargs_per_mb[mb_idx % len(loss_extra_kwargs_per_mb)]
outputs = ModelOutput(logits=output_tensor)
outputs = ModelOutput(logits=output_tensor, logps=logps)
result = loss_instance(inputs, outputs, **current_kwargs)
losses = result['loss']
counts = result['num_tokens']
if not counts:
counts = torch.tensor(1, device=losses.device)
return self.strategy.gather_loss_for_cp(losses, counts, output_tensor, logps)
return self.strategy.reduce_loss(losses, counts, output_tensor, logps)

# Define forward step function for Megatron
# forward_step_func(data_iterator, model) -> (output_tensor, partial(loss_func))
Expand All @@ -435,11 +480,15 @@ def forward_step_func(data_iterator, model):
labels = batch.pop('labels', None)
output_tensor = model(**batch)
batch['labels'] = labels
logps = None
if labels is not None:
loss_mask = (labels != -100).bool()
masked_labels = labels.clone()
masked_labels[~loss_mask] = 0
logps = selective_log_softmax(output_tensor, masked_labels)
if cp_size > 1:
logps = self._postprocess_tensor_cp(logps)
batch['labels'] = self._postprocess_tensor_cp(labels)
return output_tensor, partial(post_loss_function, inputs=batch, logps=logps)

# Get Megatron's forward-backward function
Expand Down Expand Up @@ -514,7 +563,7 @@ def forward_step_func(data_iterator, model):
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG, group=dp_cp_group)

optimizer_config.inputs = inputs
if len({_logps.shape[1] for _logps in logps}) == 1:
if logps and len({_logps.shape[1] for _logps in logps}) == 1:
logps = torch.cat(logps, dim=0)
if isinstance(loss, torch.Tensor):
loss = loss.detach().cpu().float().numpy()
Expand Down
24 changes: 2 additions & 22 deletions src/twinkle/model/megatron/strategy/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,28 +133,8 @@ def _wrap_with_megatron_ddp(

return wrapped_models

def gather_loss_for_cp(self, local_loss_sum, local_count, logits, logps):
import torch
from megatron.core import parallel_state as mpu
cp_size = mpu.get_context_parallel_world_size()

# For CP > 1, aggregate loss across CP ranks
if cp_size > 1:
# All-reduce the count across CP ranks
total_count = local_count.clone()
torch.distributed.nn.all_reduce(
total_count, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group())

# All-reduce the loss sum
total_loss_sum = local_loss_sum.clone()
torch.distributed.nn.all_reduce(
total_loss_sum, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group())

# Return global mean, divided by cp_size to counteract Megatron's multiplication
loss = (total_loss_sum / total_count.clamp(min=1)) / cp_size
else:
loss = local_loss_sum / local_count.clamp(min=1)

def reduce_loss(self, local_loss, local_count, logits, logps):
    """Compute the mean token loss and a detached stats dict for logging.

    Args:
        local_loss: summed per-token loss on this rank.
        local_count: number of contributing tokens; clamped to >= 1 so a
            fully-masked micro-batch cannot divide by zero.
        logits: raw model logits, returned detached for downstream logging.
        logps: per-token log-probabilities, returned detached.

    Returns:
        Tuple of (mean loss tensor, dict with detached loss/logits/logps).
    """
    denom = local_count.clamp(min=1)
    mean_loss = local_loss / denom
    stats = {
        'loss': mean_loss.detach(),
        'logits': logits.detach(),
        'logps': logps.detach(),
    }
    return mean_loss, stats

def get_model_config(
Expand Down
Loading