# Poly5 Softcap + BigramHash(3072) + Wider GPTQ-lite + Temperature Scaling + Z-Loss

**11L 512d + LeakyReLU(0.5)² + Poly5 Softcap + BigramHash(3072) + Wider GPTQ-lite (9-pct) + Z-Loss + Temperature Scaling + Legal TTT + Parallel Muon**

## Summary

This submission builds on the current SOTA (LeakyReLU² + Legal TTT + Parallel Muon, 1.1194 BPB) with six targeted improvements, each backed by ablation evidence or results from related submissions.

## Improvements over SOTA

### 1. Poly-5 Softcap (replaces tanh)
- **What:** Degree-5 polynomial approximation of tanh: `x * (1 - x²/3 + x⁴/15)` clamped to [-1, 1]
- **Why:** Better torch.compile kernel fusion (per the ternary submission's analysis, tanh breaks fusion, costing 16ms/step). Also provides a smoother gradient landscape.
- **Evidence:** Used successfully in the ternary submission (1.1570 BPB). Critical finding from that work: "Switching to tanh broke fusion — F63 was 16ms/step slower."
- **Expected impact:** ~0.001 BPB from faster training (more steps in 10 min) + marginal quality improvement
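A minimal sketch of the softcap, implementing the polynomial exactly as quoted (note that tanh's own Taylor expansion has 2t⁴/15 rather than t⁴/15); the `cap = 30.0` scale is taken from `LOGIT_SOFTCAP=30.0` in the run script, and a production version would be a vectorized torch expression:

```python
def poly5_softcap(x: float, cap: float = 30.0) -> float:
    """Softcap logits with the degree-5 polynomial quoted above, in place of
    cap * tanh(x / cap). The [-1, 1] clamp bounds the polynomial's tails."""
    t = x / cap
    p = t * (1.0 - t * t / 3.0 + t ** 4 / 15.0)
    return cap * max(-1.0, min(1.0, p))
```

For small |x| this tracks `cap * tanh(x / cap)` closely, while large logits saturate at ±cap; because it is plain polynomial arithmetic plus a clamp, torch.compile can fuse it into neighboring kernels.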

### 2. BigramHash(3072) (up from 2048)
- **What:** Increased bigram hash embedding vocabulary from 2048 to 3072
- **Why:** The SOTA's own ablation table shows `BigramHash 2048→3072: -0.0009 BPB`
- **Expected impact:** -0.0009 BPB
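The hashing scheme itself is not spelled out here; a minimal sketch of mapping adjacent token pairs into a 3072-bucket embedding table, where the multiplicative mixing constant and the zero-padding of position 0 are illustrative assumptions (only the table size comes from the submission):

```python
def bigram_hash_ids(tokens, vocab=3072):
    """Bucket each (previous, current) token pair into [0, vocab).

    The mixing constant is an arbitrary 64-bit odd constant; position 0 is
    padded with a previous-token id of 0.
    """
    prev = [0] + list(tokens[:-1])
    return [((p * 0x9E3779B97F4A7C15 + t) & 0xFFFFFFFFFFFFFFFF) % vocab
            for p, t in zip(prev, tokens)]
```

Each id then indexes a (3072, 128) embedding table whose rows are added to the token embeddings.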

### 3. Wider GPTQ-lite Percentile Search (9 candidates vs 5)
- **What:** Expanded quantization clip percentile candidates from `[0.999, 0.9995, 0.9999, 0.99999, 1.0]` to `[0.998, 0.999, 0.9993, 0.9995, 0.9997, 0.9999, 0.99995, 0.99999, 1.0]`
- **Why:** More candidates mean a lower per-row reconstruction MSE. Zero training cost (the search runs post-training only).
- **Expected impact:** -0.0001 to -0.0003 BPB

### 4. Temperature Scaling (T=0.95) at Eval
- **What:** Apply temperature=0.95 to logits during sliding window and TTT evaluation
- **Why:** Sharpening the distribution can improve BPB when the model is slightly oversmoothed. Used in ternary submission (T=0.90).
- **Expected impact:** -0.001 to -0.002 BPB (conservative T=0.95 vs aggressive T=0.90)
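A sketch of how eval-time temperature enters the BPB computation, assuming the logits are simply divided by T before the log-softmax (pure Python, with a one-byte-per-token simplification):

```python
import math

def bpb_with_temperature(logits, targets, T=0.95):
    """Bits-per-byte after dividing logits by T, assuming (for illustration)
    one byte per token; `logits` is a list of per-position logit rows."""
    nats = 0.0
    for row, tgt in zip(logits, targets):
        z = [v / T for v in row]
        m = max(z)
        lse = m + math.log(sum(math.exp(v - m) for v in z))
        nats += lse - z[tgt]                 # -log p(target | context)
    return nats / len(targets) / math.log(2)
```

When the model already ranks the right token first but spreads too much mass elsewhere, T < 1 sharpens the distribution and lowers BPB; when it is over-confident and wrong, T < 1 hurts, which is why 0.95 is the conservative choice relative to the ternary submission's 0.90.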

### 5. Z-Loss Regularization (weight=1e-4)
- **What:** Added `z_loss = 1e-4 * mean(logsumexp(logits)²)` to training loss
- **Why:** Stabilizes logit magnitudes, prevents loss spikes, improves training stability. Standard technique from PaLM/Gemini.
- **Expected impact:** -0.0005 to -0.001 BPB from more stable training
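The stated formula as a pure-Python sketch; in training this would be computed on the logit tensor and added to the cross-entropy loss:

```python
import math

def z_loss(logits, weight=1e-4):
    """weight * mean over positions of logsumexp(logits)^2, per the
    formula quoted above."""
    penalties = []
    for row in logits:
        m = max(row)
        lse = m + math.log(sum(math.exp(v - m) for v in row))
        penalties.append(lse * lse)
    return weight * sum(penalties) / len(penalties)
```

The penalty pulls logsumexp toward zero, so logit magnitudes cannot drift upward unchecked; at weight 1e-4 it is negligible until they do.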

### 6. LZMA Preset 9 Compression (up from 6)
- **What:** Increased lzma compression level from 6 to 9 (maximum)
- **Why:** Better compression ratio means smaller artifact, leaving more headroom for parameters
- **Expected impact:** ~0.5-1% smaller artifact, with any quality gain coming indirectly from the extra parameter headroom
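The change is a one-liner in the packing step; the payload below is a stand-in, since the real artifact bytes come from the quantized checkpoint:

```python
import lzma

def pack_artifact(weight_blob: bytes) -> bytes:
    """Compress the serialized weight blob at lzma's maximum preset (9)."""
    return lzma.compress(weight_blob, preset=9)

# Stand-in payload for the int6 weight stream.
blob = bytes(range(256)) * 4096
artifact = pack_artifact(blob)
assert lzma.decompress(artifact) == blob      # lossless round-trip
```

Preset 9 buys its ratio with a larger dictionary and more search effort, so packing takes longer and uses more memory; the artifact itself stays byte-compatible with `lzma.decompress`.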

## Architecture (preserved from SOTA)

- **Layers:** 11 (512d, 8 heads, 4 KV heads GQA)
- **MLP:** 3x expansion, LeakyReLU(0.5)² activation
- **Attention:** XSA on last 4 layers, Partial RoPE (16/64 dims)
- **Embeddings:** BigramHash(3072, 128d), tied embeddings, ValueEmbedding(128d, layers 9-10)
- **Normalization:** RMSNorm with LN Scale (1/sqrt(layer+1))
- **U-Net:** 5 encoder + 6 decoder with learned skip weights
- **SmearGate:** Token blending with learned gate
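The normalization line above can be sketched as follows; the learned gain is omitted, and applying the 1/sqrt(layer+1) LN Scale to the norm output (rather than folding it into the gain) is an assumption:

```python
import math

def rmsnorm_lnscale(x, layer_idx, eps=1e-6):
    """RMSNorm followed by a 1/sqrt(layer+1) scale, so deeper layers
    contribute progressively smaller residual updates."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / (rms * math.sqrt(layer_idx + 1)) for v in x]
```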

## Training (preserved from SOTA)

- **Optimizer:** Parallel Muon (batched Newton-Schulz) for banks + AdamW for embeddings/scalars
- **LR:** matrix=0.025, scalar=0.025, tied_embed=0.035
- **Momentum:** 0.99 (warmup from 0.92 over 1500 steps)
- **Weight decay:** 0.04 (both Muon and Adam)
- **Warmdown:** 3500 iterations
- **Weight averaging:** EMA(0.997) + Tight SWA (every 50 steps when scale < 0.2)
- **Late QAT:** STE int6 when LR scale < 0.15
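The weight-averaging mechanics can be sketched as two independent updates over flat parameter lists; how the EMA and SWA averages are combined for the final checkpoint is not specified here, so only the updates themselves are shown:

```python
def ema_update(avg, cur, decay=0.997):
    """One EMA(0.997) step over a flat list of parameter values."""
    return [decay * a + (1.0 - decay) * c for a, c in zip(avg, cur)]

class TightSWA:
    """Uniform average of snapshots taken every `every` steps, but only
    once the LR scale has decayed below `threshold` (0.2, per the recipe)."""
    def __init__(self, every=50, threshold=0.2):
        self.every, self.threshold = every, threshold
        self.total, self.count = None, 0

    def maybe_add(self, step, lr_scale, params):
        if lr_scale < self.threshold and step % self.every == 0:
            self.total = (list(params) if self.total is None
                          else [t + p for t, p in zip(self.total, params)])
            self.count += 1

    def average(self):
        return [t / self.count for t in self.total]
```

Restricting SWA to the low-LR tail ("tight" SWA) averages only checkpoints that are already close together in weight space.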

## Evaluation

- **Sliding window:** stride=64 with temperature scaling (T=0.95)
- **Legal TTT:** Score-first protocol, 3 epochs SGD per chunk, all blocks unfrozen
- **Quantization:** GPTQ-lite int6 (9-percentile search) + lzma-9 compression
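The stride=64 sliding window can be sketched as a span planner, assuming the standard arrangement where each stride of tokens is scored exactly once with up to a full window of left context (window size taken from `EVAL_SEQ_LEN=2048`):

```python
def sliding_window_spans(n_tokens, window=2048, stride=64):
    """Plan (ctx_start, end, score_start) spans: every token is scored
    exactly once, each with up to `window` tokens of left context."""
    spans, pos = [], 0
    while pos < n_tokens:
        end = min(pos + stride, n_tokens)
        ctx_start = max(0, end - window)
        spans.append((ctx_start, end, pos))
        pos = end
    return spans
```

A smaller stride gives each scored token more context at the cost of more forward passes per evaluated token.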

## Run Command

```bash
torchrun --nproc_per_node=8 train_gpt.py
```

With TTT enabled:
```bash
TTT_ENABLED=1 torchrun --nproc_per_node=8 train_gpt.py
```

## Expected Results

Based on individual technique deltas:
- SOTA baseline: 1.1194 BPB
- + Poly5 softcap: ~-0.001 (faster steps + quality)
- + BigramHash 3072: ~-0.0009 (ablation-proven)
- + Wider GPTQ-lite: ~-0.0002 (better quantization)
- + Temperature T=0.95: ~-0.001 (conservative estimate)
- + Z-loss: ~-0.0005 (training stability)
- + LZMA-9: marginal
- **Estimated total: ~1.116 BPB** (before TTT), **~1.113 BPB** (with TTT)

Note: Actual results require 8xH100 training. Individual improvements may not stack linearly.
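Stacking the listed deltas linearly reproduces the headline estimate (subject to the caveat above about non-linear stacking):

```python
# Per-technique BPB deltas from the list above (poly5, bigram, gptq, temp, z-loss).
deltas = [0.001, 0.0009, 0.0002, 0.001, 0.0005]
estimate = 1.1194 - sum(deltas)
print(round(estimate, 4))  # 1.1158 -> the "~1.116 BPB before TTT" figure
```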
#!/bin/bash
# Run training + evaluation with all improvements enabled
# Requires: 8xH100 80GB SXM

set -euo pipefail

# Download data if needed
python3 data/cached_challenge_fineweb.py --variant sp1024

# --- Training configuration ---
export SEED=${SEED:-1337}
export NUM_LAYERS=11
export MODEL_DIM=512
export NUM_HEADS=8
export NUM_KV_HEADS=4
export MLP_MULT=3
export TRAIN_SEQ_LEN=2048
export EVAL_SEQ_LEN=2048
export TRAIN_BATCH_TOKENS=786432
export BIGRAM_VOCAB_SIZE=3072
export BIGRAM_DIM=128
export XSA_LAST_N=4
export ROPE_DIMS=16
export LN_SCALE=1
export VE_ENABLED=1
export VE_DIM=128
export VE_LAYERS="9,10"
export TIE_EMBEDDINGS=1
export TIED_EMBED_LR=0.035
export MATRIX_LR=0.025
export SCALAR_LR=0.025
export MUON_MOMENTUM=0.99
export MUON_MOMENTUM_WARMUP_START=0.92
export MUON_MOMENTUM_WARMUP_STEPS=1500
export MUON_WD=0.04
export ADAM_WD=0.04
export GRAD_CLIP_NORM=0.5
export WARMDOWN_ITERS=3500
export WARMUP_STEPS=20
export LATE_QAT_THRESHOLD=0.15
export SWA_ENABLED=1
export SWA_EVERY=50
export EVAL_STRIDE=64
export MAX_WALLCLOCK_SECONDS=600.0
export LOGIT_SOFTCAP=30.0
export SOFTCAP_TYPE=poly
export Z_LOSS_WEIGHT=1e-4
export EVAL_TEMPERATURE=0.95

# --- Legal TTT ---
export TTT_ENABLED=1
export TTT_LR=0.002
export TTT_EPOCHS=3
export TTT_CHUNK_TOKENS=32768
export TTT_FREEZE_BLOCKS=0
export TTT_MOMENTUM=0.9
export TTT_BATCH_SEQS=32
export TTT_GRAD_CLIP=1.0

# --- Run ---
torchrun --nproc_per_node=8 \
records/track_10min_16mb/2026-03-26_Poly5Softcap_BigramHash3072_WiderGPTQ_TempScale/train_gpt.py \
2>&1 | tee "train_seed${SEED}.log"
{
"name": "Poly5 Softcap + BigramHash(3072) + Wider GPTQ-lite + Temperature Scaling + Z-Loss + Legal TTT",
"val_bpb": null,
"bytes_total": null,
"blurb": "Builds on LeakyReLU²+TTT+ParallelMuon SOTA with 6 targeted improvements: (1) Poly-5 softcap replaces tanh for better torch.compile kernel fusion and gradient flow, (2) BigramHash(3072) from ablation-proven expansion, (3) Wider GPTQ-lite percentile search (9 candidates vs 5) for lower quantization error, (4) Temperature scaling (T=0.95) at eval, (5) Z-loss regularization (1e-4) for training stability, (6) LZMA preset 9 for better compression. All other techniques preserved: 11L 512d, XSA4, Partial RoPE 16/64, LN Scale, EMA(0.997)+SWA, LeakyReLU(0.5)², Legal Score-First TTT, Parallel Muon.",
"author": "jimliu741523",
"github_id": "jimliu741523",
"date": "2026-03-26"
}