68 changes: 44 additions & 24 deletions miles/backends/fsdp_utils/actor.py
@@ -25,6 +25,7 @@
from miles.utils.tracking_utils import init_tracking

from ...utils import tracking_utils
from ...utils.metric_processor import process_metric
from ...utils.profile_utils import TrainProfiler
from . import checkpoint
from .data_packing import pack_sequences, pad_packed_sequence_with_cp, unpack_sequences
@@ -509,25 +510,29 @@ def _log_rollout_data(self, rollout_id: int, rollout_data, packed_batches):
for metric_key in ["log_probs", "rollout_log_probs", "ref_log_probs", "advantages", "returns"]:
if metric_key not in packed_batches[0]:
continue
val = torch.tensor([0.0], device=torch.cuda.current_device())

val = []
for _mbs_id, batches in enumerate(packed_batches):
unpacked_batches = unpack_sequences(batches)
for unpacked_batch in unpacked_batches:
if isinstance(unpacked_batch[metric_key], torch.Tensor):
loss_masks_tensor = unpacked_batch["loss_masks"].to(device=torch.cuda.current_device())
metric_tensor = unpacked_batch[metric_key].to(device=torch.cuda.current_device())
val += (metric_tensor * loss_masks_tensor).sum() / loss_masks_tensor.sum().clamp_min(1)
result = (metric_tensor * loss_masks_tensor).sum() / loss_masks_tensor.sum().clamp_min(1)
val.append(result.item())
else:
val += unpacked_batch[metric_key]
dist.all_reduce(val, op=dist.ReduceOp.SUM, group=self.dp_group)
log_dict[f"rollout/{metric_key}"] = (
val / (self.args.n_samples_per_prompt * self.args.rollout_batch_size)
).item()
val.append(float(unpacked_batch[metric_key]))
if len(val) > 0:
local_vals = torch.stack(val)
else:
# empty tensor when this rank produced no samples for this metric
local_vals = torch.tensor([], device=torch.cuda.current_device(), dtype=torch.float32)
process_metric(log_dict, f"rollout/{metric_key}", local_vals, self.dp_group)

if dist.get_rank() == 0:
logger.info(f"rollout {rollout_id}: {log_dict}")
log_dict["rollout/step"] = compute_rollout_step(self.args, rollout_id)
tracking_utils.log(self.args, log_dict, step_key="rollout/step")

if self.args.ci_test and self.args.true_on_policy_mode:
assert log_dict["rollout/log_probs"] == log_dict["rollout/rollout_log_probs"], (
f"CI check failed: true_on_policy_mode is enabled, but log_probs "
@@ -678,7 +683,7 @@ def _has_rollout_log_probs(batch) -> bool:
pg_loss = pg_loss * tis_clip

assert not self.args.calculate_per_token_loss, "calculate_per_token_loss not yet implemented"
pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks)
pg_loss, pg_loss_vec = get_sample_mean_info(pg_loss, response_lengths, loss_masks)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)

@@ -712,7 +717,7 @@ def _has_rollout_log_probs(batch) -> bool:

reported = {
"loss": loss.detach(),
"pg_loss": pg_loss.detach(),
"pg_loss": pg_loss_vec.detach(),
"pg_clipfrac": pg_clipfrac.detach(),
"ppo_kl": ppo_kl.detach(),
"entropy_loss": entropy_loss.detach(),
@@ -750,19 +755,14 @@ def _has_rollout_log_probs(batch) -> bool:
# Update learning rate
self.lr_scheduler.step()
self.optimizer.zero_grad(set_to_none=True)
# Aggregate logs
aggregated = {k: torch.stack(v).sum().item() for k, v in reported_accum.items()}
# TODO: change this, this is slow.
reduced_aggregated = [None] * self.dp_size
dist.all_gather_object(reduced_aggregated, aggregated, group=self.dp_group)
aggregated = {}
for k in reported_accum.keys():
aggregated[k] = sum([r[k] for r in reduced_aggregated]) / (self.args.global_batch_size)
log_dict = {}
for k, v_list in reported_accum.items():
# Flatten each accumulated metric and reduce it across the DP group via process_metric
combined_tensor = torch.cat([v.view(-1) for v in v_list])

process_metric(log_dict=log_dict, key=f"train/{k}", metric=combined_tensor, group=self.dp_group)
reported_accum.clear()
if dist.get_rank() == 0:
log_dict = {
f"train/{k}": (val.item() if torch.is_tensor(val) else val) for k, val in aggregated.items()
}
log_dict["train/grad_norm"] = grad_norm

# Log learning rate per parameter group; use scheduler's last computed LRs
@@ -771,13 +771,13 @@ def _has_rollout_log_probs(batch) -> bool:
log_dict[f"train/lr-pg_{gid}"] = lr_values[gid]

kl_info = ""
if self.args.use_kl_loss and "kl_loss" in aggregated:
kl_info = f", kl_loss: {aggregated['kl_loss']:.4f}, kl_penalty: {aggregated['kl_loss'] * self.args.kl_loss_coef:.4f}"
if self.args.use_kl_loss and "kl_loss" in log_dict:
kl_info = f", kl_loss: {log_dict['kl_loss']:.4f}, kl_penalty: {log_dict['kl_loss'] * self.args.kl_loss_coef:.4f}"
logger.info(kl_info)
logger.info(f"step {self.global_step}: {log_dict}")

log_dict["train/step"] = self.global_step
tracking_utils.log(self.args, log_dict, step_key="train/step")
tracking_utils.log(self.args, log_dict, step_key="rollout/step")
self.global_step += 1

@timer
@@ -1032,6 +1032,26 @@ def sum_of_sample_mean(x: torch.Tensor, response_lengths: list[int], loss_masks:
)


def get_sample_mean_info(x: torch.Tensor, response_lengths: list[int], loss_masks: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute sum and vector of per-sample means for variable-length responses.

Parameters:
x: Flat tensor of concatenated per-token values.
response_lengths: Length of each sample in `x`.
loss_masks: Per-sample masks for `response_lengths`.

Returns:
A tuple of (sum_of_means, vector_of_means). The vector is used for reduction in `process_metric`.
"""
sample_mean_list = [
(x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1)
for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False)
]
if len(sample_mean_list) == 0:
return torch.tensor(0.0, device=x.device), torch.tensor([], device=x.device)
return sum(sample_mean_list), torch.stack(sample_mean_list)


@torch.no_grad()
def move_torch_optimizer(optimizer, device):
"""ref: https://github.com/volcengine/verl/blob/main/verl/utils/fsdp_utils.py"""
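The new `get_sample_mean_info` helper above returns both the scalar that feeds the loss and the per-sample vector that is handed to `process_metric`. A minimal CPU-only sketch of that computation on toy data (values, lengths, and masks are invented for illustration, not taken from the PR):

```python
import torch

def get_sample_mean_info(x, response_lengths, loss_masks):
    # Mirrors the helper added in actor.py: one masked mean per sample.
    sample_mean_list = [
        (x_i * m_i).sum() / torch.clamp_min(m_i.sum(), 1)
        for x_i, m_i in zip(x.split(response_lengths, dim=0), loss_masks)
    ]
    if len(sample_mean_list) == 0:
        return torch.tensor(0.0, device=x.device), torch.tensor([], device=x.device)
    return sum(sample_mean_list), torch.stack(sample_mean_list)

# Two samples of lengths 3 and 2, concatenated along dim 0.
x = torch.tensor([1.0, 2.0, 3.0, 10.0, 20.0])
masks = [torch.tensor([1.0, 1.0, 0.0]), torch.tensor([1.0, 1.0])]
total, per_sample = get_sample_mean_info(x, [3, 2], masks)
print(per_sample)  # tensor([ 1.5000, 15.0000]) -- one masked mean per sample
print(total)       # tensor(16.5000) -- the sum used for the loss term
```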
43 changes: 43 additions & 0 deletions miles/backends/megatron_utils/cp_utils.py
@@ -119,6 +119,49 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor:
return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token


def get_vector_of_sample_mean(
total_lengths: list[int],
response_lengths: list[int],
loss_masks: list[torch.Tensor],
) -> Callable[[torch.Tensor], torch.Tensor]:
"""
Return a function that computes the vector of per-sample masked means,
slicing each sample's loss mask to this rank's chunks when context
parallelism (CP) is enabled.
"""
cp_size = mpu.get_context_parallel_world_size()
if cp_size == 1:

def vector_of_sample_mean(x: torch.Tensor) -> torch.Tensor:
sample_means = [
(x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1)
for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False)
]
return torch.stack(sample_means) if len(sample_means) > 0 else torch.tensor([], device=x.device)

else:
cp_chunk_lengths = []
chunked_loss_masks = []
for i, (total_length, response_length, loss_mask) in enumerate(
zip(total_lengths, response_lengths, loss_masks, strict=False)
):
prompt_length = total_length - response_length
_, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp(total_length, response_length)
loss_mask_0 = loss_mask[tokens_offset[0][0] - prompt_length : tokens_offset[0][1] - prompt_length]
loss_mask_1 = loss_mask[tokens_offset[1][0] - prompt_length : tokens_offset[1][1] - prompt_length]
chunked_loss_masks.append(torch.cat([loss_mask_0, loss_mask_1], dim=0))
cp_chunk_lengths.append(chunked_loss_masks[i].size(0))

def vector_of_sample_mean(x: torch.Tensor) -> torch.Tensor:
sample_means = [
(x_i * chunked_loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1)
for x_i, chunked_loss_mask, loss_mask in zip(
x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, loss_masks, strict=False
)
]
return torch.stack(sample_means) if len(sample_means) > 0 else torch.tensor([], device=x.device)

return vector_of_sample_mean


def all_gather_with_cp(tensor: torch.Tensor, total_length: int, response_length: int) -> torch.Tensor:
"""
Gather tensors across all ranks in the context parallel group.
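The CP branch of `get_vector_of_sample_mean` divides each chunk's masked sum by the full `loss_mask.sum()` rather than by the chunk's own sum. A small sketch, with toy numbers and no Megatron `mpu` involved, of why the partial means then add up to the true per-sample mean:

```python
import torch

full_values = torch.tensor([1.0, 2.0, 3.0, 4.0])
full_mask = torch.ones(4)

# True per-sample mean over the whole response.
true_mean = (full_values * full_mask).sum() / full_mask.sum()  # 2.5

# Pretend CP splits the sample into two chunks held on two ranks.
chunks = [full_values[:2], full_values[2:]]
chunk_masks = [full_mask[:2], full_mask[2:]]

partial_means = [
    (v * m).sum() / torch.clamp_min(full_mask.sum(), 1)  # divide by the FULL mask sum
    for v, m in zip(chunks, chunk_masks)
]
print(sum(partial_means))  # tensor(2.5000) == true_mean after the later reduction
print(true_mean)
```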
70 changes: 56 additions & 14 deletions miles/backends/megatron_utils/data.py
@@ -17,7 +17,8 @@
from miles.utils.types import RolloutBatch

from ...utils import tracking_utils
from .cp_utils import get_sum_of_sample_mean, slice_with_cp
from ...utils.metric_processor import _EXTEND_METRICS, process_metric
from .cp_utils import get_vector_of_sample_mean, slice_with_cp

logger = logging.getLogger(__name__)

@@ -149,6 +150,50 @@ def gather_log_data(
return None


def reduce_log_data(
metric_name: str,
args: Namespace,
rollout_id: int,
log_dict: dict[str, torch.Tensor],
) -> dict[str, float] | None:
"""
Reduce per-rank tensor metrics across DP group and log results.

Supports both scalar and vector tensors by using `process_metric` for reduction.
Global keys are synchronized across the DP group to handle potential missing
keys on individual ranks. Returns the reduced metrics on the DP source rank.
"""
dp_group = mpu.get_data_parallel_group(with_context_parallel=True)
dp_size = mpu.get_data_parallel_world_size(with_context_parallel=True)
local_keys = list(log_dict.keys())
all_keys_list = [None] * dist.get_world_size(group=dp_group)
dist.all_gather_object(all_keys_list, local_keys, group=dp_group)

global_keys = sorted(list(set([k for keys in all_keys_list for k in keys])))

reduced_log_dict = {}
for key in global_keys:
metric_tensor = log_dict.get(key, torch.tensor([], dtype=torch.float32, device=torch.cuda.current_device()))

process_metric(
log_dict=reduced_log_dict,
key=f"{metric_name}/{key}",
metric=metric_tensor,
group=dp_group,
amount=dp_size if metric_tensor.numel() <= 1 else -1,
)

if mpu.get_data_parallel_rank(with_context_parallel=True) == 0:
step = compute_rollout_step(args, rollout_id)
reduced_log_dict["rollout/step"] = step

logger.info(f"{metric_name} {rollout_id}: {reduced_log_dict}")
tracking_utils.log(args, reduced_log_dict, step_key="rollout/step")

return reduced_log_dict
return None


class DataIterator:
"""Micro-batch iterator over rollout dicts.

@@ -320,7 +365,6 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc
response_lengths = rollout_data["response_lengths"]
loss_masks = rollout_data["loss_masks"]
total_lengths = rollout_data["total_lengths"]
max_seq_lens = rollout_data.get("max_seq_lens", None)

for key, val in rollout_data.items():
if key in [
@@ -339,26 +383,24 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc
# NOTE: Here we have to do the clone().detach(), otherwise the tensor will be
# modified in place and will cause problems for the next rollout.
val = torch.cat(val).clone().detach()
if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "returns", "advantages", "values"]:
sum_of_sample_mean = get_sum_of_sample_mean(
total_lengths,
response_lengths,
loss_masks,
qkv_format=args.qkv_format,
max_seq_lens=max_seq_lens,
)
val = cp_size * sum_of_sample_mean(val) / len(loss_masks)
if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "advantages", "returns", "values"]:
vector_of_sample_mean = get_vector_of_sample_mean(total_lengths, response_lengths, loss_masks)
# keep the per-sample mean vector for metrics listed in _EXTEND_METRICS
if key in _EXTEND_METRICS:
val = vector_of_sample_mean(val)
else:
val = cp_size * sum(vector_of_sample_mean(val)) / len(loss_masks)
else:
val = val.mean() * cp_size
else:
val = sum(val) / len(val)
val = torch.tensor(sum(val) / len(val), device=torch.cuda.current_device())
elif isinstance(val, torch.Tensor):
val = val.float().mean()
else:
raise ValueError(f"Unsupported type: {type(val)} for key: {key}")
log_dict[key] = val.item() if isinstance(val, torch.Tensor) else val
log_dict[key] = val

reduced_log_dict = gather_log_data("rollout", args, rollout_id, log_dict)
reduced_log_dict = reduce_log_data("rollout", args, rollout_id, log_dict)
if args.ci_test and reduced_log_dict is not None:
if (
rollout_id == 0
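`reduce_log_data` first gathers the union of metric keys so that ranks missing a key still participate in every reduction. A single-process sketch of that key-synchronization step, using a hypothetical gloo group of world size 1 instead of the Megatron DP group from the PR:

```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

local_log_dict = {"advantages": torch.tensor([0.1, 0.2]), "returns": torch.tensor([1.0])}

# Gather every rank's key list, then iterate the same sorted union everywhere.
local_keys = list(local_log_dict.keys())
all_keys_list = [None] * dist.get_world_size()
dist.all_gather_object(all_keys_list, local_keys)
global_keys = sorted({k for keys in all_keys_list for k in keys})

for key in global_keys:
    # Ranks without this key contribute an empty tensor to the reduction.
    metric = local_log_dict.get(key, torch.tensor([], dtype=torch.float32))
    print(key, metric)

dist.destroy_process_group()
```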
13 changes: 10 additions & 3 deletions miles/backends/megatron_utils/loss.py
@@ -21,7 +21,12 @@
)
from miles.utils.types import RolloutBatch

from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean
from .cp_utils import (
all_gather_with_cp,
get_logits_and_tokens_offset_with_cp,
get_sum_of_sample_mean,
get_vector_of_sample_mean,
)


def get_responses(
@@ -563,8 +568,10 @@ def policy_loss_function(
args.qkv_format,
batch.get("max_seq_lens", None),
)
vector_of_sample_mean = get_vector_of_sample_mean(total_lengths, response_lengths, modified_response_masks)

pg_loss = sum_of_sample_mean(pg_loss)
pg_loss_vec = vector_of_sample_mean(pg_loss)
pg_loss = pg_loss_vec.sum()
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
ppo_kl = sum_of_sample_mean(ppo_kl)

@@ -602,7 +609,7 @@ def policy_loss_function(

reported_loss = {
"loss": loss.clone().detach(),
"pg_loss": pg_loss.clone().detach(),
"pg_loss": pg_loss_vec.clone().detach(),
"entropy_loss": entropy_loss.clone().detach(),
"pg_clipfrac": pg_clipfrac.clone().detach(),
"ppo_kl": ppo_kl.clone().detach(),
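The loss.py change relies on the fact that summing the per-sample mean vector reproduces the scalar previously returned by `sum_of_sample_mean`, while the vector itself can now be reported per sample. A toy illustration (tensors and lengths invented):

```python
import torch

per_token_pg_loss = torch.tensor([0.2, 0.4, 0.6, 1.0])
response_lengths = [3, 1]
loss_masks = [torch.tensor([1.0, 1.0, 1.0]), torch.tensor([1.0])]

# Vector of masked per-sample means, as vector_of_sample_mean would build it.
pg_loss_vec = torch.stack([
    (x_i * m_i).sum() / torch.clamp_min(m_i.sum(), 1)
    for x_i, m_i in zip(per_token_pg_loss.split(response_lengths), loss_masks)
])
pg_loss = pg_loss_vec.sum()  # scalar used in the loss, as before

print(pg_loss_vec)  # tensor([0.4000, 1.0000]) reported per sample
print(pg_loss)      # tensor(1.4000)
```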
28 changes: 15 additions & 13 deletions miles/backends/megatron_utils/model.py
@@ -24,6 +24,7 @@
from miles.utils import tracking_utils
from miles.utils.memory_utils import clear_memory

from ...utils.metric_processor import process_metric
from .checkpoint import load_checkpoint, save_checkpoint
from .cp_utils import slice_with_cp
from .data import DataIterator, get_batch
@@ -481,22 +482,22 @@ def build_loss_mask_for_mtp(batch: dict[str, object]) -> torch.Tensor | None:
optimizer.zero_grad()

if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Average loss across microbatches.
keys = losses_reduced[0]["keys"]
values = None
for x in losses_reduced:
if values is None:
values = x["values"]
else:
values += x["values"]
assert len(keys) + 1 == values.numel()
torch.distributed.all_reduce(values, group=mpu.get_data_parallel_group(with_context_parallel=True))
dp_group = mpu.get_data_parallel_group(with_context_parallel=True)
concat_vals = [
torch.cat([torch.atleast_1d(x["values"][i]) for x in losses_reduced])
for i in range(len(losses_reduced[0]["values"]))
]

local_amount = concat_vals[0].sum()
torch.distributed.all_reduce(local_amount, group=dp_group)
amount = local_amount.item()

loss_reduced = {}
values = values.tolist()
num_samples_or_tokens = values[0]
for key, value in zip(keys, values[1:], strict=False):
loss_reduced[key] = value * mpu.get_context_parallel_world_size() / num_samples_or_tokens
for i, key in enumerate(keys):
process_metric(loss_reduced, key, concat_vals[i + 1], group=dp_group, amount=amount)
loss_reduced[key] *= mpu.get_context_parallel_world_size()

return loss_reduced, grad_norm
return {}, grad_norm

@@ -723,6 +724,7 @@ def initialize_model_and_optimizer(

if torch.version.hip:
import megatron.core.dist_checkpointing.strategies.filesystem_async as filesystem_async_module

from miles.utils.rocm_checkpoint_writer import ROCmFileSystemWriterAsync

filesystem_async_module.FileSystemWriterAsync = ROCmFileSystemWriterAsync
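The model.py change replaces the single summed `values` tensor with per-position concatenation across microbatches before each key is passed to `process_metric`. A dummy-data sketch of that reshaping, with the `"keys"`/`"values"` layout inferred from the diff (index 0 holds the sample/token count) and no pipeline or DP groups involved:

```python
import torch

losses_reduced = [
    {"keys": ["loss", "pg_loss"],
     "values": [torch.tensor(2.0), torch.tensor(0.5), torch.tensor([0.1, 0.2])]},
    {"keys": ["loss", "pg_loss"],
     "values": [torch.tensor(3.0), torch.tensor(0.7), torch.tensor([0.3])]},
]

# Concatenate the i-th reported value across microbatches; values[0] is the count.
concat_vals = [
    torch.cat([torch.atleast_1d(x["values"][i]) for x in losses_reduced])
    for i in range(len(losses_reduced[0]["values"]))
]

print(concat_vals[0].sum())  # tensor(5.) -- the "amount" denominator before all-reduce
print(concat_vals[2])        # tensor([0.1000, 0.2000, 0.3000]) -- pg_loss per sample
```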