diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py
index 1e3e5b3ae..1836c3c10 100644
--- a/miles/backends/fsdp_utils/actor.py
+++ b/miles/backends/fsdp_utils/actor.py
@@ -21,6 +21,7 @@
 from miles.utils.ppo_utils import compute_approx_kl, compute_gspo_kl, compute_opsm_mask, compute_policy_loss
 from miles.utils.processing_utils import load_processor, load_tokenizer
 from miles.utils.ray_utils import Box
+from miles.utils.rolloutpostprocessor import RolloutPostprocessor
 from miles.utils.timer import Timer, inverse_timer, timer
 from miles.utils.tracking_utils import init_tracking
 
@@ -502,20 +503,45 @@ def _log_rollout_data(self, rollout_id: int, rollout_data, packed_batches):
         for metric_key in ["log_probs", "rollout_log_probs", "ref_log_probs", "advantages", "returns"]:
             if metric_key not in packed_batches[0]:
                 continue
+            # Keep existing per-sample mean aggregation for backwards compatibility
             val = torch.tensor([0.0], device=torch.cuda.current_device())
-            for _mbs_id, batches in enumerate(packed_batches):
+
+            # Also collect flat tensors to compute global masked stats via RolloutPostprocessor
+            metric_tensors = []
+            mask_tensors = []
+
+            for mbs_id, batches in enumerate(packed_batches):
                 unpacked_batches = unpack_sequences(batches)
                 for unpacked_batch in unpacked_batches:
                     if isinstance(unpacked_batch[metric_key], torch.Tensor):
                         loss_masks_tensor = unpacked_batch["loss_masks"].to(device=torch.cuda.current_device())
                         metric_tensor = unpacked_batch[metric_key].to(device=torch.cuda.current_device())
                         val += (metric_tensor * loss_masks_tensor).sum() / loss_masks_tensor.sum().clamp_min(1)
+                        metric_tensors.append(metric_tensor)
+                        mask_tensors.append(loss_masks_tensor)
                     else:
                         val += unpacked_batch[metric_key]
+
             dist.all_reduce(val, op=dist.ReduceOp.SUM, group=self.dp_group)
             log_dict[f"rollout/{metric_key}"] = (
                 val / (self.args.n_samples_per_prompt * self.args.rollout_batch_size)
             ).item()
+
+            # Compute and attach global masked statistics when possible
+            if metric_tensors and mask_tensors:
+                try:
+                    flat_metric = torch.cat(metric_tensors).to(device=torch.cuda.current_device())
+                    flat_mask = torch.cat(mask_tensors).to(device=torch.cuda.current_device())
+                    stats = RolloutPostprocessor.compute_masked_stats_safe(
+                        flat_metric, flat_mask, process_group=self.dp_group
+                    )
+                    log_dict[f"rollout/{metric_key}_global_mean"] = stats["mean"].item()
+                    log_dict[f"rollout/{metric_key}_global_std"] = stats["std"].item()
+                    log_dict[f"rollout/{metric_key}_global_max"] = stats["max"].item()
+                    log_dict[f"rollout/{metric_key}_global_min"] = stats["min"].item()
+                except Exception as e:
+                    logger.error(f"error in computing global stats for {metric_key}: {e}")
+
         if dist.get_rank() == 0:
             logger.info(f"rollout {rollout_id}: {log_dict}")
             log_dict["rollout/step"] = compute_rollout_step(self.args, rollout_id)
@@ -618,17 +644,14 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum):
         response_lengths = [batch["response_lengths"] for batch in unpacked_batches]
 
         advantages = advantages.to(device=log_probs.device)
-        old_log_probs = old_log_probs.to(device=log_probs.device)
-        ppo_kl = old_log_probs - log_probs
-
-        if self.args.use_opsm:
-            opsm_mask, opsm_clipfrac = compute_opsm_mask(
-                args=self.args,
-                full_log_probs=[batch["cur_log_probs"] for batch in unpacked_batches],
-                full_old_log_probs=[batch[old_log_prob_key] for batch in unpacked_batches],
-                advantages=[batch["advantages"] for batch in unpacked_batches],
-                loss_masks=loss_masks,
-            )
+        # compute global advantage stats (masked)
+        try:
+            flat_adv_mask = torch.cat(loss_masks).to(device=advantages.device)
+            adv_stats = RolloutPostprocessor.compute_masked_stats_safe(advantages, flat_adv_mask, process_group=self.dp_group)
+        except Exception as e:
+            logger.error(f"error in computing advantage stats: {e}")
+            adv_stats = None
+        ppo_kl = old_log_probs.to(device=log_probs.device) - log_probs
 
         if self.args.advantage_estimator == "gspo":
             ppo_kl = compute_gspo_kl(
@@ -670,6 +693,14 @@ def _has_rollout_log_probs(batch) -> bool:
             pg_loss = pg_loss * tis_clip
 
+        # compute pg_loss masked stats before reduction
+        try:
+            flat_pg_mask = torch.cat(loss_masks).to(device=pg_loss.device)
+            pg_stats = RolloutPostprocessor.compute_masked_stats_safe(pg_loss, flat_pg_mask, process_group=self.dp_group)
+        except Exception as e:
+            logger.error(f"error in computing pg_loss stats: {e}")
+            pg_stats = None
+
         assert not self.args.calculate_per_token_loss, "calculate_per_token_loss not yet implemented"
         pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks)
         pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
@@ -711,6 +742,16 @@ def _has_rollout_log_probs(batch) -> bool:
             "entropy_loss": entropy_loss.detach(),
         }
 
+        if adv_stats is not None:
+            reported["advantage_mean"] = adv_stats["mean"].detach()
+            reported["advantage_std"] = adv_stats["std"].detach()
+            reported["advantage_max"] = adv_stats["max"].detach()
+            reported["advantage_min"] = adv_stats["min"].detach()
+
+        if pg_stats is not None:
+            reported["pg_loss_max"] = pg_stats["max"].detach()
+            reported["pg_loss_min"] = pg_stats["min"].detach()
+            reported["pg_loss_std"] = pg_stats["std"].detach()
         if train_rollout_logprob_abs_diff is not None:
             reported["train_rollout_logprob_abs_diff"] = train_rollout_logprob_abs_diff
diff --git a/miles/backends/megatron_utils/data.py b/miles/backends/megatron_utils/data.py
index f94d1b7e0..c929411f8 100644
--- a/miles/backends/megatron_utils/data.py
+++ b/miles/backends/megatron_utils/data.py
@@ -13,6 +13,7 @@
 from miles.utils.data import get_minimum_num_micro_batch_size
 from miles.utils.flops_utils import calculate_fwd_flops
 from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step
+from miles.utils.rolloutpostprocessor import RolloutPostprocessor
 from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions
 from miles.utils.types import RolloutBatch
 
@@ -340,14 +341,54 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc
             # modified in place and will cause problem for the next rollout.
             val = torch.cat(val).clone().detach()
             if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "returns", "advantages", "values"]:
+                # Keep existing per-sample mean behavior
+                concatenated = val.clone().detach()
                 sum_of_sample_mean = get_sum_of_sample_mean(
                     total_lengths,
                     response_lengths,
                     loss_masks,
                     qkv_format=args.qkv_format,
                     max_seq_lens=max_seq_lens,
                 )
-                val = cp_size * sum_of_sample_mean(val) / len(loss_masks)
+                val = cp_size * sum_of_sample_mean(concatenated) / len(loss_masks)
+                # Also compute global masked stats and attach additional keys
+                try:
+                    # Prepare flat mask tensor
+                    if isinstance(loss_masks[0], torch.Tensor):
+                        flat_mask = torch.cat(loss_masks).to(device=concatenated.device)
+                    else:
+                        flat_mask = torch.tensor(
+                            [m for lm in loss_masks for m in lm], device=concatenated.device
+                        )
+
+                    # Compute local per-rank aggregates (sum, sumsq, count, min, max)
+                    mask_bool = flat_mask.bool()
+                    if mask_bool.numel() == 0 or mask_bool.sum().item() == 0:
+                        local_count = 0
+                        local_sum = 0.0
+                        local_sumsq = 0.0
+                        local_min = float("inf")
+                        local_max = float("-inf")
+                    else:
+                        masked_vals = concatenated[mask_bool]
+                        local_count = int(masked_vals.numel())
+                        local_sum = float(masked_vals.sum().item())
+                        local_sumsq = float((masked_vals * masked_vals).sum().item())
+                        local_min = float(masked_vals.min().item())
+                        local_max = float(masked_vals.max().item())
+
+                    # Emit per-rank aggregates for pooled reduction
+                    log_dict.update(
+                        {
+                            f"{key}_agg_sum": local_sum,
+                            f"{key}_agg_sumsq": local_sumsq,
+                            f"{key}_agg_count": local_count,
+                            f"{key}_agg_min": local_min,
+                            f"{key}_agg_max": local_max,
+                        }
+                    )
+                except Exception as e:
+                    logger.error(f"error in preparing aggregates for {key}: {e}")
             else:
                 val = val.mean() * cp_size
         else:
@@ -358,7 +399,15 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc
             raise ValueError(f"Unsupported type: {type(val)} for key: {key}")
         log_dict[key] = val.item() if isinstance(val, torch.Tensor) else val
 
-    reduced_log_dict = gather_log_data("rollout", args, rollout_id, log_dict)
+    # Aggregate per-rank metrics using all-reduce and log on the DP source rank
+    reduced_log_dict = RolloutPostprocessor.aggregate_and_log(
+        log_dict,
+        args,
+        rollout_id,
+        process_group=mpu.get_data_parallel_group(),
+        dp_src_rank=mpu.get_data_parallel_src_rank(with_context_parallel=True),
+        only_log_on_src=True,
+    )
     if args.ci_test and reduced_log_dict is not None:
         if (
             rollout_id == 0
diff --git a/miles/backends/megatron_utils/loss.py b/miles/backends/megatron_utils/loss.py
index d7b72a512..2df2d4dd3 100644
--- a/miles/backends/megatron_utils/loss.py
+++ b/miles/backends/megatron_utils/loss.py
@@ -19,6 +19,7 @@
     get_reinforce_plus_plus_baseline_advantages,
     get_reinforce_plus_plus_returns,
 )
+from miles.utils.rolloutpostprocessor import RolloutPostprocessor
 from miles.utils.types import RolloutBatch
 
 from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean
@@ -465,6 +466,16 @@ def policy_loss_function(
     are enabled.
""" advantages = torch.cat(batch["advantages"], dim=0) + # compute advantage-level global stats (masked by loss_masks) + try: + flat_adv_mask = torch.cat(batch["loss_masks"]).to(device=advantages.device) + adv_stats = RolloutPostprocessor.compute_masked_stats_safe( + advantages, flat_adv_mask, process_group=mpu.get_data_parallel_group() + ) + except Exception as e: + logger.error(f"error in computing advantage stats: {e}") + adv_stats = None + old_log_probs = batch["rollout_log_probs"] if args.use_rollout_logprobs else batch["log_probs"] response_lengths = batch["response_lengths"] @@ -553,6 +564,9 @@ def policy_loss_function( tis_func = vanilla_tis_function pg_loss, modified_response_masks, tis_metrics = tis_func(**tis_kwargs) + # if TIS modified masks, use them for metric computation + metric_masks = modified_response_masks + # [decouple IS and rejection] Rebuild sum_of_sample_mean with modified_response_masks for denominator correction # modified_response_masks will be sliced with cp in get_sum_of_sample_mean sum_of_sample_mean = get_sum_of_sample_mean( @@ -563,6 +577,18 @@ def policy_loss_function( args.qkv_format, batch.get("max_seq_lens", None), ) + else: + metric_masks = batch["loss_masks"] + + # compute pg_loss token-level stats (masked) + try: + flat_pg_mask = torch.cat(metric_masks).to(device=pg_loss.device) + pg_stats = RolloutPostprocessor.compute_masked_stats_safe( + pg_loss, flat_pg_mask, process_group=mpu.get_data_parallel_group() + ) + except Exception as e: + logger.error(f"error in computing pg_loss stats: {e}") + pg_stats = None pg_loss = sum_of_sample_mean(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) @@ -608,6 +634,19 @@ def policy_loss_function( "ppo_kl": ppo_kl.clone().detach(), } + # Add advantage stats if available + if adv_stats is not None: + reported_loss["advantage_mean"] = adv_stats["mean"].clone().detach() + reported_loss["advantage_std"] = adv_stats["std"].clone().detach() + reported_loss["advantage_max"] = adv_stats["max"].clone().detach() + reported_loss["advantage_min"] = adv_stats["min"].clone().detach() + + # Add pg_loss distributional stats if available + if pg_stats is not None: + reported_loss["pg_loss_max"] = pg_stats["max"].clone().detach() + reported_loss["pg_loss_min"] = pg_stats["min"].clone().detach() + reported_loss["pg_loss_std"] = pg_stats["std"].clone().detach() + if train_rollout_logprob_abs_diff is not None: reported_loss["train_rollout_logprob_abs_diff"] = train_rollout_logprob_abs_diff.clone().detach() diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 9ee0fbb8a..2023116b7 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -23,6 +23,7 @@ from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix from miles.utils.misc import load_function from miles.utils.ray_utils import Box +from miles.utils.rolloutpostprocessor import RolloutPostprocessor from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions from miles.utils.tracking_utils import init_tracking from miles.utils.types import Sample @@ -608,6 +609,31 @@ def compute_metrics_from_samples(args, samples): response_lengths = [sample.effective_response_length for sample in samples] log_dict = {} + log_dict |= dict_add_prefix(compute_statistics(response_lengths), f"response_len/") + # Add distributed-safe global stats for response lengths and rewards when possible + try: + if len(response_lengths) > 0: + vals = torch.tensor(response_lengths, dtype=torch.float32) + mask = 
+            stats = RolloutPostprocessor.compute_masked_stats_safe(vals, mask, process_group=None)
+            log_dict["response_len/global_mean"] = stats["mean"].item()
+            log_dict["response_len/global_std"] = stats["std"].item()
+            log_dict["response_len/global_max"] = stats["max"].item()
+            log_dict["response_len/global_min"] = stats["min"].item()
+
+        # rewards
+        rewards = [sample.get_reward_value(args) for sample in samples]
+        if len(rewards) > 0:
+            rvals = torch.tensor(rewards, dtype=torch.float32)
+            rmask = torch.ones_like(rvals, dtype=torch.bool)
+            rstats = RolloutPostprocessor.compute_masked_stats_safe(rvals, rmask, process_group=None)
+            log_dict["reward/global_mean"] = rstats["mean"].item()
+            log_dict["reward/global_std"] = rstats["std"].item()
+            log_dict["reward/global_max"] = rstats["max"].item()
+            log_dict["reward/global_min"] = rstats["min"].item()
+    except Exception:
+        # Best-effort only; fall back to the existing stats if anything goes wrong
+        pass
     log_dict |= dict_add_prefix(compute_statistics(response_lengths), "response_len/")
     log_dict |= _compute_zero_std_metrics(args, samples)
     log_dict |= _compute_spec_metrics(args, samples)
diff --git a/miles/utils/rolloutpostprocessor.py b/miles/utils/rolloutpostprocessor.py
new file mode 100644
index 000000000..079dc6801
--- /dev/null
+++ b/miles/utils/rolloutpostprocessor.py
@@ -0,0 +1,250 @@
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+from miles.utils import tracking_utils
+from miles.utils.metric_utils import compute_rollout_step
+
+
+class RolloutPostprocessor:
+    """Postprocessing helpers for rollout / loss metrics.
+
+    This class centralizes distributed, masked statistics used when
+    reporting metrics such as advantage and per-token pg_loss distributions.
+    """
+
+    @staticmethod
+    def compute_global_masked_stats(
+        values: torch.Tensor,
+        mask: torch.Tensor,
+        process_group: Optional[dist.ProcessGroup] = None,
+    ) -> dict:
+        """
+        Compute global mean, std, min and max over elements where `mask` is truthy,
+        aggregating across the provided process group.
+
+        Returns dict with keys `mean`, `std`, `min`, `max` as torch tensors on
+        `values.device`.
+        """
+        mask_bool = mask.bool()
+
+        local_sum = (values * mask_bool).sum()
+        local_sum_sq = ((values**2) * mask_bool).sum()
+        local_count = mask_bool.sum().to(dtype=torch.float32)
+
+        stats_tensor = torch.tensor([local_sum, local_sum_sq, local_count], device=values.device, dtype=torch.float32)
+        dist.all_reduce(stats_tensor, group=process_group)
+
+        global_sum, global_sum_sq, global_count = stats_tensor
+
+        if global_count.item() == 0:
+            zero = torch.tensor(0.0, device=values.device)
+            return {
+                "mean": zero,
+                "std": zero,
+                "min": torch.tensor(float("inf"), device=values.device),
+                "max": torch.tensor(float("-inf"), device=values.device),
+            }
+
+        global_mean = global_sum / global_count
+        global_mean_sq = global_sum_sq / global_count
+        global_var = global_mean_sq - global_mean**2
+
+        if global_count.item() >= 2:
+            bessel = global_count / (global_count - 1)
+            global_var = global_var * bessel
+
+        global_std = torch.sqrt(torch.clamp(global_var, min=0.0))
+
+        local_max = torch.where(mask_bool, values, torch.tensor(float("-inf"), device=values.device))
+        local_min = torch.where(mask_bool, values, torch.tensor(float("inf"), device=values.device))
+
+        max_tensor = local_max.max() if local_max.numel() > 0 else torch.tensor(float("-inf"), device=values.device)
+        min_tensor = local_min.min() if local_min.numel() > 0 else torch.tensor(float("inf"), device=values.device)
+
+        dist.all_reduce(max_tensor, op=dist.ReduceOp.MAX, group=process_group)
+        dist.all_reduce(min_tensor, op=dist.ReduceOp.MIN, group=process_group)
+
+        return {"mean": global_mean, "std": global_std, "min": min_tensor, "max": max_tensor}
+
+    @staticmethod
+    def compute_masked_stats_safe(
+        values: torch.Tensor,
+        mask: torch.Tensor,
+        process_group: Optional[dist.ProcessGroup] = None,
+    ) -> dict:
+        """
+        Safe wrapper around `compute_global_masked_stats` that falls back to
+        local (non-distributed) statistics when torch.distributed is not
+        available or not initialized. This avoids runtime errors in contexts
+        (e.g., Ray rollout workers) where distributed backend isn't set up.
+
+        Returns the same dict format: {"mean", "std", "min", "max"}.
+        """
+        # If distributed isn't available/initialized, compute local masked stats.
+        if not dist.is_available() or not dist.is_initialized():
+            mask_bool = mask.bool()
+            if mask_bool.numel() == 0 or mask_bool.sum().item() == 0:
+                zero = torch.tensor(0.0, device=values.device)
+                return {
+                    "mean": zero,
+                    "std": zero,
+                    "min": torch.tensor(float("inf"), device=values.device),
+                    "max": torch.tensor(float("-inf"), device=values.device),
+                }
+
+            vals = values[mask_bool]
+            mean = vals.mean()
+            std = vals.std(unbiased=False) if vals.numel() > 1 else torch.tensor(0.0, device=values.device)
+            min_v = vals.min()
+            max_v = vals.max()
+            return {"mean": mean, "std": std, "min": min_v, "max": max_v}
+
+        # Otherwise delegate to the distributed implementation
+        return RolloutPostprocessor.compute_global_masked_stats(values, mask, process_group=process_group)
+
+    @staticmethod
+    def aggregate_and_log(
+        log_dict: dict,
+        args,
+        rollout_id: int,
+        process_group: Optional[dist.ProcessGroup] = None,
+        dp_src_rank: int = 0,
+        only_log_on_src: bool = True,
+    ) -> Optional[dict]:
+        """
+        Aggregate per-rank metrics into pooled/global metrics and log them.
+
+        Expected convention for pooled fields: callers should emit per-rank
+        aggregates with suffixes: `_agg_sum`, `_agg_sumsq`, `_agg_count`,
+        `_agg_min`, `_agg_max` for fields that require pooled mean/std/min/max.
+
+        Non-aggregate scalar keys (plain numeric) will be averaged across
+        ranks via all-reduce sum/mean.
+
+        Returns the reduced dict (with keys prefixed by `rollout/`) on all
+        ranks. Logging via `tracking_utils.log` happens on the DP source rank
+        when `only_log_on_src` is True.
+        """
+        # Fast path: non-distributed -> compute locally
+        if not dist.is_available() or not dist.is_initialized() or process_group is None:
+            reduced: dict = {}
+            # Handle aggregate bases
+            agg_bases = {k[: -len("_agg_sum")] for k in log_dict.keys() if k.endswith("_agg_sum")}
+            for base in agg_bases:
+                s = float(log_dict.get(f"{base}_agg_sum", 0.0))
+                ssq = float(log_dict.get(f"{base}_agg_sumsq", 0.0))
+                cnt = int(log_dict.get(f"{base}_agg_count", 0))
+                mn = float(log_dict.get(f"{base}_agg_min", float("inf")))
+                mx = float(log_dict.get(f"{base}_agg_max", float("-inf")))
+                if cnt > 0:
+                    mean = s / cnt
+                    var = ssq / cnt - mean * mean
+                    if cnt >= 2:
+                        var = var * (cnt / (cnt - 1))
+                    std = float(max(var, 0.0) ** 0.5)
+                else:
+                    mean = 0.0
+                    std = 0.0
+                    mn = 0.0
+                    mx = 0.0
+                reduced[f"rollout/{base}_global_mean"] = mean
+                reduced[f"rollout/{base}_global_std"] = std
+                reduced[f"rollout/{base}_global_min"] = mn
+                reduced[f"rollout/{base}_global_max"] = mx
+
+            # Keep the remaining numeric keys as-is; only the raw `_agg_*` fields are dropped
+            non_agg_keys = [
+                k
+                for k in log_dict.keys()
+                if not k.endswith(("_agg_sum", "_agg_sumsq", "_agg_count", "_agg_min", "_agg_max"))
+            ]
+            for key in non_agg_keys:
+                v = log_dict[key]
+                try:
+                    reduced[f"rollout/{key}"] = float(v)
+                except Exception:
+                    reduced[f"rollout/{key}"] = v
+
+            # Add the rollout step and log directly: a single process is always the source rank
+            reduced["rollout/step"] = compute_rollout_step(args, rollout_id)
+            tracking_utils.log(args, reduced, step_key="rollout/step")
+            return reduced
+
+        # Distributed path: perform all-reduces per aggregate
+        dp_size = dist.get_world_size(group=process_group)
+
+        # NCCL-backed groups need CUDA tensors for collectives; gloo groups can reduce on CPU
+        reduce_device = (
+            torch.device("cuda", torch.cuda.current_device())
+            if dist.get_backend(process_group) == "nccl"
+            else torch.device("cpu")
+        )
+
+        gathered_reduced: dict = {}
+
+        # Find aggregate bases
+        agg_bases = {k[: -len("_agg_sum")] for k in log_dict.keys() if k.endswith("_agg_sum")}
+
+        # For each aggregate base, all-reduce sum, sumsq, count and reduce min/max
+        for base in agg_bases:
+            local_sum = torch.tensor(
+                float(log_dict.get(f"{base}_agg_sum", 0.0)), dtype=torch.float64, device=reduce_device
+            )
+            local_sumsq = torch.tensor(
+                float(log_dict.get(f"{base}_agg_sumsq", 0.0)), dtype=torch.float64, device=reduce_device
+            )
+            local_count = torch.tensor(
+                int(log_dict.get(f"{base}_agg_count", 0)), dtype=torch.float64, device=reduce_device
+            )
+            dist.all_reduce(local_sum, op=dist.ReduceOp.SUM, group=process_group)
+            dist.all_reduce(local_sumsq, op=dist.ReduceOp.SUM, group=process_group)
+            dist.all_reduce(local_count, op=dist.ReduceOp.SUM, group=process_group)
+
+            total_count = int(local_count.item())
+            total_sum = float(local_sum.item())
+            total_sumsq = float(local_sumsq.item())
+
+            # Reduce min/max
+            local_min = torch.tensor(
+                float(log_dict.get(f"{base}_agg_min", float("inf"))), dtype=torch.float64, device=reduce_device
+            )
+            local_max = torch.tensor(
+                float(log_dict.get(f"{base}_agg_max", float("-inf"))), dtype=torch.float64, device=reduce_device
+            )
+            dist.all_reduce(local_min, op=dist.ReduceOp.MIN, group=process_group)
+            dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=process_group)
+
+            if total_count > 0:
+                mean = total_sum / total_count
+                var = total_sumsq / total_count - mean * mean
+                if total_count >= 2:
+                    var = var * (total_count / (total_count - 1))
+                std = float(max(var, 0.0) ** 0.5)
+                mn = float(local_min.item())
+                mx = float(local_max.item())
+            else:
+                mean = 0.0
+                std = 0.0
+                mn = 0.0
+                mx = 0.0
+
+            gathered_reduced[f"rollout/{base}_global_mean"] = mean
+            gathered_reduced[f"rollout/{base}_global_std"] = std
+            gathered_reduced[f"rollout/{base}_global_min"] = mn
+            gathered_reduced[f"rollout/{base}_global_max"] = mx
+
+        # Handle non-aggregate numeric keys: all-reduce sum then divide by dp_size
+        non_agg_keys = [
+            k
+            for k in log_dict.keys()
+            if not k.endswith(("_agg_sum", "_agg_sumsq", "_agg_count", "_agg_min", "_agg_max"))
+        ]
+        for key in non_agg_keys:
+            v = log_dict[key]
+            try:
+                t = torch.tensor(float(v), dtype=torch.float64, device=reduce_device)
+                dist.all_reduce(t, op=dist.ReduceOp.SUM, group=process_group)
+                gathered_reduced[f"rollout/{key}"] = float(t.item()) / float(dp_size)
+            except Exception:
+                # Non-numeric -> pick first-rank's value via gather_object
+                vals = [None] * dp_size
+                dist.gather_object(
+                    log_dict[key],
+                    vals if dist.get_rank() == dp_src_rank else None,
+                    dst=dp_src_rank,
+                    group=process_group,
+                )
+                if dist.get_rank() == dp_src_rank:
+                    gathered_reduced[f"rollout/{key}"] = vals[0]
+
+        # Add rollout step
+        gathered_reduced["rollout/step"] = compute_rollout_step(args, rollout_id)
+
+        # Logging only on the source rank by default (dp_src_rank is a global rank)
+        if not only_log_on_src or dist.get_rank() == dp_src_rank:
+            tracking_utils.log(args, gathered_reduced, step_key="rollout/step")
+
+        return gathered_reduced
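
A minimal usage sketch of the new helper (not part of the diff above): with torch.distributed left uninitialized, compute_masked_stats_safe falls back to purely local masked statistics, which makes its behavior easy to check on CPU. The values and mask below are made-up illustration data.

import torch

from miles.utils.rolloutpostprocessor import RolloutPostprocessor

# Token-level metric with one padded position that the loss mask excludes.
values = torch.tensor([1.0, 2.0, 3.0, 100.0])
mask = torch.tensor([1, 1, 1, 0], dtype=torch.bool)

# torch.distributed is not initialized here, so the safe wrapper computes local stats:
# mean == 2.0, std == population std of [1, 2, 3] (unbiased=False), min == 1.0, max == 3.0.
stats = RolloutPostprocessor.compute_masked_stats_safe(values, mask, process_group=None)
print({name: tensor.item() for name, tensor in stats.items()})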
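
The `_agg_sum` / `_agg_sumsq` / `_agg_count` convention emitted by log_rollout_data lets aggregate_and_log recover an exact pooled mean and a Bessel-corrected std from a handful of per-rank scalars, without gathering token tensors. A small sketch of that arithmetic with two hypothetical ranks and no process group involved:

import torch

# Hypothetical per-rank masked values for the same metric (e.g. advantages).
rank_values = [torch.tensor([0.5, -1.0, 2.0]), torch.tensor([1.5, 0.0])]

# Each rank would report these scalars as <key>_agg_sum / _agg_sumsq / _agg_count.
sums = [float(v.sum()) for v in rank_values]
sumsqs = [float((v * v).sum()) for v in rank_values]
counts = [v.numel() for v in rank_values]

# aggregate_and_log sums the scalars across ranks and rebuilds the pooled statistics.
total_sum, total_sumsq, total_count = sum(sums), sum(sumsqs), sum(counts)
mean = total_sum / total_count
var = total_sumsq / total_count - mean * mean
var *= total_count / (total_count - 1)  # Bessel correction, as in the helper
std = max(var, 0.0) ** 0.5

# Matches the unbiased std of the concatenated values computed directly.
assert abs(std - torch.cat(rank_values).std(unbiased=True).item()) < 1e-6
print(mean, std)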