Pull request overview
Adds support for PTPTC-style FP8 MoE quantization by extending MoE weight/scale handling and propagating per-activation-token quantization flags through the fused MoE quant config, along with related checkpoint-loading fixes for per-channel scale tensors.
Changes:
- Extend `Fp8MoEMethod` to support per-token/per-channel (PTPTC) FP8 MoE via per-channel weight scales and updated post-load processing.
- Propagate `per_act_token_quant` into `FusedMoEQuantConfig` group-shape derivation for correct activation quant descriptor behavior.
- Make linear weight loading more robust for per-channel scale tensors (shape edge cases) and update merged replicated sharding logic for per-channel scale layouts.
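As a rough illustration of the PTPTC layout these changes enable — one scale per weight output channel and one scale per activation token — here is a minimal pure-Python sketch. The constant and helper names are illustrative, not the actual `Fp8MoEMethod` API:

```python
# Sketch of PTPTC ("per-token, per-channel") FP8 scale layout.
FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3

def rowwise_scales(matrix):
    """One scale per row: the row's absolute max divided by the FP8 max."""
    return [max(abs(v) for v in row) / FP8_E4M3_MAX for row in matrix]

# Per-channel weight scales: one scale per output channel (row) of the weight.
weight = [[0.5, -2.0, 1.0], [4.0, 0.1, -0.3]]
w_scales = rowwise_scales(weight)  # len == number of output channels

# Per-token activation scales: one scale per token (row) of the activation.
acts = [[1.0, -8.0, 2.0], [0.2, 0.4, -0.1]]
a_scales = rowwise_scales(acts)    # len == number of tokens

# Dividing each row by its scale puts every element in the FP8 dynamic range.
for row, s in zip(weight, w_scales):
    assert all(abs(v) / s <= FP8_E4M3_MAX for v in row)
```

In this layout the weight scale tensor naturally has shape `(num_channels, 1)`, which is why checkpoint loading needs to tolerate per-channel scale shape edge cases.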
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `atom/model_ops/moe.py` | Adds PTPTC/per-channel FP8 MoE scale allocation and refactors post-load processing by quant strategy. |
| `atom/model_ops/linear.py` | Improves checkpoint weight loading for per-channel scales (shape handling) and fixes sharding offsets for per-channel scales in merged replicated linears. |
| `atom/model_ops/fused_moe/config.py` | Ensures `per_act_token_quant` influences activation group-shape selection in fused MoE quant config creation. |
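To illustrate how a `per_act_token_quant` flag can drive activation group-shape selection, here is a hedged sketch. The `GroupShape` descriptor and its sentinel values are assumptions modeled on common fused-MoE quant configs, not the actual `FusedMoEQuantConfig` API:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class GroupShape:
    """Hypothetical quant-group descriptor: -1 means the whole dimension
    shares one scale."""
    rows: int
    cols: int

PER_TENSOR = GroupShape(-1, -1)  # one scale for the entire activation tensor
PER_TOKEN = GroupShape(1, -1)    # one scale per token (per row)

def act_group_shape(per_act_token_quant: bool) -> GroupShape:
    """Per-token quant selects a per-row group shape; otherwise per-tensor."""
    return PER_TOKEN if per_act_token_quant else PER_TENSOR
```

Without the flag feeding into this derivation, a per-token checkpoint would be described as per-tensor, producing the wrong activation quant descriptor downstream.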
`atom/model_ops/linear.py` (comment on an outdated diff)
```diff
 param.data = param.data.view(loaded_weight.dtype)
-param.data.copy_(post_process_func(loaded_weight))
+loaded_weight = post_process_func(loaded_weight)
+if loaded_weight.shape != param.data.shape and loaded_weight.numel() == param.data.numel():
```
The new fallback reshaping in weight_loader_process will silently accept any shape mismatch as long as numel() matches, which can mask real checkpoint/parameter layout bugs (e.g., transposed tensors) and lead to incorrect weights without an error. Consider restricting reshapes to the specific expected cases (e.g., 1D -> (N,1) or squeezing singleton dimensions), and otherwise raise an error with a helpful message that includes both shapes.
Suggested change (replacing `if loaded_weight.shape != param.data.shape and loaded_weight.numel() == param.data.numel():`):

```python
# Only allow very specific, safe reshapes. Arbitrary reshapes based solely
# on matching numel() can hide real layout/transpose bugs.
if loaded_weight.shape != param.data.shape:
    if loaded_weight.numel() != param.data.numel():
        raise RuntimeError(
            f"Loaded weight shape {tuple(loaded_weight.shape)} with numel="
            f"{loaded_weight.numel()} does not match parameter shape "
            f"{tuple(param.data.shape)} with numel={param.data.numel()}."
        )
    # Allow reshapes that differ only by singleton (size-1) dimensions.
    def _strip_ones(shape: torch.Size) -> tuple[int, ...]:
        return tuple(int(d) for d in shape if int(d) != 1)
    core_loaded = _strip_ones(loaded_weight.shape)
    core_param = _strip_ones(param.data.shape)
    if core_loaded != core_param:
        raise RuntimeError(
            "Loaded weight shape is incompatible with parameter shape even "
            "though numel() matches, which may indicate a layout/transpose "
            "mismatch.\n"
            f"  Checkpoint shape: {tuple(loaded_weight.shape)}\n"
            f"  Parameter shape: {tuple(param.data.shape)}"
        )
```
Force-pushed from 21bbc93 to 0d24431.