13 changes: 13 additions & 0 deletions miles/backends/megatron_utils/actor.py
@@ -554,6 +554,19 @@ def update_weights(self) -> None:
if self.args.offload_train:
destroy_process_groups()

def check_weights(self, action: str = "compare") -> None:
"""Verify weights between training and rollout engines."""
if self.args.debug_train_only or self.args.debug_rollout_only:
return

if self.args.offload_train:
reload_process_groups()

self.weight_updater.check_weights(action=action)

if self.args.offload_train:
destroy_process_groups()

def load_other_checkpoint(self, model_tag: str, path: str) -> None:
old_args = self.args.load, self.args.no_load_optim, self.args.no_load_rng, self.args.finetune
self.args.load = path
@@ -1,3 +1,4 @@
import hashlib
import socket
import time
from argparse import Namespace
@@ -82,6 +83,8 @@ def update_weights(self) -> None:
ray.get([engine.flush_cache.remote() for engine in self.rollout_engines])
dist.barrier(group=get_gloo_group())

self._last_checksums = {}

buffer_size = 0
converted_named_tensors = []
# non expert params
@@ -116,6 +119,22 @@ def update_weights(self) -> None:
ray.get([engine.continue_generation.remote() for engine in self.rollout_engines])
dist.barrier(group=get_gloo_group())

def check_weights(self, action: str) -> None:
if (self.args.enable_weight_checker or action == "compare") and self._last_checksums:
if dist.get_rank() == 0:
# Send the entire dictionary in one Ray RPC call to ensure a single summary log.
ray.get(
[
engine.check_weights.remote(
action="compare_checksum", payload={"checksums": self._last_checksums}
)
for engine in self.rollout_engines
]
)
else:
if dist.get_rank() == 0:
ray.get([engine.check_weights.remote(action=action) for engine in self.rollout_engines])

def _update_weight_from_distributed(
self,
name: str,
@@ -208,6 +227,11 @@ def _update_bucket_weights_from_distributed(
"""
Lock → broadcast → clear → unlock → pbar++. Lock prevents NCCL deadlock.
"""
if self.args.enable_weight_checker or self.args.check_weight_update_equal:
for name, tensor in converted_named_tensors:
t_cpu = tensor.detach().cpu().contiguous()
self._last_checksums[name] = hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest()

# lock the rollout engines to prevent deadlock on broadcast.
while not ray.get(self.rollout_engine_lock.acquire.remote()):
time.sleep(0.1)
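For reference, the checksum recipe recorded above, pulled out as a self-contained sketch. The helper below is illustrative only and not part of the change: each tensor is moved to CPU, reinterpreted as raw bytes, and hashed with SHA-256, so the later comparison is bitwise rather than tolerance-based.

import hashlib

import torch

def tensor_checksum(tensor: torch.Tensor) -> str:
    # Same recipe as the diff above: hash the tensor's raw bytes so the
    # comparison is exact (bitwise), not an allclose-style tolerance check.
    t_cpu = tensor.detach().cpu().contiguous()
    return hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest()

w = torch.full((4, 8), 0.1)  # float32
assert tensor_checksum(w) == tensor_checksum(w.clone())
# Any lossy conversion changes the bytes, and therefore the digest:
assert tensor_checksum(w) != tensor_checksum(w.to(torch.bfloat16).to(torch.float32))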
@@ -1,3 +1,4 @@
import hashlib
from argparse import Namespace
from collections.abc import Callable, Mapping, Sequence
from typing import Any
@@ -116,7 +117,14 @@ def update_weights(self) -> None:

megatron_local_weights = self.weights_getter()

self._last_checksums = {}

for hf_named_tensors in self._hf_weight_iterator.get_hf_weight_chunks(megatron_local_weights):
if self.args.enable_weight_checker or self.args.check_weight_update_equal:
for name, tensor in hf_named_tensors:
t_cpu = tensor.detach().cpu().contiguous()
self._last_checksums[name] = hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest()

refs, long_lived_tensors = self._send_hf_params(hf_named_tensors)
ray.get(refs)
del long_lived_tensors
@@ -148,6 +156,22 @@ def _send_hf_params(self, hf_named_tensors) -> tuple[list[ObjectRef], Any]:

return all_refs, long_lived_tensors

def check_weights(self, action: str) -> None:
if (self.args.enable_weight_checker or action == "compare") and self._last_checksums:
if dist.get_rank() == 0:
# One request per update cycle to keep logs clean and minimize network overhead.
ray.get(
[
engine.check_weights.remote(
action="compare_checksum", payload={"checksums": self._last_checksums}
)
for engine in self.rollout_engines
]
)
else:
if dist.get_rank() == 0:
ray.get([engine.check_weights.remote(action=action) for engine in self.rollout_engines])


def _send_to_colocated_engine(
hf_named_tensors: list[tuple[str, torch.Tensor]],
7 changes: 5 additions & 2 deletions miles/backends/sglang_utils/sglang_engine.py
@@ -346,8 +346,11 @@ def resume_memory_occupation(self, tags: list[str] = None):
{"tags": tags},
)

def check_weights(self, action: str):
return self._make_request("weights_checker", {"action": action})
def check_weights(self, action: str, payload: dict | None = None):
data = {"action": action}
if payload is not None:
data.update(payload)
return self._make_request("weights_checker", data)

def init_weights_update_group(self, master_address, master_port, rank_offset, world_size, group_name, backend):
return self._make_request(
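Because the payload dictionary is merged into the request body with data.update(payload), the engine's weights_checker endpoint receives {"action": "compare_checksum", "checksums": {...}} with the checksums at the top level. The server-side handler is not part of this diff; the following is only a hypothetical sketch of the comparison it is expected to perform, with placeholder names (compare_checksums, engine_weights).

import hashlib

import torch

def compare_checksums(engine_weights: dict[str, torch.Tensor], checksums: dict[str, str]) -> list[str]:
    # Hypothetical engine-side counterpart (not in this diff): recompute the
    # SHA-256 digest of each weight the trainer reported and collect the names
    # whose bytes differ.
    mismatched = []
    for name, expected in checksums.items():
        tensor = engine_weights.get(name)
        if tensor is None:
            mismatched.append(name)  # weight never reached the engine
            continue
        t_cpu = tensor.detach().cpu().contiguous()
        if hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest() != expected:
            mismatched.append(name)  # bytes differ after the update
    return mismatched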
4 changes: 4 additions & 0 deletions miles/ray/actor_group.py
@@ -124,6 +124,10 @@ def update_weights(self):
"""Broadcast weights from rank 0 to all other ranks."""
return ray.get([actor.update_weights.remote() for actor in self._actor_handlers])

def check_weights(self, action: str = "compare"):
"""Verify weights between training and rollout engines."""
return ray.get([actor.check_weights.remote(action=action) for actor in self._actor_handlers])

def onload(self):
return ray.get([actor.wake_up.remote() for actor in self._actor_handlers])

10 changes: 8 additions & 2 deletions miles/ray/rollout.py
@@ -205,8 +205,14 @@ def health_monitoring_resume(self) -> None:
if self._health_monitor is not None:
self._health_monitor.resume()

def check_weights(self, action: str):
return ray.get([engine.check_weights.remote(action=action) for engine in self.rollout_engines])
def check_weights(self, action: str, payload: dict | None = None):
return ray.get(
[
engine.check_weights.remote(action=action, payload=payload)
for engine in self.rollout_engines
if engine is not None
]
)

def _get_rollout_data(self, rollout_id):
if self.args.load_debug_rollout_data:
4 changes: 4 additions & 0 deletions miles/utils/arguments.py
@@ -1126,6 +1126,7 @@ def add_debug_arguments(parser):
choices=["torch", "memray"],
default="torch",
)
parser.add_argument("--enable-weight-checker", action="store_true")
parser.add_argument("--check-weight-update-equal", action="store_true")
return parser

@@ -1569,6 +1570,9 @@ def miles_validate_args(args):
)
args.debug_train_only = True

if args.enable_weight_checker:
pass # We will use this flag directly in the updaters to trigger the bitwise check.

args.use_critic = args.advantage_estimator == "ppo"
if args.critic_num_gpus_per_node is None:
args.critic_num_gpus_per_node = args.actor_num_gpus_per_node
68 changes: 68 additions & 0 deletions tests/test_weight_update_correctness.py
@@ -0,0 +1,68 @@
import unittest
import pytest
import torch
import miles.utils.external_utils.command_utils as U

MODEL_NAME = "Qwen2.5-3B"
MODEL_TYPE = "qwen2.5-3B"
NUM_GPUS = 2
BASE_DIR = "/root"


@pytest.mark.skipif(torch.cuda.device_count() < NUM_GPUS, reason=f"Need at least {NUM_GPUS} GPUs")
class TestWeightUpdateCorrectness(unittest.TestCase):
def setUp(self):
U.exec_command(f"mkdir -p {BASE_DIR}/models {BASE_DIR}/datasets")
U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir {BASE_DIR}/models/{MODEL_NAME}")
U.exec_command(
f"hf download --repo-type dataset zhuzilin/dapo-math-17k --local-dir {BASE_DIR}/datasets/dapo-math-17k"
)

U.convert_checkpoint(
model_name=MODEL_NAME,
megatron_model_type=MODEL_TYPE,
num_gpus_per_node=NUM_GPUS,
dir_dst=f"{BASE_DIR}/models",
hf_checkpoint=f"{BASE_DIR}/models/{MODEL_NAME}",
)

def test_weight_correctness(self):
ckpt_args = (
f"--hf-checkpoint {BASE_DIR}/models/{MODEL_NAME}/ --ref-load {BASE_DIR}/models/{MODEL_NAME}_torch_dist "
)

rollout_args = (
f"--prompt-data {BASE_DIR}/datasets/dapo-math-17k/dapo-math-17k.jsonl "
"--input-key prompt "
"--label-key label "
"--apply-chat-template "
"--num-rollout 2 "
"--rollout-batch-size 4 "
"--n-samples-per-prompt 1 "
"--rollout-max-response-len 128 "
"--global-batch-size 4 "
)

ppo_args = "--advantage-estimator grpo --rm-type math "
sglang_args = (
f"--rollout-num-gpus-per-engine {NUM_GPUS} --rollout-num-gpus {NUM_GPUS} --sglang-mem-fraction-static 0.6 "
)
checker_args = "--enable-weight-checker "

misc_args = (
"--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_GPUS} " "--colocate " "--update-weights-interval 1"
)

train_args = (
f"{ckpt_args} " f"{rollout_args} " f"{ppo_args} " f"{sglang_args} " f"{checker_args} " f"{misc_args}"
)

U.execute_train(
train_args=train_args,
num_gpus_per_node=NUM_GPUS,
megatron_model_type=MODEL_TYPE,
)


if __name__ == "__main__":
unittest.main()
6 changes: 4 additions & 2 deletions train.py
@@ -26,8 +26,8 @@ def train(args):
# always update weight first so that sglang has the loaded weights from training.
actor_model.update_weights()

if args.check_weight_update_equal:
ray.get(rollout_manager.check_weights.remote(action="compare"))
if args.check_weight_update_equal or args.enable_weight_checker:
actor_model.check_weights()

if args.offload_rollout:
ray.get(rollout_manager.onload_kv.remote())
@@ -87,6 +87,8 @@ def save(rollout_id):
if args.offload_rollout:
ray.get(rollout_manager.onload_weights.remote())
actor_model.update_weights()
if args.check_weight_update_equal or args.enable_weight_checker:
actor_model.check_weights()
if args.offload_rollout:
ray.get(rollout_manager.onload_kv.remote())

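For orientation, the ordering that train.py enforces after this change, condensed from the two hunks above. This is a restatement for readability, not additional code in the PR: _last_checksums is rebuilt inside update_weights(), so check_weights() is only meaningful immediately after an update.

actor_model.update_weights()                     # records a checksum per transferred tensor
if args.check_weight_update_equal or args.enable_weight_checker:
    actor_model.check_weights()                  # engines are asked to compare against the recorded checksums
if args.offload_rollout:
    ray.get(rollout_manager.onload_kv.remote())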