diff --git a/cookbook/megatron/tp.py b/cookbook/megatron/tp.py index f985d740..8bf4525c 100644 --- a/cookbook/megatron/tp.py +++ b/cookbook/megatron/tp.py @@ -65,8 +65,6 @@ def train(): for step, batch in enumerate(dataloader): # Do forward and backward model.forward_backward(inputs=batch) - _inputs = [input_feature_to_datum(b) for b in batch] - _temp = TwinkleCompatModelBase._get_forward_output(_inputs, model.optimizer_group['default'].outputs['logits']) # Step model.clip_grad_and_step() if step % 5 == 0: diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 1e8f773f..e3e27d2c 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -280,6 +280,15 @@ def _collect_func(method: Union[Literal['none', 'flatten', 'mean', 'sum', 'first return np.array(flatten) return type(result[0])(flatten) elif method in ('avg', 'mean'): + if isinstance(result[0], dict): + output = {} + for key in result[0]: + vals = [r[key] for r in result if key in r] + try: + output[key] = np.mean(vals) + except (TypeError, ValueError): + output[key] = vals + return output return np.mean(result) elif method == 'sum': return np.sum(result) diff --git a/src/twinkle/loss/grpo.py b/src/twinkle/loss/grpo.py index d4997710..c0ff950a 100644 --- a/src/twinkle/loss/grpo.py +++ b/src/twinkle/loss/grpo.py @@ -289,18 +289,18 @@ def __call__( if labels.dim() == 1: labels = labels.unsqueeze(0) - logits = outputs.get('logits') - if logits.shape[1] != labels.shape[1]: - # some mllm return logits with image tokens, exclude here - logits = logits[:, -labels.shape[1]:] - - # labels = torch.roll(labels, shifts=-1, dims=1) - loss_mask = (labels != self.ignore_index).bool() - masked_labels = labels.clone() - masked_labels[~loss_mask] = 0 - logps = selective_log_softmax(logits, masked_labels) - - del logits + logps = outputs.get('logps') + if logps is None: + logits = outputs.get('logits') + if logits.shape[1] != labels.shape[1]: + # some mllm return logits with image tokens, exclude here + logits = logits[:, -labels.shape[1]:] + + # labels = torch.roll(labels, shifts=-1, dims=1) + loss_mask = (labels != self.ignore_index).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + logps = selective_log_softmax(logits, masked_labels) device = logps.device diff --git a/src/twinkle/loss/vocab_parallel_cross_entropy.py b/src/twinkle/loss/vocab_parallel_cross_entropy.py index 0a30429f..166e843f 100644 --- a/src/twinkle/loss/vocab_parallel_cross_entropy.py +++ b/src/twinkle/loss/vocab_parallel_cross_entropy.py @@ -4,39 +4,17 @@ class VocabParallelCrossEntropyLoss(Loss): - """Vocab-parallel cross entropy loss for Megatron training with TP > 1. - - This loss uses Megatron's tensor_parallel.vocab_parallel_cross_entropy to - correctly compute cross entropy when vocabulary is sharded across TP ranks. - - NOTE: Labels are expected to be pre-shifted by the template (using np.roll). - This loss does NOT perform additional shifting. - - Args: - ignore_index: The label value to ignore when computing loss. Default: -100. - """ def __init__(self, ignore_index: int = -100): super().__init__() self.ignore_index = ignore_index def __call__(self, inputs, outputs, **kwargs): - from megatron.core import tensor_parallel - - logits = outputs['logits'] labels = inputs['labels'] + logps = outputs.get('logps') - # Transpose: [batch, seq, vocab] -> [seq, batch, vocab] - logits_sbv = logits.transpose(0, 1).contiguous() - labels_sb = labels.transpose(0, 1).contiguous() - - # Compute vocab-parallel cross entropy - per_token_loss = tensor_parallel.vocab_parallel_cross_entropy(logits_sbv, labels_sb) - per_token_loss = per_token_loss.transpose(0, 1).contiguous() - - # Apply loss mask loss_mask = (labels != self.ignore_index).float() return LossOutput( - loss=(per_token_loss * loss_mask).sum(), + loss=(-logps * loss_mask).sum(), num_tokens=loss_mask.sum().clamp(min=1), ) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 512318bd..22977721 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -290,6 +290,51 @@ def _not_encoded(inputs): assert isinstance(inputs, dict) return 'input_ids' not in inputs and 'input_embedding' not in inputs + def _postprocess_tensor_cp(self, tensor): + """All-gather and reconstruct full sequence from CP-split tensor. + + Uses load-balanced split pattern: each CP rank holds chunks [rank] and + [2*cp_size - rank - 1] from the original 2*cp_size chunks. + + Only the current rank's slice retains the original tensor (and its + gradient graph); other ranks' slices are plain copies. This means + backward through the reconstructed tensor only produces gradients for + the local chunk, naturally distributing the gradient across CP ranks + without extra scaling. + + Args: + tensor: [batch_size, seq_len/cp_size] CP-split tensor + + Returns: + [batch_size, full_seq_len] reconstructed full tensor + """ + from megatron.core import parallel_state as mpu + cp_size = mpu.get_context_parallel_world_size() + if cp_size <= 1: + return tensor + + cp_rank = mpu.get_context_parallel_rank() + cp_group = mpu.get_context_parallel_group() + + gathered = [torch.empty_like(tensor) for _ in range(cp_size)] + torch.distributed.all_gather(gathered, tensor.contiguous(), group=cp_group) + gathered[cp_rank] = tensor + + batch_size = tensor.shape[0] + seq_len_per_cp = tensor.shape[1] + full_seq_len = seq_len_per_cp * cp_size + chunk_len = full_seq_len // (2 * cp_size) + half_len = seq_len_per_cp // 2 + + output = tensor.new_zeros(batch_size, full_seq_len) + for j in range(cp_size): + o = gathered[j] + output[:, j * chunk_len:(j + 1) * chunk_len] = o[:, :half_len] + reverse_idx = 2 * cp_size - j - 1 + output[:, reverse_idx * chunk_len:(reverse_idx + 1) * chunk_len] = o[:, half_len:] + + return output + @remote_function() def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`') @@ -420,13 +465,13 @@ def post_loss_function(output_tensor, inputs, logps): mb_idx = _mb_counter[0] _mb_counter[0] += 1 current_kwargs = loss_extra_kwargs_per_mb[mb_idx % len(loss_extra_kwargs_per_mb)] - outputs = ModelOutput(logits=output_tensor) + outputs = ModelOutput(logits=output_tensor, logps=logps) result = loss_instance(inputs, outputs, **current_kwargs) losses = result['loss'] counts = result['num_tokens'] if not counts: counts = torch.tensor(1, device=losses.device) - return self.strategy.gather_loss_for_cp(losses, counts, output_tensor, logps) + return self.strategy.reduce_loss(losses, counts, output_tensor, logps) # Define forward step function for Megatron # forward_step_func(data_iterator, model) -> (output_tensor, partial(loss_func)) @@ -435,11 +480,15 @@ def forward_step_func(data_iterator, model): labels = batch.pop('labels', None) output_tensor = model(**batch) batch['labels'] = labels + logps = None if labels is not None: loss_mask = (labels != -100).bool() masked_labels = labels.clone() masked_labels[~loss_mask] = 0 logps = selective_log_softmax(output_tensor, masked_labels) + if cp_size > 1: + logps = self._postprocess_tensor_cp(logps) + batch['labels'] = self._postprocess_tensor_cp(labels) return output_tensor, partial(post_loss_function, inputs=batch, logps=logps) # Get Megatron's forward-backward function @@ -514,7 +563,7 @@ def forward_step_func(data_iterator, model): torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG, group=dp_cp_group) optimizer_config.inputs = inputs - if len({_logps.shape[1] for _logps in logps}) == 1: + if logps and len({_logps.shape[1] for _logps in logps}) == 1: logps = torch.cat(logps, dim=0) if isinstance(loss, torch.Tensor): loss = loss.detach().cpu().float().numpy() diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index 7dc8f4d6..758afb4e 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -133,28 +133,8 @@ def _wrap_with_megatron_ddp( return wrapped_models - def gather_loss_for_cp(self, local_loss_sum, local_count, logits, logps): - import torch - from megatron.core import parallel_state as mpu - cp_size = mpu.get_context_parallel_world_size() - - # For CP > 1, aggregate loss across CP ranks - if cp_size > 1: - # All-reduce the count across CP ranks - total_count = local_count.clone() - torch.distributed.nn.all_reduce( - total_count, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group()) - - # All-reduce the loss sum - total_loss_sum = local_loss_sum.clone() - torch.distributed.nn.all_reduce( - total_loss_sum, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group()) - - # Return global mean, divided by cp_size to counteract Megatron's multiplication - loss = (total_loss_sum / total_count.clamp(min=1)) / cp_size - else: - loss = local_loss_sum / local_count.clamp(min=1) - + def reduce_loss(self, local_loss, local_count, logits, logps): + loss = local_loss / local_count.clamp(min=1) return loss, {'loss': loss.detach(), 'logits': logits.detach(), 'logps': logps.detach()} def get_model_config(