From f9fd375b5feecbaa560288b742d8062e72921b2c Mon Sep 17 00:00:00 2001 From: ratish Date: Sat, 10 Jan 2026 22:33:49 +0530 Subject: [PATCH 1/4] feat: Implement bitwise weight correctness checker for Miles-SGLang sync --- docker/patch/v0.5.7/sglang.patch | 125 ++++++++++++++++++ .../update_weight_from_distributed.py | 24 ++++ .../update_weight_from_tensor.py | 24 ++++ miles/backends/sglang_utils/sglang_engine.py | 7 +- miles/ray/rollout.py | 6 +- miles/utils/arguments.py | 4 + tests/test_weight_update_correctness.py | 68 ++++++++++ 7 files changed, 254 insertions(+), 4 deletions(-) create mode 100644 tests/test_weight_update_correctness.py diff --git a/docker/patch/v0.5.7/sglang.patch b/docker/patch/v0.5.7/sglang.patch index 42d23ed65..ae99e0023 100644 --- a/docker/patch/v0.5.7/sglang.patch +++ b/docker/patch/v0.5.7/sglang.patch @@ -862,3 +862,128 @@ index a702df4f8..61d9ae366 100644 return Device2DraftCudaGraphRunner = { + + +diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py +index 1234567..89abcdef 100644 +--- a/python/sglang/srt/utils/weight_checker.py ++++ b/python/sglang/srt/utils/weight_checker.py +@@ -1,5 +1,6 @@ + import logging +-from typing import Dict, Iterable, Tuple ++import hashlib ++from typing import Dict, Iterable, Tuple, Optional + + import torch + +@@ -16,13 +17,19 @@ class WeightChecker: + self._model_runner = model_runner + self._snapshot_tensors = None + +- def handle(self, action: str): ++ def handle(self, action: str, checksums=None, rank_offset=0): + logger.info(f"[WeightChecker] handle action={action}") + if action == "snapshot": + self._snapshot() + elif action == "reset_tensors": + self._reset_tensors() + elif action == "compare": + self._compare() ++ elif action == "compare_checksum": ++ self._compare_checksum(checksums, rank_offset=rank_offset) + else: + raise Exception(f"Unsupported {action=}") + ++ def _compare_checksum(self, expected_checksums, rank_offset=0): ++ if expected_checksums is None: ++ return ++ ++ # 1. 
Get raw model state (BF16 for Qwen2.5-3B) ++ actual_state = dict(self._model_state()) ++ ++ import sglang.srt.distributed.parallel_state as ps ++ tp_group = ps.get_tp_group() ++ tp_rank = ps.get_tensor_model_parallel_rank() ++ ++ errors = [] ++ matched_count = 0 ++ ++ for name in sorted(actual_state.keys()): ++ if name in expected_checksums: ++ param = actual_state[name] ++ expected_hash = expected_checksums[name] ++ ++ data = param.data.to(torch.bfloat16) ++ ++ # STAGE A: Direct Match (Handles Replicated layers like Norms/Embeds) ++ t_cpu = data.detach().cpu().contiguous() ++ actual_hash = hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest() ++ if actual_hash == expected_hash: ++ matched_count += 1 ++ continue ++ ++ # STAGE B: Shard Reconstruction (Handles TP layers) ++ if tp_group.world_size > 1: ++ # Gather shards from all TP ranks ++ all_shards = [torch.empty_like(data) for _ in range(tp_group.world_size)] ++ torch.distributed.all_gather(all_shards, data, group=tp_group.device_group) ++ ++ # Try Dim 0 (ColumnParallel: Gate, Up, QKV) ++ full_p0 = torch.cat(all_shards, dim=0) ++ if hashlib.sha256(full_p0.detach().cpu().contiguous().view(torch.uint8).numpy()).hexdigest() == expected_hash: ++ matched_count += 1 ++ continue ++ ++ # Try Dim 1 (RowParallel: O_proj, Down_proj) - only for 2D weights ++ if data.ndim > 1: ++ full_p1 = torch.cat(all_shards, dim=1) ++ if hashlib.sha256(full_p1.detach().cpu().contiguous().view(torch.uint8).numpy()).hexdigest() == expected_hash: ++ matched_count += 1 ++ continue ++ ++ # record mismatch if all reconstruction attempts fail ++ errors.append(f"name={name} TP_rank={tp_rank} mismatch! expected={expected_hash}, actual_shard={actual_hash}") ++ ++ print(f"[WeightChecker] verified {matched_count} parameters on TP_rank {tp_rank}") ++ if errors: ++ raise Exception("Weight checksum verification failed:\n" + "\n".join(errors[:5])) + + def _snapshot(self): + named_tensors = [ +diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py +index 293a84350..0947f77e0 100644 +--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py ++++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py +@@ -190,4 +190,4 @@ class SchedulerUpdateWeightsMixin: + return ResumeMemoryOccupationReqOutput() + + def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): +- self.tp_worker.model_runner.check_weights(action=recv_req.action) ++ self.tp_worker.model_runner.check_weights(action=recv_req.action, checksums=recv_req.checksums, rank_offset=recv_req.rank_offset) + return CheckWeightsReqOutput(success=True, message="Success.") + +diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py +index 1234567..89abcdef 100644 +--- a/python/sglang/srt/managers/io_struct.py ++++ b/python/sglang/srt/managers/io_struct.py +@@ -628,6 +628,8 @@ class ResumeMemoryOccupationReqOutput(BaseReq): + @dataclass + class CheckWeightsReqInput(BaseReq): + action: str ++ checksums: Optional[list | dict] = None ++ rank_offset: int = 0 + + @dataclass + class CheckWeightsReqOutput(BaseReq): +diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py +index 1d69c0582..9027374be 100644 +--- a/python/sglang/srt/model_executor/model_runner.py ++++ b/python/sglang/srt/model_executor/model_runner.py +@@ -2687,5 +2687,5 @@ class ModelRunner(ModelRunnerKVCacheMixin): + return (1.0 - self.mem_fraction_static) * 
total_gpu_memory + return (self.server_args.mem_fraction_static) * total_gpu_memory + +- def check_weights(self, action: str): +- self._weight_checker.handle(action=action) ++ def check_weights(self, action: str, checksums=None, rank_offset=0): ++ self._weight_checker.handle(action=action, checksums=checksums, rank_offset=rank_offset) diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 801074553..ef9f1f862 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -1,3 +1,4 @@ +import hashlib import socket import time from argparse import Namespace @@ -82,6 +83,8 @@ def update_weights(self) -> None: ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) + self._last_checksums = {} + buffer_size = 0 converted_named_tensors = [] # non expert params @@ -116,6 +119,22 @@ def update_weights(self) -> None: ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) + def check_weights(self, action: str) -> None: + if action == "compare" and self._last_checksums: + if dist.get_rank() == 0: + # Rank 0 hashes represent the full parameters (post-gathering). + ray.get( + [ + engine.check_weights.remote( + action="compare_checksum", payload={"checksums": self._last_checksums} + ) + for engine in self.rollout_engines + ] + ) + else: + if dist.get_rank() == 0: + ray.get([engine.check_weights.remote(action=action) for engine in self.rollout_engines]) + def _update_weight_from_distributed( self, name: str, @@ -208,6 +227,11 @@ def _update_bucket_weights_from_distributed( """ Lock → broadcast → clear → unlock → pbar++. Lock prevents NCCL deadlock. """ + if self.args.enable_weight_checker or self.args.check_weight_update_equal: + for name, tensor in converted_named_tensors: + t_cpu = tensor.detach().cpu().contiguous() + self._last_checksums[name] = hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest() + # lock the rollout engines to prevent dead lock on broadcast. 
while not ray.get(self.rollout_engine_lock.acquire.remote()): time.sleep(0.1) diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 527d3cfe9..94dbbb5df 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -1,3 +1,4 @@ +import hashlib from argparse import Namespace from collections.abc import Callable, Mapping, Sequence from typing import Any @@ -116,7 +117,14 @@ def update_weights(self) -> None: megatron_local_weights = self.weights_getter() + self._last_checksums = {} + for hf_named_tensors in self._hf_weight_iterator.get_hf_weight_chunks(megatron_local_weights): + if self.args.check_weight_update_equal: + for name, tensor in hf_named_tensors: + t_cpu = tensor.detach().cpu().contiguous() + self._last_checksums[name] = hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest() + refs, long_lived_tensors = self._send_hf_params(hf_named_tensors) ray.get(refs) del long_lived_tensors @@ -148,6 +156,22 @@ def _send_hf_params(self, hf_named_tensors) -> tuple[list[ObjectRef], Any]: return all_refs, long_lived_tensors + def check_weights(self, action: str) -> None: + if (self.args.enable_weight_checker or action == "compare") and self._last_checksums: + if dist.get_rank() == 0: + # Rank 0 hashes represent the full parameters (post-gathering). + ray.get( + [ + engine.check_weights.remote( + action="compare_checksum", payload={"checksums": self._last_checksums} + ) + for engine in self.rollout_engines + ] + ) + else: + if dist.get_rank() == 0: + ray.get([engine.check_weights.remote(action=action) for engine in self.rollout_engines]) + def _send_to_colocated_engine( hf_named_tensors: list[tuple[str, torch.Tensor]], diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index 179306023..cb5f15eaf 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -346,8 +346,11 @@ def resume_memory_occupation(self, tags: list[str] = None): {"tags": tags}, ) - def check_weights(self, action: str): - return self._make_request("weights_checker", {"action": action}) + def check_weights(self, action: str, payload: dict | None = None): + data = {"action": action} + if payload is not None: + data.update(payload) + return self._make_request("weights_checker", data) def init_weights_update_group(self, master_address, master_port, rank_offset, world_size, group_name, backend): return self._make_request( diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 79c6649be..06acf85b8 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -205,8 +205,10 @@ def health_monitoring_resume(self) -> None: if self._health_monitor is not None: self._health_monitor.resume() - def check_weights(self, action: str): - return ray.get([engine.check_weights.remote(action=action) for engine in self.rollout_engines]) + def check_weights(self, action: str, payload: dict | None = None): + return ray.get( + [engine.check_weights.remote(action=action, payload=payload) for engine in self.rollout_engines] + ) def _get_rollout_data(self, rollout_id): if self.args.load_debug_rollout_data: diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 51b3d970b..c5475f53f 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1126,6 +1126,7 @@ def add_debug_arguments(parser): 
choices=["torch", "memray"],
         default="torch",
     )
+    parser.add_argument("--enable-weight-checker", action="store_true")
     parser.add_argument("--check-weight-update-equal", action="store_true")
 
     return parser
@@ -1569,6 +1570,9 @@ def miles_validate_args(args):
         )
         args.debug_train_only = True
 
+    # --enable-weight-checker needs no extra validation here; the weight
+    # updaters read the flag directly to trigger the bitwise checksum check.
+
     args.use_critic = args.advantage_estimator == "ppo"
     if args.critic_num_gpus_per_node is None:
         args.critic_num_gpus_per_node = args.actor_num_gpus_per_node
diff --git a/tests/test_weight_update_correctness.py b/tests/test_weight_update_correctness.py
new file mode 100644
index 000000000..2a102ae3e
--- /dev/null
+++ b/tests/test_weight_update_correctness.py
@@ -0,0 +1,68 @@
+import unittest
+import pytest
+import torch
+import miles.utils.external_utils.command_utils as U
+
+MODEL_NAME = "Qwen2.5-3B"
+MODEL_TYPE = "qwen2.5-3B"
+NUM_GPUS = 2
+BASE_DIR = "/root"
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < NUM_GPUS, reason=f"Need at least {NUM_GPUS} GPUs")
+class TestWeightUpdateCorrectness(unittest.TestCase):
+    def setUp(self):
+        U.exec_command(f"mkdir -p {BASE_DIR}/models {BASE_DIR}/datasets")
+        U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir {BASE_DIR}/models/{MODEL_NAME}")
+        U.exec_command(
+            f"hf download --repo-type dataset zhuzilin/dapo-math-17k --local-dir {BASE_DIR}/datasets/dapo-math-17k"
+        )
+
+        U.convert_checkpoint(
+            model_name=MODEL_NAME,
+            megatron_model_type=MODEL_TYPE,
+            num_gpus_per_node=NUM_GPUS,
+            dir_dst=f"{BASE_DIR}/models",
+            hf_checkpoint=f"{BASE_DIR}/models/{MODEL_NAME}",
+        )
+
+    def test_weight_correctness(self):
+        ckpt_args = (
+            f"--hf-checkpoint {BASE_DIR}/models/{MODEL_NAME}/ --ref-load {BASE_DIR}/models/{MODEL_NAME}_torch_dist "
+        )
+
+        rollout_args = (
+            f"--prompt-data {BASE_DIR}/datasets/dapo-math-17k/dapo-math-17k.jsonl "
+            "--input-key prompt "
+            "--label-key label "
+            "--apply-chat-template "
+            "--num-rollout 2 "
+            "--rollout-batch-size 4 "
+            "--n-samples-per-prompt 1 "
+            "--rollout-max-response-len 128 "
+            "--global-batch-size 4 "
+        )
+
+        ppo_args = "--advantage-estimator grpo --rm-type math "
+        sglang_args = (
+            f"--rollout-num-gpus-per-engine {NUM_GPUS} --rollout-num-gpus {NUM_GPUS} --sglang-mem-fraction-static 0.6 "
+        )
+        checker_args = "--enable-weight-checker "
+
+        misc_args = (
+            "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_GPUS} " "--colocate " "--update-weights-interval 1"
+        )
+
+        train_args = (
+            f"{ckpt_args} " f"{rollout_args} " f"{ppo_args} " f"{sglang_args} " f"{checker_args} " f"{misc_args}"
+        )
+
+        U.execute_train(
+            train_args=train_args,
+            num_gpus_per_node=NUM_GPUS,
+            megatron_model_type=MODEL_TYPE,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()

From c133f130a06aea963662530b82dd8c310a0ac154 Mon Sep 17 00:00:00 2001
From: ratish
Date: Sun, 11 Jan 2026 00:08:00 +0530
Subject: [PATCH 2/4] fix: drop unused rank_offset and guard missing rollout engines

---
 docker/patch/v0.5.7/sglang.patch                | 13 ++++++-------
 .../update_weight/update_weight_from_tensor.py  |  2 +-
 miles/ray/rollout.py                            |  6 +++++-
 3 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/docker/patch/v0.5.7/sglang.patch b/docker/patch/v0.5.7/sglang.patch
index ae99e0023..b4881e1b2 100644
--- a/docker/patch/v0.5.7/sglang.patch
+++ b/docker/patch/v0.5.7/sglang.patch
@@ -881,7 +881,7 @@ index 1234567..89abcdef 100644
          self._snapshot_tensors = None
 
 -    def handle(self, action: str):
-+    def handle(self, action: str, checksums=None, rank_offset=0):
++    def handle(self, action: str, checksums=None):
logger.info(f"[WeightChecker] handle action={action}") if action == "snapshot": self._snapshot() @@ -890,11 +890,11 @@ index 1234567..89abcdef 100644 elif action == "compare": self._compare() + elif action == "compare_checksum": -+ self._compare_checksum(checksums, rank_offset=rank_offset) ++ self._compare_checksum(checksums) else: raise Exception(f"Unsupported {action=}") -+ def _compare_checksum(self, expected_checksums, rank_offset=0): ++ def _compare_checksum(self, expected_checksums): + if expected_checksums is None: + return + @@ -944,7 +944,7 @@ index 1234567..89abcdef 100644 + # record mismatch if all reconstruction attempts fail + errors.append(f"name={name} TP_rank={tp_rank} mismatch! expected={expected_hash}, actual_shard={actual_hash}") + -+ print(f"[WeightChecker] verified {matched_count} parameters on TP_rank {tp_rank}") ++ logger.info(f"[WeightChecker] verified {matched_count} parameters on TP_rank {tp_rank}") + if errors: + raise Exception("Weight checksum verification failed:\n" + "\n".join(errors[:5])) @@ -959,7 +959,7 @@ index 293a84350..0947f77e0 100644 def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): - self.tp_worker.model_runner.check_weights(action=recv_req.action) -+ self.tp_worker.model_runner.check_weights(action=recv_req.action, checksums=recv_req.checksums, rank_offset=recv_req.rank_offset) ++ self.tp_worker.model_runner.check_weights(action=recv_req.action, checksums=recv_req.checksums) return CheckWeightsReqOutput(success=True, message="Success.") diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py @@ -971,7 +971,6 @@ index 1234567..89abcdef 100644 class CheckWeightsReqInput(BaseReq): action: str + checksums: Optional[list | dict] = None -+ rank_offset: int = 0 @dataclass class CheckWeightsReqOutput(BaseReq): @@ -986,4 +985,4 @@ index 1d69c0582..9027374be 100644 - def check_weights(self, action: str): - self._weight_checker.handle(action=action) + def check_weights(self, action: str, checksums=None, rank_offset=0): -+ self._weight_checker.handle(action=action, checksums=checksums, rank_offset=rank_offset) ++ self._weight_checker.handle(action=action, checksums=checksums) diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 94dbbb5df..57f297284 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -120,7 +120,7 @@ def update_weights(self) -> None: self._last_checksums = {} for hf_named_tensors in self._hf_weight_iterator.get_hf_weight_chunks(megatron_local_weights): - if self.args.check_weight_update_equal: + if self.args.check.enable_weight_checker or self.args.check_weight_update_equal: for name, tensor in hf_named_tensors: t_cpu = tensor.detach().cpu().contiguous() self._last_checksums[name] = hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest() diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 06acf85b8..d462cffb5 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -207,7 +207,11 @@ def health_monitoring_resume(self) -> None: def check_weights(self, action: str, payload: dict | None = None): return ray.get( - [engine.check_weights.remote(action=action, payload=payload) for engine in self.rollout_engines] + [ + engine.check_weights.remote(action=action, payload=payload) + for engine in self.rollout_engines + if engine is not None + ] ) def 
_get_rollout_data(self, rollout_id): From 8aab8da439d3f1e60cf16b56ec39c1f050bd3278 Mon Sep 17 00:00:00 2001 From: ratish Date: Tue, 13 Jan 2026 13:13:22 +0530 Subject: [PATCH 3/4] remove sglang patch --- docker/patch/v0.5.7/sglang.patch | 124 ------------------------------- 1 file changed, 124 deletions(-) diff --git a/docker/patch/v0.5.7/sglang.patch b/docker/patch/v0.5.7/sglang.patch index b4881e1b2..42d23ed65 100644 --- a/docker/patch/v0.5.7/sglang.patch +++ b/docker/patch/v0.5.7/sglang.patch @@ -862,127 +862,3 @@ index a702df4f8..61d9ae366 100644 return Device2DraftCudaGraphRunner = { - - -diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py -index 1234567..89abcdef 100644 ---- a/python/sglang/srt/utils/weight_checker.py -+++ b/python/sglang/srt/utils/weight_checker.py -@@ -1,5 +1,6 @@ - import logging --from typing import Dict, Iterable, Tuple -+import hashlib -+from typing import Dict, Iterable, Tuple, Optional - - import torch - -@@ -16,13 +17,19 @@ class WeightChecker: - self._model_runner = model_runner - self._snapshot_tensors = None - -- def handle(self, action: str): -+ def handle(self, action: str, checksums=None): - logger.info(f"[WeightChecker] handle action={action}") - if action == "snapshot": - self._snapshot() - elif action == "reset_tensors": - self._reset_tensors() - elif action == "compare": - self._compare() -+ elif action == "compare_checksum": -+ self._compare_checksum(checksums) - else: - raise Exception(f"Unsupported {action=}") - -+ def _compare_checksum(self, expected_checksums): -+ if expected_checksums is None: -+ return -+ -+ # 1. Get raw model state (BF16 for Qwen2.5-3B) -+ actual_state = dict(self._model_state()) -+ -+ import sglang.srt.distributed.parallel_state as ps -+ tp_group = ps.get_tp_group() -+ tp_rank = ps.get_tensor_model_parallel_rank() -+ -+ errors = [] -+ matched_count = 0 -+ -+ for name in sorted(actual_state.keys()): -+ if name in expected_checksums: -+ param = actual_state[name] -+ expected_hash = expected_checksums[name] -+ -+ data = param.data.to(torch.bfloat16) -+ -+ # STAGE A: Direct Match (Handles Replicated layers like Norms/Embeds) -+ t_cpu = data.detach().cpu().contiguous() -+ actual_hash = hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest() -+ if actual_hash == expected_hash: -+ matched_count += 1 -+ continue -+ -+ # STAGE B: Shard Reconstruction (Handles TP layers) -+ if tp_group.world_size > 1: -+ # Gather shards from all TP ranks -+ all_shards = [torch.empty_like(data) for _ in range(tp_group.world_size)] -+ torch.distributed.all_gather(all_shards, data, group=tp_group.device_group) -+ -+ # Try Dim 0 (ColumnParallel: Gate, Up, QKV) -+ full_p0 = torch.cat(all_shards, dim=0) -+ if hashlib.sha256(full_p0.detach().cpu().contiguous().view(torch.uint8).numpy()).hexdigest() == expected_hash: -+ matched_count += 1 -+ continue -+ -+ # Try Dim 1 (RowParallel: O_proj, Down_proj) - only for 2D weights -+ if data.ndim > 1: -+ full_p1 = torch.cat(all_shards, dim=1) -+ if hashlib.sha256(full_p1.detach().cpu().contiguous().view(torch.uint8).numpy()).hexdigest() == expected_hash: -+ matched_count += 1 -+ continue -+ -+ # record mismatch if all reconstruction attempts fail -+ errors.append(f"name={name} TP_rank={tp_rank} mismatch! 
expected={expected_hash}, actual_shard={actual_hash}")
-+
-+        logger.info(f"[WeightChecker] verified {matched_count} parameters on TP_rank {tp_rank}")
-+        if errors:
-+            raise Exception("Weight checksum verification failed:\n" + "\n".join(errors[:5]))
- 
-     def _snapshot(self):
-         named_tensors = [
-diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
-index 293a84350..0947f77e0 100644
---- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py
-+++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
-@@ -190,4 +190,4 @@ class SchedulerUpdateWeightsMixin:
-         return ResumeMemoryOccupationReqOutput()
- 
-     def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput):
--        self.tp_worker.model_runner.check_weights(action=recv_req.action)
-+        self.tp_worker.model_runner.check_weights(action=recv_req.action, checksums=recv_req.checksums)
-         return CheckWeightsReqOutput(success=True, message="Success.")
- 
-diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
-index 1234567..89abcdef 100644
---- a/python/sglang/srt/managers/io_struct.py
-+++ b/python/sglang/srt/managers/io_struct.py
-@@ -628,6 +628,8 @@ class ResumeMemoryOccupationReqOutput(BaseReq):
- @dataclass
- class CheckWeightsReqInput(BaseReq):
-     action: str
-+    checksums: Optional[list | dict] = None
- 
- @dataclass
- class CheckWeightsReqOutput(BaseReq):
-diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
-index 1d69c0582..9027374be 100644
---- a/python/sglang/srt/model_executor/model_runner.py
-+++ b/python/sglang/srt/model_executor/model_runner.py
-@@ -2687,5 +2687,5 @@ class ModelRunner(ModelRunnerKVCacheMixin):
-     return (1.0 - self.mem_fraction_static) * total_gpu_memory
-     return (self.server_args.mem_fraction_static) * total_gpu_memory
- 
--    def check_weights(self, action: str):
--        self._weight_checker.handle(action=action)
-+    def check_weights(self, action: str, checksums=None, rank_offset=0):
-+        self._weight_checker.handle(action=action, checksums=checksums)

From d73d8e5257da04d42d03bc07cbed17c672b73976 Mon Sep 17 00:00:00 2001
From: ratish
Date: Thu, 15 Jan 2026 15:05:45 +0530
Subject: [PATCH 4/4] feat: wire weight verification through the actor and training loop

---
 miles/backends/megatron_utils/actor.py              | 13 +++++++++++++
 .../update_weight/update_weight_from_distributed.py |  4 ++--
 .../update_weight/update_weight_from_tensor.py      |  4 ++--
 miles/ray/actor_group.py                            |  4 ++++
 train.py                                            |  6 ++++--
 5 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py
index 1e6af26b7..9e603980d 100644
--- a/miles/backends/megatron_utils/actor.py
+++ b/miles/backends/megatron_utils/actor.py
@@ -554,6 +554,19 @@ def update_weights(self) -> None:
         if self.args.offload_train:
             destroy_process_groups()
 
+    def check_weights(self, action: str = "compare") -> None:
+        """Verify weights between training and rollout engines."""
+        if self.args.debug_train_only or self.args.debug_rollout_only:
+            return
+
+        if self.args.offload_train:
+            reload_process_groups()
+
+        self.weight_updater.check_weights(action=action)
+
+        if self.args.offload_train:
+            destroy_process_groups()
+
     def load_other_checkpoint(self, model_tag: str, path: str) -> None:
         old_args = self.args.load, self.args.no_load_optim, self.args.no_load_rng, self.args.finetune
         self.args.load = path
diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py 
b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py index ef9f1f862..434f252b0 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -120,9 +120,9 @@ def update_weights(self) -> None: dist.barrier(group=get_gloo_group()) def check_weights(self, action: str) -> None: - if action == "compare" and self._last_checksums: + if (self.args.enable_weight_checker or action == "compare") and self._last_checksums: if dist.get_rank() == 0: - # Rank 0 hashes represent the full parameters (post-gathering). + # Send the entire dictionary in one Ray RPC call to ensure a single summary log. ray.get( [ engine.check_weights.remote( diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 57f297284..12ed36d02 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -120,7 +120,7 @@ def update_weights(self) -> None: self._last_checksums = {} for hf_named_tensors in self._hf_weight_iterator.get_hf_weight_chunks(megatron_local_weights): - if self.args.check.enable_weight_checker or self.args.check_weight_update_equal: + if self.args.enable_weight_checker or self.args.check_weight_update_equal: for name, tensor in hf_named_tensors: t_cpu = tensor.detach().cpu().contiguous() self._last_checksums[name] = hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest() @@ -159,7 +159,7 @@ def _send_hf_params(self, hf_named_tensors) -> tuple[list[ObjectRef], Any]: def check_weights(self, action: str) -> None: if (self.args.enable_weight_checker or action == "compare") and self._last_checksums: if dist.get_rank() == 0: - # Rank 0 hashes represent the full parameters (post-gathering). + # One request per update cycle to keep logs clean and minimize network overhead. ray.get( [ engine.check_weights.remote( diff --git a/miles/ray/actor_group.py b/miles/ray/actor_group.py index 0b11df312..def574969 100644 --- a/miles/ray/actor_group.py +++ b/miles/ray/actor_group.py @@ -124,6 +124,10 @@ def update_weights(self): """Broadcast weights from rank 0 to all other ranks.""" return ray.get([actor.update_weights.remote() for actor in self._actor_handlers]) + def check_weights(self, action: str = "compare"): + """Verify weights between training and rollout engines.""" + return ray.get([actor.check_weights.remote(action=action) for actor in self._actor_handlers]) + def onload(self): return ray.get([actor.wake_up.remote() for actor in self._actor_handlers]) diff --git a/train.py b/train.py index 745dcbed6..d57972207 100644 --- a/train.py +++ b/train.py @@ -26,8 +26,8 @@ def train(args): # always update weight first so that sglang has the loaded weights from training. actor_model.update_weights() - if args.check_weight_update_equal: - ray.get(rollout_manager.check_weights.remote(action="compare")) + if args.check_weight_update_equal or args.enable_weight_checker: + actor_model.check_weights() if args.offload_rollout: ray.get(rollout_manager.onload_kv.remote()) @@ -87,6 +87,8 @@ def save(rollout_id): if args.offload_rollout: ray.get(rollout_manager.onload_weights.remote()) actor_model.update_weights() + if args.check_weight_update_equal or args.enable_weight_checker: + actor_model.check_weights() if args.offload_rollout: ray.get(rollout_manager.onload_kv.remote())
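
Note (illustrative sketch, not part of the patches): the producer side of the
checker reduces each fully gathered HF parameter to a SHA-256 digest of its
raw bytes, so the later comparison is bitwise rather than tolerance-based.
The standalone sketch below isolates that hashing step with no Megatron or
Ray dependencies; checksum_named_tensors is a hypothetical helper name, while
the hashing recipe (detach, CPU copy, contiguous, view as uint8, sha256) is
exactly what the updaters above do.

import hashlib

import torch


def checksum_named_tensors(named_tensors):
    """Map each parameter name to the SHA-256 hexdigest of its raw bytes."""
    checksums = {}
    for name, tensor in named_tensors:
        # Hash the contiguous CPU copy reinterpreted as bytes: a change in
        # dtype or in a single bit of any element changes the digest.
        t_cpu = tensor.detach().cpu().contiguous()
        checksums[name] = hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest()
    return checksums


if __name__ == "__main__":
    # Identical bf16 tensors produce identical digests; prints True.
    a = torch.randn(4, 8).to(torch.bfloat16)
    print(checksum_named_tensors([("layer.weight", a)]) == checksum_named_tensors([("layer.weight", a.clone())]))

Because the digest covers the raw byte stream, this catches dtype drift and
single-bit corruption that a torch.allclose-style comparison would miss.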
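On the consumer side, the sglang-side WeightChecker (added in PATCH 1 and
reverted in PATCH 3) cannot hash tensor-parallel layers directly, because
each TP rank holds only a shard of the full parameter. The single-process
sketch below reproduces its two-stage strategy under simplified assumptions:
torch.chunk stands in for the all_gather across TP ranks, and sha256_of and
verify_param are hypothetical names, not functions from the patches.

import hashlib

import torch


def sha256_of(t: torch.Tensor) -> str:
    return hashlib.sha256(t.detach().cpu().contiguous().view(torch.uint8).numpy()).hexdigest()


def verify_param(shards: list[torch.Tensor], expected_hash: str) -> bool:
    # STAGE A: direct match on the local tensor (replicated layers such as
    # norms and embeddings are identical on every rank).
    if sha256_of(shards[0]) == expected_hash:
        return True
    # STAGE B: reconstruct the full parameter from TP shards. Column-parallel
    # weights (gate/up/QKV) shard along dim 0; row-parallel weights (o_proj,
    # down_proj) shard along dim 1, which only exists for 2-D tensors.
    if sha256_of(torch.cat(shards, dim=0)) == expected_hash:
        return True
    if shards[0].ndim > 1 and sha256_of(torch.cat(shards, dim=1)) == expected_hash:
        return True
    return False


if __name__ == "__main__":
    # A row-parallel weight: its dim-1 shards only match after the dim-1
    # concat, so stage B recovers the match; prints True.
    full = torch.randn(8, 8).to(torch.bfloat16)
    print(verify_param(list(torch.chunk(full, 2, dim=1)), sha256_of(full)))

Trying dim 0 before dim 1 is safe: the two reconstructions arrange the same
bytes differently, and a SHA-256 collision between them is not a practical
concern, so at most one candidate matches the digest of the full parameter.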