65 changes: 53 additions & 12 deletions miles/backends/fsdp_utils/actor.py
@@ -21,6 +21,7 @@
from miles.utils.ppo_utils import compute_approx_kl, compute_gspo_kl, compute_opsm_mask, compute_policy_loss
from miles.utils.processing_utils import load_processor, load_tokenizer
from miles.utils.ray_utils import Box
from miles.utils.rolloutpostprocessor import RolloutPostprocessor
from miles.utils.timer import Timer, inverse_timer, timer
from miles.utils.tracking_utils import init_tracking

@@ -502,20 +503,45 @@ def _log_rollout_data(self, rollout_id: int, rollout_data, packed_batches):
for metric_key in ["log_probs", "rollout_log_probs", "ref_log_probs", "advantages", "returns"]:
if metric_key not in packed_batches[0]:
continue
# Keep existing per-sample mean aggregation for backwards compatibility
val = torch.tensor([0.0], device=torch.cuda.current_device())
for _mbs_id, batches in enumerate(packed_batches):

# Also collect flat tensors to compute global masked stats via Postprocessor
metric_tensors = []
mask_tensors = []

for mbs_id, batches in enumerate(packed_batches):
unpacked_batches = unpack_sequences(batches)
for unpacked_batch in unpacked_batches:
if isinstance(unpacked_batch[metric_key], torch.Tensor):
loss_masks_tensor = unpacked_batch["loss_masks"].to(device=torch.cuda.current_device())
metric_tensor = unpacked_batch[metric_key].to(device=torch.cuda.current_device())
val += (metric_tensor * loss_masks_tensor).sum() / loss_masks_tensor.sum().clamp_min(1)
metric_tensors.append(metric_tensor)
mask_tensors.append(loss_masks_tensor)
else:
val += unpacked_batch[metric_key]

dist.all_reduce(val, op=dist.ReduceOp.SUM, group=self.dp_group)
log_dict[f"rollout/{metric_key}"] = (
val / (self.args.n_samples_per_prompt * self.args.rollout_batch_size)
).item()

# Compute and attach global masked statistics when possible
if metric_tensors and mask_tensors:
try:
flat_metric = torch.cat(metric_tensors).to(device=torch.cuda.current_device())
flat_mask = torch.cat(mask_tensors).to(device=torch.cuda.current_device())
stats = RolloutPostprocessor.compute_masked_stats_safe(
flat_metric, flat_mask, process_group=self.dp_group
)
log_dict[f"rollout/{metric_key}_global_mean"] = stats["mean"].item()
log_dict[f"rollout/{metric_key}_global_std"] = stats["std"].item()
log_dict[f"rollout/{metric_key}_global_max"] = stats["max"].item()
log_dict[f"rollout/{metric_key}_global_min"] = stats["min"].item()
except Exception as e:
logger.error(f"error in computing global stats for {metric_key}: {e}")

if dist.get_rank() == 0:
logger.info(f"rollout {rollout_id}: {log_dict}")
log_dict["rollout/step"] = compute_rollout_step(self.args, rollout_id)
@@ -618,17 +644,14 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum):
response_lengths = [batch["response_lengths"] for batch in unpacked_batches]

advantages = advantages.to(device=log_probs.device)
old_log_probs = old_log_probs.to(device=log_probs.device)
ppo_kl = old_log_probs - log_probs

if self.args.use_opsm:
opsm_mask, opsm_clipfrac = compute_opsm_mask(
args=self.args,
full_log_probs=[batch["cur_log_probs"] for batch in unpacked_batches],
full_old_log_probs=[batch[old_log_prob_key] for batch in unpacked_batches],
advantages=[batch["advantages"] for batch in unpacked_batches],
loss_masks=loss_masks,
)
# compute global advantage stats (masked)
try:
flat_adv_mask = torch.cat(loss_masks).to(device=advantages.device)
adv_stats = RolloutPostprocessor.compute_masked_stats_safe(advantages, flat_adv_mask, process_group=self.dp_group)
except Exception as e:
logger.error(f"error in computing advantage stats: {e}")
adv_stats = None
ppo_kl = old_log_probs.to(device=log_probs.device) - log_probs

if self.args.advantage_estimator == "gspo":
ppo_kl = compute_gspo_kl(
@@ -670,6 +693,14 @@ def _has_rollout_log_probs(batch) -> bool:

pg_loss = pg_loss * tis_clip

# compute pg_loss masked stats before reduction
try:
flat_pg_mask = torch.cat(loss_masks).to(device=pg_loss.device)
pg_stats = RolloutPostprocessor.compute_masked_stats_safe(pg_loss, flat_pg_mask, process_group=self.dp_group)
except Exception as e:
logger.error(f"error in computing pg_loss stats: {e}")
pg_stats = None

assert not self.args.calculate_per_token_loss, "calculate_per_token_loss not yet implemented"
pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
@@ -711,6 +742,16 @@ def _has_rollout_log_probs(batch) -> bool:
"entropy_loss": entropy_loss.detach(),
}

if adv_stats is not None:
reported["advantage_mean"] = adv_stats["mean"].detach()
reported["advantage_std"] = adv_stats["std"].detach()
reported["advantage_max"] = adv_stats["max"].detach()
reported["advantage_min"] = adv_stats["min"].detach()

if pg_stats is not None:
reported["pg_loss_max"] = pg_stats["max"].detach()
reported["pg_loss_min"] = pg_stats["min"].detach()
reported["pg_loss_std"] = pg_stats["std"].detach()
if train_rollout_logprob_abs_diff is not None:
reported["train_rollout_logprob_abs_diff"] = train_rollout_logprob_abs_diff

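Reviewer note on the metric semantics above: the pre-existing rollout/{metric} value is a mean of per-sample means, while the new *_global_mean keys weight every unmasked token equally. The two diverge whenever response lengths differ, which is why both are kept. A toy illustration with hypothetical numbers:

    import torch

    # Two samples: a 1-token response and a 4-token response.
    vals = [torch.tensor([2.0]), torch.tensor([0.0, 0.0, 0.0, 0.0])]
    per_sample_mean = sum(v.mean() for v in vals) / len(vals)  # (2.0 + 0.0) / 2 = 1.0
    global_mean = torch.cat(vals).mean()                       # 2.0 / 5 = 0.4
    print(per_sample_mean.item(), global_mean.item())          # 1.0 0.4
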
61 changes: 52 additions & 9 deletions miles/backends/megatron_utils/data.py
@@ -13,6 +13,7 @@
from miles.utils.data import get_minimum_num_micro_batch_size
from miles.utils.flops_utils import calculate_fwd_flops
from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step
from miles.utils.rolloutpostprocessor import RolloutPostprocessor
from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions
from miles.utils.types import RolloutBatch

@@ -340,14 +341,48 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatch
# modified in place and will cause problem for the next rollout.
val = torch.cat(val).clone().detach()
if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "returns", "advantages", "values"]:
sum_of_sample_mean = get_sum_of_sample_mean(
total_lengths,
response_lengths,
loss_masks,
qkv_format=args.qkv_format,
max_seq_lens=max_seq_lens,
)
val = cp_size * sum_of_sample_mean(val) / len(loss_masks)
# Keep existing per-sample mean behavior
concatenated = val.clone().detach()
sum_of_sample_mean = get_sum_of_sample_mean(
total_lengths,
response_lengths,
loss_masks,
qkv_format=args.qkv_format,
max_seq_lens=max_seq_lens,
)
val = cp_size * sum_of_sample_mean(concatenated) / len(loss_masks)
# Also compute global masked stats and attach additional keys
try:
# Prepare flat mask tensor
if isinstance(loss_masks[0], torch.Tensor):
flat_mask = torch.cat(loss_masks).to(device=concatenated.device)
else:
flat_mask = torch.tensor(
[m for lm in loss_masks for m in lm], device=concatenated.device
)

# Compute local per-rank aggregates (sum, sumsq, count, min, max)
mask_bool = flat_mask.bool()
if mask_bool.numel() == 0 or mask_bool.sum().item() == 0:
local_count = 0
local_sum = 0.0
local_sumsq = 0.0
local_min = float("inf")
local_max = float("-inf")
else:
masked_vals = concatenated[mask_bool]
local_count = int(masked_vals.numel())
local_sum = float(masked_vals.sum().item())
local_sumsq = float((masked_vals * masked_vals).sum().item())
local_min = float(masked_vals.min().item())
local_max = float(masked_vals.max().item())

# Emit per-rank aggregates for pooled reduction
log_dict.update(
{
f"{key}_agg_sum": local_sum,
f"{key}_agg_sumsq": local_sumsq,
f"{key}_agg_count": local_count,
f"{key}_agg_min": local_min,
f"{key}_agg_max": local_max,
}
)
except Exception as e:
logger.error(f"error in preparing aggregates for {key}: {e}")
else:
val = val.mean() * cp_size
else:
@@ -358,7 +393,15 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatch
raise ValueError(f"Unsupported type: {type(val)} for key: {key}")
log_dict[key] = val.item() if isinstance(val, torch.Tensor) else val

reduced_log_dict = gather_log_data("rollout", args, rollout_id, log_dict)
# Aggregate per-rank metrics using all-reduce and log on DP source rank
reduced_log_dict = RolloutPostprocessor.aggregate_and_log(
log_dict,
args,
rollout_id,
process_group=mpu.get_data_parallel_group(),
dp_src_rank=mpu.get_data_parallel_src_rank(with_context_parallel=True),
only_log_on_src=True,
)
if args.ci_test and reduced_log_dict is not None:
if (
rollout_id == 0
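The {key}_agg_sum/_agg_sumsq/_agg_count/_agg_min/_agg_max keys emitted above are sufficient to reconstruct exact global statistics after a SUM (respectively MIN/MAX) reduction, without gathering token tensors across ranks. Assuming aggregate_and_log pools them that way, the arithmetic reduces to the following sketch (not the actual RolloutPostprocessor code):

    import math

    def pool_aggregates(sums, sumsqs, counts, mins, maxs):
        """Combine per-rank (sum, sumsq, count, min, max) into global stats."""
        count = sum(counts)
        if count == 0:
            return None  # no unmasked tokens on any rank
        mean = sum(sums) / count
        # Var = E[x^2] - E[x]^2; clamp at 0 against floating-point drift.
        var = max(sum(sumsqs) / count - mean * mean, 0.0)
        return {
            "mean": mean,
            "std": math.sqrt(var),
            "min": min(mins),
            "max": max(maxs),
        }

    # Two ranks holding masked values [1, 3] and [5]:
    print(pool_aggregates([4.0, 5.0], [10.0, 25.0], [2, 1], [1.0, 5.0], [3.0, 5.0]))
    # {'mean': 3.0, 'std': 1.63..., 'min': 1.0, 'max': 5.0}
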
39 changes: 39 additions & 0 deletions miles/backends/megatron_utils/loss.py
@@ -19,6 +19,7 @@
get_reinforce_plus_plus_baseline_advantages,
get_reinforce_plus_plus_returns,
)
from miles.utils.rolloutpostprocessor import RolloutPostprocessor
from miles.utils.types import RolloutBatch

from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean
@@ -465,6 +466,16 @@ def policy_loss_function(
are enabled.
"""
advantages = torch.cat(batch["advantages"], dim=0)
# compute advantage-level global stats (masked by loss_masks)
try:
flat_adv_mask = torch.cat(batch["loss_masks"]).to(device=advantages.device)
adv_stats = RolloutPostprocessor.compute_masked_stats_safe(
advantages, flat_adv_mask, process_group=mpu.get_data_parallel_group()
)
except Exception as e:
logger.error(f"error in computing advantage stats: {e}")
adv_stats = None

old_log_probs = batch["rollout_log_probs"] if args.use_rollout_logprobs else batch["log_probs"]

response_lengths = batch["response_lengths"]
@@ -553,6 +564,9 @@ def policy_loss_function(
tis_func = vanilla_tis_function
pg_loss, modified_response_masks, tis_metrics = tis_func(**tis_kwargs)

# if TIS modified masks, use them for metric computation
metric_masks = modified_response_masks

# [decouple IS and rejection] Rebuild sum_of_sample_mean with modified_response_masks for denominator correction
# modified_response_masks will be sliced with cp in get_sum_of_sample_mean
sum_of_sample_mean = get_sum_of_sample_mean(
@@ -563,6 +577,18 @@
args.qkv_format,
batch.get("max_seq_lens", None),
)
else:
metric_masks = batch["loss_masks"]

# compute pg_loss token-level stats (masked)
try:
flat_pg_mask = torch.cat(metric_masks).to(device=pg_loss.device)
pg_stats = RolloutPostprocessor.compute_masked_stats_safe(
pg_loss, flat_pg_mask, process_group=mpu.get_data_parallel_group()
)
except Exception as e:
logger.error(f"error in computing pg_loss stats: {e}")
pg_stats = None

pg_loss = sum_of_sample_mean(pg_loss)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
@@ -608,6 +634,19 @@ def policy_loss_function(
"ppo_kl": ppo_kl.clone().detach(),
}

# Add advantage stats if available
if adv_stats is not None:
reported_loss["advantage_mean"] = adv_stats["mean"].clone().detach()
reported_loss["advantage_std"] = adv_stats["std"].clone().detach()
reported_loss["advantage_max"] = adv_stats["max"].clone().detach()
reported_loss["advantage_min"] = adv_stats["min"].clone().detach()

# Add pg_loss distributional stats if available
if pg_stats is not None:
reported_loss["pg_loss_max"] = pg_stats["max"].clone().detach()
reported_loss["pg_loss_min"] = pg_stats["min"].clone().detach()
reported_loss["pg_loss_std"] = pg_stats["std"].clone().detach()

if train_rollout_logprob_abs_diff is not None:
reported_loss["train_rollout_logprob_abs_diff"] = train_rollout_logprob_abs_diff.clone().detach()

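Ordering note for this file: the pg_stats block runs on the token-level pg_loss tensor before sum_of_sample_mean collapses it to a scalar, because per-token extremes are unrecoverable afterwards. A toy illustration (hypothetical numbers):

    import torch

    pg_loss_tokens = torch.tensor([0.1, -0.4, 2.3, 0.0])  # per-token policy-gradient loss
    mask = torch.tensor([1, 1, 1, 0], dtype=torch.bool)   # last token masked out

    # Extremes must be read here; after reduction only a scalar remains.
    print(pg_loss_tokens[mask].max().item(), pg_loss_tokens[mask].min().item())  # 2.3 -0.4
    reduced = pg_loss_tokens[mask].mean()  # roughly what the loss reduction produces
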
26 changes: 26 additions & 0 deletions miles/ray/rollout.py
@@ -23,6 +23,7 @@
from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix
from miles.utils.misc import load_function
from miles.utils.ray_utils import Box
from miles.utils.rolloutpostprocessor import RolloutPostprocessor
from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions
from miles.utils.tracking_utils import init_tracking
from miles.utils.types import Sample
@@ -608,6 +609,31 @@ def compute_metrics_from_samples(args, samples):
response_lengths = [sample.effective_response_length for sample in samples]

log_dict = {}
# Add distributed-safe global stats for response lengths and rewards when possible
try:
if len(response_lengths) > 0:
vals = torch.tensor(response_lengths, dtype=torch.float32)
mask = torch.ones_like(vals, dtype=torch.bool)
stats = RolloutPostprocessor.compute_masked_stats_safe(vals, mask, process_group=None)
log_dict["response_len/global_mean"] = stats["mean"].item()
log_dict["response_len/global_std"] = stats["std"].item()
log_dict["response_len/global_max"] = stats["max"].item()
log_dict["response_len/global_min"] = stats["min"].item()

# rewards
rewards = [sample.get_reward_value(args) for sample in samples]
if len(rewards) > 0:
rvals = torch.tensor(rewards, dtype=torch.float32)
rmask = torch.ones_like(rvals, dtype=torch.bool)
rstats = RolloutPostprocessor.compute_masked_stats_safe(rvals, rmask, process_group=None)
log_dict["reward/global_mean"] = rstats["mean"].item()
log_dict["reward/global_std"] = rstats["std"].item()
log_dict["reward/global_max"] = rstats["max"].item()
log_dict["reward/global_min"] = rstats["min"].item()
except Exception:
# Best-effort only; fall back to existing stats if anything goes wrong
pass
log_dict |= dict_add_prefix(compute_statistics(response_lengths), "response_len/")
log_dict |= _compute_zero_std_metrics(args, samples)
log_dict |= _compute_spec_metrics(args, samples)
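With process_group=None, as in compute_metrics_from_samples above, the helper degenerates to plain masked statistics on a local CPU tensor. A usage example against the compute_masked_stats_safe sketch given earlier (numbers hypothetical):

    import torch

    lengths = torch.tensor([128.0, 512.0, 96.0])
    mask = torch.ones_like(lengths, dtype=torch.bool)
    stats = compute_masked_stats_safe(lengths, mask, process_group=None)
    print({k: round(v.item(), 2) for k, v in stats.items()})
    # {'mean': 245.33, 'std': 189.01, 'max': 512.0, 'min': 96.0}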