## Summary
The `clamp_min(1.0 / 31.0)` in `quantize_int6_per_row` (and in the matching QAT/STE fake-quantization path used by many of the int6 submissions we inspected) appears to waste int6 quantization levels on weight rows where `row_max < 1.0`.
We measured this clamp triggering on 93% of rows even at WD=0.02, which is the default in the code we were using. At that setting the effect is mild, because most rows are still close enough to 1.0 that the forced scale is only slightly too large. At higher weight decay, where weights shrink more, the effect becomes much larger.
I am not suggesting this invalidates current leaderboard results. The clamp looks like a reasonable safety choice, and current strong configs around WD=0.02-0.04 are only modestly affected. The issue is that it seems to block exploration of higher-WD regimes that otherwise produce much smaller artifacts.
We ran into this while exploring weight decay as a compression strategy. Sharing in case it saves others time, or points toward a better scale encoding.
## What's happening
The int6 quantization used in several submissions computes per-row scales like this:
```python
row_max = t.abs().amax(dim=1)
scale = (row_max / 31.0).clamp_min(1.0 / 31.0)
```

When `row_max < 1.0`, the scale is forced to 1/31 ≈ 0.032, which gives an effective quantization range of [-1.0, +1.0].
If a row’s actual weights live in [-0.3, 0.3], only about 19 of the 64 int6 levels fall into the region the weights actually use. The remaining levels are effectively unused.
The same clamp also appears in the STE fake-quantization path used during QAT (`CastedLinear.forward`). So training sees the same coarse grid for small rows, rather than a grid matched to the actual weight range.
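To make the wasted-levels effect concrete, here is a minimal sketch of the per-row scheme described above (numpy instead of torch, and the symmetric [-31, 31] range is my assumption, not the submission's exact code):

```python
import numpy as np

def quantize_int6_per_row(t, clamp_floor=1.0 / 31.0):
    # Per-row symmetric quantization mirroring the pattern above.
    row_max = np.abs(t).max(axis=1)
    scale = np.maximum(row_max / 31.0, clamp_floor)  # the clamp in question
    q = np.clip(np.round(t / scale[:, None]), -31, 31).astype(np.int8)
    return q, scale

# A row living in [-0.3, 0.3]: with the clamp, only levels -9..9 are reachable.
row = np.linspace(-0.3, 0.3, 64).reshape(1, -1)
q_clamped, _ = quantize_int6_per_row(row)
q_free, _ = quantize_int6_per_row(row, clamp_floor=1e-7)
print(len(np.unique(q_clamped)), len(np.unique(q_free)))  # 19 vs 63 levels used
```

With the clamp removed, the same row spreads across the full symmetric level range instead of a narrow band around zero.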
## How widespread it is
On one 12-layer checkpoint trained with WD=0.02:
- 93.1% of weight rows have `row_max < 1.0`
- Median `row_max`: 0.54
- Mean `row_max`: 0.71
This is just one checkpoint, so I would not generalize too far from it. But it does suggest the clamp is active on most rows even at fairly standard settings.
The important distinction is frequency vs. severity:

- if `row_max = 0.9`, forcing the scale from 0.029 to 0.032 is probably minor
- if `row_max = 0.1`, forcing the scale from 0.003 to 0.032 is much more damaging
At WD=0.02, many rows are still in the mildly affected range. As WD increases, the mismatch gets worse.
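The severity ratio is easy to compute directly: the clamp inflates the per-row quantization step by a factor of roughly `1 / row_max` (a back-of-envelope sketch, assuming symmetric scaling to ±31):

```python
# The natural per-row step is row_max / 31; the clamp forces 1/31 whenever
# row_max < 1, so the step is inflated by a factor of 1 / row_max.
for row_max in (0.9, 0.5, 0.1):
    natural = row_max / 31.0
    forced = 1.0 / 31.0
    levels_used = 2 * round(row_max * 31) + 1  # int6 levels the row can reach
    print(f"row_max={row_max}: step {forced / natural:.1f}x too coarse, "
          f"~{levels_used} of 63 levels used")
```

This is why frequency alone is misleading: a `row_max` of 0.9 costs almost nothing, while 0.1 throws away an order of magnitude of resolution.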
## When it becomes destructive
We swept Muon weight decay from 0.02 to 0.3 on an 11L / 512d model (4xH200, about 3300 steps in 600s):
| MUON_WD | Float BPB | Int6 BPB | Quant Gap | Artifact Size |
|---|---|---|---|---|
| 0.02 | 1.1739 | 1.1829 | 0.009 | 15.73 MB |
| 0.05 | 1.1734 | 1.1932 | 0.020 | 13.56 MB |
| 0.10 | 1.1796 | 1.2168 | 0.037 | 10.98 MB |
| 0.30 | 1.2121 | 1.3257 | 0.114 | 7.14 MB |
The compression benefit is real. Higher WD produces smaller weights, and those compress very well under zstd. But the quantization gap grows quickly and eventually overwhelms the size benefit.
At WD=0.02, the clamp is common but usually not too harmful. At higher WD, it becomes a real bottleneck.
## The fix helps quantization but hurts artifact size
Changing `clamp_min(1.0 / 31.0)` to `clamp_min(1e-7)` improves quantization quality, at least in the setup below (640-dim model, WD=0.1):
| Variant | QAT Clamp | Export Clamp | Quant Gap | Artifact |
|---|---|---|---|---|
| Original | 1/31 | 1/31 | 0.042 | 14.94 MB |
| Fixed | 1e-7 | 1e-7 | 0.015 | 28.88 MB |
So this is not a free win.
Allowing scales to track very small rows more accurately improves the quantization gap a lot. But the stored scales now span a much wider range, which seems to increase entropy enough that zstd compresses them much worse.
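A quick stdlib proxy for this compression effect (zlib rather than zstd, and synthetic log-uniform row maxima, so treat the exact numbers as illustrative only):

```python
import zlib
import numpy as np

rng = np.random.default_rng(0)
# Synthetic per-row maxima spanning several orders of magnitude, as at high WD.
row_max = np.exp(rng.uniform(np.log(1e-4), np.log(2.0), 8192))

clamped = np.maximum(row_max / 31.0, 1.0 / 31.0).astype(np.float16)  # mostly the constant 1/31
free = (row_max / 31.0).astype(np.float16)                           # wide dynamic range

print(len(zlib.compress(clamped.tobytes(), 9)),
      len(zlib.compress(free.tobytes(), 9)))  # clamped stream is much smaller
```

The clamped scale stream is dominated by a single repeated float16 value, so it compresses almost for free; the unclamped stream is close to incompressible.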
## Why this seems worth discussing
We trained a 640-dim model at WD=0.1 and got our best float BPB so far, 1.1563 in 2997 steps. The wider model looks genuinely better in float. The main problem is that quantization quality degrades badly on export.
That makes this feel less like "high WD is bad" and more like "current scale encoding may be the bottleneck."
## Possible directions
I have not tested these yet, but they seem like the obvious next things to try:
- Quantized scales: store per-row scales as uint8 plus a per-tensor `scale_max`, instead of float16
- Grouped scales: share one scale across N rows to reduce the number of unique scale values
- Other low-entropy scale encodings: anything that preserves small-row resolution without making the scale stream expensive to compress
If anyone has already explored this, I’d be very interested to hear what did or did not work.
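For the first direction, here is a sketch of what a log-spaced uint8 scale encoding could look like (untested in training; `min_ratio` and the log spacing are my assumptions):

```python
import numpy as np

MIN_RATIO = 1e-4  # assumed dynamic range; rows below scale_max * 1e-4 saturate

def encode_scales_uint8(scales, min_ratio=MIN_RATIO):
    # Map log(scale / scale_max) from [log(min_ratio), 0] onto 0..255.
    scale_max = scales.max()
    ratio = np.clip(scales / scale_max, min_ratio, 1.0)
    codes = np.round(np.log(ratio) / np.log(min_ratio) * 255).astype(np.uint8)
    return codes, scale_max

def decode_scales_uint8(codes, scale_max, min_ratio=MIN_RATIO):
    return scale_max * np.exp(codes / 255.0 * np.log(min_ratio))

scales = np.array([0.03, 0.003, 0.0005, 0.02])
codes, smax = encode_scales_uint8(scales)
recovered = decode_scales_uint8(codes, smax)
print(np.max(np.abs(recovered - scales) / scales))  # relative error below ~2%
```

The idea is that log spacing keeps relative resolution roughly constant across small and large rows, while the uint8 code stream has at most 256 symbols, which should be far cheaper for zstd than raw float16 scales.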
## Reproducing
```bash
# Baseline: clamp is common, effect is mild
MUON_WD=0.02 NUM_LAYERS=11 XSA_LAST_N=4 EMA_ENABLED=1 LATE_QAT=1 \
torchrun --standalone --nproc_per_node=4 train_gpt.py

# Higher WD: clamp remains common, effect becomes much larger
MUON_WD=0.1 NUM_LAYERS=11 XSA_LAST_N=4 EMA_ENABLED=1 LATE_QAT=1 \
torchrun --standalone --nproc_per_node=4 train_gpt.py
```

To check how many rows hit the clamp, for each 2D weight tensor:

```python
row_max = param.abs().amax(dim=1)
pct_clamped = (row_max < 1.0).float().mean()
```

## Environment
- 4x NVIDIA H200 (144GB)
- PyTorch 2.11+cu130
- Flash Attention v3 (Hopper)