From be82fb097bf097fd0e3b0f907f093afea52afe23 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 28 Dec 2025 20:13:22 +0000 Subject: [PATCH 1/5] add unified logging Signed-off-by: root --- miles/backends/fsdp_utils/actor.py | 49 +++++++++++++ miles/backends/megatron_utils/data.py | 29 +++++++- miles/backends/megatron_utils/loss.py | 36 +++++++++ miles/ray/rollout.py | 25 +++++++ miles/utils/postprocessor.py | 102 ++++++++++++++++++++++++++ 5 files changed, 240 insertions(+), 1 deletion(-) create mode 100644 miles/utils/postprocessor.py diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index c55ed5248..7ab3b9b74 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -22,6 +22,7 @@ from miles.utils.distributed_utils import get_gloo_group from miles.utils.memory_utils import clear_memory, print_memory from miles.utils.metric_utils import compute_rollout_step +from miles.utils.postprocessor import Postprocessor from miles.utils.ppo_utils import compute_approx_kl, compute_policy_loss from miles.utils.ray_utils import Box from miles.utils.timer import Timer, inverse_timer, timer @@ -455,7 +456,13 @@ def _train_core(self, rollout_id: int, rollout_data_ref: Box) -> None: for metric_key in ["log_probs", "ref_log_probs", "advantages", "returns"]: if metric_key not in packed_batches[0]: continue + # Keep existing per-sample mean aggregation for backwards compatibility val = torch.tensor([0.0], device=torch.cuda.current_device()) + + # Also collect flat tensors to compute global masked stats via Postprocessor + metric_tensors = [] + mask_tensors = [] + for mbs_id, batches in enumerate(packed_batches): unpacked_batches = unpack_sequences(batches) for unpacked_batch in unpacked_batches: @@ -463,12 +470,30 @@ def _train_core(self, rollout_id: int, rollout_data_ref: Box) -> None: loss_masks_tensor = unpacked_batch["loss_masks"].to(device=torch.cuda.current_device()) metric_tensor = unpacked_batch[metric_key].to(device=torch.cuda.current_device()) val += (metric_tensor * loss_masks_tensor).sum() / loss_masks_tensor.sum().clamp_min(1) + metric_tensors.append(metric_tensor) + mask_tensors.append(loss_masks_tensor) else: val += unpacked_batch[metric_key] + dist.all_reduce(val, op=dist.ReduceOp.SUM, group=self.dp_group) log_dict[f"rollout/{metric_key}"] = ( val / (self.args.n_samples_per_prompt * self.args.rollout_batch_size) ).item() + + # Compute and attach global masked statistics when possible + if metric_tensors and mask_tensors: + try: + flat_metric = torch.cat(metric_tensors).to(device=torch.cuda.current_device()) + flat_mask = torch.cat(mask_tensors).to(device=torch.cuda.current_device()) + stats = Postprocessor.compute_masked_stats_safe( + flat_metric, flat_mask, process_group=self.dp_group + ) + log_dict[f"rollout/{metric_key}_global_mean"] = stats["mean"].item() + log_dict[f"rollout/{metric_key}_global_std"] = stats["std"].item() + log_dict[f"rollout/{metric_key}_global_max"] = stats["max"].item() + log_dict[f"rollout/{metric_key}_global_min"] = stats["min"].item() + except Exception: + pass if dist.get_rank() == 0: logger.info(f"rollout {rollout_id}: {log_dict}") log_dict["rollout/step"] = compute_rollout_step(self.args, rollout_id) @@ -535,6 +560,12 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum): response_lengths = [batch["response_lengths"] for batch in unpacked_batches] advantages = advantages.to(device=log_probs.device) + # compute global advantage stats (masked) + try: + flat_adv_mask = 
torch.cat(loss_masks).to(device=advantages.device) + adv_stats = Postprocessor.compute_masked_stats_safe(advantages, flat_adv_mask, process_group=self.dp_group) + except Exception: + adv_stats = None ppo_kl = old_log_probs.to(device=log_probs.device) - log_probs if self.args.advantage_estimator == "gspo": @@ -579,6 +610,13 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum): pg_loss = pg_loss * tis_clip + # compute pg_loss masked stats before reduction + try: + flat_pg_mask = torch.cat(loss_masks).to(device=pg_loss.device) + pg_stats = Postprocessor.compute_masked_stats_safe(pg_loss, flat_pg_mask, process_group=self.dp_group) + except Exception: + pg_stats = None + assert not self.args.calculate_per_token_loss, "calculate_per_token_loss not yet implemented" pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks) pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks) @@ -614,6 +652,17 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum): "train_rollout_logprob_abs_diff": train_rollout_logprob_abs_diff, } + if adv_stats is not None: + reported["advantage_mean"] = adv_stats["mean"].detach() + reported["advantage_std"] = adv_stats["std"].detach() + reported["advantage_max"] = adv_stats["max"].detach() + reported["advantage_min"] = adv_stats["min"].detach() + + if pg_stats is not None: + reported["pg_loss_max"] = pg_stats["max"].detach() + reported["pg_loss_min"] = pg_stats["min"].detach() + reported["pg_loss_std"] = pg_stats["std"].detach() + if self.args.use_kl_loss: reported["kl_loss"] = kl_loss.detach() diff --git a/miles/backends/megatron_utils/data.py b/miles/backends/megatron_utils/data.py index fcd8c184c..e5a0ff915 100644 --- a/miles/backends/megatron_utils/data.py +++ b/miles/backends/megatron_utils/data.py @@ -13,6 +13,7 @@ from miles.utils.data import get_minimum_num_micro_batch_size from miles.utils.flops_utils import calculate_fwd_flops from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step +from miles.utils.postprocessor import Postprocessor from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions from miles.utils.types import RolloutBatch @@ -324,8 +325,34 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc # modified in place and will cause problem for the next rollout. 
val = torch.cat(val).clone().detach() if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "returns", "advantages", "values"]: + # Keep existing per-sample mean behavior + concatenated = val.clone().detach() sum_of_sample_mean = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks) - val = cp_size * sum_of_sample_mean(val) / len(loss_masks) + val = cp_size * sum_of_sample_mean(concatenated) / len(loss_masks) + # Also compute global masked stats and attach additional keys + try: + # Prepare flat mask tensor + if isinstance(loss_masks[0], torch.Tensor): + flat_mask = torch.cat(loss_masks).to(device=concatenated.device) + else: + flat_mask = torch.tensor( + [m for lm in loss_masks for m in lm], device=concatenated.device + ) + + stats = Postprocessor.compute_masked_stats_safe( + concatenated, flat_mask, process_group=mpu.get_data_parallel_group() + ) + # Attach extra keys with a rollout/ prefix; gather_log_data will prefix again + val_stats = { + f"{key}_global_mean": stats["mean"].item(), + f"{key}_global_std": stats["std"].item(), + f"{key}_global_max": stats["max"].item(), + f"{key}_global_min": stats["min"].item(), + } + # Merge these into log_dict (they will be reduced/gathered later) + log_dict.update(val_stats) + except Exception: + pass else: val = val.mean() * cp_size else: diff --git a/miles/backends/megatron_utils/loss.py b/miles/backends/megatron_utils/loss.py index 83c150402..e71c097f1 100644 --- a/miles/backends/megatron_utils/loss.py +++ b/miles/backends/megatron_utils/loss.py @@ -7,6 +7,7 @@ from miles.utils.distributed_utils import distributed_masked_whiten from miles.utils.misc import load_function +from miles.utils.postprocessor import Postprocessor from miles.utils.ppo_utils import ( calculate_log_probs_and_entropy, compute_approx_kl, @@ -392,6 +393,14 @@ def policy_loss_function( are enabled. 
""" advantages = torch.cat(batch["advantages"], dim=0) + # compute advantage-level global stats (masked by loss_masks) + try: + flat_adv_mask = torch.cat(batch["loss_masks"]).to(device=advantages.device) + adv_stats = Postprocessor.compute_masked_stats_safe( + advantages, flat_adv_mask, process_group=mpu.get_data_parallel_group() + ) + except Exception: + adv_stats = None old_log_probs = batch["rollout_log_probs"] if args.use_rollout_logprobs else batch["log_probs"] response_lengths = batch["response_lengths"] @@ -479,11 +488,25 @@ def vanilla_tis_function( tis_func = vanilla_tis_function pg_loss, modified_response_masks, tis_metrics = tis_func(**tis_kwargs) + # if TIS modified masks, use them for metric computation + metric_masks = modified_response_masks + # [decouple IS and rejection] Rebuild sum_of_sample_mean with modified_response_masks for denominator correction # modified_response_masks will be sliced with cp in get_sum_of_sample_mean sum_of_sample_mean = get_sum_of_sample_mean( total_lengths, response_lengths, modified_response_masks, args.calculate_per_token_loss ) + else: + metric_masks = batch["loss_masks"] + + # compute pg_loss token-level stats (masked) + try: + flat_pg_mask = torch.cat(metric_masks).to(device=pg_loss.device) + pg_stats = Postprocessor.compute_masked_stats_safe( + pg_loss, flat_pg_mask, process_group=mpu.get_data_parallel_group() + ) + except Exception: + pg_stats = None pg_loss = sum_of_sample_mean(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) @@ -525,6 +548,19 @@ def vanilla_tis_function( "ppo_kl": ppo_kl.clone().detach(), } + # Add advantage stats if available + if adv_stats is not None: + reported_loss["advantage_mean"] = adv_stats["mean"].clone().detach() + reported_loss["advantage_std"] = adv_stats["std"].clone().detach() + reported_loss["advantage_max"] = adv_stats["max"].clone().detach() + reported_loss["advantage_min"] = adv_stats["min"].clone().detach() + + # Add pg_loss distributional stats if available + if pg_stats is not None: + reported_loss["pg_loss_max"] = pg_stats["max"].clone().detach() + reported_loss["pg_loss_min"] = pg_stats["min"].clone().detach() + reported_loss["pg_loss_std"] = pg_stats["std"].clone().detach() + if train_rollout_logprob_abs_diff is not None: reported_loss["train_rollout_logprob_abs_diff"] = train_rollout_logprob_abs_diff.clone().detach() diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 2519f613c..cb4386fda 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -21,6 +21,7 @@ from miles.utils.metric_checker import MetricChecker from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix from miles.utils.misc import load_function +from miles.utils.postprocessor import Postprocessor from miles.utils.ray_utils import Box from miles.utils.tracking_utils import init_tracking from miles.utils.types import Sample @@ -505,6 +506,30 @@ def _compute_metrics_from_samples(args, samples): log_dict = {} log_dict |= dict_add_prefix(compute_statistics(response_lengths), f"response_len/") + # Add distributed-safe global stats for response lengths and rewards when possible + try: + if len(response_lengths) > 0: + vals = torch.tensor(response_lengths, dtype=torch.float32) + mask = torch.ones_like(vals, dtype=torch.bool) + stats = Postprocessor.compute_masked_stats_safe(vals, mask, process_group=None) + log_dict["response_len/global_mean"] = stats["mean"].item() + log_dict["response_len/global_std"] = stats["std"].item() + 
log_dict["response_len/global_max"] = stats["max"].item() + log_dict["response_len/global_min"] = stats["min"].item() + + # rewards + rewards = [sample.get_reward_value(args) for sample in samples] + if len(rewards) > 0: + rvals = torch.tensor(rewards, dtype=torch.float32) + rmask = torch.ones_like(rvals, dtype=torch.bool) + rstats = Postprocessor.compute_masked_stats_safe(rvals, rmask, process_group=None) + log_dict["reward/global_mean"] = rstats["mean"].item() + log_dict["reward/global_std"] = rstats["std"].item() + log_dict["reward/global_max"] = rstats["max"].item() + log_dict["reward/global_min"] = rstats["min"].item() + except Exception: + # Best-effort only; fall back to existing stats if anything goes wrong + pass log_dict |= _compute_zero_std_metrics(args, samples) log_dict |= _compute_spec_metrics(args, samples) log_dict |= _compute_reward_cat_metrics(args, samples) diff --git a/miles/utils/postprocessor.py b/miles/utils/postprocessor.py new file mode 100644 index 000000000..517cec67f --- /dev/null +++ b/miles/utils/postprocessor.py @@ -0,0 +1,102 @@ +from typing import Optional + +import torch +import torch.distributed as dist + + +class Postprocessor: + """Postprocessing helpers for rollout / loss metrics. + + This class centralizes distributed, masked statistics used when + reporting metrics such as advantage and per-token pg_loss distributions. + """ + + @staticmethod + def compute_global_masked_stats( + values: torch.Tensor, + mask: torch.Tensor, + process_group: Optional[dist.ProcessGroup] = None, + ) -> dict: + """ + Compute global mean, std, min and max over elements where `mask` is truthy, + aggregating across the provided process group. + + Returns dict with keys `mean`, `std`, `min`, `max` as torch tensors on + `values.device`. 
+ """ + mask_bool = mask.bool() + + local_sum = (values * mask_bool).sum() + local_sum_sq = ((values**2) * mask_bool).sum() + local_count = mask_bool.sum().to(dtype=torch.float32) + + stats_tensor = torch.tensor([local_sum, local_sum_sq, local_count], device=values.device, dtype=torch.float32) + dist.all_reduce(stats_tensor, group=process_group) + + global_sum, global_sum_sq, global_count = stats_tensor + + if global_count.item() == 0: + zero = torch.tensor(0.0, device=values.device) + return { + "mean": zero, + "std": zero, + "min": torch.tensor(float("inf"), device=values.device), + "max": torch.tensor(float("-inf"), device=values.device), + } + + global_mean = global_sum / global_count + global_mean_sq = global_sum_sq / global_count + global_var = global_mean_sq - global_mean**2 + + if global_count.item() >= 2: + bessel = global_count / (global_count - 1) + global_var = global_var * bessel + + global_std = torch.sqrt(torch.clamp(global_var, min=0.0)) + + local_max = torch.where(mask_bool, values, torch.tensor(float("-inf"), device=values.device)) + local_min = torch.where(mask_bool, values, torch.tensor(float("inf"), device=values.device)) + + max_tensor = local_max.max() if local_max.numel() > 0 else torch.tensor(float("-inf"), device=values.device) + min_tensor = local_min.min() if local_min.numel() > 0 else torch.tensor(float("inf"), device=values.device) + + dist.all_reduce(max_tensor, op=dist.ReduceOp.MAX, group=process_group) + dist.all_reduce(min_tensor, op=dist.ReduceOp.MIN, group=process_group) + + return {"mean": global_mean, "std": global_std, "min": min_tensor, "max": max_tensor} + + @staticmethod + def compute_masked_stats_safe( + values: torch.Tensor, + mask: torch.Tensor, + process_group: Optional[dist.ProcessGroup] = None, + ) -> dict: + """ + Safe wrapper around `compute_global_masked_stats` that falls back to + local (non-distributed) statistics when torch.distributed is not + available or not initialized. This avoids runtime errors in contexts + (e.g., Ray rollout workers) where distributed backend isn't set up. + + Returns the same dict format: {"mean", "std", "min", "max"}. + """ + # If distributed isn't available/initialized, compute local masked stats. 
if not dist.is_available() or not dist.is_initialized():
+            mask_bool = mask.bool()
+            if mask_bool.numel() == 0 or mask_bool.sum().item() == 0:
+                zero = torch.tensor(0.0, device=values.device)
+                return {
+                    "mean": zero,
+                    "std": zero,
+                    "min": torch.tensor(float("inf"), device=values.device),
+                    "max": torch.tensor(float("-inf"), device=values.device),
+                }
+
+            vals = values[mask_bool]
+            mean = vals.mean()
+            std = vals.std(unbiased=False) if vals.numel() > 1 else torch.tensor(0.0, device=values.device)
+            min_v = vals.min()
+            max_v = vals.max()
+            return {"mean": mean, "std": std, "min": min_v, "max": max_v}
+
+        # Otherwise delegate to the distributed implementation
+        return Postprocessor.compute_global_masked_stats(values, mask, process_group=process_group)

From 607728cb840112fe57369240579b40f6f0a636b8 Mon Sep 17 00:00:00 2001
From: root
Date: Sun, 28 Dec 2025 20:45:56 +0000
Subject: [PATCH 2/5] add unified logging

Signed-off-by: root
---
 miles/backends/fsdp_utils/actor.py    | 11 +++--
 miles/backends/megatron_utils/data.py | 64 +++++++++++++++++++++++----
 miles/backends/megatron_utils/loss.py |  7 ++-
 3 files changed, 68 insertions(+), 14 deletions(-)

diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py
index 7ab3b9b74..7b277e95e 100644
--- a/miles/backends/fsdp_utils/actor.py
+++ b/miles/backends/fsdp_utils/actor.py
@@ -492,8 +492,9 @@ def _train_core(self, rollout_id: int, rollout_data_ref: Box) -> None:
                     log_dict[f"rollout/{metric_key}_global_std"] = stats["std"].item()
                     log_dict[f"rollout/{metric_key}_global_max"] = stats["max"].item()
                     log_dict[f"rollout/{metric_key}_global_min"] = stats["min"].item()
-                except Exception:
-                    pass
+                except Exception as e:
+                    logger.error(f"error in computing global stats for {metric_key}: {e}")
+
         if dist.get_rank() == 0:
             logger.info(f"rollout {rollout_id}: {log_dict}")
             log_dict["rollout/step"] = compute_rollout_step(self.args, rollout_id)
@@ -564,7 +565,8 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum):
         try:
             flat_adv_mask = torch.cat(loss_masks).to(device=advantages.device)
             adv_stats = Postprocessor.compute_masked_stats_safe(advantages, flat_adv_mask, process_group=self.dp_group)
-        except Exception:
+        except Exception as e:
+            logger.error(f"error in computing advantage stats: {e}")
             adv_stats = None
 
         ppo_kl = old_log_probs.to(device=log_probs.device) - log_probs
@@ -614,7 +616,8 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum):
         try:
             flat_pg_mask = torch.cat(loss_masks).to(device=pg_loss.device)
             pg_stats = Postprocessor.compute_masked_stats_safe(pg_loss, flat_pg_mask, process_group=self.dp_group)
-        except Exception:
+        except Exception as e:
+            logger.error(f"error in computing pg_loss stats: {e}")
             pg_stats = None
 
         assert not self.args.calculate_per_token_loss, "calculate_per_token_loss not yet implemented"
diff --git a/miles/backends/megatron_utils/data.py b/miles/backends/megatron_utils/data.py
index e5a0ff915..8f54490e0 100644
--- a/miles/backends/megatron_utils/data.py
+++ b/miles/backends/megatron_utils/data.py
@@ -1,4 +1,5 @@
 import logging
+import math
 from argparse import Namespace
 from typing import Optional, Sequence, Union
 
@@ -107,7 +108,7 @@ def gather_log_data(
     dp_size = mpu.get_data_parallel_world_size(with_context_parallel=True)
     gathered_log_dict = [None] * dp_size
 
-    # Not sure if this will be a performance bottleneck.
+ # Gather per-rank dicts to the DP source rank dist.gather_object( log_dict, gathered_log_dict, @@ -115,9 +116,56 @@ def gather_log_data( group=mpu.get_data_parallel_group_gloo(with_context_parallel=True), ) - reduced_log_dict = { - f"{metric_name}/{key}": sum([d[key] for d in gathered_log_dict]) / dp_size for key in log_dict - } + reduced_log_dict = {} + + # For keys that already represent global values (contain "_global_"), + # avoid averaging them across ranks. Instead, take the first value and + # warn if ranks disagree (this is the minimal safe behavior per option 3). + for key in log_dict: + try: + vals = [d[key] for d in gathered_log_dict] + except Exception: + # Missing key in some ranks; skip + continue + + if "_global_" in key: + # Numeric comparison: ensure values are (nearly) identical across ranks + first = vals[0] + consistent = True + for v in vals[1:]: + try: + if ( + isinstance(first, float) + or isinstance(v, float) + or isinstance(first, int) + or isinstance(v, int) + ): + if not math.isclose(float(first), float(v), rel_tol=1e-6, abs_tol=1e-9): + consistent = False + break + else: + if first != v: + consistent = False + break + except Exception: + consistent = False + break + + if not consistent: + logger.warning( + f"Inconsistent per-rank values for global key '{key}' at rollout {rollout_id}; using first rank's value." + ) + + reduced_log_dict[f"{metric_name}/{key}"] = first + else: + # Default behavior: arithmetic mean across ranks + try: + numeric_vals = [float(v) for v in vals] + reduced_log_dict[f"{metric_name}/{key}"] = sum(numeric_vals) / dp_size + except Exception: + # Fallback: keep first + reduced_log_dict[f"{metric_name}/{key}"] = vals[0] + logger.info(f"{metric_name} {rollout_id}: {reduced_log_dict}") # Calculate step once to avoid duplication @@ -342,17 +390,17 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc stats = Postprocessor.compute_masked_stats_safe( concatenated, flat_mask, process_group=mpu.get_data_parallel_group() ) - # Attach extra keys with a rollout/ prefix; gather_log_data will prefix again + # Attach extra keys with a rollout/ prefix, gather_log_data will prefix again val_stats = { f"{key}_global_mean": stats["mean"].item(), f"{key}_global_std": stats["std"].item(), f"{key}_global_max": stats["max"].item(), f"{key}_global_min": stats["min"].item(), } - # Merge these into log_dict (they will be reduced/gathered later) + # Merge these into log_dict, they will be reduced/gathered later log_dict.update(val_stats) - except Exception: - pass + except Exception as e: + logger.error(f"error in computing global stats for {key}: {e}") else: val = val.mean() * cp_size else: diff --git a/miles/backends/megatron_utils/loss.py b/miles/backends/megatron_utils/loss.py index e71c097f1..1136eb216 100644 --- a/miles/backends/megatron_utils/loss.py +++ b/miles/backends/megatron_utils/loss.py @@ -399,8 +399,10 @@ def policy_loss_function( adv_stats = Postprocessor.compute_masked_stats_safe( advantages, flat_adv_mask, process_group=mpu.get_data_parallel_group() ) - except Exception: + except Exception as e: + logger.error(f"error in computing advantage stats: {e}") adv_stats = None + old_log_probs = batch["rollout_log_probs"] if args.use_rollout_logprobs else batch["log_probs"] response_lengths = batch["response_lengths"] @@ -505,7 +507,8 @@ def vanilla_tis_function( pg_stats = Postprocessor.compute_masked_stats_safe( pg_loss, flat_pg_mask, process_group=mpu.get_data_parallel_group() ) - except Exception: + except Exception as 
e: + logger.error(f"error in computing pg_loss stats: {e}") pg_stats = None pg_loss = sum_of_sample_mean(pg_loss) From 7349975170717477939d2d2b4d7bd6c1ffcf7173 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 28 Dec 2025 21:00:27 +0000 Subject: [PATCH 3/5] rename files Signed-off-by: root --- miles/backends/fsdp_utils/actor.py | 8 ++++---- miles/backends/megatron_utils/data.py | 4 ++-- miles/backends/megatron_utils/loss.py | 6 +++--- miles/ray/rollout.py | 6 +++--- miles/utils/{postprocessor.py => rolloutpostprocessor.py} | 4 ++-- 5 files changed, 14 insertions(+), 14 deletions(-) rename miles/utils/{postprocessor.py => rolloutpostprocessor.py} (96%) diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index 7b277e95e..e603370ef 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -22,9 +22,9 @@ from miles.utils.distributed_utils import get_gloo_group from miles.utils.memory_utils import clear_memory, print_memory from miles.utils.metric_utils import compute_rollout_step -from miles.utils.postprocessor import Postprocessor from miles.utils.ppo_utils import compute_approx_kl, compute_policy_loss from miles.utils.ray_utils import Box +from miles.utils.rolloutpostprocessor import RolloutPostprocessor from miles.utils.timer import Timer, inverse_timer, timer from miles.utils.tracking_utils import init_tracking @@ -485,7 +485,7 @@ def _train_core(self, rollout_id: int, rollout_data_ref: Box) -> None: try: flat_metric = torch.cat(metric_tensors).to(device=torch.cuda.current_device()) flat_mask = torch.cat(mask_tensors).to(device=torch.cuda.current_device()) - stats = Postprocessor.compute_masked_stats_safe( + stats = RolloutPostprocessor.compute_masked_stats_safe( flat_metric, flat_mask, process_group=self.dp_group ) log_dict[f"rollout/{metric_key}_global_mean"] = stats["mean"].item() @@ -564,7 +564,7 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum): # compute global advantage stats (masked) try: flat_adv_mask = torch.cat(loss_masks).to(device=advantages.device) - adv_stats = Postprocessor.compute_masked_stats_safe(advantages, flat_adv_mask, process_group=self.dp_group) + adv_stats = RolloutPostprocessor.compute_masked_stats_safe(advantages, flat_adv_mask, process_group=self.dp_group) except Exception as e: logger.error(f"error in computing advantage stats: {e}") adv_stats = None @@ -615,7 +615,7 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum): # compute pg_loss masked stats before reduction try: flat_pg_mask = torch.cat(loss_masks).to(device=pg_loss.device) - pg_stats = Postprocessor.compute_masked_stats_safe(pg_loss, flat_pg_mask, process_group=self.dp_group) + pg_stats = RolloutPostprocessor.compute_masked_stats_safe(pg_loss, flat_pg_mask, process_group=self.dp_group) except Exception as e: logger.error(f"error in computing pg_loss stats: {e}") pg_stats = None diff --git a/miles/backends/megatron_utils/data.py b/miles/backends/megatron_utils/data.py index 8f54490e0..a9c635a4d 100644 --- a/miles/backends/megatron_utils/data.py +++ b/miles/backends/megatron_utils/data.py @@ -14,7 +14,7 @@ from miles.utils.data import get_minimum_num_micro_batch_size from miles.utils.flops_utils import calculate_fwd_flops from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step -from miles.utils.postprocessor import Postprocessor +from miles.utils.rolloutpostprocessor import RolloutPostprocessor from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions from 
miles.utils.types import RolloutBatch @@ -387,7 +387,7 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc [m for lm in loss_masks for m in lm], device=concatenated.device ) - stats = Postprocessor.compute_masked_stats_safe( + stats = RolloutPostprocessor.compute_masked_stats_safe( concatenated, flat_mask, process_group=mpu.get_data_parallel_group() ) # Attach extra keys with a rollout/ prefix, gather_log_data will prefix again diff --git a/miles/backends/megatron_utils/loss.py b/miles/backends/megatron_utils/loss.py index 1136eb216..961133f03 100644 --- a/miles/backends/megatron_utils/loss.py +++ b/miles/backends/megatron_utils/loss.py @@ -7,7 +7,6 @@ from miles.utils.distributed_utils import distributed_masked_whiten from miles.utils.misc import load_function -from miles.utils.postprocessor import Postprocessor from miles.utils.ppo_utils import ( calculate_log_probs_and_entropy, compute_approx_kl, @@ -17,6 +16,7 @@ get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, ) +from miles.utils.rolloutpostprocessor import RolloutPostprocessor from miles.utils.types import RolloutBatch from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean @@ -396,7 +396,7 @@ def policy_loss_function( # compute advantage-level global stats (masked by loss_masks) try: flat_adv_mask = torch.cat(batch["loss_masks"]).to(device=advantages.device) - adv_stats = Postprocessor.compute_masked_stats_safe( + adv_stats = RolloutPostprocessor.compute_masked_stats_safe( advantages, flat_adv_mask, process_group=mpu.get_data_parallel_group() ) except Exception as e: @@ -504,7 +504,7 @@ def vanilla_tis_function( # compute pg_loss token-level stats (masked) try: flat_pg_mask = torch.cat(metric_masks).to(device=pg_loss.device) - pg_stats = Postprocessor.compute_masked_stats_safe( + pg_stats = RolloutPostprocessor.compute_masked_stats_safe( pg_loss, flat_pg_mask, process_group=mpu.get_data_parallel_group() ) except Exception as e: diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index cb4386fda..7638537a7 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -21,8 +21,8 @@ from miles.utils.metric_checker import MetricChecker from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix from miles.utils.misc import load_function -from miles.utils.postprocessor import Postprocessor from miles.utils.ray_utils import Box +from miles.utils.rolloutpostprocessor import RolloutPostprocessor from miles.utils.tracking_utils import init_tracking from miles.utils.types import Sample @@ -511,7 +511,7 @@ def _compute_metrics_from_samples(args, samples): if len(response_lengths) > 0: vals = torch.tensor(response_lengths, dtype=torch.float32) mask = torch.ones_like(vals, dtype=torch.bool) - stats = Postprocessor.compute_masked_stats_safe(vals, mask, process_group=None) + stats = RolloutPostprocessor.compute_masked_stats_safe(vals, mask, process_group=None) log_dict["response_len/global_mean"] = stats["mean"].item() log_dict["response_len/global_std"] = stats["std"].item() log_dict["response_len/global_max"] = stats["max"].item() @@ -522,7 +522,7 @@ def _compute_metrics_from_samples(args, samples): if len(rewards) > 0: rvals = torch.tensor(rewards, dtype=torch.float32) rmask = torch.ones_like(rvals, dtype=torch.bool) - rstats = Postprocessor.compute_masked_stats_safe(rvals, rmask, process_group=None) + rstats = RolloutPostprocessor.compute_masked_stats_safe(rvals, rmask, 
process_group=None) log_dict["reward/global_mean"] = rstats["mean"].item() log_dict["reward/global_std"] = rstats["std"].item() log_dict["reward/global_max"] = rstats["max"].item() diff --git a/miles/utils/postprocessor.py b/miles/utils/rolloutpostprocessor.py similarity index 96% rename from miles/utils/postprocessor.py rename to miles/utils/rolloutpostprocessor.py index 517cec67f..8ada6b050 100644 --- a/miles/utils/postprocessor.py +++ b/miles/utils/rolloutpostprocessor.py @@ -4,7 +4,7 @@ import torch.distributed as dist -class Postprocessor: +class RolloutPostprocessor: """Postprocessing helpers for rollout / loss metrics. This class centralizes distributed, masked statistics used when @@ -99,4 +99,4 @@ def compute_masked_stats_safe( return {"mean": mean, "std": std, "min": min_v, "max": max_v} # Otherwise delegate to the distributed implementation - return Postprocessor.compute_global_masked_stats(values, mask, process_group=process_group) + return RolloutPostprocessor.compute_global_masked_stats(values, mask, process_group=process_group) From 30e8eb6bb7de5e527e77f407aaca96cb2134b5dd Mon Sep 17 00:00:00 2001 From: root Date: Sun, 28 Dec 2025 21:26:40 +0000 Subject: [PATCH 4/5] further consolidate Signed-off-by: root --- miles/backends/megatron_utils/data.py | 48 ++++++--- miles/utils/rolloutpostprocessor.py | 148 ++++++++++++++++++++++++++ 2 files changed, 183 insertions(+), 13 deletions(-) diff --git a/miles/backends/megatron_utils/data.py b/miles/backends/megatron_utils/data.py index a9c635a4d..9a60ee594 100644 --- a/miles/backends/megatron_utils/data.py +++ b/miles/backends/megatron_utils/data.py @@ -387,20 +387,34 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc [m for lm in loss_masks for m in lm], device=concatenated.device ) - stats = RolloutPostprocessor.compute_masked_stats_safe( - concatenated, flat_mask, process_group=mpu.get_data_parallel_group() + # Compute local per-rank aggregates (sum, sumsq, count, min, max) + mask_bool = flat_mask.bool() + if mask_bool.numel() == 0 or mask_bool.sum().item() == 0: + local_count = 0 + local_sum = 0.0 + local_sumsq = 0.0 + local_min = float("inf") + local_max = float("-inf") + else: + masked_vals = concatenated[mask_bool] + local_count = int(masked_vals.numel()) + local_sum = float(masked_vals.sum().item()) + local_sumsq = float((masked_vals * masked_vals).sum().item()) + local_min = float(masked_vals.min().item()) + local_max = float(masked_vals.max().item()) + + # Emit per-rank aggregates for pooled reduction + log_dict.update( + { + f"{key}_agg_sum": local_sum, + f"{key}_agg_sumsq": local_sumsq, + f"{key}_agg_count": local_count, + f"{key}_agg_min": local_min, + f"{key}_agg_max": local_max, + } ) - # Attach extra keys with a rollout/ prefix, gather_log_data will prefix again - val_stats = { - f"{key}_global_mean": stats["mean"].item(), - f"{key}_global_std": stats["std"].item(), - f"{key}_global_max": stats["max"].item(), - f"{key}_global_min": stats["min"].item(), - } - # Merge these into log_dict, they will be reduced/gathered later - log_dict.update(val_stats) except Exception as e: - logger.error(f"error in computing global stats for {key}: {e}") + logger.error(f"error in preparing aggregates for {key}: {e}") else: val = val.mean() * cp_size else: @@ -411,7 +425,15 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc raise ValueError(f"Unsupported type: {type(val)}") log_dict[key] = val.item() if isinstance(val, torch.Tensor) else val - reduced_log_dict = 
gather_log_data("rollout", args, rollout_id, log_dict) + # Aggregate per-rank metrics using all-reduce and log on DP source rank + reduced_log_dict = RolloutPostprocessor.aggregate_and_log( + log_dict, + args, + rollout_id, + process_group=mpu.get_data_parallel_group(), + dp_src_rank=mpu.get_data_parallel_src_rank(with_context_parallel=True), + only_log_on_src=True, + ) if args.ci_test and reduced_log_dict is not None: if ( rollout_id == 0 diff --git a/miles/utils/rolloutpostprocessor.py b/miles/utils/rolloutpostprocessor.py index 8ada6b050..079dc6801 100644 --- a/miles/utils/rolloutpostprocessor.py +++ b/miles/utils/rolloutpostprocessor.py @@ -100,3 +100,151 @@ def compute_masked_stats_safe( # Otherwise delegate to the distributed implementation return RolloutPostprocessor.compute_global_masked_stats(values, mask, process_group=process_group) + + @staticmethod + def aggregate_and_log( + log_dict: dict, + args, + rollout_id: int, + process_group: Optional[dist.ProcessGroup] = None, + dp_src_rank: int = 0, + only_log_on_src: bool = True, + ) -> Optional[dict]: + """ + Aggregate per-rank metrics into pooled/global metrics and log them. + + Expected convention for pooled fields: callers should emit per-rank + aggregates with suffixes: `_agg_sum`, `_agg_sumsq`, `_agg_count`, + `_agg_min`, `_agg_max` for fields that require pooled mean/std/min/max. + + Non-aggregate scalar keys (plain numeric) will be averaged across + ranks via all-reduce sum/mean. + + Returns the reduced dict (with keys prefixed by `rollout/`) on all + ranks. Logging via `tracking_utils.log` happens on the DP source rank + when `only_log_on_src` is True. + """ + # Fast path: non-distributed -> compute locally + if not dist.is_available() or not dist.is_initialized() or process_group is None: + reduced: dict = {} + # Handle aggregate bases + agg_bases = {k[: -len("_agg_sum")] for k in log_dict.keys() if k.endswith("_agg_sum")} + for base in agg_bases: + s = float(log_dict.get(f"{base}_agg_sum", 0.0)) + ssq = float(log_dict.get(f"{base}_agg_sumsq", 0.0)) + cnt = int(log_dict.get(f"{base}_agg_count", 0)) + mn = float(log_dict.get(f"{base}_agg_min", float("inf"))) + mx = float(log_dict.get(f"{base}_agg_max", float("-inf"))) + if cnt > 0: + mean = s / cnt + var = ssq / cnt - mean * mean + if cnt >= 2: + var = var * (cnt / (cnt - 1)) + std = float(max(var, 0.0) ** 0.5) + else: + mean = 0.0 + std = 0.0 + mn = 0.0 + mx = 0.0 + reduced[f"rollout/{base}_global_mean"] = mean + reduced[f"rollout/{base}_global_std"] = std + reduced[f"rollout/{base}_global_min"] = mn + reduced[f"rollout/{base}_global_max"] = mx + + # Average non-aggregate numeric keys + non_agg_keys = [k for k in log_dict.keys() if not any(k.startswith(b) for b in agg_bases)] + for key in non_agg_keys: + v = log_dict[key] + try: + reduced[f"rollout/{key}"] = float(v) + except Exception: + reduced[f"rollout/{key}"] = v + + # Add step if available + reduced["rollout/step"] = compute_rollout_step(args, rollout_id) + if not only_log_on_src: + tracking_utils.log(args, reduced, step_key="rollout/step") + return reduced + + # Distributed path: perform all-reduces per aggregate + dp_size = dist.get_world_size(group=process_group) + + gathered_reduced: dict = {} + + # Find aggregate bases + agg_bases = {k[: -len("_agg_sum")] for k in log_dict.keys() if k.endswith("_agg_sum")} + + # For each aggregate base, all-reduce sum, sumsq, count and reduce min/max + for base in agg_bases: + local_sum = torch.tensor(float(log_dict.get(f"{base}_agg_sum", 0.0)), dtype=torch.float64, 
device="cpu") + local_sumsq = torch.tensor( + float(log_dict.get(f"{base}_agg_sumsq", 0.0)), dtype=torch.float64, device="cpu" + ) + local_count = torch.tensor(int(log_dict.get(f"{base}_agg_count", 0)), dtype=torch.float64, device="cpu") + # Use CPU tensors for small reductions to avoid GPU sync + dist.all_reduce(local_sum, op=dist.ReduceOp.SUM, group=process_group) + dist.all_reduce(local_sumsq, op=dist.ReduceOp.SUM, group=process_group) + dist.all_reduce(local_count, op=dist.ReduceOp.SUM, group=process_group) + + total_count = int(local_count.item()) + total_sum = float(local_sum.item()) + total_sumsq = float(local_sumsq.item()) + + # Reduce min/max + local_min = torch.tensor( + float(log_dict.get(f"{base}_agg_min", float("inf"))), dtype=torch.float64, device="cpu" + ) + local_max = torch.tensor( + float(log_dict.get(f"{base}_agg_max", float("-inf"))), dtype=torch.float64, device="cpu" + ) + dist.all_reduce(local_min, op=dist.ReduceOp.MIN, group=process_group) + dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=process_group) + + if total_count > 0: + mean = total_sum / total_count + var = total_sumsq / total_count - mean * mean + if total_count >= 2: + var = var * (total_count / (total_count - 1)) + std = float(max(var, 0.0) ** 0.5) + mn = float(local_min.item()) + mx = float(local_max.item()) + else: + mean = 0.0 + std = 0.0 + mn = 0.0 + mx = 0.0 + + gathered_reduced[f"rollout/{base}_global_mean"] = mean + gathered_reduced[f"rollout/{base}_global_std"] = std + gathered_reduced[f"rollout/{base}_global_min"] = mn + gathered_reduced[f"rollout/{base}_global_max"] = mx + + # Handle non-aggregate numeric keys: all-reduce sum then divide by dp_size + non_agg_keys = [k for k in log_dict.keys() if not any(k.startswith(b) for b in agg_bases)] + for key in non_agg_keys: + v = log_dict[key] + try: + t = torch.tensor(float(v), dtype=torch.float64, device="cpu") + dist.all_reduce(t, op=dist.ReduceOp.SUM, group=process_group) + gathered_reduced[f"rollout/{key}"] = float(t.item()) / float(dp_size) + except Exception: + # Non-numeric -> pick first-rank's value via gather_object + vals = [None] * dp_size + dist.gather_object( + log_dict[key], + vals if dist.get_rank() == dp_src_rank else None, + dst=dp_src_rank, + group=process_group, + ) + if dist.get_rank() == dp_src_rank: + gathered_reduced[f"rollout/{key}"] = vals[0] + + # Add rollout step + gathered_reduced["rollout/step"] = compute_rollout_step(args, rollout_id) + + # Logging only on source rank by default + rank = dist.get_rank(group=process_group) if hasattr(dist, "get_rank") else dist.get_rank() + if not only_log_on_src or rank == dp_src_rank: + tracking_utils.log(args, gathered_reduced, step_key="rollout/step") + + return gathered_reduced From e19f0f0c52c455c5b34618aa65a6354e44007895 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 28 Dec 2025 21:36:32 +0000 Subject: [PATCH 5/5] further consolidate Signed-off-by: root --- miles/backends/megatron_utils/data.py | 58 +++------------------------ 1 file changed, 5 insertions(+), 53 deletions(-) diff --git a/miles/backends/megatron_utils/data.py b/miles/backends/megatron_utils/data.py index 9a60ee594..00996af77 100644 --- a/miles/backends/megatron_utils/data.py +++ b/miles/backends/megatron_utils/data.py @@ -1,5 +1,4 @@ import logging -import math from argparse import Namespace from typing import Optional, Sequence, Union @@ -95,7 +94,7 @@ def gather_log_data( args: Namespace, rollout_id: int, log_dict: dict[str, float], -) -> Optional[dict[str, float]]: +) -> dict[str, float] | None: """ Gather 
per-rank metrics, reduce by mean on the DP source rank, and log. @@ -108,7 +107,7 @@ def gather_log_data( dp_size = mpu.get_data_parallel_world_size(with_context_parallel=True) gathered_log_dict = [None] * dp_size - # Gather per-rank dicts to the DP source rank + # Not sure if this will be a performance bottleneck. dist.gather_object( log_dict, gathered_log_dict, @@ -116,56 +115,9 @@ def gather_log_data( group=mpu.get_data_parallel_group_gloo(with_context_parallel=True), ) - reduced_log_dict = {} - - # For keys that already represent global values (contain "_global_"), - # avoid averaging them across ranks. Instead, take the first value and - # warn if ranks disagree (this is the minimal safe behavior per option 3). - for key in log_dict: - try: - vals = [d[key] for d in gathered_log_dict] - except Exception: - # Missing key in some ranks; skip - continue - - if "_global_" in key: - # Numeric comparison: ensure values are (nearly) identical across ranks - first = vals[0] - consistent = True - for v in vals[1:]: - try: - if ( - isinstance(first, float) - or isinstance(v, float) - or isinstance(first, int) - or isinstance(v, int) - ): - if not math.isclose(float(first), float(v), rel_tol=1e-6, abs_tol=1e-9): - consistent = False - break - else: - if first != v: - consistent = False - break - except Exception: - consistent = False - break - - if not consistent: - logger.warning( - f"Inconsistent per-rank values for global key '{key}' at rollout {rollout_id}; using first rank's value." - ) - - reduced_log_dict[f"{metric_name}/{key}"] = first - else: - # Default behavior: arithmetic mean across ranks - try: - numeric_vals = [float(v) for v in vals] - reduced_log_dict[f"{metric_name}/{key}"] = sum(numeric_vals) / dp_size - except Exception: - # Fallback: keep first - reduced_log_dict[f"{metric_name}/{key}"] = vals[0] - + reduced_log_dict = { + f"{metric_name}/{key}": sum([d[key] for d in gathered_log_dict]) / dp_size for key in log_dict + } logger.info(f"{metric_name} {rollout_id}: {reduced_log_dict}") # Calculate step once to avoid duplication
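
Note: the pooled reduction added in [PATCH 4/5] turns per-rank (sum, sumsq, count, min, max)
aggregates into a global mean/std/min/max with a Bessel-corrected variance. Below is a minimal
standalone sketch of that math, assuming plain Python floats; the function and variable names are
illustrative only and are not part of the patch itself.

    import math

    def pool_masked_stats(per_rank_aggregates):
        """Combine per-rank (sum, sumsq, count, min, max) tuples into global stats.

        Mirrors the pooling math in RolloutPostprocessor.aggregate_and_log;
        names here are hypothetical, for illustration only.
        """
        total_sum = sum(a[0] for a in per_rank_aggregates)
        total_sumsq = sum(a[1] for a in per_rank_aggregates)
        total_count = sum(a[2] for a in per_rank_aggregates)
        if total_count == 0:
            return {"mean": 0.0, "std": 0.0, "min": 0.0, "max": 0.0}
        mean = total_sum / total_count
        var = total_sumsq / total_count - mean * mean
        if total_count >= 2:
            var *= total_count / (total_count - 1)  # Bessel correction, as in the patch
        return {
            "mean": mean,
            "std": math.sqrt(max(var, 0.0)),
            "min": min(a[3] for a in per_rank_aggregates),
            "max": max(a[4] for a in per_rank_aggregates),
        }

    # Example: rank 0 sees {1, 2}, rank 1 sees {3} -> mean 2.0, std 1.0, min 1.0, max 3.0
    print(pool_masked_stats([(3.0, 5.0, 2, 1.0, 2.0), (3.0, 9.0, 1, 3.0, 3.0)]))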