# Normalized N-gram + Bayesian First-Match + Pre-Enrichment + XSA

> **Status: Active development** — Score-first TTT integration in progress.

val_bpb: **0.3922** (full-vocab 1024-token normalized n-gram, Bayesian first-match, fixed 0.5 blend)
| 1.1478 (sliding window) | 14.94 MB | 8xH100 SXM

| Metric | PR #810 (standard) | This PR (normalized) |
|---|---|---|
| val_bpb | 0.2722 | **0.3922** |
| Sliding BPP | 1.1478 | 1.1478 |
| N-gram gain over neural | -0.876 BPP | -0.756 BPP |
| **Collision premium** | — | **0.120 BPP** |

## Progress

| | v1 | v2 | v3 | v4 | v5 | v6 | v7 | v8 | v9 | v10 | v11 (this) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| val_bpb | 1.1855 | 1.1709 | 1.1668 | 1.1629 | 1.0689 | 0.9784 | 0.9408 | 0.9393 | 0.2995 | 0.2722 | **0.3922** |
| Eval method | sliding | sliding | sliding | sliding | 5-gram | 2-7 backoff | 2-11 backoff | +PE conf | shared cache | +phrase cache | **normalized** |

v11 is intentionally higher than v10. I replaced standard scoring with
full-vocab 1024-token normalized distributions. The 0.12 BPP increase is the
**collision premium** — the portion of n-gram gain that comes from inflated
pseudo-probabilities rather than genuine statistical signal.

## Key Finding: The Collision Premium

Standard n-gram scoring computes `p = pair_count / ctx_count`. With 4M hash buckets:
- Both counts are inflated by unrelated entries hashing to the same bucket
- The pair bucket is queried specifically for the target token being scored
- This creates an information leak through the collision structure

**Evidence:**
- 256M-bucket experiment (near collision-free): n-gram gain drops to near-zero (1.1123 vs 1.1109 float base)
- This submission (1024-token normalization): 0.3922 vs 0.2722 = **0.120 BPP collision premium**
- The remaining 0.756 BPP gain (1.1478 → 0.3922) is genuine n-gram signal
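The collision mechanism can be reproduced in a toy setting. The sketch below is a minimal numpy simulation, not the submission's code: bucket counts, the `PAIR_MULT` constant, and the observation stream are all illustrative (the real cache uses 4M buckets). It shows that querying the pair bucket for one specific target token credits all collision mass to that token, while normalizing over the full vocabulary spreads it out:

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, N_BUCKETS, PAIR_MULT = 1024, 4096, 2654435761  # toy sizes, illustrative constant

# Toy hash-table n-gram cache: counts accumulate into shared buckets.
pair_table = np.zeros(N_BUCKETS, dtype=np.int64)
ctx_table = np.zeros(N_BUCKETS, dtype=np.int64)

# Many unrelated (context, token) observations collide into the buckets.
ctx_hashes = rng.integers(0, 2**31, size=50_000)
tokens = rng.integers(0, VOCAB, size=50_000)
np.add.at(pair_table, (ctx_hashes * PAIR_MULT + tokens) % N_BUCKETS, 1)
np.add.at(ctx_table, ctx_hashes % N_BUCKETS, 1)

# Score a never-seen context: its buckets contain only collision noise.
h, target = int(rng.integers(0, 2**31)), 7

# Standard scoring queries the target's pair bucket directly, so collision
# mass in that one bucket is read as evidence for the target token.
p_standard = pair_table[(h * PAIR_MULT + target) % N_BUCKETS] / max(ctx_table[h % N_BUCKETS], 1)

# Normalized scoring looks up all 1024 tokens and renormalizes, so collision
# mass is shared across the whole vocabulary instead of credited to the target.
all_counts = pair_table[(h * PAIR_MULT + np.arange(VOCAB)) % N_BUCKETS]
p_normalized = all_counts[target] / max(all_counts.sum(), 1)
```

Under these toy settings `p_standard` comes out orders of magnitude larger than `p_normalized`, even though the context was never observed: that gap is the collision premium in miniature.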


## Key Contributions

### Full-Vocab 1024-Token Distribution Scoring
For each scored position and each n-gram order, look up counts for all 1024
vocabulary tokens and normalize to sum to 1.0:
```
pair_h = (ctx_hash[:, None] * PAIR_MULT + all_tokens[None, :]) % NG_B # [chunk, 1024]
pair_counts = ng_pair[order][pair_h] # [chunk, 1024]
p_ng = pair_counts / pair_counts.sum(dim=1, keepdim=True) # normalized distribution
```
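The lookup above can be sketched as a self-contained numpy example. Table size, hash constant, and the stand-in count table are assumptions for illustration; only the broadcasted hash-and-normalize pattern matches the snippet above:

```python
import numpy as np

NG_B, PAIR_MULT, VOCAB = 1 << 16, 2654435761, 1024  # toy bucket count, illustrative multiplier

ng_pair = np.random.default_rng(1).integers(0, 5, size=NG_B)  # stand-in count table
ctx_hash = np.array([12345, 67890])                           # [chunk] context hashes

# Hash (context, token) for every vocabulary token at once: [chunk, 1024]
all_tokens = np.arange(VOCAB)
pair_h = (ctx_hash[:, None] * PAIR_MULT + all_tokens[None, :]) % NG_B
pair_counts = ng_pair[pair_h].astype(np.float64)              # [chunk, 1024]

# Normalize each row to a proper distribution (guard all-zero rows).
denom = np.maximum(pair_counts.sum(axis=1, keepdims=True), 1.0)
p_ng = pair_counts / denom
```

Each row of `p_ng` sums to 1.0, so a collision in any single bucket can only redistribute probability, never inflate the score of the queried token.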

### Bayesian First-Match with Neural Prior
Instead of the raw `pair/ctx` ratio, use a Bayesian estimate with the neural model as prior:
```
p_local = (raw_correct + beta * p_neural) / (ctx_count + beta)
```
`beta=2.0` — neural prior contributes 2 pseudo-counts. Low-evidence contexts are
smoothed toward the neural prediction rather than overfit to sparse counts.
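A minimal worked example of the formula above (the helper name and the sample numbers are illustrative):

```python
import numpy as np

def bayesian_first_match(raw_correct, ctx_count, p_neural, beta=2.0):
    """Shrink the raw n-gram ratio toward the neural prior.

    With few observations (small ctx_count) the estimate stays close to
    p_neural; with many observations it approaches raw_correct / ctx_count.
    """
    return (raw_correct + beta * p_neural) / (ctx_count + beta)

# Sparse context: 1 match in 1 observation does NOT yield p = 1.0.
p_sparse = bayesian_first_match(1, 1, p_neural=0.05)    # (1 + 0.1) / 3 ≈ 0.367
# Well-observed context: the counts dominate the prior.
p_dense = bayesian_first_match(90, 100, p_neural=0.05)  # 90.1 / 102 ≈ 0.883
```

With `beta=2.0` the prior behaves like 2 pseudo-counts, so a single lucky match is pulled well below 1.0 while a context with 100 observations is barely moved.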

### A/B Mixing Experiments

| Config | val_bpb | Finding |
|---|---|---|
| Fixed 0.5 blend | **0.3922** | Best — less gating = better |
| Count-confidence (gain=12) | 0.4942 | Confidence gating attenuates real signal |
| Count-confidence (gain=50) | 0.7041 | Too conservative, near-neural baseline |
| Dirichlet mixing (#944 style) | 0.3171 | Wrong for incremental cache (needs high counts) |
| CTW recursive (10 orders) | 2.5326 | Compounding across orders kills neural signal |

Once distributions are normalized, simple mixing outperforms sophisticated approaches.
The n-gram signal is real but sparse — adaptive schemes tend to attenuate it further.

### Phrase Cache Removed
Dropped entirely. The phrase cache uses the same hash-table structure and suffers
the same collision inflation.

## Neural Architecture (unchanged from PR #810)

- **Model**: 10L, 512d, 8H/4KV GQA, MLP 3x, tied embeddings, U-Net skip connections
- **GELU Pre-Enrichment** (512->768->512): wider nonlinear transformation before transformer blocks
- **XSA** on last 4 layers: removes self-value bias (arXiv:2603.09078)
- **SmearGate**: per-dim gate blending each token with previous token
- **BigramHash** (2048x128): hash-table embedding for token bigrams
- **EMA** (decay=0.997) on GPU: 37% faster training (64.7ms vs 101ms/step)
- **Int6 QAT + lzma**: 14.94 MB artifact, quant gap 0.004
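Of the components above, SmearGate is the simplest to sketch. The following is a minimal numpy rendering of the mechanism as described (per-dim gate blending each token with its predecessor); the additive form, sigmoid parameterization, and initialization are assumptions, and the PR's exact implementation may differ:

```python
import numpy as np

def smear_gate(x, gate_logits):
    """Per-dimension gate mixing each token with the previous token.

    x: [seq, dim] activations; gate_logits: [dim] learned parameters.
    Sketch of the mechanism only, not the submission's exact code.
    """
    g = 1.0 / (1.0 + np.exp(-gate_logits))                    # sigmoid gate in (0, 1)
    x_prev = np.concatenate([np.zeros_like(x[:1]), x[:-1]])   # shift right by one token
    return x + g * x_prev  # each token absorbs a gated copy of its predecessor

x = np.random.default_rng(2).standard_normal((8, 512))
out = smear_gate(x, np.full(512, -2.0))  # negative init: gates start mostly closed
```

The first token is unchanged (it has no predecessor), and each dimension learns independently how much of the previous token to smear forward.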

## Compliance

- Score-first: n-gram cache updated AFTER scoring each chunk
- Backward-looking: cache at position p contains only tokens 0..p-1
- No oracle selection: blend weight is fixed 0.5, never depends on ground truth
- No training data access during eval
- No two-pass rescoring
- **Normalized distributions**: n-gram probabilities computed across all 1024 tokens

## Reproduction

```
python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
8xH100 SXM, 600s training + ~193s eval.

Tunable env vars: `CTW_BETA=2.0`, `CTW_BLEND=0.5`, `NG_MIN=1`

## Key Metrics

| Metric | Value |
|---|---|
| val_bpb (normalized n-gram) | 0.3922 |
| Sliding window val_bpb | 1.1478 |
| Post-quant val_bpb (standard) | 1.1690 |
| Collision premium vs PR #810 | 0.120 BPP |
| Eval time | 193,472ms |
| Artifact size | 14,942,971 bytes |
| Model parameters | 25,254,992 |

## Credits

- Muon optimizer — modded-nanogpt baseline (kellerjordan)
- SmearGate + BigramHash — PR #65 (@aquariouseworkman)
- XSA — arXiv:2603.09078; GQA-aware PR #265 (@unnir)
- EMA + GPTQ-lite + warmdown tuning — PR #414 (@signalrush)
- N-gram eval cache — concept PR #659 (@deanbrr); fixed 5-gram PR #706 (@newjordan)
- Multi-order backoff — PR #727 (@Asukabot0)
- Shared GPU n-gram cache — PR #796 (@Robby955); PR #800 (@newjordan); PR #809 (@AayushBaniya2006)
- Dirichlet mixing inspiration — PR #944 (@aamodbhatt)
- 256M-bucket collision analysis — competition Issue #140 discussion
- Context Tree Weighting theory — Willems, Shtarkov, Tjalkens (1995)
- GELU Pre-Enrichment — original to this submission
- EMA on GPU — original to this submission
- Full-vocab normalized n-gram scoring — original to this submission
- Collision premium quantification — original to this submission

## Update Log

- v1 (1.1855): int8+zlib, MLP 2x, seq 1024
- v2 (1.1709): int6 QAT + lzma, MLP 3x, SWA, seq 2048
- v3 (1.1668): + SmearGate + BigramHash + EMA + wider pre-enrichment
- v4 (1.1629): + XSA on last 4 layers
- v5 (1.0689): + EMA on GPU (64ms/step) + 5-gram eval cache
- v6 (0.9784): + multi-order backoff 2-7 + entropy-adaptive alpha
- v7 (0.9408): + extended to orders 2-11 + steeper alpha
- v8 (0.9393): + pre-enrichment confidence modulation
- v9 (0.2995): + two-phase shared cache + per-order adaptive alpha (3-seed: 0.2995)
- v10 (0.2722): + long phrase cache (lengths 48, 36, 28, 20, 16)
- **v11 (0.3922): full-vocab normalized n-gram scoring + Bayesian first-match + collision premium analysis**
{
"author": "Idanr",
"github_id": "idan3011",
"name": "Normalized N-gram + Bayesian First-Match + Pre-Enrichment + XSA",
"blurb": "Full-vocab 1024-token normalized n-gram scoring + Bayesian first-match-wins (beta=2.0) + fixed 0.5 blend. Multi-order backoff (2-11). Two-phase shared cache. EMA on GPU (64.7ms/step). 10L 512d.",
"date": "2026-03-27T15:30:00Z",
"val_loss": 1.93793804,
"val_bpb": 0.39220592,
"pre_quant_val_loss": 1.9663,
"pre_quant_val_bpb": 1.1646,
"step_stop": 9268,
"wallclock_seconds": 600.031,
"eval_time_seconds": 193.472,
"bytes_total": 14942971,
"bytes_model_int6_lzma": 14878748,
"bytes_code": 64223
}
W0326 02:39:19.172000 34413 torch/distributed/run.py:803]
W0326 02:39:19.172000 34413 torch/distributed/run.py:803] *****************************************
W0326 02:39:19.172000 34413 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0326 02:39:19.172000 34413 torch/distributed/run.py:803] *****************************************
logs/0d771539-26db-4427-b5a8-0a4c24bd56ad.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:25254992
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=True flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9319 val_bpb:4.1055 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9318 train_time:62ms step_avg:61.75ms
step:2/20000 train_loss:7.1516 train_time:121ms step_avg:60.53ms
step:3/20000 train_loss:6.1791 train_time:185ms step_avg:61.59ms
step:4/20000 train_loss:6.4189 train_time:249ms step_avg:62.18ms
step:5/20000 train_loss:6.5862 train_time:313ms step_avg:62.55ms
step:6/20000 train_loss:6.2277 train_time:377ms step_avg:62.78ms
step:7/20000 train_loss:5.4960 train_time:441ms step_avg:62.97ms
step:8/20000 train_loss:5.2973 train_time:505ms step_avg:63.10ms
step:9/20000 train_loss:5.0005 train_time:569ms step_avg:63.20ms
step:10/20000 train_loss:4.8514 train_time:633ms step_avg:63.30ms
step:200/20000 train_loss:2.7511 train_time:12872ms step_avg:64.36ms
step:400/20000 train_loss:2.2579 train_time:25781ms step_avg:64.45ms
step:600/20000 train_loss:2.4713 train_time:38736ms step_avg:64.56ms
step:800/20000 train_loss:2.2316 train_time:51722ms step_avg:64.65ms
step:1000/20000 train_loss:2.3340 train_time:64727ms step_avg:64.73ms
step:1000/20000 val_loss:2.2855 val_bpb:1.3536 train_time:64739ms step_avg:64.74ms
step:1200/20000 train_loss:2.3620 train_time:77744ms step_avg:64.79ms
step:1400/20000 train_loss:2.3964 train_time:90750ms step_avg:64.82ms
step:1600/20000 train_loss:2.0689 train_time:103750ms step_avg:64.84ms
step:1800/20000 train_loss:2.1729 train_time:116742ms step_avg:64.86ms
step:2000/20000 train_loss:2.2158 train_time:129716ms step_avg:64.86ms
step:2000/20000 val_loss:2.1975 val_bpb:1.3015 train_time:129728ms step_avg:64.86ms
step:2200/20000 train_loss:2.0324 train_time:142686ms step_avg:64.86ms
step:2400/20000 train_loss:2.1624 train_time:155641ms step_avg:64.85ms
step:2600/20000 train_loss:2.3841 train_time:168596ms step_avg:64.84ms
step:2800/20000 train_loss:2.2002 train_time:181543ms step_avg:64.84ms
step:3000/20000 train_loss:2.1908 train_time:194474ms step_avg:64.82ms
step:3000/20000 val_loss:2.1539 val_bpb:1.2757 train_time:194486ms step_avg:64.83ms
step:3200/20000 train_loss:2.1563 train_time:207406ms step_avg:64.81ms
step:3400/20000 train_loss:2.1250 train_time:220338ms step_avg:64.81ms
step:3600/20000 train_loss:2.0721 train_time:233268ms step_avg:64.80ms
step:3800/20000 train_loss:2.1786 train_time:246196ms step_avg:64.79ms
step:4000/20000 train_loss:2.1419 train_time:259115ms step_avg:64.78ms
step:4000/20000 val_loss:2.1367 val_bpb:1.2655 train_time:259127ms step_avg:64.78ms
step:4200/20000 train_loss:2.1372 train_time:272101ms step_avg:64.79ms
step:4400/20000 train_loss:2.0839 train_time:285022ms step_avg:64.78ms
step:4600/20000 train_loss:1.9446 train_time:297946ms step_avg:64.77ms
step:4800/20000 train_loss:2.2371 train_time:310856ms step_avg:64.76ms
step:5000/20000 train_loss:1.9905 train_time:323763ms step_avg:64.75ms
step:5000/20000 val_loss:2.1285 val_bpb:1.2606 train_time:323775ms step_avg:64.76ms
step:5200/20000 train_loss:2.1516 train_time:336678ms step_avg:64.75ms
step:5400/20000 train_loss:2.1670 train_time:349585ms step_avg:64.74ms
step:5600/20000 train_loss:2.1609 train_time:362500ms step_avg:64.73ms
step:5800/20000 train_loss:2.1178 train_time:375416ms step_avg:64.73ms
step:6000/20000 train_loss:2.1963 train_time:388331ms step_avg:64.72ms
step:6000/20000 val_loss:2.1194 val_bpb:1.2552 train_time:388343ms step_avg:64.72ms
step:6200/20000 train_loss:2.0618 train_time:401239ms step_avg:64.72ms
step:6400/20000 train_loss:2.1328 train_time:414152ms step_avg:64.71ms
step:6600/20000 train_loss:2.0839 train_time:427067ms step_avg:64.71ms
step:6800/20000 train_loss:2.1327 train_time:439971ms step_avg:64.70ms
step:7000/20000 train_loss:2.1739 train_time:452890ms step_avg:64.70ms
step:7000/20000 val_loss:2.0766 val_bpb:1.2299 train_time:452903ms step_avg:64.70ms
step:7200/20000 train_loss:2.1442 train_time:465802ms step_avg:64.69ms
step:7400/20000 train_loss:2.0575 train_time:478715ms step_avg:64.69ms
step:7600/20000 train_loss:1.9264 train_time:491637ms step_avg:64.69ms
step:7800/20000 train_loss:2.0683 train_time:504556ms step_avg:64.69ms
step:8000/20000 train_loss:2.0304 train_time:517550ms step_avg:64.69ms
step:8000/20000 val_loss:2.0324 val_bpb:1.2037 train_time:517563ms step_avg:64.70ms
step:8200/20000 train_loss:2.1001 train_time:530461ms step_avg:64.69ms
step:8400/20000 train_loss:2.0298 train_time:543436ms step_avg:64.69ms
step:8600/20000 train_loss:2.0308 train_time:556429ms step_avg:64.70ms
step:8800/20000 train_loss:1.9809 train_time:569549ms step_avg:64.72ms
step:9000/20000 train_loss:1.8848 train_time:582572ms step_avg:64.73ms
step:9000/20000 val_loss:1.9773 val_bpb:1.1711 train_time:582573ms step_avg:64.73ms
step:9200/20000 train_loss:1.9494 train_time:595634ms step_avg:64.74ms
step:9268/20000 val_loss:1.9663 val_bpb:1.1646 train_time:600031ms step_avg:64.74ms
stopping_early: wallclock_cap train_time:600031ms step:9268/20000
peak memory allocated: 13058 MiB reserved: 13280 MiB
swa: averaging 14 checkpoints on top of EMA
ema: loading weights
Serialized model: 99486509 bytes
Code size: 64223 bytes
Total submission size: 99550732 bytes
Serialized model int6+lzma: 14878748 bytes (payload:25993024 raw_torch:26045291 payload_ratio:3.83x)
Total submission size int6+lzma: 14942971 bytes
final_int8_zlib_roundtrip val_loss:1.9738 val_bpb:1.1690 eval_time:2054ms
final_int8_zlib_roundtrip_exact val_loss:1.97382834 val_bpb:1.16901232
final_sliding_window sliding_bpb:1.1478 val_bpb:0.3922 eval_time:193472ms
final_sliding_window_exact sliding_bpb:1.14775606 val_bpb:0.39220592