def check_weights(self, action: str = "compare") -> None:
    """Verify weights between the training side and the rollout engines.

    The actual check is delegated to ``self.weight_updater.check_weights``.
    Nothing happens when either ``debug_train_only`` or ``debug_rollout_only``
    is set on ``self.args``.

    Args:
        action: Checker action, forwarded unchanged to the weight updater.
    """
    args = self.args
    if args.debug_train_only or args.debug_rollout_only:
        return

    if not args.offload_train:
        self.weight_updater.check_weights(action=action)
        return

    # With offload_train the check is bracketed by reload/destroy of the
    # process groups. The ``finally`` guarantees the destroy step still runs
    # when the check raises (e.g. a mismatch reported by an engine), so a
    # failed check does not leave the reloaded groups behind.
    reload_process_groups()
    try:
        self.weight_updater.check_weights(action=action)
    finally:
        destroy_process_groups()
single summary log. + ray.get( + [ + engine.check_weights.remote( + action="compare_checksum", payload={"checksums": self._last_checksums} + ) + for engine in self.rollout_engines + ] + ) + else: + if dist.get_rank() == 0: + ray.get([engine.check_weights.remote(action=action) for engine in self.rollout_engines]) + def _update_weight_from_distributed( self, name: str, @@ -208,6 +227,11 @@ def _update_bucket_weights_from_distributed( """ Lock → broadcast → clear → unlock → pbar++. Lock prevents NCCL deadlock. """ + if self.args.enable_weight_checker or self.args.check_weight_update_equal: + for name, tensor in converted_named_tensors: + t_cpu = tensor.detach().cpu().contiguous() + self._last_checksums[name] = hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest() + # lock the rollout engines to prevent dead lock on broadcast. while not ray.get(self.rollout_engine_lock.acquire.remote()): time.sleep(0.1) diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 527d3cfe9..12ed36d02 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -1,3 +1,4 @@ +import hashlib from argparse import Namespace from collections.abc import Callable, Mapping, Sequence from typing import Any @@ -116,7 +117,14 @@ def update_weights(self) -> None: megatron_local_weights = self.weights_getter() + self._last_checksums = {} + for hf_named_tensors in self._hf_weight_iterator.get_hf_weight_chunks(megatron_local_weights): + if self.args.enable_weight_checker or self.args.check_weight_update_equal: + for name, tensor in hf_named_tensors: + t_cpu = tensor.detach().cpu().contiguous() + self._last_checksums[name] = hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest() + refs, long_lived_tensors = self._send_hf_params(hf_named_tensors) ray.get(refs) del long_lived_tensors 
def check_weights(self, action: str) -> None:
    """Ask every rollout engine to check its weights; only rank 0 sends requests.

    If checksums were recorded during the last ``update_weights()`` call and
    either ``self.args.enable_weight_checker`` is set or ``action`` is
    ``"compare"``, each engine receives a ``"compare_checksum"`` request that
    carries the full checksum dictionary. Otherwise ``action`` is forwarded
    to the engines unchanged.

    Args:
        action: Checker action requested by the caller.
    """
    # Both request paths are issued from rank 0 only, so every other rank
    # can leave immediately.
    if dist.get_rank() != 0:
        return

    # ``_last_checksums`` is only created inside ``update_weights()``; fall
    # back to an empty dict so that a check issued before the first weight
    # update takes the plain-forward path instead of raising AttributeError.
    checksums = getattr(self, "_last_checksums", None) or {}
    use_checksums = bool(checksums) and (self.args.enable_weight_checker or action == "compare")

    if use_checksums:
        # One request per engine carrying the whole dictionary, rather than
        # one request per tensor.
        pending = [
            engine.check_weights.remote(action="compare_checksum", payload={"checksums": checksums})
            for engine in self.rollout_engines
        ]
    else:
        pending = [engine.check_weights.remote(action=action) for engine in self.rollout_engines]

    ray.get(pending)
def check_weights(self, action: str, payload: dict | None = None):
    """Invoke ``check_weights`` on each rollout engine and gather the results.

    Entries of ``self.rollout_engines`` that are ``None`` are skipped.

    Args:
        action: Checker action passed to every engine.
        payload: Optional extra data passed to every engine alongside ``action``.

    Returns:
        Whatever ``ray.get`` yields for the list of per-engine remote calls.
    """
    pending = []
    for engine in self.rollout_engines:
        if engine is None:
            continue
        pending.append(engine.check_weights.remote(action=action, payload=payload))
    return ray.get(pending)
MODEL_NAME = "Qwen2.5-3B"
MODEL_TYPE = "qwen2.5-3B"
NUM_GPUS = 2
BASE_DIR = "/root"


# unittest.skipIf is honoured by both pytest and the ``unittest.main()`` entry
# point of this file; a pytest-only mark would be ignored by the latter.
@unittest.skipIf(torch.cuda.device_count() < NUM_GPUS, f"Need at least {NUM_GPUS} GPUs")
class TestWeightUpdateCorrectness(unittest.TestCase):
    """End-to-end run of a short training job with ``--enable-weight-checker``."""

    def setUp(self):
        """Download the model and dataset, then convert the HF checkpoint."""
        U.exec_command(f"mkdir -p {BASE_DIR}/models {BASE_DIR}/datasets")
        U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir {BASE_DIR}/models/{MODEL_NAME}")
        U.exec_command(
            f"hf download --repo-type dataset zhuzilin/dapo-math-17k --local-dir {BASE_DIR}/datasets/dapo-math-17k"
        )

        U.convert_checkpoint(
            model_name=MODEL_NAME,
            megatron_model_type=MODEL_TYPE,
            num_gpus_per_node=NUM_GPUS,
            dir_dst=f"{BASE_DIR}/models",
            hf_checkpoint=f"{BASE_DIR}/models/{MODEL_NAME}",
        )

    def test_weight_correctness(self):
        """Launch training with the weight checker enabled.

        The test body makes no assertions of its own; presumably
        ``U.execute_train`` raises when the run (including the weight check)
        fails -- TODO confirm against ``command_utils``.
        """
        ckpt_args = (
            f"--hf-checkpoint {BASE_DIR}/models/{MODEL_NAME}/ --ref-load {BASE_DIR}/models/{MODEL_NAME}_torch_dist "
        )

        rollout_args = (
            f"--prompt-data {BASE_DIR}/datasets/dapo-math-17k/dapo-math-17k.jsonl "
            "--input-key prompt "
            "--label-key label "
            "--apply-chat-template "
            "--num-rollout 2 "
            "--rollout-batch-size 4 "
            "--n-samples-per-prompt 1 "
            "--rollout-max-response-len 128 "
            "--global-batch-size 4 "
        )

        ppo_args = "--advantage-estimator grpo --rm-type math "
        sglang_args = (
            f"--rollout-num-gpus-per-engine {NUM_GPUS} --rollout-num-gpus {NUM_GPUS} --sglang-mem-fraction-static 0.6 "
        )
        checker_args = "--enable-weight-checker "

        misc_args = (
            "--actor-num-nodes 1 "
            f"--actor-num-gpus-per-node {NUM_GPUS} "
            "--colocate "
            "--update-weights-interval 1"
        )

        # Groups are joined with a single space, yielding the same command
        # line as the original chained f-strings.
        train_args = " ".join(
            [ckpt_args, rollout_args, ppo_args, sglang_args, checker_args, misc_args]
        )

        U.execute_train(
            train_args=train_args,
            num_gpus_per_node=NUM_GPUS,
            megatron_model_type=MODEL_TYPE,
        )