
Conversation

@Ratish1 (Contributor) commented Jan 13, 2026

Motivation

This PR adds a mechanism to verify that the model weights loaded in SGLang are bit-for-bit identical to the weights produced by the training engine. For more information, see radixark/miles#415.

Modifications

  • Protocol: Added compare_checksum to WeightChecker. It gathers the local TP shards, reconstructs the full parameter (supporting both column and row parallelism), and verifies its SHA256 hash against the provided ground truth (see the sketch after this list).
  • Precision: Forces a cast to bfloat16 during hashing to ensure bit-perfect alignment with the trainer's base precision.
  • API: Updated CheckWeightsReqInput to accept an optional checksums payload and propagated it from the API layer to the ModelRunner.
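
A minimal sketch of the verification flow described above. The helper below is illustrative only: the exact WeightChecker internals, the tensor-parallel group handle (tp_group.world_size, tp_group.device_group), and the shard_dim argument are assumptions, not the PR's actual signatures.

import hashlib

import torch
import torch.distributed

def _sha256_of(t: torch.Tensor) -> str:
    # Hash the raw bytes of a contiguous CPU copy of the tensor.
    return hashlib.sha256(
        t.detach().cpu().contiguous().view(torch.uint8).numpy()
    ).hexdigest()

def compare_checksum(local_shard, expected_hash, tp_group, shard_dim=0):
    # Gather all TP shards and reconstruct the full parameter along the
    # sharding dimension (0 for column parallel, 1 for row parallel).
    shards = [torch.empty_like(local_shard) for _ in range(tp_group.world_size)]
    torch.distributed.all_gather(shards, local_shard, group=tp_group.device_group)
    full = torch.cat(shards, dim=shard_dim)
    # Cast to bfloat16 so the bits line up with the trainer's base precision.
    return _sha256_of(full.to(torch.bfloat16)) == expected_hash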

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.


@zhaochenyang20 (Collaborator) left a comment:

Great PR for bitwise weight verification. I have several comments:

  1. In the actual weight-update logic in RL frameworks, try to overlap the checksum with the weight refit. The idea is that the checksum computation on the CPU is considerably slower than moving data from the GPU to the CPU over PCIe. If we are currently doing this:

import hashlib

import torch

# Blocking checksum: the D2H copy and the CPU hash run serially.
for name, tensor in tensors:
    t_cpu = tensor.detach().cpu().contiguous()
    digest = hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest()

The ideal version is to make it non-blocking:

import concurrent.futures
import hashlib

import torch

def compute_sha256(t_cpu: torch.Tensor) -> str:
    return hashlib.sha256(t_cpu.view(torch.uint8).numpy()).hexdigest()

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
future = None

for tensor in tensors:
    # While the worker thread hashes the previous tensor, the main
    # thread performs the next GPU->CPU copy.
    t_cpu = tensor.detach().cpu().contiguous()
    if future:
        check_hash(future.result())  # compare against the expected digest
    future = executor.submit(compute_sha256, t_cpu)
if future:
    check_hash(future.result())
  2. But for the first step, favor simplicity. There is no need to build the non-blocking version of the checksum right now; it should be okay. On this point, we should not turn the checksum on by default. Let users decide whether to checksum in each weight-update round.

  3. I added some comments in python/sglang/srt/utils/weight_checker.py. I think we should add dedicated documentation for RL users on @app.post("/weights_checker").
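
One possible refinement of the non-blocking sketch above (my own assumption, not something proposed in this thread): staging the transfer through pinned host memory lets the D2H copy itself overlap other GPU work instead of blocking on a pageable copy. This extends the loop body above and reuses its executor and compute_sha256.

import torch

# Hypothetical pinned staging buffer; with non_blocking=True the D2H copy
# is asynchronous with respect to other work on the CUDA stream.
staging = torch.empty(tensor.shape, dtype=tensor.dtype, device="cpu", pin_memory=True)
staging.copy_(tensor, non_blocking=True)
torch.cuda.current_stream().synchronize()  # the copy must finish before hashing
future = executor.submit(compute_sha256, staging.clone())  # clone: buffer is reused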

param = actual_state[name]
expected_hash = expected_checksums[name]

data = param.data.to(torch.bfloat16)
@zhaochenyang20 (Collaborator) commented:

  1. Why should we copy the data? It doubles GPU memory usage.
  2. Why should we move it to bf16? If we are doing FP8 training, we want to compare everything in FP8.

@Ratish1 (Contributor, Author) replied:

Thanks for the feedback!

1. Regarding GPU memory: You are absolutely right. I will refactor the code to process parameters one by one and delete each reconstructed tensor.

2. Regarding BF16 vs. FP8: The reason for the BF16 cast is that the training engine (Miles) provides the ground-truth hashes in the model's base precision (BF16). If SGLang is running with FP8 quantization, a raw bitwise comparison of the FP8 memory will always fail against the trainer's BF16 source, correct me if I'm wrong on this.
Since the goal of the issue here is to support "bf16 training, fp8 rollout" verification, we must align the precision with the trainer's source to verify equality. Or do you suggest we instead quantize the weights on the trainer side before hashing, or is dequantizing one by one in SGLang for the check acceptable? Let me know which way you would prefer.

Comment on lines 74 to 103
# Try Dim 0 (ColumnParallel: Gate, Up, QKV)
full_p0 = torch.cat(all_shards, dim=0)
if (
    hashlib.sha256(
        full_p0.detach().cpu().contiguous().view(torch.uint8).numpy()
    ).hexdigest()
    == expected_hash
):
    matched_count += 1
    continue

# Try Dim 1 (RowParallel: O_proj, Down_proj) - only for 2D weights
if data.ndim > 1:
    full_p1 = torch.cat(all_shards, dim=1)
    if (
        hashlib.sha256(
            full_p1.detach().cpu().contiguous().view(torch.uint8).numpy()
        ).hexdigest()
        == expected_hash
    ):
        matched_count += 1
        continue
@zhaochenyang20 (Collaborator) commented:

The code tries dim 0, calculates the hash, and if that fails, tries dim 1. This is inefficient: for row-parallel layers you perform the reconstruction and hashing twice, which wastes CPU/GPU bandwidth.

Please determine the correct sharding dimension from the parameter name or metadata (row vs. column parallel) and reconstruct only once, as sketched below.
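
One possible shape for this, as a hedged sketch: vLLM-style weight loaders (which SGLang inherits) usually attach output_dim/input_dim attributes to parallel-linear parameters, but treat the attribute names, gather_full_param, and sha256_of as assumptions for illustration.

def get_shard_dim(param):
    # Column-parallel weights typically carry an output_dim attribute,
    # row-parallel weights an input_dim; replicated params carry neither.
    if getattr(param, "output_dim", None) is not None:
        return param.output_dim
    if getattr(param, "input_dim", None) is not None:
        return param.input_dim
    return None

shard_dim = get_shard_dim(param)
if shard_dim is None:
    actual_hash = sha256_of(data)  # replicated: hash directly, no all_gather
else:
    full = gather_full_param(data, tp_group, dim=shard_dim)  # reconstruct once
    actual_hash = sha256_of(full)
if actual_hash == expected_hash:
    matched_count += 1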

all_shards = [
    torch.empty_like(data) for _ in range(tp_group.world_size)
]
torch.distributed.all_gather(
@zhaochenyang20 (Collaborator) commented Jan 14, 2026:

As I mentioned, this is an all_gather in a for loop: it launches GPU kernels N times. We should use overlap methods to verify each parameter, rather than verifying all the parameters from one large expected_checksums.

In other words, if we can make the checksum non-blocking, the passed-in expected_checksums should be a one-element dictionary of {parameter_name: hex_digest}, as in the example below.
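
Concretely, the per-parameter request could look like this; the /weights_checker path appears earlier in this thread, while base_url, the action value, the parameter name, and the response shape are placeholders:

import requests

resp = requests.post(
    f"{base_url}/weights_checker",
    json={
        "action": "compare_checksum",  # assumed action string
        "checksums": {
            # one entry per request, so hashing overlaps the next transfer
            "model.layers.0.self_attn.qkv_proj.weight": "<sha256 hex digest>",
        },
    },
)
resp.raise_for_status()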

@dataclass
class CheckWeightsReqInput(BaseReq):
    action: str
    checksums: Optional[list | dict] = None
@zhaochenyang20 (Collaborator) commented:

From your code, I think this is indeed a {name: hex} mapping. Isn't this type hint wrong? See the sketch below for a tighter hint.

Also, please follow this:

https://github.com/sgl-project/sglang/pull/17009/changes#r2692231715
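
A tightened version of the hint, assuming the payload is always a name-to-digest mapping:

from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class CheckWeightsReqInput(BaseReq):
    action: str
    # Mapping of parameter name -> expected SHA256 hex digest.
    checksums: Optional[Dict[str, str]] = None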

# 1. Get raw model state (BF16 for Qwen2.5-3B)
actual_state = dict(self._model_state())

import sglang.srt.distributed.parallel_state as ps
@zhaochenyang20 (Collaborator) commented:

Do not import anything here.

sufubao pushed a commit to ModelTC/LightLLM that referenced this pull request Jan 15, 2026
…17009)

This commit implements bitwise weight verification through SHA256 checksums,
similar to sglang's weight verification feature. This enables integrity checking
for model weights, especially useful in distributed inference with tensor parallelism.

Key changes:
- Add WeightChecker utility class (lightllm/utils/weight_checker.py) for computing
  and verifying SHA256 checksums of weight tensors
- Extend BaseWeightTpl and MMWeightTpl with checksum computation methods
- Add TpPartBaseModel methods for model-level weight verification:
  * enable_weight_checksum_verification()
  * compute_weight_checksums()
  * verify_weight_checksums()
- Add CheckWeightsReqInput and CheckWeightsResult dataclasses for API integration
- Include comprehensive documentation (docs/WEIGHT_VERIFICATION.md)
- Add usage examples (examples/weight_verification_example.py)

Features:
- SHA256 checksum computation with optional bfloat16 casting for consistency
- Support for tensor-parallel (TP) sharded weight verification
- Model-level and layer-level checksum APIs
- Detect weight corruption or loading mismatches

Reference: sgl-project/sglang#17009

# STAGE A: Direct Match (Handles Replicated layers like Norms/Embeds)
t_cpu = data.detach().cpu().contiguous()
actual_hash = hashlib.sha256(

Maybe use a faster hash?

@Ratish1 (Contributor, Author) replied:

Thanks, that's something we could consider. My initial choice was based on avoiding new dependencies. Do you have any suggestions for a faster algorithm or a more standard library we could use for this PR? If so, I will have to update my Miles PR accordingly as well.
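
For what it's worth, two hedged candidates: hashlib.blake2b keeps us inside the standard library, while the third-party xxhash package is typically much faster but non-cryptographic (which should be fine for integrity checking):

import hashlib

def fast_digest(buf) -> str:
    # blake2b ships with the standard library and is often faster than sha256.
    return hashlib.blake2b(buf, digest_size=32).hexdigest()

# Alternatively, with the third-party xxhash package:
# import xxhash
# def fast_digest(buf) -> str:
#     return xxhash.xxh3_128(buf).hexdigest()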
