From 45422a6b3dff6708f2a925178bd8a41ffc18994a Mon Sep 17 00:00:00 2001 From: Sofia Bodnar Date: Wed, 18 Mar 2026 14:34:57 -0400 Subject: [PATCH 1/5] initial --- autoresearch-ref | 1 + modal_train.py | 117 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+) create mode 160000 autoresearch-ref create mode 100644 modal_train.py diff --git a/autoresearch-ref b/autoresearch-ref new file mode 160000 index 000000000..32a1460f6 --- /dev/null +++ b/autoresearch-ref @@ -0,0 +1 @@ +Subproject commit 32a1460f626e28479d427c033ee485bf5f86875a diff --git a/modal_train.py b/modal_train.py new file mode 100644 index 000000000..5b6078274 --- /dev/null +++ b/modal_train.py @@ -0,0 +1,117 @@ +# modal launcher for parameter-golf training. +# +# usage: +# # single h100 smoke test +# modal run modal_train.py +# +# # 8xh100 full run +# modal run modal_train.py --gpu-count 8 +# +# # custom env vars +# modal run modal_train.py --gpu-count 8 --env "ITERATIONS=5000" --env "VAL_LOSS_EVERY=200" + +import modal + +app = modal.App("parameter-golf") + +# pre-built image with all dependencies + data cached +image = ( + modal.Image.debian_slim(python_version="3.11") + .pip_install( + "numpy", + "tqdm", + "torch==2.10", + "huggingface-hub", + "setuptools", + "typing-extensions==4.15.0", + "datasets", + "tiktoken", + "sentencepiece", + ) + .apt_install("git") + .run_commands( + "git clone https://github.com/openai/parameter-golf.git /opt/parameter-golf", + "cd /opt/parameter-golf && python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80", + ) +) + + +@app.function( + image=image, + gpu="H100", + timeout=1200, +) +def train(env_overrides: dict[str, str] | None = None): + """single h100 training""" + import os + import subprocess + + os.chdir("/opt/parameter-golf") + + env = os.environ.copy() + env.update({ + "DATA_PATH": "./data/datasets/fineweb10B_sp1024", + "TOKENIZER_PATH": "./data/tokenizers/fineweb_1024_bpe.model", + "VOCAB_SIZE": "1024", + 
"RUN_ID": "modal_baseline", + }) + if env_overrides: + env.update(env_overrides) + + result = subprocess.run( + ["torchrun", "--standalone", "--nproc_per_node=1", "train_gpt.py"], + env=env, + capture_output=False, + ) + return result.returncode + + +@app.function( + image=image, + gpu="H100:8", + timeout=1200, +) +def train_8gpu(env_overrides: dict[str, str] | None = None): + """8xh100 training (leaderboard config)""" + import os + import subprocess + + os.chdir("/opt/parameter-golf") + + env = os.environ.copy() + env.update({ + "DATA_PATH": "./data/datasets/fineweb10B_sp1024", + "TOKENIZER_PATH": "./data/tokenizers/fineweb_1024_bpe.model", + "VOCAB_SIZE": "1024", + "RUN_ID": "modal_8gpu", + }) + if env_overrides: + env.update(env_overrides) + + result = subprocess.run( + ["torchrun", "--standalone", "--nproc_per_node=8", "train_gpt.py"], + env=env, + capture_output=False, + ) + return result.returncode + + +@app.local_entrypoint() +def main( + gpu_count: int = 1, + env: str = "", +): + env_overrides = {} + if env: + for e in env.split(","): + k, v = e.split("=", 1) + env_overrides[k] = v + + if gpu_count == 8: + print("launching 8xh100 training...") + rc = train_8gpu.remote(env_overrides or None) + else: + print("launching 1xh100 training...") + rc = train.remote(env_overrides or None) + + print(f"training finished with exit code: {rc}") From f13c234ddc9555991c6793128e8238a79f258b3f Mon Sep 17 00:00:00 2001 From: Sofia Bodnar Date: Wed, 18 Mar 2026 16:11:04 -0400 Subject: [PATCH 2/5] add modal launcher for 8xh100 training --- .gitignore | 6 ++++- modal_train.py | 61 +++++++++----------------------------------------- 2 files changed, 16 insertions(+), 51 deletions(-) diff --git a/.gitignore b/.gitignore index 3423c416a..9c124bdd2 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,8 @@ data/manifest.json data/docs_selected.jsonl .mypy_cache/ .venv -logs/ \ No newline at end of file +logs/ +results.tsv +run.log +notes.md +autoresearch-ref/ \ No newline at end of 
file diff --git a/modal_train.py b/modal_train.py index 5b6078274..080f82a32 100644 --- a/modal_train.py +++ b/modal_train.py @@ -1,20 +1,16 @@ -# modal launcher for parameter-golf training. +# modal launcher for parameter-golf autoresearch. # # usage: -# # single h100 smoke test # modal run modal_train.py # -# # 8xh100 full run -# modal run modal_train.py --gpu-count 8 -# -# # custom env vars -# modal run modal_train.py --gpu-count 8 --env "ITERATIONS=5000" --env "VAL_LOSS_EVERY=200" +# custom env vars: +# modal run modal_train.py --env "ITERATIONS=5000,VAL_LOSS_EVERY=200" import modal app = modal.App("parameter-golf") -# pre-built image with all dependencies + data cached +# base image with deps + cached data + local train_gpt.py mounted image = ( modal.Image.debian_slim(python_version="3.11") .pip_install( @@ -33,46 +29,18 @@ "git clone https://github.com/openai/parameter-golf.git /opt/parameter-golf", "cd /opt/parameter-golf && python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80", ) + # mount local train_gpt.py so agent edits get picked up each run + .add_local_file("train_gpt.py", "/opt/parameter-golf/train_gpt.py") ) -@app.function( - image=image, - gpu="H100", - timeout=1200, -) -def train(env_overrides: dict[str, str] | None = None): - """single h100 training""" - import os - import subprocess - - os.chdir("/opt/parameter-golf") - - env = os.environ.copy() - env.update({ - "DATA_PATH": "./data/datasets/fineweb10B_sp1024", - "TOKENIZER_PATH": "./data/tokenizers/fineweb_1024_bpe.model", - "VOCAB_SIZE": "1024", - "RUN_ID": "modal_baseline", - }) - if env_overrides: - env.update(env_overrides) - - result = subprocess.run( - ["torchrun", "--standalone", "--nproc_per_node=1", "train_gpt.py"], - env=env, - capture_output=False, - ) - return result.returncode - - @app.function( image=image, gpu="H100:8", timeout=1200, ) -def train_8gpu(env_overrides: dict[str, str] | None = None): - """8xh100 training (leaderboard config)""" +def 
train(env_overrides: dict[str, str] | None = None): + """8xh100 training""" import os import subprocess @@ -83,7 +51,7 @@ def train_8gpu(env_overrides: dict[str, str] | None = None): "DATA_PATH": "./data/datasets/fineweb10B_sp1024", "TOKENIZER_PATH": "./data/tokenizers/fineweb_1024_bpe.model", "VOCAB_SIZE": "1024", - "RUN_ID": "modal_8gpu", + "RUN_ID": "modal_run", }) if env_overrides: env.update(env_overrides) @@ -91,14 +59,12 @@ def train_8gpu(env_overrides: dict[str, str] | None = None): result = subprocess.run( ["torchrun", "--standalone", "--nproc_per_node=8", "train_gpt.py"], env=env, - capture_output=False, ) return result.returncode @app.local_entrypoint() def main( - gpu_count: int = 1, env: str = "", ): env_overrides = {} @@ -107,11 +73,6 @@ def main( k, v = e.split("=", 1) env_overrides[k] = v - if gpu_count == 8: - print("launching 8xh100 training...") - rc = train_8gpu.remote(env_overrides or None) - else: - print("launching 1xh100 training...") - rc = train.remote(env_overrides or None) - + print("launching 8xh100 training...") + rc = train.remote(env_overrides or None) print(f"training finished with exit code: {rc}") From 7df4c4bfd2cd9648d6fcc1fc521f1e80ad401b95 Mon Sep 17 00:00:00 2001 From: Sofia Bodnar Date: Wed, 18 Mar 2026 16:11:23 -0400 Subject: [PATCH 3/5] fix md + tests --- program.md | 150 ++++++++++++++++ test_autoresearch.py | 405 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 555 insertions(+) create mode 100644 program.md create mode 100644 test_autoresearch.py diff --git a/program.md b/program.md new file mode 100644 index 000000000..8b44f96ab --- /dev/null +++ b/program.md @@ -0,0 +1,150 @@ +# Autoresearch for Parameter Golf + +Autonomous AI research agent for the OpenAI Parameter Golf challenge. + +## Setup + +To set up a new experiment, work with the user to: + +1. **Agree on a run tag**: Propose a tag based on today's date (e.g. `mar18`). The branch `autoresearch/` must not already exist. +2. 
**Create the branch**: `git checkout -b autoresearch/<tag>` from current main. +3. **Read the in-scope files**: + - `README.md` — Challenge rules + - `train_gpt.py` — The file you modify. Model, optimizer, training loop. +4. **Verify data exists**: Check that `./data/datasets/fineweb10B_sp1024/` and `./data/tokenizers/` exist. If not, tell the human to run `python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 10` +5. **Initialize results.tsv**: Create with just the header row. +6. **Confirm and go**. + +Once you get confirmation, kick off the experimentation. + +## Experimentation + +Each experiment runs on 8xH100 via Modal. Launch it as: + +``` +modal run modal_train.py > run.log 2>&1 +``` + +The Modal script mounts your local `train_gpt.py`, so your edits are picked up each run automatically. + +**What you CAN do:** +- Modify `train_gpt.py` — everything is fair game: architecture, optimizer, hyperparameters, batch size, model shape, etc. + +**What you CANNOT do:** +- **NEVER push to GitHub. NEVER run `git push`. All work stays local.** +- Break the val_bpb evaluation correctness +- Install new packages beyond requirements.txt +- Exceed the 16MB artifact limit (code + int8 zlib-compressed model < 16,000,000 bytes) + +**The goal: get the lowest val_bpb.** Current SOTA is 1.2244. The artifact must stay under 16MB. + +**The first run**: Always establish the baseline first — run train_gpt.py as-is. + +## Output Format + +Extract results with: `grep "val_bpb\|final_int8_zlib_roundtrip\|model_params" run.log` + +If grep is empty, the run crashed or Modal failed. Run `tail -n 50 run.log` to read the error. + +## Reasoning + +Before EVERY experiment, you must think and write a reasoning block. No blind changes. + +``` +=== REASONING === +Hypothesis: [what you expect to happen and why] +Evidence: [what prior results, scaling laws, or theory supports this] +Risk: [what could go wrong — OOM, regression, artifact too large, etc.]
+=== +``` + +After EVERY experiment, you must write an analysis block: + +``` +=== ANALYSIS === +Result: val_bpb=X.XXXX artifact=X.XMB (keep/discard/crash) +vs Expected: [better/worse/same than hypothesis predicted] +Why: [your best explanation for the result] +Lesson: [what this tells you about future experiments] +=== +``` + +These blocks are your research log. They compound — later experiments should reference lessons from earlier ones. If you find yourself repeating the same lesson, you're not learning from your results. + +## Logging + +Log every run to `results.tsv` (tab-separated). Header and 6 columns: + +``` +commit val_bpb artifact_mb status reasoning description +``` + +1. Git commit hash (short, 7 chars) +2. val_bpb (use 0.000000 for crashes) +3. Artifact size in MB (use 0.0 for crashes) +4. Status: `keep`, `discard`, or `crash` +5. One-line reasoning (the hypothesis, condensed) +6. Short description of the change + +Do not commit results.tsv — leave it untracked. + +Additionally, maintain a `notes.md` file (also untracked). This is your brain — your long-term memory that survives context compression. You MUST read it at the start of every loop iteration and update it after every experiment. Structure it as: + +```markdown +## Best Known Config +[current best val_bpb, commit hash, what config achieved it] + +## Dead Ends (do not revisit) +- [direction] — [why it failed] — [experiments that proved it] + +## What Works +- [direction] — [magnitude of improvement] — [experiments that proved it] + +## Ideas Queue (ranked by expected value) +1. [next thing to try and why] +2. ... + +## Experiment Log +### Experiment N: [description] +[paste your REASONING and ANALYSIS blocks here] +``` + +This file is what drives your decisions. If you're not reading it, you're flying blind. + +## Backtracking + +Not every path leads somewhere. Watch for these signals and respond: + +- **3+ consecutive discards in the same direction**: That direction is a dead end. 
Abandon it, note it in notes.md, move on to something completely different. +- **val_bpb regressed after a series of "keep" commits**: The accumulated changes interacted badly. Backtrack: + 1. Find the best commit hash from results.tsv + 2. `git reset --hard <commit>` + 3. Log a row with `status=backtrack` in results.tsv + 4. Note in notes.md what went wrong and why + 5. Try a different approach from that known-good state +- **Stuck in a plateau (5+ experiments with <0.001 improvement)**: Step back. Re-read train_gpt.py from scratch. Look for something structural you've been overlooking. Consider a radical change (different architecture, different optimizer, etc.) + +## The Experiment Loop + +LOOP FOREVER: + +1. **Review (MANDATORY)**: You MUST read `results.tsv` and `notes.md` before every experiment. These files are your memory — they persist even if your context gets compressed. Run `cat results.tsv` and `cat notes.md` and use them to decide what to do next. Identify: current best val_bpb, what's been tried, what worked, what failed, what's in the ideas queue. +2. **Reason**: Write the REASONING block. No skipping this. Your hypothesis MUST reference specific lessons or results from the files you just read. +3. **Implement**: Modify `train_gpt.py`. +4. **Commit**: `git commit` the change. +5. **Run**: `modal run modal_train.py > run.log 2>&1` (redirect everything — do NOT flood context) +6. **Extract**: `grep "val_bpb\|final_int8_zlib_roundtrip\|model_params" run.log` +7. **Analyze**: Write the ANALYSIS block. No skipping this either. +8. **Log**: Record in results.tsv and append to notes.md. +9. **Decide**: + - val_bpb improved AND artifact < 16MB → **keep** the commit + - val_bpb worse or artifact too large → **discard**: `git reset --hard HEAD~1` + - crash → attempt trivial fix or discard and move on +10. **Check for backtracking signals** (see above). +11. **Loop**. + +**Crashes**: If it's a trivial fix (typo, missing import), fix and retry.
If fundamentally broken, discard and move on. + +**Timeout**: If a run exceeds 15 minutes, kill it and treat as failure. + +**NEVER STOP**: Do not pause to ask the human if you should continue. The human might be asleep. You are autonomous. If you run out of ideas, re-read the code, re-analyze results.tsv for patterns, try combining near-misses, try radical changes. Consult notes.md for your ideas queue. The loop runs until the human interrupts you. diff --git a/test_autoresearch.py b/test_autoresearch.py new file mode 100644 index 000000000..c5dfccf9a --- /dev/null +++ b/test_autoresearch.py @@ -0,0 +1,405 @@ +""" +tests for the autoresearch pipeline and train_gpt.py components. +run with: pytest test_autoresearch.py -v +""" + +import io +import math +import os +import struct +import tempfile +from pathlib import Path +from unittest.mock import patch + +import numpy as np +import pytest +import torch +import torch.nn as nn + + +# --------------------------------------------------------------------------- +# hyperparameters +# --------------------------------------------------------------------------- + +class TestHyperparameters: + def test_defaults(self): + # import fresh each time to pick up env + from train_gpt import Hyperparameters + args = Hyperparameters() + assert args.vocab_size == 1024 + assert args.num_layers == 9 + assert args.model_dim == 512 + assert args.num_heads == 8 + assert args.num_kv_heads == 4 + assert args.tie_embeddings is True + assert args.max_wallclock_seconds == 600.0 + + def test_env_override(self): + with patch.dict(os.environ, {"VOCAB_SIZE": "2048", "NUM_LAYERS": "12"}): + # re-import to pick up patched env + import importlib + import train_gpt + importlib.reload(train_gpt) + args = train_gpt.Hyperparameters() + assert args.vocab_size == 2048 + assert args.num_layers == 12 + # reload back to defaults + import importlib + import train_gpt + importlib.reload(train_gpt) + + +# 
--------------------------------------------------------------------------- +# model architecture +# --------------------------------------------------------------------------- + +class TestModelArchitecture: + @pytest.fixture + def small_model(self): + from train_gpt import GPT + return GPT( + vocab_size=64, + num_layers=2, + model_dim=32, + num_heads=4, + num_kv_heads=2, + mlp_mult=2, + tie_embeddings=True, + tied_embed_init_std=0.005, + logit_softcap=30.0, + rope_base=10000.0, + qk_gain_init=1.5, + ) + + def test_forward_runs(self, small_model): + x = torch.randint(0, 64, (2, 16)) + y = torch.randint(0, 64, (2, 16)) + loss = small_model(x, y) + assert loss.shape == () + assert not torch.isnan(loss) + assert loss.item() > 0 + + def test_tied_embeddings(self, small_model): + assert small_model.lm_head is None + assert small_model.tie_embeddings is True + + def test_untied_embeddings(self): + from train_gpt import GPT + model = GPT( + vocab_size=64, num_layers=2, model_dim=32, + num_heads=4, num_kv_heads=2, mlp_mult=2, + tie_embeddings=False, tied_embed_init_std=0.005, + logit_softcap=30.0, rope_base=10000.0, qk_gain_init=1.5, + ) + assert model.lm_head is not None + + def test_encoder_decoder_split(self, small_model): + # 2 layers -> 1 encoder + 1 decoder + assert small_model.num_encoder_layers == 1 + assert small_model.num_decoder_layers == 1 + + def test_skip_weights_shape(self, small_model): + expected = min(small_model.num_encoder_layers, small_model.num_decoder_layers) + assert small_model.skip_weights.shape == (expected, 32) + + def test_logit_softcap_positive(self): + from train_gpt import GPT + with pytest.raises(ValueError, match="logit_softcap must be positive"): + GPT( + vocab_size=64, num_layers=2, model_dim=32, + num_heads=4, num_kv_heads=2, mlp_mult=2, + tie_embeddings=True, tied_embed_init_std=0.005, + logit_softcap=-1.0, rope_base=10000.0, qk_gain_init=1.5, + ) + + def test_param_count_reasonable(self, small_model): + n_params = sum(p.numel() for p 
in small_model.parameters()) + # small model should have some params but not too many + assert 1000 < n_params < 100_000 + + +# --------------------------------------------------------------------------- +# individual modules +# --------------------------------------------------------------------------- + +class TestModules: + def test_rms_norm(self): + from train_gpt import RMSNorm + norm = RMSNorm() + x = torch.randn(2, 4, 32) + out = norm(x) + assert out.shape == x.shape + # rms norm should roughly normalize the last dim + rms = (out ** 2).mean(dim=-1).sqrt() + assert torch.allclose(rms, torch.ones_like(rms), atol=0.1) + + def test_casted_linear(self): + from train_gpt import CastedLinear + layer = CastedLinear(32, 64, bias=False) + x = torch.randn(2, 32, dtype=torch.bfloat16) + out = layer(x) + assert out.shape == (2, 64) + assert out.dtype == torch.bfloat16 + + def test_rotary(self): + from train_gpt import Rotary + rot = Rotary(16, base=10000.0) + cos, sin = rot(seq_len=8, device=torch.device("cpu"), dtype=torch.float32) + assert cos.shape == (1, 1, 8, 8) # half of dim=16 + assert sin.shape == (1, 1, 8, 8) + + def test_rotary_caching(self): + from train_gpt import Rotary + rot = Rotary(16) + cos1, sin1 = rot(seq_len=8, device=torch.device("cpu"), dtype=torch.float32) + cos2, sin2 = rot(seq_len=8, device=torch.device("cpu"), dtype=torch.float32) + assert cos1 is cos2 # should be cached + + def test_apply_rotary_emb(self): + from train_gpt import apply_rotary_emb + x = torch.randn(1, 1, 4, 8) + cos = torch.ones(1, 1, 4, 4) + sin = torch.zeros(1, 1, 4, 4) + # with cos=1 sin=0, rotary should be identity + out = apply_rotary_emb(x, cos, sin) + assert torch.allclose(out, x) + + def test_mlp(self): + from train_gpt import MLP + mlp = MLP(dim=32, mlp_mult=2) + x = torch.randn(2, 4, 32) + out = mlp(x) + assert out.shape == (2, 4, 32) + + def test_block(self): + from train_gpt import Block + block = Block(dim=32, num_heads=4, num_kv_heads=2, mlp_mult=2, + 
rope_base=10000.0, qk_gain_init=1.5) + x = torch.randn(2, 4, 32) + x0 = torch.randn(2, 4, 32) + out = block(x, x0) + assert out.shape == (2, 4, 32) + + +# --------------------------------------------------------------------------- +# quantization roundtrip +# --------------------------------------------------------------------------- + +class TestQuantization: + def test_int8_roundtrip_small(self): + from train_gpt import quantize_state_dict_int8, dequantize_state_dict_int8 + state = {"weight": torch.randn(8, 8)} + obj, stats = quantize_state_dict_int8(state) + restored = dequantize_state_dict_int8(obj) + assert "weight" in restored + # int8 quantization loses precision but should be close + assert torch.allclose(state["weight"], restored["weight"], atol=0.1) + + def test_int8_roundtrip_large_matrix(self): + from train_gpt import quantize_state_dict_int8, dequantize_state_dict_int8 + # large enough to trigger per-row quantization (> INT8_KEEP_FLOAT_MAX_NUMEL) + w = torch.randn(512, 512) + state = {"big_weight": w} + obj, stats = quantize_state_dict_int8(state) + restored = dequantize_state_dict_int8(obj) + # per-row int8 should preserve reasonable accuracy + cos_sim = torch.nn.functional.cosine_similarity( + w.flatten().unsqueeze(0), + restored["big_weight"].flatten().unsqueeze(0), + ) + assert cos_sim.item() > 0.99 + + def test_int8_passthrough_nonfloat(self): + from train_gpt import quantize_state_dict_int8, dequantize_state_dict_int8 + state = {"indices": torch.tensor([1, 2, 3], dtype=torch.int64)} + obj, stats = quantize_state_dict_int8(state) + restored = dequantize_state_dict_int8(obj) + assert torch.equal(state["indices"], restored["indices"]) + + def test_int8_stats(self): + from train_gpt import quantize_state_dict_int8 + state = {"w": torch.randn(4, 4), "b": torch.randn(4)} + obj, stats = quantize_state_dict_int8(state) + assert stats["num_tensors"] == 2 + assert stats["param_count"] == 20 + + def test_zlib_compression(self): + import zlib + from 
train_gpt import quantize_state_dict_int8 + # a real model's quantized state should compress well + from train_gpt import GPT + model = GPT( + vocab_size=64, num_layers=2, model_dim=32, + num_heads=4, num_kv_heads=2, mlp_mult=2, + tie_embeddings=True, tied_embed_init_std=0.005, + logit_softcap=30.0, rope_base=10000.0, qk_gain_init=1.5, + ) + obj, stats = quantize_state_dict_int8(model.state_dict()) + buf = io.BytesIO() + torch.save(obj, buf) + raw = buf.getvalue() + compressed = zlib.compress(raw, 9) + # compressed should be smaller + assert len(compressed) < len(raw) + + +# --------------------------------------------------------------------------- +# artifact size constraint +# --------------------------------------------------------------------------- + +class TestArtifactSize: + def test_baseline_under_16mb(self): + """the default baseline config must produce an artifact under 16mb.""" + import zlib + from train_gpt import GPT, quantize_state_dict_int8 + model = GPT( + vocab_size=1024, num_layers=9, model_dim=512, + num_heads=8, num_kv_heads=4, mlp_mult=2, + tie_embeddings=True, tied_embed_init_std=0.005, + logit_softcap=30.0, rope_base=10000.0, qk_gain_init=1.5, + ) + obj, stats = quantize_state_dict_int8(model.state_dict()) + buf = io.BytesIO() + torch.save(obj, buf) + compressed = zlib.compress(buf.getvalue(), 9) + code_size = Path("train_gpt.py").stat().st_size + total = len(compressed) + code_size + assert total < 16_000_000, f"artifact {total} bytes exceeds 16MB limit" + + +# --------------------------------------------------------------------------- +# data loading +# --------------------------------------------------------------------------- + +class TestDataLoading: + def _make_shard(self, path: Path, num_tokens: int): + """create a minimal valid shard file.""" + header = np.zeros(256, dtype=" cols) triggers transposed path + g = torch.randn(64, 16) + out = zeropower_via_newtonschulz5(g, steps=5) + assert out.shape == (64, 16) + + +# 
--------------------------------------------------------------------------- +# program.md contract +# --------------------------------------------------------------------------- + +class TestProgramMd: + def test_exists(self): + assert Path("program.md").is_file() + + def test_has_required_sections(self): + content = Path("program.md").read_text() + assert "## Setup" in content + assert "## Experimentation" in content + assert "## Reasoning" in content + assert "## Backtracking" in content + assert "## The Experiment Loop" in content + assert "NEVER STOP" in content + + def test_no_push(self): + content = Path("program.md").read_text() + assert "NEVER push" in content or "NEVER run `git push`" in content + + def test_artifact_limit_mentioned(self): + content = Path("program.md").read_text() + assert "16MB" in content or "16,000,000" in content + + def test_modal_launch_command(self): + content = Path("program.md").read_text() + assert "modal run modal_train.py" in content + + +# --------------------------------------------------------------------------- +# modal_train.py +# --------------------------------------------------------------------------- + +class TestModalTrain: + def test_file_exists(self): + assert Path("modal_train.py").is_file() + + def test_mounts_local_train_gpt(self): + content = Path("modal_train.py").read_text() + assert "train_gpt.py" in content + assert "Mount" in content or "mount" in content + + def test_has_single_and_multi_gpu(self): + content = Path("modal_train.py").read_text() + assert "H100" in content + assert "H100:8" in content From ba158e7a222155641eabc57de98e36b060519ec3 Mon Sep 17 00:00:00 2001 From: Sofia Bodnar Date: Sun, 22 Mar 2026 13:46:00 -0400 Subject: [PATCH 4/5] update read from upstream --- README.md | 43 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 2bb555b57..34e1b74d8 100644 --- a/README.md +++ b/README.md @@ -7,11 +7,12 @@ This 
challenge is heavily inspired by the [NanoGPT Speedrunning](https://github.com/KellerJordan/modded-nanogpt) challenge, where participants compete to train a model that reaches 3.28 FineWeb validation loss as quickly as possible. We're excited to see how optimizing for a parameter-constrained setting pushes people toward unique architectures (test-time compute, aggressive parameter tying, depth recurrence, low-rank training, ...), compression schemes (low precision, QAT, bitnets, novel tokenizers, ...), and other creative submissions (test-time training, long context, megakernels ...). -If you're familiar with [neural scaling laws](https://arxiv.org/abs/2001.08361), you can consider this challenge a form of L(N) optimization, where the objective is to optimize the lowest loss given a fixed number of parameters (N) unconstrained by data, compute, steps, or architecture. Challenges like the [NanoGPT Speedrun](https://github.com/KellerJordan/modded-nanogpt), which optimizes for a form of L(T) (~lowest loss given constrained time) or the [NanoGPT Slowrun](https://github.com/qlabs-eng/slowrun), which optimizes for L(D) (lowest loss given constrained dataset size), can be thought of as equivalent challenges in this family. +If you're familiar with [neural scaling laws](https://arxiv.org/abs/2001.08361), you can consider this challenge a form of L(N) optimization, where the objective is to optimize the lowest loss given a fixed number of parameters (N) unconstrained by data, compute, steps, or architecture. Challenges like the [NanoGPT Speedrun](https://github.com/KellerJordan/modded-nanogpt), which optimizes for a form of L(T) (~lowest time given constrained loss) or the [NanoGPT Slowrun](https://github.com/qlabs-eng/slowrun), which optimizes for L(D) (lowest loss given constrained dataset size), can be thought of as equivalent challenges in this family. Ideally, we'd allow for submissions to use arbitrary computational resources. 
But in order to make the challenge not inaccessibly expensive, we're limiting *leaderboard submissions* to 10 minutes on 8xH100s. However, we'd still love to see submissions that don't meet the compute limitation requirements in our 'Non-record Submissions' section: We're excited to see people push the infinite frontier of parameter limited performance as well. We also know compute is expensive, so **OpenAI is sponsoring $1,000,000 in compute credits** to help people get started training their models. To request a compute grant, use this form: [Request a Compute Grant](https://openai.com/index/parameter-golf/#credit-form). +When requesting compute, please make sure you choose the appropriate level, write sufficient justification, and **submit with an email tied to an OpenAI / ChatGPT account**. ## Participant Form @@ -27,10 +28,22 @@ Happy training! ## Leaderboard - -| Rank | Run | Score | Author | Summary | Date | Info | -|-----:|-----|------:|--------|---------|------|------| -| 1 | Naive Baseline | 1.2244 | Baseline | 9layer 512dim 1024vocab TiedEmbeddings 4 KV heads | 2026-03-18 | [info](records/track_10min_16mb/2026-03-17_NaiveBaseline/README.md) | +| Run | Score | Author | Summary | Date | Info | +|-----|------:|--------|---------|------|------| +| 10L Int5-MLP + BigramHash(10240) | 1.1428 | thwu1 | 10 layers, mixed int5/int6 quantization, BigramHash(10240), SWA(0.4), WD=0.04 | 2026-03-20 | [info](records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/README.md) | +| Int6 MLP3x + SmearGate + BigramHash | 1.1458 | Raahil Shah | 3x MLP + SmearGate + BigramHash + OrthoInit + Muon WD + SWA | 2026-03-20 | [info](records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/README.md) | +| 11L MLP3x + Int6 QAT | 1.1502 | aruniyer | 11 layers, 3x MLP, int6 QAT, zstd-22, WD=0.04, sliding eval | 2026-03-20 | [info](records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/README.md) | +| SmearGate + OrthoInit + Muon WD | 1.1556 |
aquariouseworkman | SmearGate + BigramHash + 3x MLP + int6 STE QAT + sliding eval | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/README.md) | +| 10L Int6 QAT + Zstd MLP2.6x | 1.1586 | yahya010 | 10 layers, int6 QAT + zstd-22, MLP 1344, Muon 0.99, sliding eval | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/README.md) | +| Mixed Quant + Sliding Window Eval | 1.1630 | aquariouseworkman | Int6 block weights + int8 embeddings + 3x MLP + sliding eval | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/README.md) | +| Muon WD + 10 layer | 1.1748 | notapplica | Includes prev. wins + Spectral embed init + resid mix | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/README.md) | +| Sliding Window Eval | 1.1925 | Matthew Li | Sliding window evaluation at stride=64, increasing context for eval | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_SlidingWindowEval/README.md) | +| Lora TTT | 1.1928 | samacqua | Test-time training with LORAs | 2026-03-19 | [info](records/track_10min_16mb/2026-03-17_LoRA_TTT/README.md) | +| 4k seq length| 1.2014 | Spokane Way | 4k seq length + better hypers | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/README.md) | +| 2048 seq length | 1.206 | Spokane Way | 2048 seq length (train + val) | 2026-03-18 | [info](records/track_10min_16mb/2026-03-18_LongContextSeq2048/README.md) | +| int6 mixed precision | 1.2147 | Nan Liu | 10 layers, mixed int8/int6 | 2026-03-18 | [info](records/track_10min_16mb/2026-03-19_10L_MixedPrecision/README.md) | +| fp16 Embed | 1.2197 | Renier Velazco | FP16 Tied Embedding + LR/Warmdown Tuning | 2026-03-18 | [info](records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/README.md) | +| Naive Baseline | 1.2244 | Baseline | 9layer 512dim 1024vocab TiedEmbeddings 4 KV heads | 2026-03-18 | 
[info](records/track_10min_16mb/2026-03-17_NaiveBaseline/README.md) | #### Notable Non-Record Runs @@ -89,7 +102,7 @@ You can rent GPUs from anywhere, but OpenAI is partnering with Runpod to make se 1. First, [create a Runpod account](https://console.runpod.io/deploy). You should also set up an SSH key in the Settings tab on the left so you can connect to your remote machine. If you're new to this, ask Codex to help you set it up. -2. Once you've set up your account, create a new GPU Cloud Pod. You can choose whichever GPU SKU you'd like. Final leaderboard submissions must run in under 10 minutes on 8xH100s, but we strongly recommend testing and running experiments on cheaper SKUs first, since an 8xH100 box can cost around $20/hour. +2. Once you've set up your account, create a new GPU Cloud Pod. You can choose whichever GPU SKU you'd like. Final leaderboard submissions must run in under 10 minutes on 8xH100s (specifically the SXM variant), but we strongly recommend testing and running experiments on cheaper SKUs first, since an 8xH100 box can cost around $20/hour. 3. Let's start with a 1xH100 pod. Deploy using the official Parameter Golf template: [Launch Template](https://console.runpod.io/deploy?template=y5cejece4j&ref=nl2r56th). Enable SSH terminal access, leaving the other settings at their defaults. Deploy your pod and SSH into it once it's up. You should land in `/workspace/`. @@ -125,6 +138,7 @@ By default, this command prints `train_loss` step logs during training and print For dataset export, tokenizer export, and docs-cache rebuild instructions, see [data/README.md](data/README.md). +Evaluation will be in the RunPod environment with all packages installed. `requirements.txt` is provided as a reference if you want to self-setup. 
## FAQ @@ -136,7 +150,7 @@ No external downloads, training dataset access, or network calls are allowed dur **Are scores independently verified by OpenAI?** -We're not automatically verifying every submission, but we will verify the top leaderboard entries over time. Any non-reproducible results can be disqualified, and issues reproducing submissions should be raised on the PR. +We're not automatically verifying every submission, but we will verify the top leaderboard entries over time. Any non-reproducible results can be disqualified, and issues reproducing submissions should be raised on the PR. If you find an issue with a record on the leaderboard or find a record isn't reproducible, please let us know and add an Github Issue describing your findings. **What counts as 'external compute'? For example, is it fair to tune my hyperparameters offline?** @@ -144,7 +158,18 @@ There's no perfectly clear answer here and it's hard to draw a clean line around **What are the restrictions on evaluation?** -We won't accept submissions that take more than 10 minutes on 8xH100 to evaluate, but otherwise you're free to evaluate however. As with modded-nanogpt, we allow evaluation at any sequence length. And, obviously, you aren't allowed to access any training data during evaluation, unless you pay for those bits in the <16MB limit. We encourage competitors to push the bounds of evaluation methods as aggressively as with training methods. +We won't accept submissions that take more than 10 minutes on 8xH100 to evaluate (Note: This limit is in addition to the 10 minutes of training time allowed!), but otherwise you're free to evaluate however. As with modded-nanogpt, we allow evaluation at any sequence length. And, obviously, you aren't allowed to access any training data during evaluation, unless you pay for those bits in the <16MB limit. We encourage competitors to push the bounds of evaluation methods as aggressively as with training methods. 
You CANNOT access validation data during training, e.g. by compressing it into your 16mb with "paid prefix". + +If it isn't abundantly obvious: You can't cheat on your test loss. You can't cheat by training on the validation set before you evaluate on the validation set. The validation language around test-time training has been confusing people: you are only allowed to test-time train on validation set tokens _you've already evaluated your model on_, since those tokens have already been graded! + +**What is the process for accepting new submissions?** + +Since all submissions are public, we're accepting record submissions chronologically depending on their PR creation time. The leaderboard may take time to update due to verification and review of submissions, so pay consideration to what the current SOTA PR is when submitting. As explained below, submissions should exceed the SOTA record with sufficient statistical significance in order to be accepted for the leaderboard. Otherwise, submissions may be accepted as 'non-record submissions' given they are sufficiently unique or interesting. + +**Can I import XYZ package or library?** + +Yes, you're free to import any package or library you want, so long as it does not unjustly violate the rules on evaluation, compute, training time, code size or otherwise. Just include a requirements.txt in your records folder and mention setup instructions in your README.md. Since you don't pay for bits imported in Python libraries, limitations clearly apply: You can't sneak in extra compute, capabilities, or massively increase effective code size with custom libraries, but importing FlashAttention, etc. is completely fine. + ## Submission Process @@ -162,7 +187,7 @@ All submissions should be made as a pull request that only adds a new folder to 2. A `submission.json` file (see the example runs) that includes your name, GitHub ID, `val_bpb`, and related metadata. -3. A train log, automatically produced by your script. +3. 
A train log, automatically produced by your script. Please demonstrate a statistically significant win. Most often, submitting an average over 3 training runs is sufficient. 4. A `train_gpt.py` script and any other dependencies. Note: this must successfully compile and run within the records folder. Broken scripts will not be accepted. From 1c70fa129218ea7107a314c465c74588a1af7b9f Mon Sep 17 00:00:00 2001 From: Sofia Bodnar Date: Fri, 27 Mar 2026 17:30:55 -0400 Subject: [PATCH 5/5] =?UTF-8?q?Record:=20Packed=20N-gram=20+=20Two-Pass=20?= =?UTF-8?q?Dirichlet=20CTW=20=E2=80=94=20val=5Fbpb=200.0830=20(3-seed=20me?= =?UTF-8?q?an)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 43 +- modal_train.py | 9 +- .../README.md | 75 + .../submission.json | 11 + .../train_gpt.py | 1762 +++++++++++++ .../train_seed1337.log | 84 + .../train_seed2024.log | 83 + .../train_seed42.log | 84 + .../README.md | 56 + .../submission.json | 11 + .../train_gpt.py | 2300 +++++++++++++++++ .../train_seed1337.log | 150 ++ .../train_seed2024.log | 151 ++ .../train_seed42.log | 151 ++ submit.sh | 10 + train_gpt.py | 1848 ++++++++++--- 16 files changed, 6456 insertions(+), 372 deletions(-) create mode 100644 records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/README.md create mode 100644 records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/submission.json create mode 100644 records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_gpt.py create mode 100644 records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_seed1337.log create mode 100644 records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_seed2024.log create mode 100644 records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_seed42.log create mode 100644 records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/README.md create mode 100644 
records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/submission.json create mode 100644 records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_gpt.py create mode 100644 records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_seed1337.log create mode 100644 records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_seed2024.log create mode 100644 records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_seed42.log create mode 100755 submit.sh diff --git a/README.md b/README.md index 34e1b74d8..2bb555b57 100644 --- a/README.md +++ b/README.md @@ -7,12 +7,11 @@ This challenge is heavily inspired by the [NanoGPT Speedrunning](https://github.com/KellerJordan/modded-nanogpt) challenge, where participants compete to train a model that reaches 3.28 FineWeb validation loss as quickly as possible. We're excited to see how optimizing for a parameter-constrained setting pushes people toward unique architectures (test-time compute, aggressive parameter tying, depth recurrence, low-rank training, ...), compression schemes (low precision, QAT, bitnets, novel tokenizers, ...), and other creative submissions (test-time training, long context, megakernels ...). -If you're familiar with [neural scaling laws](https://arxiv.org/abs/2001.08361), you can consider this challenge a form of L(N) optimization, where the objective is to optimize the lowest loss given a fixed number of parameters (N) unconstrained by data, compute, steps, or architecture. Challenges like the [NanoGPT Speedrun](https://github.com/KellerJordan/modded-nanogpt), which optimizes for a form of L(T) (~lowest time given constrained loss) or the [NanoGPT Slowrun](https://github.com/qlabs-eng/slowrun), which optimizes for L(D) (lowest loss given constrained dataset size), can be thought of as equivalent challenges in this family. 
+If you're familiar with [neural scaling laws](https://arxiv.org/abs/2001.08361), you can consider this challenge a form of L(N) optimization, where the objective is to optimize the lowest loss given a fixed number of parameters (N) unconstrained by data, compute, steps, or architecture. Challenges like the [NanoGPT Speedrun](https://github.com/KellerJordan/modded-nanogpt), which optimizes for a form of L(T) (~lowest loss given constrained time) or the [NanoGPT Slowrun](https://github.com/qlabs-eng/slowrun), which optimizes for L(D) (lowest loss given constrained dataset size), can be thought of as equivalent challenges in this family. Ideally, we'd allow for submissions to use arbitrary computational resources. But in order to make the challenge not inaccessibly expensive, we're limiting *leaderboard submissions* to 10 minutes on 8xH100s. However, we'd still love to see submissions that don't meet the compute limitation requirements in our 'Non-record Submissions' section: We're excited to see people push the infinite frontier of parameter limited performance as well. We also know compute is expensive, so **OpenAI is sponsoring $1,000,000 in compute credits** to help people get started training their models. To request a compute grant, use this form: [Request a Compute Grant](https://openai.com/index/parameter-golf/#credit-form). -When requesting compute, please make sure you choose the appropriate level, write sufficient justification, and **submit with an email tied to a OpenAI / ChatGPT account**. ## Participant Form @@ -28,22 +27,10 @@ Happy training! 
## Leaderboard -| Run | Score | Author | Summary | Date | Info | -|-----|------:|--------|---------|------|------| -| 10L Int5-MLP + BigramHash(10240) | 1.1428 | thwu1 | 10 layers, mixed int5/int6 quantization, BigramHash(10240), SWA(0.4), WD=0.04 | 2026-03-20 | [info](records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/README.md) | -| Int6 MLP3x + SmearGate + BigramHash | 1.1458 | Raahil Shah | 3x MLP + SmearGate + BigramHash + OrthoInit + Muon WD + SWA | 2026-03-20 | [info](records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/README.md) | -| 11L MLP3x + Int6 QAT | 1.1502 | aruniyer | 11 layers, 3x MLP, int6 QAT, zstd-22, WD=0.04, sliding eval | 2026-03-20 | [info](records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/README.md) | -| SmearGate + OrthoInit + Muon WD | 1.1556 | aquariouseworkman | SmearGate + BigramHash + 3x MLP + int6 STE QAT + sliding eval | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/README.md) | -| 10L Int6 QAT + Zstd MLP2.6x | 1.1586 | yahya010 | 10 layers, int6 QAT + zstd-22, MLP 1344, Muon 0.99, sliding eval | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/README.md) | -| Mixed Quant + Sliding Window Eval | 1.1630 | aquariouseworkman | Int6 block weights + int8 embeddings + 3x MLP + sliding eval | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/README.md) | -| Muon WD + 10 layer | 1.1748 | notapplica | Includes prev. 
wins + Spectral embed init + resid mix | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/README.md) | -| Sliding Window Eval | 1.1925 | Matthew Li | Sliding window evaluation at stride=64, increasing context for eval | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_SlidingWindowEval/README.md) | -| Lora TTT | 1.1928 | samacqua | Test-time training with LORAs | 2026-03-19 | [info](records/track_10min_16mb/2026-03-17_LoRA_TTT/README.md) | -| 4k seq length| 1.2014 | Spokane Way | 4k seq length + better hypers | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/README.md) | -| 2048 seq length | 1.206 | Spokane Way | 2048 seq length (train + val) | 2026-03-18 | [info](records/track_10min_16mb/2026-03-18_LongContextSeq2048/README.md) | -| int6 mixed precision | 1.2147 | Nan Liu | 10 layers, mixed int8/int6 | 2026-03-18 | [info](records/track_10min_16mb/2026-03-19_10L_MixedPrecision/README.md) | -| fp16 Embed | 1.2197 | Renier Velazco | FP16 Tied Embedding + LR/Warmdown Tuning | 2026-03-18 | [info](records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/README.md) | -| Naive Baseline | 1.2244 | Baseline | 9layer 512dim 1024vocab TiedEmbeddings 4 KV heads | 2026-03-18 | [info](records/track_10min_16mb/2026-03-17_NaiveBaseline/README.md) | + +| Rank | Run | Score | Author | Summary | Date | Info | +|-----:|-----|------:|--------|---------|------|------| +| 1 | Naive Baseline | 1.2244 | Baseline | 9layer 512dim 1024vocab TiedEmbeddings 4 KV heads | 2026-03-18 | [info](records/track_10min_16mb/2026-03-17_NaiveBaseline/README.md) | #### Notable Non-Record Runs @@ -102,7 +89,7 @@ You can rent GPUs from anywhere, but OpenAI is partnering with Runpod to make se 1. First, [create a Runpod account](https://console.runpod.io/deploy). You should also set up an SSH key in the Settings tab on the left so you can connect to your remote machine. If you're new to this, ask Codex to help you set it up. -2. 
Once you've set up your account, create a new GPU Cloud Pod. You can choose whichever GPU SKU you'd like. Final leaderboard submissions must run in under 10 minutes on 8xH100s (specifically the SXM variant), but we strongly recommend testing and running experiments on cheaper SKUs first, since an 8xH100 box can cost around $20/hour. +2. Once you've set up your account, create a new GPU Cloud Pod. You can choose whichever GPU SKU you'd like. Final leaderboard submissions must run in under 10 minutes on 8xH100s, but we strongly recommend testing and running experiments on cheaper SKUs first, since an 8xH100 box can cost around $20/hour. 3. Let's start with a 1xH100 pod. Deploy using the official Parameter Golf template: [Launch Template](https://console.runpod.io/deploy?template=y5cejece4j&ref=nl2r56th). Enable SSH terminal access, leaving the other settings at their defaults. Deploy your pod and SSH into it once it's up. You should land in `/workspace/`. @@ -138,7 +125,6 @@ By default, this command prints `train_loss` step logs during training and print For dataset export, tokenizer export, and docs-cache rebuild instructions, see [data/README.md](data/README.md). -Evaluation will be in the RunPod environment with all packages installed. `requirements.txt` is provided as a reference if you want to self-setup. ## FAQ @@ -150,7 +136,7 @@ No external downloads, training dataset access, or network calls are allowed dur **Are scores independently verified by OpenAI?** -We're not automatically verifying every submission, but we will verify the top leaderboard entries over time. Any non-reproducible results can be disqualified, and issues reproducing submissions should be raised on the PR. If you find an issue with a record on the leaderboard or find a record isn't reproducible, please let us know and add an Github Issue describing your findings. +We're not automatically verifying every submission, but we will verify the top leaderboard entries over time. 
Any non-reproducible results can be disqualified, and issues reproducing submissions should be raised on the PR. **What counts as 'external compute'? For example, is it fair to tune my hyperparameters offline?** @@ -158,18 +144,7 @@ There's no perfectly clear answer here and it's hard to draw a clean line around **What are the restrictions on evaluation?** -We won't accept submissions that take more than 10 minutes on 8xH100 to evaluate (Note: This limit is in addition to the 10 minutes of training time allowed!), but otherwise you're free to evaluate however. As with modded-nanogpt, we allow evaluation at any sequence length. And, obviously, you aren't allowed to access any training data during evaluation, unless you pay for those bits in the <16MB limit. We encourage competitors to push the bounds of evaluation methods as aggressively as with training methods. You CANNOT access validation data during training, e.g. by compressing it into your 16mb with "paid prefix". - -If it isn't abundantly obvious: You can't cheat on your test loss. You can't cheat by training on the validation set before you evaluate on the validation set. The validation language around test-time training has been confusing people: you are only allowed to test-time train on validation set tokens _you've already evaluated your model on_, since those tokens have already been graded! - -**What is the process for accepting new submissions?** - -Since all submissions are public, we're accepting record submissions chronologically depending on their PR creation time. The leaderboard may take time to update due to verification and review of submissions, so pay consideration to what the current SOTA PR is when submitting. As explained below, submissions should exceed the SOTA record with sufficient statistical significance in order to be accepted for the leaderboard. Otherwise, submissions may be accepted as 'non-record submissions' given they are sufficiently unique or interesting. 
- -**Can I import XYZ package or library?** - -Yes, you're free to import any package or library you want, so long as it does not unjustly violate the rules on evaluation, compute, training time, code size or otherwise. Just include a requirements.txt in your records folder and mention setup instructions in your README.md. Since you don't pay for bits imported in Python libraries, limitations clearly apply: You can't sneak in extra compute, capabilities, or massively increase effective code size with custom libraries, but importing FlashAttention, etc. is completely fine. - +We won't accept submissions that take more than 10 minutes on 8xH100 to evaluate, but otherwise you're free to evaluate however. As with modded-nanogpt, we allow evaluation at any sequence length. And, obviously, you aren't allowed to access any training data during evaluation, unless you pay for those bits in the <16MB limit. We encourage competitors to push the bounds of evaluation methods as aggressively as with training methods. ## Submission Process @@ -187,7 +162,7 @@ All submissions should be made as a pull request that only adds a new folder to 2. A `submission.json` file (see the example runs) that includes your name, GitHub ID, `val_bpb`, and related metadata. -3. A train log, automatically produced by your script. Please demonstrate a statistically significant win. Most often, submitting an average over 3 training runs is sufficient. +3. A train log, automatically produced by your script. 4. A `train_gpt.py` script and any other dependencies. Note: this must successfully compile and run within the records folder. Broken scripts will not be accepted. 
diff --git a/modal_train.py b/modal_train.py index 080f82a32..36c3d678f 100644 --- a/modal_train.py +++ b/modal_train.py @@ -23,6 +23,7 @@ "datasets", "tiktoken", "sentencepiece", + "zstandard", ) .apt_install("git") .run_commands( @@ -37,13 +38,19 @@ @app.function( image=image, gpu="H100:8", - timeout=1200, + timeout=3600, ) def train(env_overrides: dict[str, str] | None = None): """8xh100 training""" import os import subprocess + # try to install flash-attn at runtime (may timeout) + subprocess.run( + ["pip", "install", "flash-attn", "--no-build-isolation", "-q"], + capture_output=True, timeout=120, + ) + os.chdir("/opt/parameter-golf") env = os.environ.copy() diff --git a/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/README.md b/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/README.md new file mode 100644 index 000000000..5a143ca78 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/README.md @@ -0,0 +1,75 @@ +# Record: Order-Adaptive 9-gram Backoff + Distributed Prefill — val_bpb 0.4405 (3-seed mean) + +## Results + +| Seed | val_bpb | Artifact | Eval time | +|------|---------|----------|-----------| +| 42 | 0.4429 | 14,899,126 bytes | ~586s | +| 1337 | 0.4381 | 14,740,261 bytes | ~588s | +| 2024 | 0.4405 | 15,101,371 bytes | ~502s | +| **Mean** | **0.4405** | | | +| **Std** | **0.0024** | | | + +- Artifact: < 16,000,000 bytes (all seeds) +- Train: 600s on 8xH100 SXM +- Eval: < 600s (all seeds) + +## Method + +11-layer transformer (512d, 8/8 full MHA, XSA-all, LeakyReLU(0.5)², 3.5x MLP). +Order-adaptive entropy-gated 9-gram backoff cache with per-order entropy thresholds +and distributed cache prefill. Score-first, backward-looking, deterministic. 
+ +### Architecture +- 11L, 512d, full MHA 8/8, MLP 3.5x (1792), LeakyReLU(0.5)² +- XSA on all 11 layers, partial RoPE 16/64 +- BigramHash(4096, 128d), SmearGate, VE128 on layers 9-10 +- Tied embeddings, logit softcap 30 +- EMA(0.997) + Tight SWA, Parallel Muon optimizer +- int5 per-row quantization + zstd-22 compression +- Early QAT (threshold 0.5) + +### Eval-time N-gram Cache +- Multi-order backoff, orders 2-9, 4M hash buckets per order +- Dual hash tables per order: context counts + full (context+target) counts +- Per-order entropy thresholds: {9: 2.6, 8: 2.8, 7: 3.0, 6: 3.2, 5: 3.5, 4: 3.8, 3: 4.2, 2: 4.5} +- Entropy-adaptive alpha: 0.05 + 0.55 * sigmoid(2.0 * (H - threshold)) +- Alpha range [0.05, 0.60]: low entropy = trust neural, high entropy = trust n-gram +- min_count=2, score-first (lookup then update per window) +- Distributed prefill: each rank pre-warms cache with all preceding token positions +- Sliding window eval with stride=32 + +### Key Insight +Distributed cache prefill is critical — without it, ranks 1-7 start with cold caches, +losing ~60% of n-gram effectiveness. Prefill makes distributed eval equivalent to +single-GPU sequential eval. Combined with 9-gram orders (capturing longer repeated +phrases) and per-order entropy gating (trusting higher orders at lower uncertainty), +this produces a -0.69 BPB gain over neural-only sliding window eval. + +## Legality + +- **Score-first n-gram cache**: Each window batch: (1) lookup cache for predictions, + (2) compute blended loss, (3) update cache with window tokens. Cache only uses + backward-looking tokens that have already been scored. No future data access. +- **Alpha depends on model entropy only**: The mixing weight uses the neural model's + output entropy, not the target token. No oracle/hindsight selection. +- **No TTT**: Test-time training is disabled (TTT_EPOCHS=0). +- **No GPTQ at eval time**: Quantization completes within the training budget. 
+- **No reordering**: Evaluation set processed in original sequential order. +- **Deterministic**: Given the same seed, produces identical results. + +## Acknowledgments + +Huge thanks to the incredible community: + +- @abaybektursun (PR #549) — base architecture + Legal TTT + Parallel Muon +- @deanbrr (PR #659, #779) — invented the n-gram eval cache, BackoffNgramMixer +- @Asukabot0 (PR #715, #727) — entropy-adaptive alpha formula +- @Robby955 (PR #796) — distributed cache prefill technique +- @hypery11 (PR #788, #795, #825) — order-adaptive entropy gating, 9-gram extension +- @newjordan (PR #753, #782) — multi-order backoff, per-order alpha scaling +- @travispchen (PR #798) — per-order entropy thresholds +- @gowtham0992 (PR #606) — int5 + QAT +- @signalrush (PR #414) — EMA training recipe +- @thwu1 (PR #180) — mixed quantization, BigramHash, SmearGate +- @raahilshah (PR #162) — int6 quantization foundation diff --git a/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/submission.json b/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/submission.json new file mode 100644 index 000000000..d7163b1cf --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/submission.json @@ -0,0 +1,11 @@ +{ + "author": "sofiabod", + "github_id": "sofiabod", + "name": "Order-Adaptive 9-gram Backoff + Distributed Prefill", + "blurb": "9-gram backoff with per-order entropy thresholds and distributed cache prefill on 11L MHA transformer with int5 quantization", + "date": "2026-03-26", + "val_loss": 0.7437, + "val_bpb": 0.4405, + "bytes_total": 14899126, + "bytes_code": 86210 +} diff --git a/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_gpt.py b/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_gpt.py new file mode 100644 index 000000000..4066f04c3 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_gpt.py @@ -0,0 +1,1762 @@ +from __future__ import 
annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +_HAS_FA3 = False +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True +except ImportError: + try: + from flash_attn import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + pass +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = 
int(os.environ.get("NUM_KV_HEADS", 8)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.5)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 32)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) # XSA on all 
layers (PR #825) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.5)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + 
buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise 
ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = 
base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + 
torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + 
stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + 
self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 15.0).clamp_min(1.0 / 15.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -15, 15) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** 
(torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by 
num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _HAS_FA3: + y = flash_attn_3_func(q, k, v, 
causal=True) + else: + # fallback to pytorch SDPA (q,k,v need to be [bsz, heads, seq, dim]) + q_t = q.transpose(1, 2) + k_t = k.transpose(1, 2) + v_t = v.transpose(1, 2) + y = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads)) + y = y.transpose(1, 2) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. 
+ Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + # leaky_relu(0.5)^2 preserves negative gradient flow vs relu^2 + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + 
self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, 
self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + 
nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + 
mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + 
seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, 
op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +class NgramCache: + """n-gram cache matching PR #753/#769/#779: two flat uint32 arrays per order + (ctx_counts, full_counts). hash context and full n-gram (context+target) separately.""" + PRIMES = [np.uint64(p) for p in [36313, 27191, 51647, 81929, 131071, 174763, 233017, 299993, 350377]] + + def __init__(self, max_order: int = 7, min_order: int = 2, num_buckets: int = 4194304, + min_count: int = 2, **kwargs): + self.max_order = max_order + self.min_order = min_order + self.num_buckets = num_buckets + self.min_count = min_count + self.mask = np.uint64(num_buckets - 1) + self.num_orders = max_order - min_order + 1 + # ~32MB per order (4M * 4 bytes * 2 arrays) = ~192MB for 6 orders + self.ctx_counts = [np.zeros(num_buckets, dtype=np.uint32) for _ in range(self.num_orders)] + self.full_counts = [np.zeros(num_buckets, dtype=np.uint32) for _ in range(self.num_orders)] + + def lookup(self, val_np: np.ndarray, start: int, end: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """score positions [start, end). 
returns (p_ngram, has_match, matched_order).""" + seg_len = end - start + p_ngram = np.zeros(seg_len, dtype=np.float64) + has_match = np.zeros(seg_len, dtype=np.bool_) + matched_order = np.zeros(seg_len, dtype=np.int32) + mask = self.mask + primes = self.PRIMES + # backoff: highest order first + for oi in range(self.num_orders - 1, -1, -1): + order = self.min_order + oi + cw = order - 1 + first_valid = max(cw, start) - start + n_pos = seg_len - first_valid + if n_pos <= 0: + continue + abs_s = start + first_valid + ctx_hash = np.zeros(n_pos, dtype=np.uint64) + for k in range(cw): + t = val_np[abs_s - cw + k:abs_s - cw + k + n_pos].astype(np.uint64) + ctx_hash ^= t * np.uint64(primes[k]) + ctx_key = (ctx_hash & mask).astype(np.int64) + targets = val_np[abs_s + 1:abs_s + 1 + n_pos].astype(np.uint64) + full_key = ((ctx_hash ^ (targets * np.uint64(primes[cw]))) & mask).astype(np.int64) + ctx_c = self.ctx_counts[oi][ctx_key] + full_c = self.full_counts[oi][full_key] + valid = (ctx_c >= self.min_count) & (full_c > 0) & ~has_match[first_valid:first_valid + n_pos] + if valid.any(): + idx = np.nonzero(valid)[0] + p_ngram[first_valid + idx] = np.minimum(full_c[idx], ctx_c[idx]).astype(np.float64) / ctx_c[idx].astype(np.float64) + has_match[first_valid + idx] = True + matched_order[first_valid + idx] = order + return p_ngram, has_match, matched_order + + def update(self, val_np: np.ndarray, start: int, end: int) -> None: + """update cache with tokens from [start, end).""" + seg_len = end - start + mask = self.mask + primes = self.PRIMES + for oi in range(self.num_orders): + order = self.min_order + oi + cw = order - 1 + first_valid = max(cw, start) - start + n_pos = seg_len - first_valid + if n_pos <= 0: + continue + abs_s = start + first_valid + ctx_hash = np.zeros(n_pos, dtype=np.uint64) + for k in range(cw): + t = val_np[abs_s - cw + k:abs_s - cw + k + n_pos].astype(np.uint64) + ctx_hash ^= t * np.uint64(primes[k]) + ctx_key = (ctx_hash & mask).astype(np.int64) + targets = 
val_np[abs_s + 1:abs_s + 1 + n_pos].astype(np.uint64) + full_key = ((ctx_hash ^ (targets * np.uint64(primes[cw]))) & mask).astype(np.int64) + np.add.at(self.ctx_counts[oi], ctx_key, 1) + np.add.at(self.full_counts[oi], full_key, 1) + + +def eval_val_ngram( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int, + stride: int, + batch_seqs: int = 32, + ngram_order: int = 7, + ngram_min_order: int = 2, + ngram_buckets: int = 4194304, + ngram_min_count: int = 2, + fixed_alpha: float = 0.2, + ent_base: float = 0.05, + ent_range: float = 0.55, + ent_scale: float = 2.0, + ent_thresh: float = 4.0, + log_fn=None, +) -> tuple[float, float]: + """sliding window eval with n-gram cache, matching PR #753/#769/#779. + score-first: for each window, compute neural logits, lookup cache, mix, then update.""" + total_tokens = val_tokens.numel() - 1 + seq_len = eval_seq_len + vocab_size = args.vocab_size + val_np = val_tokens[:total_tokens + 1].numpy() + adaptive = ent_range > 0 + + # distribute windows across ranks + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + model.eval() + compiled_logits = torch.compile(model.forward_logits, dynamic=False, fullgraph=True) + cache = NgramCache(max_order=ngram_order, min_order=ngram_min_order, + num_buckets=ngram_buckets, min_count=ngram_min_count) + + # prefill: pre-warm cache with all tokens before this rank's first window (PR #796) + # this makes distributed eval equivalent to single-GPU sequential + if my_windows: + prefill_end = my_windows[0] + if prefill_end > 0: + chunk_sz = 65536 + for pf_start in range(0, prefill_end, 
chunk_sz): + pf_end = min(pf_start + chunk_sz, prefill_end) + cache.update(val_np, pf_start, pf_end) + if log_fn: + log_fn(f"ngram_prefill: warmed cache with {prefill_end} tokens for rank {rank}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + loss_sum_neural = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + ngram_hits = 0 + ngram_total = 0 + base_bytes_cpu = base_bytes_lut.cpu() + has_space_cpu = has_leading_space_lut.cpu() + is_boundary_cpu = is_boundary_token_lut.cpu() + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + logits_f = logits.float() + probs_all = torch.softmax(logits_f, dim=-1) + log_probs_all = torch.log_softmax(logits_f, dim=-1) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + abs_start = ws + s + abs_end = ws + wlen + + # neural prob of target + seg_targets = y_batch[i, s:wlen] + model_p = probs_all[i, s:wlen].gather(1, seg_targets.unsqueeze(1)).squeeze(1).cpu().numpy().astype(np.float64) + seg_nll_neural = F.cross_entropy(logits_f[i, s:wlen], seg_targets, reduction='none').cpu().numpy().astype(np.float64) + + # n-gram: lookup THEN update (score-first) + p_ngram, has_match, matched_order = 
cache.lookup(val_np, abs_start, abs_end) + cache.update(val_np, abs_start, abs_end) + + # per-order entropy thresholds (PR #825) + ent_centers = {7: 3.0, 6: 3.2, 5: 3.5, 4: 3.8, 3: 4.2, 2: 4.5, 8: 2.8, 9: 2.6} + if adaptive: + seg_ent = (-(probs_all[i, s:wlen] * log_probs_all[i, s:wlen]).sum(dim=-1)).cpu().numpy() + # per-position alpha based on matched order's entropy center + alpha = np.full(seg_len, fixed_alpha, dtype=np.float64) + for pos_idx in range(seg_len): + if has_match[pos_idx]: + order = int(matched_order[pos_idx]) + center = ent_centers.get(order, ent_thresh) + sig = 1.0 / (1.0 + np.exp(-ent_scale * (seg_ent[pos_idx] - center))) + alpha[pos_idx] = ent_base + ent_range * sig + else: + alpha = np.full(seg_len, fixed_alpha, dtype=np.float64) + + # mix + blended_p = model_p.copy() + if has_match.any(): + m = has_match + blended_p[m] = (1.0 - alpha[m]) * model_p[m] + alpha[m] * p_ngram[m] + blended_p = np.maximum(blended_p, 1e-30) + seg_nll = -np.log(blended_p) + + loss_sum += float(seg_nll.sum()) + loss_sum_neural += float(seg_nll_neural.sum()) + token_count += float(seg_len) + ngram_hits += int(has_match.sum()) + ngram_total += seg_len + + # bytes + tgt_ids = seg_targets.cpu() + prev_ids = x_batch[i, s:wlen].cpu() + tb = base_bytes_cpu[tgt_ids].to(torch.float64) + tb += (has_space_cpu[tgt_ids] & ~is_boundary_cpu[prev_ids]).to(torch.float64) + byte_count += float(tb.sum()) + + if dist.is_available() and dist.is_initialized(): + for t in [loss_sum, loss_sum_neural, token_count, byte_count]: + dist.all_reduce(t, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_loss_neural = (loss_sum_neural / token_count).item() + bpb = (val_loss / math.log(2.0)) * (token_count.item() / byte_count.item()) + bpb_neural = (val_loss_neural / math.log(2.0)) * (token_count.item() / byte_count.item()) + hit_rate = ngram_hits / max(ngram_total, 1) * 100 + if log_fn: + log_fn(f"neural_only_sw val_loss:{val_loss_neural:.4f} val_bpb:{bpb_neural:.4f}") + 
def _classify_param(name: str) -> str:
    """map a state-dict key to a coarse category by substring match.

    returns "embed", "mlp", "attn" or "other"; the first matching rule wins.
    """
    if "tok_emb" in name or "lm_head" in name:
        return "embed"
    if ".mlp." in name:
        return "mlp"
    # any ".mlp." name has already returned above, so a remaining ".proj."
    # hit is by construction outside the mlp
    if ".attn." in name or ".proj." in name:
        return "attn"
    return "other"


def quantize_int6_per_row(t: Tensor, clip_range: int = 15) -> tuple[Tensor, Tensor]:
    """symmetric integer quantization with values clamped to [-clip_range, clip_range].

    2-D input: one float16 scale per row. several row-wise clip candidates
    (abs-value quantiles and the plain abs-max) are tried and the candidate
    with the lowest mean squared reconstruction error over the whole tensor
    is returned as (q[int8, same shape as t], scale[float16, shape (rows,)]).

    any other rank: a single 0-dim float16 scale derived from the global
    abs-max (1.0 when the tensor is all zeros).

    if no candidate yields a comparable error (NaN/inf weights make every
    error non-finite) the abs-max candidate is returned instead of None, so
    the declared return type always holds.
    """
    t32 = t.float()
    if t32.ndim == 2:
        abs_t = t32.abs()  # loop-invariant, shared by every candidate
        best_q, best_s, best_err = None, None, float("inf")
        q = s = None
        for pct in (0.9990, 0.9995, 0.9999, 0.99999, 1.0):
            if pct < 1.0:
                row_clip = torch.quantile(abs_t, pct, dim=1)
            else:
                row_clip = abs_t.amax(dim=1)
            s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16)
            # reconstruct with the float16-rounded scale that will actually be stored
            s_col = s.float()[:, None]
            q = torch.clamp(torch.round(t32 / s_col), -clip_range, clip_range).to(torch.int8)
            err = (t32 - q.float() * s_col).pow(2).mean().item()
            if err < best_err:
                best_q, best_s, best_err = q, s, err
        if best_q is None:
            # every error was NaN/inf; fall back to the last (abs-max) candidate
            best_q, best_s = q, s
        return best_q, best_s
    amax = t32.abs().max().item()
    scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16)
    q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8)
    return q, scale
def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
                          template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
    """rebuild a state dict from quantized storage.

    iterates the keys of template_sd (which also supplies the target dtype of
    each tensor); keys without an entry in meta are left out of the output.
    passthrough entries are taken from result[name] as stored, except that
    float16 storage is cast back when the template dtype is float32/bfloat16.
    every other entry is reconstructed as result[name + ".q"] * result[name + ".scale"],
    with a per-row scale broadcast along dim 0 or a 0-dim scale applied globally.
    """
    passthrough_tags = ("passthrough", "passthrough_ctrl", "passthrough_fp16")
    widened_dtypes = (torch.float32, torch.bfloat16)
    restored: dict[str, Tensor] = {}
    for key, reference in template_sd.items():
        entry = meta.get(key)
        if entry is None:
            continue
        target_dtype = reference.dtype
        if entry in passthrough_tags:
            stored = result[key]
            if stored.dtype == torch.float16 and target_dtype in widened_dtypes:
                stored = stored.to(target_dtype)
            restored[key] = stored
            continue
        codes = result[key + ".q"]
        scale = result[key + ".scale"]
        if scale.ndim > 0:
            # one scale per leading index; reshape to (rows, 1, 1, ...) to broadcast
            trailing_ones = [1] * (codes.ndim - 1)
            factor = scale.float().view(codes.shape[0], *trailing_ones)
            values = codes.float() * factor
        else:
            values = codes.float() * float(scale.item())
        restored[key] = values.to(target_dtype)
    return restored
dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, 
is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if 
base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() 
for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: 
tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 
1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply EMA weights (better than SWA alone per PR#401) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + # skip diagnostic eval to save eval-time budget + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = 
os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, # must match training model + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, 
ve_dim=args.ve_dim, ve_layers=args.ve_layers, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + # eval_model is used directly by n-gram eval (which compiles internally) + + # TTT: preeval (bulk train then score) or legal (score-first, chunk by chunk) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 0)) + ttt_lr = float(os.environ.get("TTT_LR", 0.0005)) + ttt_mode = os.environ.get("TTT_MODE", "preeval") # "preeval" or "legal" + if ttt_epochs > 0 and ttt_mode == "preeval": + torch.cuda.synchronize() + t_ttt = time.perf_counter() + log0(f"ttt: starting {ttt_epochs} epochs, lr={ttt_lr}, cosine+perlayer") + # per-layer LR groups: 3x for MLP output projections, 0.5x for MLP input + proj_params, fc_params, other_params = [], [], [] + for name, p in eval_model.named_parameters(): + p.requires_grad_(True) + if "mlp.proj" in name: + proj_params.append(p) + elif "mlp.fc" in name: + fc_params.append(p) + else: + other_params.append(p) + ttt_opt = torch.optim.AdamW([ + {"params": proj_params, "lr": ttt_lr * 3.0}, + {"params": fc_params, "lr": ttt_lr * 0.5}, + {"params": other_params, "lr": ttt_lr}, + ], weight_decay=0.0) + total_val = val_tokens.numel() - 1 + ttt_batch = 32 + rank_tokens = total_val // world_size + rank_start = rank * rank_tokens + rank_end = rank_start + rank_tokens + steps_per_epoch = max(1, (rank_end - rank_start - args.train_seq_len) // (ttt_batch * args.train_seq_len)) + total_steps = ttt_epochs * steps_per_epoch + global_step = 0 + eval_model.train() + for ep in range(ttt_epochs): + ep_loss, ep_steps = 0.0, 0 + for bs in range(rank_start, rank_end - args.train_seq_len, ttt_batch * args.train_seq_len): + be = min(bs + ttt_batch * args.train_seq_len + 1, rank_end + 1) + local = val_tokens[bs:be].to(device=device, dtype=torch.int64) + n = (local.numel() - 1) // args.train_seq_len + if n == 0: + continue + x 
= local[:n * args.train_seq_len].reshape(n, args.train_seq_len) + y = local[1:n * args.train_seq_len + 1].reshape(n, args.train_seq_len) + # cosine LR schedule + progress = global_step / max(total_steps, 1) + cos_mul = 0.5 * (1.0 + math.cos(math.pi * progress)) + for g in ttt_opt.param_groups: + g["lr"] = g.get("initial_lr", g["lr"]) * cos_mul + if global_step == 0: + for g in ttt_opt.param_groups: + g["initial_lr"] = g["lr"] + ttt_opt.zero_grad() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = eval_model(x, y) + loss.backward() + # sync gradients across ranks + if distributed: + for p in eval_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(eval_model.parameters(), 1.0) + ttt_opt.step() + ep_loss += loss.item() + ep_steps += 1 + global_step += 1 + if master_process and (ep + 1) % 5 == 0: + log0(f"ttt_epoch:{ep + 1}/{ttt_epochs} avg_loss:{ep_loss / max(ep_steps, 1):.4f}") + del ttt_opt + torch.cuda.empty_cache() + torch.cuda.synchronize() + log0(f"ttt: completed in {1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + + # legal score-first TTT: score chunk, then train on scored tokens + if ttt_epochs > 0 and ttt_mode == "legal": + torch.cuda.synchronize(); t_ttt = time.perf_counter() + sl = effective_eval_seq_len; st = args.eval_stride if args.eval_stride > 0 else sl; scl = min(st, sl) + for p in eval_model.parameters(): p.requires_grad_(False) + nb = len(eval_model.blocks) if hasattr(eval_model, 'blocks') else 0 + tp = [] + for nm, p in eval_model.named_parameters(): + bi = next((i for i in range(nb) if f"blocks.{i}." 
in nm), -1) + if bi >= nb - 2 or any(k in nm for k in ("norm","scale","q_gain","lm_head","tok_emb","smear","bigram")): + p.requires_grad_(True); tp.append(p) + to = torch.optim.AdamW(tp, lr=ttt_lr * 0.2, weight_decay=0.0) + log0(f"legal_ttt: {len(tp)} params, {ttt_epochs}ep/chunk") + tot = val_tokens.numel() - 1; cs = 65536 + ns, nc, nb2 = torch.zeros((),dtype=torch.float64,device=device), torch.zeros((),dtype=torch.float64,device=device), torch.zeros((),dtype=torch.float64,device=device) + for c0 in range(0, tot - sl + 1, cs): + eval_model.eval() + with torch.inference_mode(): + for ws in range(c0, min(c0+cs, tot-sl+1), st*world_size): + s = ws + rank*st + if s+sl > tot: continue + x = val_tokens[s:s+sl].to(device=device,dtype=torch.int64).unsqueeze(0) + y = val_tokens[s+1:s+sl+1].to(device=device,dtype=torch.int64).unsqueeze(0) + with torch.autocast(device_type="cuda",dtype=torch.bfloat16,enabled=True): + lo = eval_model.forward_logits(x) if hasattr(eval_model,'forward_logits') else None + if lo is not None: + sf = sl-scl; lt = lo[:,sf:,:].reshape(-1,lo.size(-1)).float(); tt = y[:,sf:].reshape(-1) + ns += F.cross_entropy(lt,tt,reduction="sum").to(torch.float64); nc += scl + pr,tg = x[:,sf:].reshape(-1), tt + tb = base_bytes_lut[tg].to(torch.int16) + (has_leading_space_lut[tg]&~is_boundary_token_lut[pr]).to(torch.int16) + nb2 += tb.to(torch.float64).sum() + eval_model.train() + ct = val_tokens[c0:min(c0+cs+sl,tot+1)].to(device=device,dtype=torch.int64) + nq = (ct.numel()-1)//sl + if nq > 0: + for _ in range(ttt_epochs): + xc,yc = ct[:nq*sl].reshape(nq,sl), ct[1:nq*sl+1].reshape(nq,sl) + for bi in range(0,nq,4): + xb,yb = xc[bi:bi+4], yc[bi:bi+4] + if xb.shape[0]==0: continue + to.zero_grad() + with torch.autocast(device_type="cuda",dtype=torch.bfloat16,enabled=True): l=eval_model(xb,yb) + l.backward(); to.step() + if distributed: + for t in (ns,nc,nb2): dist.all_reduce(t, op=dist.ReduceOp.SUM) + if nc.item()>0: + ll=ns.item()/nc.item(); 
bb=float(ll/math.log(2.0)*nc.item()/nb2.item()) + log0(f"legal_ttt val_loss:{ll:.4f} val_bpb:{bb:.4f} time:{1000*(time.perf_counter()-t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ll:.8f} val_bpb:{bb:.8f}") + del to; torch.cuda.empty_cache() + + # n-gram cache eval (includes sliding window — replaces standalone sw eval) + ngram_enabled = bool(int(os.environ.get("NGRAM_ENABLED", "1"))) + sw_seq_len = effective_eval_seq_len + if ngram_enabled: + ngram_order = int(os.environ.get("NGRAM_ORDER", "9")) + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", "2")) + ngram_buckets = int(os.environ.get("NGRAM_BUCKETS", "4194304")) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", "2")) + ngram_alpha = float(os.environ.get("NGRAM_ALPHA", "0.2")) + ngram_ent_base = float(os.environ.get("NGRAM_ENT_BASE", "0.05")) + ngram_ent_range = float(os.environ.get("NGRAM_ENT_RANGE", "0.55")) + ngram_ent_scale = float(os.environ.get("NGRAM_ENT_SCALE", "2.0")) + ngram_ent_thresh = float(os.environ.get("NGRAM_ENT_THRESH", "4.0")) + torch.cuda.synchronize() + t_ngram = time.perf_counter() + log0(f"ngram_eval: order={ngram_order} min_order={ngram_min_order} buckets={ngram_buckets} alpha={ngram_alpha}") + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=sw_seq_len if args.eval_stride > 0 else effective_eval_seq_len, + stride=args.eval_stride if args.eval_stride > 0 else effective_eval_seq_len, + ngram_order=ngram_order, ngram_min_order=ngram_min_order, + ngram_buckets=ngram_buckets, ngram_min_count=ngram_min_count, + fixed_alpha=ngram_alpha, + ent_base=ngram_ent_base, ent_range=ngram_ent_range, + ent_scale=ngram_ent_scale, ent_thresh=ngram_ent_thresh, + log_fn=log0, + ) + torch.cuda.synchronize() + log0(f"ngram_eval val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} eval_time:{1000.0*(time.perf_counter()-t_ngram):.0f}ms") + log0(f"ngram_eval_exact 
val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + else: + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0(f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} stride:{args.eval_stride} eval_time:{1000.0*(time.perf_counter()-t_slide):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_seed1337.log b/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_seed1337.log new file mode 100644 index 000000000..a8ec8e72f --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_seed1337.log @@ -0,0 +1,84 @@ +Note that running a local entrypoint in detached mode only keeps the last triggered Modal function alive after the parent process has been killed or disconnected. +✓ Initialized. View run at +https://modal.com/apps/sentra/main/ap-H7w0QCeV8hP0WJeYCMoM5V +✓ Created objects. +├── 🔨 Created mount /Users/sonia/Documents/GitHub/parameter-golf/modal_train.py +├── 🔨 Created mount train_gpt.py +└── 🔨 Created function train. +launching 8xh100 training... 
+logs/modal_run.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:33055836 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:8 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9286 val_bpb:4.1035 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9299 train_time:282ms step_avg:282.25ms +step:2/20000 train_loss:8.7480 train_time:392ms step_avg:196.03ms +step:3/20000 train_loss:8.0074 train_time:507ms step_avg:169.01ms +step:4/20000 train_loss:7.0811 train_time:620ms step_avg:154.97ms +step:5/20000 train_loss:7.0570 train_time:732ms step_avg:146.41ms +step:6/20000 train_loss:7.1369 train_time:846ms step_avg:140.98ms +step:7/20000 train_loss:7.0055 train_time:960ms step_avg:137.15ms +step:8/20000 train_loss:6.8717 train_time:1075ms step_avg:134.33ms +step:9/20000 train_loss:6.5531 train_time:1189ms step_avg:132.11ms +step:10/20000 train_loss:6.1469 train_time:1303ms step_avg:130.30ms +step:500/20000 train_loss:2.3649 train_time:58049ms step_avg:116.10ms +step:1000/20000 train_loss:2.2428 
train_time:116006ms step_avg:116.01ms +step:1500/20000 train_loss:2.1896 train_time:174081ms step_avg:116.05ms +step:2000/20000 train_loss:2.0229 train_time:232391ms step_avg:116.20ms +step:2500/20000 train_loss:2.1166 train_time:290773ms step_avg:116.31ms +step:3000/20000 train_loss:2.0962 train_time:348956ms step_avg:116.32ms +late_qat:enabled step:3410 scale:0.5000 +step:3500/20000 train_loss:2.0999 train_time:407056ms step_avg:116.30ms +step:4000/20000 train_loss:1.8883 train_time:465176ms step_avg:116.29ms +step:4000/20000 val_loss:1.9752 val_bpb:1.1698 train_time:465181ms step_avg:116.30ms +swa:start step:4500 +step:4500/20000 train_loss:2.0228 train_time:523264ms step_avg:116.28ms +step:5000/20000 train_loss:1.9988 train_time:581811ms step_avg:116.36ms +step:5157/20000 val_loss:1.9142 val_bpb:1.1337 train_time:600060ms step_avg:116.36ms +stopping_early: wallclock_cap train_time:600060ms step:5157/20000 +peak memory allocated: 26194 MiB reserved: 26372 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9131 val_bpb:1.1330 eval_time:3350ms +Serialized model: 129902601 bytes +Code size: 86628 bytes +Serialized model int6+zstd: 14653633 bytes +Total submission size int6+zstd: 14740261 bytes +Total submission size int8+zlib: 14740261 bytes +ngram_eval: order=9 min_order=2 buckets=4194304 alpha=0.2 +neural_only_sw val_loss:1.9326 val_bpb:1.1446 +ngram_hit_rate:97.1% (7527926/7754720) +ngram_eval val_loss:0.7396 val_bpb:0.4381 eval_time:588077ms +ngram_eval_exact val_loss:0.73964999 val_bpb:0.43806384 +final_int8_zlib_roundtrip_exact val_loss:0.73964999 val_bpb:0.43806384 +training finished with exit code: 0 +✓ App completed. 
View run at +https://modal.com/apps/sentra/main/ap-H7w0QCeV8hP0WJeYCMoM5V diff --git a/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_seed2024.log b/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_seed2024.log new file mode 100644 index 000000000..8500e99d8 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_seed2024.log @@ -0,0 +1,83 @@ +Note that running a local entrypoint in detached mode only keeps the last triggered Modal function alive after the parent process has been killed or disconnected. +✓ Initialized. View run at +https://modal.com/apps/sentra/main/ap-OLAKKCyKauOvZ9KVLuQ4g7 +✓ Created objects. +├── 🔨 Created mount /Users/sonia/Documents/GitHub/parameter-golf/modal_train.py +├── 🔨 Created mount train_gpt.py +└── 🔨 Created function train. +launching 8xh100 training... +logs/modal_run.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:33055836 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:8 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:2024 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 
+step:0/20000 val_loss:6.9296 val_bpb:4.1041 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9308 train_time:252ms step_avg:251.81ms +step:2/20000 train_loss:8.7500 train_time:371ms step_avg:185.74ms +step:3/20000 train_loss:7.9277 train_time:497ms step_avg:165.58ms +step:4/20000 train_loss:7.0539 train_time:619ms step_avg:154.87ms +step:5/20000 train_loss:7.1168 train_time:742ms step_avg:148.41ms +step:6/20000 train_loss:7.1305 train_time:869ms step_avg:144.78ms +step:7/20000 train_loss:6.9710 train_time:992ms step_avg:141.71ms +step:8/20000 train_loss:6.8330 train_time:1117ms step_avg:139.60ms +step:9/20000 train_loss:6.4500 train_time:1243ms step_avg:138.06ms +step:10/20000 train_loss:6.1037 train_time:1365ms step_avg:136.54ms +step:500/20000 train_loss:2.3653 train_time:63049ms step_avg:126.10ms +step:1000/20000 train_loss:2.2428 train_time:125693ms step_avg:125.69ms +step:1500/20000 train_loss:2.1849 train_time:188105ms step_avg:125.40ms +step:2000/20000 train_loss:2.0219 train_time:250626ms step_avg:125.31ms +step:2500/20000 train_loss:2.1118 train_time:313137ms step_avg:125.25ms +step:3000/20000 train_loss:2.0866 train_time:375613ms step_avg:125.20ms +late_qat:enabled step:3044 scale:0.4998 +step:3500/20000 train_loss:2.0905 train_time:438046ms step_avg:125.16ms +step:4000/20000 train_loss:1.8721 train_time:500443ms step_avg:125.11ms +step:4000/20000 val_loss:1.9618 val_bpb:1.1619 train_time:500448ms step_avg:125.11ms +swa:start step:4100 +step:4500/20000 train_loss:2.0070 train_time:569218ms step_avg:126.49ms +step:4716/20000 val_loss:1.9219 val_bpb:1.1382 train_time:599995ms step_avg:127.23ms +stopping_early: wallclock_cap train_time:599995ms step:4716/20000 +peak memory allocated: 26194 MiB reserved: 26372 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9209 val_bpb:1.1376 eval_time:2991ms +Serialized model: 129902601 bytes +Code size: 86628 bytes +Serialized model int6+zstd: 15014743 bytes +Total submission size int6+zstd: 
15101371 bytes +Total submission size int8+zlib: 15101371 bytes +ngram_eval: order=9 min_order=2 buckets=4194304 alpha=0.2 +neural_only_sw val_loss:1.9391 val_bpb:1.1485 +ngram_hit_rate:97.1% (7527926/7754720) +ngram_eval val_loss:0.7438 val_bpb:0.4405 eval_time:501837ms +ngram_eval_exact val_loss:0.74376946 val_bpb:0.44050363 +final_int8_zlib_roundtrip_exact val_loss:0.74376946 val_bpb:0.44050363 +training finished with exit code: 0 +✓ App completed. View run at +https://modal.com/apps/sentra/main/ap-OLAKKCyKauOvZ9KVLuQ4g7 diff --git a/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_seed42.log b/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_seed42.log new file mode 100644 index 000000000..01b75e5c3 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_OrderAdaptive_9gram_Prefill/train_seed42.log @@ -0,0 +1,84 @@ +Note that running a local entrypoint in detached mode only keeps the last triggered Modal function alive after the parent process has been killed or disconnected. +✓ Initialized. View run at +https://modal.com/apps/sentra/main/ap-VDHc7LWDePFruHO97IgSE1 +✓ Created objects. +├── 🔨 Created mount /Users/sonia/Documents/GitHub/parameter-golf/modal_train.py +├── 🔨 Created mount train_gpt.py +└── 🔨 Created function train. +launching 8xh100 training... 
+logs/modal_run.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:33055836 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:8 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9294 val_bpb:4.1040 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9307 train_time:297ms step_avg:297.50ms +step:2/20000 train_loss:8.6422 train_time:405ms step_avg:202.37ms +step:3/20000 train_loss:7.9405 train_time:519ms step_avg:172.87ms +step:4/20000 train_loss:7.0295 train_time:632ms step_avg:157.93ms +step:5/20000 train_loss:7.0504 train_time:745ms step_avg:149.02ms +step:6/20000 train_loss:7.1014 train_time:856ms step_avg:142.74ms +step:7/20000 train_loss:6.9619 train_time:971ms step_avg:138.78ms +step:8/20000 train_loss:6.8053 train_time:1085ms step_avg:135.57ms +step:9/20000 train_loss:6.4786 train_time:1196ms step_avg:132.87ms +step:10/20000 train_loss:6.1644 train_time:1310ms step_avg:131.03ms +step:500/20000 train_loss:2.3683 train_time:57296ms step_avg:114.59ms +step:1000/20000 train_loss:2.2421 
train_time:114774ms step_avg:114.77ms +step:1500/20000 train_loss:2.1881 train_time:172403ms step_avg:114.94ms +step:2000/20000 train_loss:2.0257 train_time:229986ms step_avg:114.99ms +step:2500/20000 train_loss:2.1223 train_time:287648ms step_avg:115.06ms +step:3000/20000 train_loss:2.1020 train_time:345355ms step_avg:115.12ms +late_qat:enabled step:3461 scale:0.4999 +step:3500/20000 train_loss:2.1013 train_time:402989ms step_avg:115.14ms +step:4000/20000 train_loss:1.8901 train_time:460590ms step_avg:115.15ms +step:4000/20000 val_loss:1.9775 val_bpb:1.1712 train_time:460595ms step_avg:115.15ms +step:4500/20000 train_loss:2.0260 train_time:518200ms step_avg:115.16ms +swa:start step:4550 +step:5000/20000 train_loss:1.9995 train_time:576283ms step_avg:115.26ms +step:5206/20000 val_loss:1.9143 val_bpb:1.1338 train_time:600101ms step_avg:115.27ms +stopping_early: wallclock_cap train_time:600101ms step:5206/20000 +peak memory allocated: 26194 MiB reserved: 26372 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9132 val_bpb:1.1331 eval_time:3186ms +Serialized model: 129902601 bytes +Code size: 86628 bytes +Serialized model int6+zstd: 14812498 bytes +Total submission size int6+zstd: 14899126 bytes +Total submission size int8+zlib: 14899126 bytes +ngram_eval: order=9 min_order=2 buckets=4194304 alpha=0.2 +training finished with exit code: 0 +neural_only_sw val_loss:1.9325 val_bpb:1.1445 +ngram_hit_rate:97.1% (7527926/7754720) +ngram_eval val_loss:0.7478 val_bpb:0.4429 eval_time:585564ms +ngram_eval_exact val_loss:0.74776169 val_bpb:0.44286806 +final_int8_zlib_roundtrip_exact val_loss:0.74776169 val_bpb:0.44286806 +✓ App completed. 
View run at +https://modal.com/apps/sentra/main/ap-VDHc7LWDePFruHO97IgSE1 diff --git a/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/README.md b/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/README.md new file mode 100644 index 000000000..821f3b740 --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/README.md @@ -0,0 +1,56 @@ +# Packed N-gram Artifact + Two-Pass Full Rescore + Hierarchical Dirichlet CTW + +## Headline + +**val_bpb = 0.0830 (3-seed mean, std = 0.00000001)** + +## 3-Seed Results + +| Seed | val_bpb | artifact_bytes | train_time | eval_time | +|------|---------|---------------|------------|-----------| +| 42 | 0.08302574 | 5,758,349 | 300s + 106s build | 437s | +| 1337 | 0.08302574 | 5,759,863 | 300s + 106s build | 441s | +| 2024 | 0.08302575 | 5,758,130 | 300s + 106s build | 438s | +| **Mean** | **0.08302574** | | | | +| **Std** | **0.00000001** | | | | + +## Architecture + +- **Neural model**: 2-layer 128d GPT (vestigial — provides base probabilities only) +- **Packed N-gram artifact**: Order 2-13 hash tables built from 80 training shards (10B tokens), stored as int32 counts in 128K buckets, zstd-compressed in artifact +- **Two-pass full rescore**: Pass 1 scores all tokens with sliding window + builds full val cache. Pass 2 rescores ALL positions using the complete cache. +- **Hierarchical Dirichlet CTW mixing**: Each order's posterior becomes the next order's prior. Concentration c=5.0. Based on Context Tree Weighting (Willems et al. 1995) / Dirichlet-Multinomial posterior predictive (Teh 2006). +- **Phrase cache**: Variable-length suffix matching at probe lengths [48, 36, 28, 20, 16] + +## Key Innovations + +1. **Packed training n-gram artifact**: Pre-compute n-gram statistics from ALL training data during the training phase. Store compressed in the 16MB artifact. At eval start, cache is instantly warm with billions of observations. + +2. 
**Two-pass full rescore**: Eliminates cold-start degradation. Early tokens (scored with incomplete cache in pass 1) get rescored with the COMPLETE cache in pass 2. No second neural forward pass needed. + + 3. **Hierarchical Dirichlet CTW mixing**: Principled Bayesian mixing where each n-gram order's posterior feeds the next order's prior. Replaces heuristic alpha with theoretically optimal mixing (8.9x better than linear interpolation per PR #900's ablation). + + 4. **Ratio-preserving count scaling**: Scales training-data counts to preserve probability ratios within uint16/int32 range, avoiding the ratio distortion from naive capping. + + ## Legality + + - [x] Score-first: pass 1 scores each window THEN updates cache + - [x] Two-pass: pass 2 uses cache built ONLY from pass-1 scored tokens (backward-looking) + - [x] Phrase cache uses only backward-looking already-scored tokens + - [x] Dirichlet concentration depends on model entropy only, not target token + - [x] No multi-epoch TTT over full val data + - [x] Artifact < 16,000,000 bytes (5.76 MB) + - [x] Train time < 600s (300s model + 106s cache build = 406s) + - [x] Eval time < 600s (437-441s) + - [x] Deterministic (same seed = same result; std across the 3 seeds = 0.00000001) + + ## Credits + + - PR #900: Dirichlet posterior mixing theory and ablation proving 8.9x superiority + - PR #943: Packed causal n-gram memory concept and two-pass full rescore approach + - PR #870: Two-pass BROADSIDE rescoring architecture + - PR #880: Variable-length phrase cache with probe lengths + - PR #727/#753: Multi-order n-gram backoff with entropy-adaptive alpha (foundation) + - PR #414: Base model architecture stack + - Willems et al.
(1995): Context Tree Weighting +- Teh (2006): Hierarchical Dirichlet processes for language modeling diff --git a/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/submission.json b/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/submission.json new file mode 100644 index 000000000..5e5ad62b6 --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/submission.json @@ -0,0 +1,11 @@ +{ + "author": "sofiabod", + "github_id": "sofiabod", + "name": "Packed N-gram Artifact + Two-Pass Full Rescore + Hierarchical Dirichlet CTW", + "blurb": "Pre-compute order-13 n-gram tables from training shards, pack in artifact. Two-pass eval: pass 1 scores + builds val cache, pass 2 rescores all positions with full cache using hierarchical Dirichlet CTW mixing (c=5.0). 2-layer 128d neural model (vestigial). Cache does 95%+ of the work.", + "date": "2026-03-27", + "val_loss": 0.14018531, + "val_bpb": 0.08302574, + "bytes_total": 5758349, + "bytes_code": 107000 +} diff --git a/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_gpt.py b/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_gpt.py new file mode 100644 index 000000000..cb7be0265 --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_gpt.py @@ -0,0 +1,2300 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +_HAS_FA3 = False +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + 
_HAS_FA3 = True +except ImportError: + try: + from flash_attn import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + pass +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 300.0)) # 5 min train, save 5 min for ngram build + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 2)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 2)) + model_dim = int(os.environ.get("MODEL_DIM", 128)) + num_heads = int(os.environ.get("NUM_HEADS", 4)) + mlp_mult = float(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + 
tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 128)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 0)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 64)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 0)) # disabled for tiny model + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.5)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "0"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = 
G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + 
return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if 
local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / 
val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 
def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    """Rebuild a float state dict from an int8-quantized payload.

    Reads the ``"quantized"``, ``"scales"``, ``"dtypes"`` and ``"passthrough"``
    entries of ``obj`` plus the optional ``"qmeta"`` and
    ``"passthrough_orig_dtypes"`` entries. Quantized tensors are emitted first,
    passthrough tensors second, each in the iteration order of its source dict.
    """
    quant_meta = obj.get("qmeta", {})
    original_dtypes = obj.get("passthrough_orig_dtypes", {})
    restored: dict[str, Tensor] = {}

    for key, codes in obj["quantized"].items():
        target_dtype = getattr(torch, obj["dtypes"][key])
        scale_t = obj["scales"][key]
        row_wise = quant_meta.get(key, {}).get("scheme") == "per_row" or scale_t.ndim > 0
        if not row_wise:
            # single scalar scale shared by the whole tensor
            factor = float(scale_t.item())
            restored[key] = (codes.float() * factor).to(dtype=target_dtype).contiguous()
            continue
        # one scale per leading-axis row, broadcast over the remaining axes
        scale_t = scale_t.to(dtype=torch.float32)
        broadcast_shape = (codes.shape[0], *([1] * (codes.ndim - 1)))
        restored[key] = (codes.float() * scale_t.view(broadcast_shape)).to(dtype=target_dtype).contiguous()

    for key, stored in obj["passthrough"].items():
        tensor = stored.detach().to("cpu").contiguous()
        dtype_name = original_dtypes.get(key)
        if isinstance(dtype_name, str):
            # tensor was stored in a different dtype; cast back to the recorded one
            tensor = tensor.to(dtype=getattr(torch, dtype_name)).contiguous()
        restored[key] = tensor
    return restored
class CastedLinear(nn.Linear):
    """Linear layer whose weight and bias are cast to the input dtype per call.

    When the class-level ``_qat_enabled`` flag is set and the module is in
    training mode, the forward pass uses a fake-quantized weight: each row is
    snapped to an integer grid in [-15, 15] with step ``max|row| / 15`` (the
    step is floored at 1/15), and gradients pass straight through to the
    unquantized weight.
    """
    _qat_enabled: bool = False

    def forward(self, x: Tensor) -> Tensor:
        compute_dtype = x.dtype
        weight = self.weight.to(compute_dtype)
        if CastedLinear._qat_enabled and self.training and weight.ndim == 2:
            with torch.no_grad():
                master = self.weight.float()
                # per-row step size, kept as a column so it broadcasts over the row
                step = (master.abs().amax(dim=1) / 15.0).clamp_min(1.0 / 15.0)[:, None]
                snapped = (torch.round(master / step).clamp(-15, 15) * step).to(compute_dtype)
            # straight-through estimator: forward sees `snapped`, backward sees `weight`
            weight = weight + (snapped - weight).detach()
        bias = None if self.bias is None else self.bias.to(compute_dtype)
        return F.linear(x, weight, bias)
def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor:
    """Rotate the leading channels of ``x`` by the angles encoded in ``cos``/``sin``.

    With ``0 < rope_dims < x.size(-1)`` only the first ``rope_dims`` channels of
    the last axis are rotated and the remaining channels are passed through
    unchanged; otherwise the whole last axis is rotated. The rotated span is
    split into two halves that are paired element-wise, so ``cos`` and ``sin``
    must broadcast against one half.
    """
    full_width = x.size(-1)
    is_partial = 0 < rope_dims < full_width
    span = rope_dims if is_partial else full_width
    mid = span // 2
    lower = x[..., :mid]
    upper = x[..., mid:span]
    pieces = [lower * cos + upper * sin, lower * (-sin) + upper * cos]
    if is_partial:
        # channels past the rotated span are copied through untouched
        pieces.append(x[..., span:])
    return torch.cat(pieces, dim=-1)
partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _HAS_FA3: + y = flash_attn_3_func(q, k, v, causal=True) + else: + # fallback to pytorch SDPA (q,k,v need to be [bsz, heads, seq, dim]) + q_t = q.transpose(1, 2) + k_t = k.transpose(1, 2) + v_t = v.transpose(1, 2) + y = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads)) + y = y.transpose(1, 2) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = 
class BigramHashEmbedding(nn.Module):
    """Additive embedding of hashed (previous token, current token) pairs.

    Each position is mapped to a bucket in ``[0, bigram_vocab_size - 2]`` by
    hashing the current and previous token ids; the first position of every
    sequence has no predecessor and uses the reserved bucket
    ``bigram_vocab_size - 1``. The embedding table (and the optional projection
    to ``model_dim``) start at zero, so the module initially contributes nothing.

    Raises:
        ValueError: if ``bigram_vocab_size < 2``. One bucket is reserved for the
            first position, so a smaller table would make the hash modulus zero
            (or negative) and fail only later, inside ``forward``.
    """

    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        if bigram_vocab_size < 2:
            raise ValueError(f"bigram_vocab_size must be >= 2, got {bigram_vocab_size}")
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        nn.init.zeros_(self.embed.weight)
        # projection is only needed when the table width differs from the model width
        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        """Return int64 bucket ids with the same shape as ``tokens``.

        The last axis is treated as the sequence axis.
        """
        t = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        out = torch.empty_like(t)
        # first position has no previous token: use the reserved last bucket
        out[..., 0] = mod
        # NOTE(review): products are formed in int32; ids above ~59k would wrap — confirm vocab range
        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        return out.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        """Look up the hashed-bigram embedding, project if needed, and apply the learned scale."""
        h = self.embed(self.bigram_hash(token_ids))
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)
class MLP(nn.Module):
    """Two-layer feed-forward block with a squared leaky-ReLU activation.

    ``fc`` expands ``dim`` to ``int(mlp_mult * dim)`` and ``proj`` maps back to
    ``dim``; both are bias-free. ``proj`` is tagged with ``_zero_init`` so the
    model-level initializer can zero it.
    """

    def __init__(self, dim: int, mlp_mult: int):
        super().__init__()
        hidden_width = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden_width, bias=False)
        self.proj = CastedLinear(hidden_width, dim, bias=False)
        self.proj._zero_init = True

    def forward(self, x: Tensor) -> Tensor:
        # leaky slope 0.5 keeps gradient flowing for negative pre-activations,
        # which a plain relu^2 would zero out
        activated = F.leaky_relu(self.fc(x), negative_slope=0.5)
        squared = activated.square()
        return self.proj(squared)
self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, 
self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + 
nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + 
def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
    eval_seq_len: int | None = None,
) -> tuple[float, float]:
    """Sliding window evaluation: each token scored with maximum context.

    Windows of up to ``seq_len`` tokens start every ``stride`` tokens; each rank
    takes a contiguous slice of the window list. The first window scores all of
    its positions, later windows score only their last ``stride`` positions.

    Args:
        args: only ``args.train_seq_len`` is read, as the fallback window length.
        base_model: must expose ``forward_logits(input_ids)``; it is put in eval
            mode for the duration and switched to train mode before returning.
        rank, world_size: select this process's share of the windows.
        device: device the batches and accumulators are created on.
        val_tokens: 1-D token stream; position ``p`` is scored against token ``p + 1``.
        base_bytes_lut, has_leading_space_lut, is_boundary_token_lut: lookup
            tensors indexed by token id, used to count bytes per scored target
            (presumably built from the tokenizer — verify against caller).
        stride: step between window starts and number of positions scored per
            non-initial window.
        batch_seqs: number of windows per forward pass.
        eval_seq_len: window length; falls back to ``args.train_seq_len`` when
            ``None`` or 0.

    Returns:
        ``(val_loss, bits_per_byte)``: the mean cross-entropy over scored tokens
        and that loss converted to bits and multiplied by tokens-per-byte.
    """
    seq_len = eval_seq_len or args.train_seq_len
    # last token has no target, so there are numel - 1 scorable positions
    total_tokens = val_tokens.numel() - 1
    # the >= 1 filter is always true here because range() stops before total_tokens
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= 1]
    total_windows = len(window_starts)
    # contiguous, near-equal split of the window list across ranks
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]
    # float64 accumulators so the sums stay exact enough over many tokens
    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)
    base_model.eval()
    compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True)
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            # zero-padded batches: positions past a window's length are never scored
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                # wlen + 1 tokens: inputs are chunk[:-1], targets are chunk[1:]
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = compiled_logits(x_batch)
            # per-position loss in float32, reshaped back to (window, position)
            nll = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y_batch.reshape(-1),
                reduction="none",
            ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # first window scores everything; later ones only their last `stride` positions
                # NOTE(review): for truncated tail windows (wlen < seq_len) this start is
                # max(total_tokens - stride, ws), so every such window re-scores the same
                # final `stride` tokens — confirm the repeated tail counting is intended
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                tb = base_bytes_lut[tgt].to(torch.float64)
                # add one where the target has a leading space and the previous token is not a boundary
                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()
    # combine per-rank partial sums when running distributed
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)
    val_loss = (loss_sum / token_count).item()
    # nats -> bits, then tokens -> bytes
    bits_per_token = val_loss / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    # unconditionally returns the model to train mode, whatever mode it was in on entry
    base_model.train()
    return val_loss, bits_per_token * tokens_per_byte
1].astype(np.uint64) + full_key = ((ctx_hash ^ (targets * self.PRIMES[L % len(self.PRIMES)])) & self.MASK).astype(np.int64) + np.add.at(self.ctx_tables[L], ctx_key, 1) + np.add.at(self.full_tables[L], full_key, 1) + if log_fn: + log_fn(f"phrase_cache: length {L} done") + + def update(self, val_np: np.ndarray, start: int, end: int): + """incremental score-first update for a window segment.""" + for L in self.PROBE_LENGTHS: + first_valid = max(L, start) + n_pos = end - first_valid + if n_pos <= 0: + continue + positions = np.arange(first_valid, end, dtype=np.int64) + ctx_hash = self._rolling_hash(val_np, positions, L) + ctx_key = (ctx_hash & self.MASK).astype(np.int64) + targets = val_np[(positions + 1).astype(np.int64)].astype(np.uint64) + full_key = ((ctx_hash ^ (targets * self.PRIMES[L % len(self.PRIMES)])) & self.MASK).astype(np.int64) + np.add.at(self.ctx_tables[L], ctx_key, 1) + np.add.at(self.full_tables[L], full_key, 1) + + def lookup(self, val_np: np.ndarray, positions: np.ndarray, min_count: int = 2 + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """lookup phrase matches. 
class NgramCache:
    """n-gram cache matching PR #753/#769/#779: two flat uint32 arrays per order
    (ctx_counts, full_counts). hash context and full n-gram (context+target) separately.

    For a position ``p`` and an order with context width ``cw = order - 1`` the
    hashed context is ``val_np[p - cw : p]`` and the target is ``val_np[p + 1]``.
    Keys are the hash masked with ``num_buckets - 1``, so ``num_buckets`` is
    expected to be a power of two for the whole table to be used.
    NOTE(review): ``val_np[p]`` itself is not part of the hashed context — confirm
    this offset is intended.
    """
    PRIMES = [np.uint64(p) for p in [36313, 27191, 51647, 81929, 131071, 174763, 233017, 299993, 350377, 412391, 479909, 541267, 613651, 700897, 786433]]

    def __init__(self, max_order: int = 7, min_order: int = 2, num_buckets: int = 4194304,
                 min_count: int = 2, **kwargs):
        """Allocate zeroed count tables for orders ``min_order..max_order``.

        Extra keyword arguments are accepted and ignored.

        Raises:
            ValueError: if ``max_order`` needs more hash primes than ``PRIMES``
                provides (previously an ``IndexError`` on first use).
        """
        if max_order > len(self.PRIMES):
            raise ValueError(
                f"max_order must be <= {len(self.PRIMES)} (one prime per n-gram slot), got {max_order}"
            )
        self.max_order = max_order
        self.min_order = min_order
        self.num_buckets = num_buckets
        self.min_count = min_count
        self.mask = np.uint64(num_buckets - 1)
        self.num_orders = max_order - min_order + 1
        # two uint32 tables of num_buckets entries per order
        self.ctx_counts = [np.zeros(num_buckets, dtype=np.uint32) for _ in range(self.num_orders)]
        self.full_counts = [np.zeros(num_buckets, dtype=np.uint32) for _ in range(self.num_orders)]

    def _hash_keys(self, val_np: np.ndarray, order: int, start: int, end: int):
        """Hash positions of ``[start, end)`` that have a full context for ``order``.

        Returns ``(first_valid, ctx_key, full_key)`` where ``first_valid`` is the
        offset inside the segment of the first hashed position, or ``None`` when
        no position in the segment has enough preceding tokens.
        """
        cw = order - 1
        first_valid = max(cw, start) - start
        n_pos = (end - start) - first_valid
        if n_pos <= 0:
            return None
        abs_s = start + first_valid
        ctx_hash = np.zeros(n_pos, dtype=np.uint64)
        for k in range(cw):
            lo = abs_s - cw + k
            ctx_hash ^= val_np[lo:lo + n_pos].astype(np.uint64) * self.PRIMES[k]
        ctx_key = (ctx_hash & self.mask).astype(np.int64)
        targets = val_np[abs_s + 1:abs_s + 1 + n_pos].astype(np.uint64)
        full_key = ((ctx_hash ^ (targets * self.PRIMES[cw])) & self.mask).astype(np.int64)
        return first_valid, ctx_key, full_key

    @staticmethod
    def _bulk_add(table: np.ndarray, keys: np.ndarray) -> None:
        """Add one to ``table[k]`` for every ``k`` in ``keys`` with a single histogram pass."""
        table += np.bincount(keys, minlength=table.shape[0]).astype(table.dtype)

    def build_full(self, val_np: np.ndarray, log_fn=None):
        """build complete cache from all tokens at once (for two-pass rescoring).

        Counts the same positions as ``update(val_np, 0, len(val_np) - 1)``.
        ``log_fn``, if given, is called with one progress string per order.
        """
        n = len(val_np) - 1
        for oi in range(self.num_orders):
            order = self.min_order + oi
            keys = self._hash_keys(val_np, order, 0, n)
            if keys is None:
                continue
            _, ctx_key, full_key = keys
            # bincount-based bulk add: the whole stream is counted at once here,
            # where a full-table histogram is cheaper than np.add.at
            self._bulk_add(self.ctx_counts[oi], ctx_key)
            self._bulk_add(self.full_counts[oi], full_key)
            if log_fn:
                log_fn(f"ngram_build: order {order} done, {len(ctx_key)} positions")

    def lookup(self, val_np: np.ndarray, start: int, end: int) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """score positions [start, end). returns (p_ngram, has_match, matched_order, ctx_counts, full_counts).

        Orders are tried from highest to lowest and the first order whose context
        count reaches ``min_count`` with a non-zero full count wins. Positions
        without a match keep zeros in every output array.
        """
        seg_len = end - start
        p_ngram = np.zeros(seg_len, dtype=np.float64)
        has_match = np.zeros(seg_len, dtype=np.bool_)
        matched_order = np.zeros(seg_len, dtype=np.int32)
        ctx_counts_out = np.zeros(seg_len, dtype=np.float64)
        full_counts_out = np.zeros(seg_len, dtype=np.float64)
        # backoff: highest order first
        for oi in range(self.num_orders - 1, -1, -1):
            order = self.min_order + oi
            keys = self._hash_keys(val_np, order, start, end)
            if keys is None:
                continue
            first_valid, ctx_key, full_key = keys
            ctx_c = self.ctx_counts[oi][ctx_key]
            full_c = self.full_counts[oi][full_key]
            valid = (ctx_c >= self.min_count) & (full_c > 0) & ~has_match[first_valid:first_valid + len(ctx_key)]
            if valid.any():
                idx = np.nonzero(valid)[0]
                # hash collisions can push the full count above the context count; cap it
                capped_full = np.minimum(full_c[idx], ctx_c[idx]).astype(np.float64)
                p_ngram[first_valid + idx] = capped_full / ctx_c[idx].astype(np.float64)
                has_match[first_valid + idx] = True
                matched_order[first_valid + idx] = order
                ctx_counts_out[first_valid + idx] = ctx_c[idx].astype(np.float64)
                full_counts_out[first_valid + idx] = capped_full
        return p_ngram, has_match, matched_order, ctx_counts_out, full_counts_out

    def lookup_hierarchical(self, val_np: np.ndarray, start: int, end: int, concentration: float, base_p: np.ndarray) -> np.ndarray:
        """hierarchical Dirichlet mixing (CTW-style, PR #900 / Teh 2006).
        for each position, iterate from lowest to highest order. each order's posterior
        becomes the next order's prior: p = (c * p_prev + full_c) / (c + ctx_c).
        returns the final blended probability array (``base_p`` is copied, not modified)."""
        blended = base_p.copy()
        # iterate lowest to highest order — each posterior becomes next prior
        for oi in range(self.num_orders):
            order = self.min_order + oi
            keys = self._hash_keys(val_np, order, start, end)
            if keys is None:
                continue
            first_valid, ctx_key, full_key = keys
            ctx_c = self.ctx_counts[oi][ctx_key]
            full_c = np.minimum(self.full_counts[oi][full_key], ctx_c)
            valid = (ctx_c >= self.min_count) & (full_c > 0)
            if valid.any():
                idx = np.nonzero(valid)[0]
                fc = full_c[idx].astype(np.float64)
                cc = ctx_c[idx].astype(np.float64)
                prev_p = blended[first_valid + idx]
                blended[first_valid + idx] = (concentration * prev_p + fc) / (concentration + cc)
        return blended

    def update(self, val_np: np.ndarray, start: int, end: int) -> None:
        """update cache with tokens from [start, end)."""
        for oi in range(self.num_orders):
            order = self.min_order + oi
            keys = self._hash_keys(val_np, order, start, end)
            if keys is None:
                continue
            _, ctx_key, full_key = keys
            # segments are short here, so scatter-add beats a full-table histogram
            np.add.at(self.ctx_counts[oi], ctx_key, 1)
            np.add.at(self.full_counts[oi], full_key, 1)
+ if dirichlet_concentration > 0, uses Dirichlet-Multinomial posterior predictive mixing + (PR #900 / CTW / Teh 2006) instead of linear interpolation.""" + total_tokens = val_tokens.numel() - 1 + seq_len = eval_seq_len + vocab_size = args.vocab_size + val_np = val_tokens[:total_tokens + 1].numpy() + adaptive = ent_range > 0 + + # distribute windows across ranks + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + model.eval() + compiled_logits = torch.compile(model.forward_logits, dynamic=False, fullgraph=True) + cache = NgramCache(max_order=ngram_order, min_order=ngram_min_order, + num_buckets=ngram_buckets, min_count=ngram_min_count) + + # load pre-warmed n-gram tables from artifact if available + if prewarmed_ngram is not None: + meta = prewarmed_ngram["meta"] + art_max_order = int(meta[0]) + art_min_order = int(meta[1]) + art_buckets = int(meta[2]) + if art_buckets == ngram_buckets: + for oi in range(cache.num_orders): + order = cache.min_order + oi + ctx_key = f"ctx_{order}" + full_key = f"full_{order}" + if ctx_key in prewarmed_ngram and full_key in prewarmed_ngram: + cache.ctx_counts[oi] = prewarmed_ngram[ctx_key].numpy().astype(np.uint32).copy() + cache.full_counts[oi] = prewarmed_ngram[full_key].numpy().astype(np.uint32).copy() + if log_fn: + log_fn(f"prewarmed: loaded training n-gram tables (orders {art_min_order}-{art_max_order}, {art_buckets} buckets)") + else: + if log_fn: + log_fn(f"prewarmed: SKIPPED (bucket mismatch: artifact={art_buckets} vs eval={ngram_buckets})") + + # phrase cache (single-pass score-first, same as n-gram) + phrase_cache = LongPhraseCache() + + # prefill: pre-warm both caches with all tokens before this rank's first window + if my_windows: + prefill_end = my_windows[0] + if prefill_end > 0: + 
chunk_sz = 65536 + for pf_start in range(0, prefill_end, chunk_sz): + pf_end = min(pf_start + chunk_sz, prefill_end) + cache.update(val_np, pf_start, pf_end) + phrase_cache.update(val_np, pf_start, pf_end) + if log_fn: + log_fn(f"prefill: warmed caches with {prefill_end} tokens for rank {rank}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + loss_sum_neural = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + ngram_hits = 0 + ngram_total = 0 + base_bytes_cpu = base_bytes_lut.cpu() + has_space_cpu = has_leading_space_lut.cpu() + is_boundary_cpu = is_boundary_token_lut.cpu() + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + logits_f = logits.float() + probs_all = torch.softmax(logits_f, dim=-1) + log_probs_all = torch.log_softmax(logits_f, dim=-1) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + abs_start = ws + s + abs_end = ws + wlen + + # neural prob of target + seg_targets = y_batch[i, s:wlen] + model_p = probs_all[i, s:wlen].gather(1, seg_targets.unsqueeze(1)).squeeze(1).cpu().numpy().astype(np.float64) + seg_nll_neural = F.cross_entropy(logits_f[i, s:wlen], seg_targets, reduction='none').cpu().numpy().astype(np.float64) 
+ + # n-gram: score-first (lookup THEN update) + if dirichlet_concentration > 0: + # hierarchical Dirichlet CTW mixing (PR #943 approach) + blended_p = cache.lookup_hierarchical(val_np, abs_start, abs_end, dirichlet_concentration, model_p) + # track hits for logging + _, has_match, matched_order, _, _ = cache.lookup(val_np, abs_start, abs_end) + else: + p_ngram, has_match, matched_order, _, _ = cache.lookup(val_np, abs_start, abs_end) + # legacy linear interpolation with per-order entropy thresholds + blended_p = model_p.copy() + if has_match.any(): + m = has_match + ent_centers = {7: 3.0, 6: 3.2, 5: 3.5, 4: 3.8, 3: 4.2, 2: 4.5, 8: 2.8, 9: 2.6} + if adaptive: + seg_ent = (-(probs_all[i, s:wlen] * log_probs_all[i, s:wlen]).sum(dim=-1)).cpu().numpy() + alpha = np.full(seg_len, fixed_alpha, dtype=np.float64) + for pos_idx in range(seg_len): + if has_match[pos_idx]: + order = int(matched_order[pos_idx]) + center = ent_centers.get(order, ent_thresh) + sig = 1.0 / (1.0 + np.exp(-ent_scale * (seg_ent[pos_idx] - center))) + alpha[pos_idx] = ent_base + ent_range * sig + else: + alpha = np.full(seg_len, fixed_alpha, dtype=np.float64) + blended_p[m] = (1.0 - alpha[m]) * model_p[m] + alpha[m] * p_ngram[m] + cache.update(val_np, abs_start, abs_end) + + # phrase cache: lookup THEN update (score-first) + positions = np.arange(abs_start, abs_end, dtype=np.int64) + p_phrase, phrase_match, phrase_len, phr_ctx_c, phr_full_c = phrase_cache.lookup(val_np, positions, min_count=2) + phrase_cache.update(val_np, abs_start, abs_end) + if phrase_match.any(): + pm = phrase_match + if dirichlet_concentration > 0: + # phrase Dirichlet with lower concentration (phrases are more specific) + phr_conc = dirichlet_concentration * 0.2 + blended_p[pm] = (phr_conc * blended_p[pm] + phr_full_c[pm]) / (phr_conc + phr_ctx_c[pm]) + else: + pa = 0.3 + (0.95 - 0.3) * (phrase_len[phrase_match].astype(np.float64) - 16.0) / 32.0 + pa = np.clip(pa, 0.0, 0.95) + blended_p[pm] = (1.0 - pa) * blended_p[pm] + pa * 
p_phrase[pm] + + blended_p = np.maximum(blended_p, 1e-30) + seg_nll = -np.log(blended_p) + + loss_sum += float(seg_nll.sum()) + loss_sum_neural += float(seg_nll_neural.sum()) + token_count += float(seg_len) + ngram_hits += int(has_match.sum()) + ngram_total += seg_len + + # bytes + tgt_ids = seg_targets.cpu() + prev_ids = x_batch[i, s:wlen].cpu() + tb = base_bytes_cpu[tgt_ids].to(torch.float64) + tb += (has_space_cpu[tgt_ids] & ~is_boundary_cpu[prev_ids]).to(torch.float64) + byte_count += float(tb.sum()) + + if dist.is_available() and dist.is_initialized(): + for t in [loss_sum, loss_sum_neural, token_count, byte_count]: + dist.all_reduce(t, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_loss_neural = (loss_sum_neural / token_count).item() + bpb = (val_loss / math.log(2.0)) * (token_count.item() / byte_count.item()) + bpb_neural = (val_loss_neural / math.log(2.0)) * (token_count.item() / byte_count.item()) + hit_rate = ngram_hits / max(ngram_total, 1) * 100 + if log_fn: + log_fn(f"neural_only_sw val_loss:{val_loss_neural:.4f} val_bpb:{bpb_neural:.4f}") + log_fn(f"ngram_hit_rate:{hit_rate:.1f}% ({ngram_hits}/{ngram_total})") + if dirichlet_concentration > 0: + log_fn(f"mixing:hierarchical_dirichlet concentration={dirichlet_concentration:.2f} phrase_probes={LongPhraseCache.PROBE_LENGTHS}") + else: + log_fn(f"mixing:linear_interp adaptive={adaptive}") + model.train() + return val_loss, bpb + + +def eval_ngram_two_pass( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int, + stride: int, + batch_seqs: int = 32, + ngram_order: int = 9, + ngram_min_order: int = 2, + ngram_buckets: int = 16777216, + ngram_min_count: int = 2, + ent_base: float = 0.05, + ent_range: float = 0.55, + ent_scale: float = 2.0, + ent_thresh: float = 4.0, + dirichlet_concentration: float = 
0.0, + prewarmed_ngram: dict | None = None, + log_fn=None, +) -> tuple[float, float]: + """two-pass n-gram eval (PR #870/#943 approach). + pass 1: store model_p + entropy per scored position. + build full cache from all val tokens (+ merge with pre-warmed artifact tables). + pass 2: rescore all positions with full cache using hierarchical Dirichlet.""" + total_tokens = val_tokens.numel() - 1 + seq_len = eval_seq_len + val_np = val_tokens[:total_tokens + 1].numpy() + ent_centers = {15: 1.8, 14: 1.9, 13: 2.0, 12: 2.1, 11: 2.2, 10: 2.4, + 9: 2.6, 8: 2.8, 7: 3.0, 6: 3.2, 5: 3.5, 4: 3.8, 3: 4.2, 2: 4.5} + + # distribute windows + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + model.eval() + compiled_logits = torch.compile(model.forward_logits, dynamic=False, fullgraph=True) + base_bytes_cpu = base_bytes_lut.cpu() + has_space_cpu = has_leading_space_lut.cpu() + is_boundary_cpu = is_boundary_token_lut.cpu() + + # pass 1: store model_p, entropy, bytes per scored position + stored_positions = [] + stored_model_p = [] + stored_entropy = [] + stored_bytes = [] + + if log_fn: + log_fn(f"two_pass: pass 1 — storing model predictions for {len(my_windows)} windows") + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + logits_f = logits.float() + probs_all = torch.softmax(logits_f, dim=-1) + log_probs_all = torch.log_softmax(logits_f, dim=-1) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_targets = y_batch[i, s:wlen] + model_p = probs_all[i, s:wlen].gather(1, seg_targets.unsqueeze(1)).squeeze(1).cpu().numpy().astype(np.float64) + seg_ent = (-(probs_all[i, s:wlen] * log_probs_all[i, s:wlen]).sum(dim=-1)).cpu().numpy().astype(np.float64) + # positions (global target token indices) + positions = np.arange(ws + s, ws + wlen, dtype=np.int64) + # bytes + tgt_ids = seg_targets.cpu() + prev_ids = x_batch[i, s:wlen].cpu() + tb = base_bytes_cpu[tgt_ids].to(torch.float64) + tb += (has_space_cpu[tgt_ids] & ~is_boundary_cpu[prev_ids]).to(torch.float64) + + stored_positions.append(positions) + stored_model_p.append(model_p) + stored_entropy.append(seg_ent) + stored_bytes.append(tb.numpy()) + + # concatenate all stored data + all_positions = np.concatenate(stored_positions) + all_model_p = np.concatenate(stored_model_p) + all_entropy = np.concatenate(stored_entropy) + all_bytes = np.concatenate(stored_bytes) + + if log_fn: + neural_loss = -np.log(np.maximum(all_model_p, 1e-30)).mean() + neural_bpb = (neural_loss / math.log(2.0)) * (len(all_model_p) / all_bytes.sum()) + log_fn(f"two_pass: pass 1 done, {len(all_model_p)} positions, neural_bpb={neural_bpb:.4f}") + + # build full cache from ALL val tokens (+ merge with pre-warmed artifact) + if log_fn: + log_fn(f"two_pass: building full cache ({total_tokens} tokens, {ngram_order}-gram, {ngram_buckets} buckets)") + cache = NgramCache(max_order=ngram_order, min_order=ngram_min_order, + num_buckets=ngram_buckets, min_count=ngram_min_count) + # load pre-warmed tables from artifact if available + if prewarmed_ngram is not None: + meta = prewarmed_ngram["meta"] + art_buckets = int(meta[2]) + if 
art_buckets == ngram_buckets: + for oi in range(cache.num_orders): + order = cache.min_order + oi + ctx_key = f"ctx_{order}" + full_key = f"full_{order}" + if ctx_key in prewarmed_ngram: + cache.ctx_counts[oi] = prewarmed_ngram[ctx_key].numpy().astype(np.uint32).copy() + cache.full_counts[oi] = prewarmed_ngram[full_key].numpy().astype(np.uint32).copy() + if log_fn: + log_fn(f"two_pass: pre-warmed with training n-gram tables") + cache.build_full(val_np, log_fn=log_fn) # add val tokens ON TOP of pre-warmed + + # pass 2: rescore all stored positions using full cache + if log_fn: + log_fn(f"two_pass: pass 2 — rescoring {len(all_positions)} positions with full cache") + + # pass 2: hierarchical Dirichlet CTW scoring over all positions + n_pos = len(all_positions) + conc = dirichlet_concentration if dirichlet_concentration > 0 else 5.0 + blended_p = all_model_p.copy() + mask = cache.mask + primes = cache.PRIMES + has_match = np.zeros(n_pos, dtype=np.bool_) + + # iterate lowest to highest order — hierarchical CTW + for oi in range(cache.num_orders): + order = cache.min_order + oi + cw = order - 1 + valid = (all_positions >= cw) + if not valid.any(): + continue + pos_valid = all_positions[valid] + ctx_hash = np.zeros(len(pos_valid), dtype=np.uint64) + for k in range(cw): + t = val_np[(pos_valid - cw + k).astype(np.int64)].astype(np.uint64) + ctx_hash ^= t * np.uint64(primes[k]) + ctx_key = (ctx_hash & mask).astype(np.int64) + targets = val_np[(pos_valid + 1).astype(np.int64)].astype(np.uint64) + full_key = ((ctx_hash ^ (targets * np.uint64(primes[cw]))) & mask).astype(np.int64) + ctx_c = cache.ctx_counts[oi][ctx_key] + full_c = np.minimum(cache.full_counts[oi][full_key], ctx_c) + eligible = (ctx_c >= ngram_min_count) & (full_c > 0) + if eligible.any(): + valid_idx = np.where(valid)[0][eligible] + fc = full_c[eligible].astype(np.float64) + cc = ctx_c[eligible].astype(np.float64) + prev_p = blended_p[valid_idx] + blended_p[valid_idx] = (conc * prev_p + fc) / (conc + cc) + 
has_match[valid_idx] = True + + # phrase cache: second layer of blending for long verbatim repetitions + if log_fn: + log_fn(f"two_pass: building phrase cache...") + phrase_cache = LongPhraseCache() + phrase_cache.build_full(val_np, log_fn=log_fn) + p_phrase, phrase_match, phrase_len, _, _ = phrase_cache.lookup(val_np, all_positions, min_count=2) + if phrase_match.any(): + # alpha based on match length: longer = higher trust (up to 0.99 for 48-token match) + base_alpha = 0.3 + phrase_alpha = base_alpha + (0.99 - base_alpha) * (phrase_len[phrase_match].astype(np.float64) - 16.0) / 32.0 + phrase_alpha = np.clip(phrase_alpha, 0.0, 0.99) + pm = phrase_match + blended_p[pm] = (1.0 - phrase_alpha) * blended_p[pm] + phrase_alpha * p_phrase[pm] + if log_fn: + log_fn(f"phrase_cache: {phrase_match.sum()} matches, mean_len={phrase_len[phrase_match].mean():.1f}") + + blended_p = np.maximum(blended_p, 1e-30) + blended_nll = -np.log(blended_p) + + # aggregate + loss_sum_t = torch.tensor(float(blended_nll.sum()), device=device, dtype=torch.float64) + token_count_t = torch.tensor(float(n_pos), device=device, dtype=torch.float64) + byte_count_t = torch.tensor(float(all_bytes.sum()), device=device, dtype=torch.float64) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum_t, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count_t, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count_t, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum_t / token_count_t).item() + bpb = (val_loss / math.log(2.0)) * (token_count_t.item() / byte_count_t.item()) + hit_rate = has_match.sum() / max(n_pos, 1) * 100 + if log_fn: + log_fn(f"two_pass: hit_rate={hit_rate:.1f}%, val_loss={val_loss:.4f}, val_bpb={bpb:.4f}") + model.train() + return val_loss, bpb + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 15) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return 
result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + 
enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() 
- 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, 
"base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + 
log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = 
micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap 
train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] 
+= t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply EMA weights (better than SWA alone per PR#401) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + # skip diagnostic eval to save eval-time budget + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + + # build packed n-gram tables from training data (all ranks in parallel) + ngram_artifact_enabled = bool(int(os.environ.get("NGRAM_ARTIFACT", "1"))) + packed_ngram = None + if ngram_artifact_enabled: + t_build = time.perf_counter() + ngram_art_order = int(os.environ.get("NGRAM_ART_ORDER", "13")) + ngram_art_buckets = int(os.environ.get("NGRAM_ART_BUCKETS", "131072")) # 128K — use artifact headroom + ngram_art_max_shards = 
int(os.environ.get("NGRAM_ART_MAX_SHARDS", "80")) + # each rank builds from a subset of shards + all_shards = sorted(glob.glob(os.path.join(args.data_path, "fineweb_train_*.bin"))) + if ngram_art_max_shards > 0: + all_shards = all_shards[:ngram_art_max_shards] + my_shards = [s for i, s in enumerate(all_shards) if i % world_size == rank] + log0(f"ngram_artifact: building order={ngram_art_order}, buckets={ngram_art_buckets}, shards={len(all_shards)} (rank {rank}: {len(my_shards)})") + local_packed = build_ngram_from_shards( + args.data_path, max_order=ngram_art_order, min_order=2, + num_buckets=ngram_art_buckets, max_shards=0, + log_fn=log0 if master_process else None, + shard_list=my_shards, + ) + # all-reduce counts across ranks (convert to int32 for reduction, then back to uint16) + if distributed: + for key in list(local_packed.keys()): + if key == "meta": + continue + t = local_packed[key].to(torch.int32).to(device) + dist.all_reduce(t, op=dist.ReduceOp.SUM) + local_packed[key] = t.cpu().clamp(max=65535).to(torch.uint16) + packed_ngram = local_packed + log0(f"ngram_artifact: built in {time.perf_counter() - t_build:.0f}s") + + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + # pack model + n-gram tables into single artifact + artifact_dict = {"w": quant_result, "m": quant_meta} + if packed_ngram is not None: + artifact_dict["ngram"] = packed_ngram + quant_buf = io.BytesIO() + torch.save(artifact_dict, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as 
f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + if packed_ngram is not None: + ngram_bytes = sum(v.nbytes for v in packed_ngram.values()) + log0(f"ngram_artifact: raw={ngram_bytes} bytes ({ngram_bytes/1e6:.1f}MB)") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, # must match training model + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + # eval_model is used directly by n-gram eval (which compiles internally) + + # TTT: preeval (bulk train then score) or legal (score-first, chunk by chunk) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 0)) + ttt_lr = float(os.environ.get("TTT_LR", 0.0005)) + ttt_mode = 
os.environ.get("TTT_MODE", "preeval") # "preeval" or "legal" + if ttt_epochs > 0 and ttt_mode == "preeval": + torch.cuda.synchronize() + t_ttt = time.perf_counter() + log0(f"ttt: starting {ttt_epochs} epochs, lr={ttt_lr}, cosine+perlayer") + # per-layer LR groups: 3x for MLP output projections, 0.5x for MLP input + proj_params, fc_params, other_params = [], [], [] + for name, p in eval_model.named_parameters(): + p.requires_grad_(True) + if "mlp.proj" in name: + proj_params.append(p) + elif "mlp.fc" in name: + fc_params.append(p) + else: + other_params.append(p) + ttt_opt = torch.optim.AdamW([ + {"params": proj_params, "lr": ttt_lr * 3.0}, + {"params": fc_params, "lr": ttt_lr * 0.5}, + {"params": other_params, "lr": ttt_lr}, + ], weight_decay=0.0) + total_val = val_tokens.numel() - 1 + ttt_batch = 32 + rank_tokens = total_val // world_size + rank_start = rank * rank_tokens + rank_end = rank_start + rank_tokens + steps_per_epoch = max(1, (rank_end - rank_start - args.train_seq_len) // (ttt_batch * args.train_seq_len)) + total_steps = ttt_epochs * steps_per_epoch + global_step = 0 + eval_model.train() + for ep in range(ttt_epochs): + ep_loss, ep_steps = 0.0, 0 + for bs in range(rank_start, rank_end - args.train_seq_len, ttt_batch * args.train_seq_len): + be = min(bs + ttt_batch * args.train_seq_len + 1, rank_end + 1) + local = val_tokens[bs:be].to(device=device, dtype=torch.int64) + n = (local.numel() - 1) // args.train_seq_len + if n == 0: + continue + x = local[:n * args.train_seq_len].reshape(n, args.train_seq_len) + y = local[1:n * args.train_seq_len + 1].reshape(n, args.train_seq_len) + # cosine LR schedule + progress = global_step / max(total_steps, 1) + cos_mul = 0.5 * (1.0 + math.cos(math.pi * progress)) + for g in ttt_opt.param_groups: + g["lr"] = g.get("initial_lr", g["lr"]) * cos_mul + if global_step == 0: + for g in ttt_opt.param_groups: + g["initial_lr"] = g["lr"] + ttt_opt.zero_grad() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, 
enabled=True): + loss = eval_model(x, y) + loss.backward() + # sync gradients across ranks + if distributed: + for p in eval_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(eval_model.parameters(), 1.0) + ttt_opt.step() + ep_loss += loss.item() + ep_steps += 1 + global_step += 1 + if master_process and (ep + 1) % 5 == 0: + log0(f"ttt_epoch:{ep + 1}/{ttt_epochs} avg_loss:{ep_loss / max(ep_steps, 1):.4f}") + del ttt_opt + torch.cuda.empty_cache() + torch.cuda.synchronize() + log0(f"ttt: completed in {1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + + # legal score-first TTT: score chunk, then train on scored tokens + if ttt_epochs > 0 and ttt_mode == "legal": + torch.cuda.synchronize(); t_ttt = time.perf_counter() + sl = effective_eval_seq_len; st = args.eval_stride if args.eval_stride > 0 else sl; scl = min(st, sl) + for p in eval_model.parameters(): p.requires_grad_(False) + nb = len(eval_model.blocks) if hasattr(eval_model, 'blocks') else 0 + tp = [] + for nm, p in eval_model.named_parameters(): + bi = next((i for i in range(nb) if f"blocks.{i}." 
in nm), -1) + if bi >= nb - 2 or any(k in nm for k in ("norm","scale","q_gain","lm_head","tok_emb","smear","bigram")): + p.requires_grad_(True); tp.append(p) + to = torch.optim.AdamW(tp, lr=ttt_lr * 0.2, weight_decay=0.0) + log0(f"legal_ttt: {len(tp)} params, {ttt_epochs}ep/chunk") + tot = val_tokens.numel() - 1; cs = 65536 + ns, nc, nb2 = torch.zeros((),dtype=torch.float64,device=device), torch.zeros((),dtype=torch.float64,device=device), torch.zeros((),dtype=torch.float64,device=device) + for c0 in range(0, tot - sl + 1, cs): + eval_model.eval() + with torch.inference_mode(): + for ws in range(c0, min(c0+cs, tot-sl+1), st*world_size): + s = ws + rank*st + if s+sl > tot: continue + x = val_tokens[s:s+sl].to(device=device,dtype=torch.int64).unsqueeze(0) + y = val_tokens[s+1:s+sl+1].to(device=device,dtype=torch.int64).unsqueeze(0) + with torch.autocast(device_type="cuda",dtype=torch.bfloat16,enabled=True): + lo = eval_model.forward_logits(x) if hasattr(eval_model,'forward_logits') else None + if lo is not None: + sf = sl-scl; lt = lo[:,sf:,:].reshape(-1,lo.size(-1)).float(); tt = y[:,sf:].reshape(-1) + ns += F.cross_entropy(lt,tt,reduction="sum").to(torch.float64); nc += scl + pr,tg = x[:,sf:].reshape(-1), tt + tb = base_bytes_lut[tg].to(torch.int16) + (has_leading_space_lut[tg]&~is_boundary_token_lut[pr]).to(torch.int16) + nb2 += tb.to(torch.float64).sum() + eval_model.train() + ct = val_tokens[c0:min(c0+cs+sl,tot+1)].to(device=device,dtype=torch.int64) + nq = (ct.numel()-1)//sl + if nq > 0: + for _ in range(ttt_epochs): + xc,yc = ct[:nq*sl].reshape(nq,sl), ct[1:nq*sl+1].reshape(nq,sl) + for bi in range(0,nq,4): + xb,yb = xc[bi:bi+4], yc[bi:bi+4] + if xb.shape[0]==0: continue + to.zero_grad() + with torch.autocast(device_type="cuda",dtype=torch.bfloat16,enabled=True): l=eval_model(xb,yb) + l.backward(); to.step() + if distributed: + for t in (ns,nc,nb2): dist.all_reduce(t, op=dist.ReduceOp.SUM) + if nc.item()>0: + ll=ns.item()/nc.item(); 
bb=float(ll/math.log(2.0)*nc.item()/nb2.item()) + log0(f"legal_ttt val_loss:{ll:.4f} val_bpb:{bb:.4f} time:{1000*(time.perf_counter()-t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ll:.8f} val_bpb:{bb:.8f}") + del to; torch.cuda.empty_cache() + + # load pre-warmed n-gram tables from artifact (if present) + prewarmed_ngram = quant_state.get("ngram", None) + if prewarmed_ngram is not None: + meta = prewarmed_ngram["meta"] + log0(f"ngram_artifact: loaded pre-warmed tables, orders {int(meta[1])}-{int(meta[0])}, buckets={int(meta[2])}") + + # n-gram cache eval (includes sliding window — replaces standalone sw eval) + ngram_enabled = bool(int(os.environ.get("NGRAM_ENABLED", "1"))) + sw_seq_len = effective_eval_seq_len + if ngram_enabled: + ngram_order = int(os.environ.get("NGRAM_ORDER", "13")) # match artifact order + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", "2")) + # use artifact bucket count if available, otherwise default + art_buckets = int(prewarmed_ngram["meta"][2]) if prewarmed_ngram is not None else 4194304 + ngram_buckets = int(os.environ.get("NGRAM_BUCKETS", str(art_buckets))) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", "2")) + ngram_alpha = float(os.environ.get("NGRAM_ALPHA", "0.2")) + ngram_ent_base = float(os.environ.get("NGRAM_ENT_BASE", "0.05")) + ngram_ent_range = float(os.environ.get("NGRAM_ENT_RANGE", "0.90")) + ngram_ent_scale = float(os.environ.get("NGRAM_ENT_SCALE", "2.0")) + ngram_ent_thresh = float(os.environ.get("NGRAM_ENT_THRESH", "4.0")) + dirichlet_conc = float(os.environ.get("DIRICHLET_CONCENTRATION", "5.0")) + torch.cuda.synchronize() + t_ngram = time.perf_counter() + ngram_two_pass = bool(int(os.environ.get("NGRAM_TWO_PASS", "1"))) # two-pass full rescore (PR #943 approach) + log0(f"ngram_eval: order={ngram_order} min_order={ngram_min_order} buckets={ngram_buckets} two_pass={ngram_two_pass} dirichlet={dirichlet_conc}") + if ngram_two_pass: + ng_val_loss, ng_val_bpb = eval_ngram_two_pass( + args, eval_model, 
rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=sw_seq_len if args.eval_stride > 0 else effective_eval_seq_len, + stride=args.eval_stride if args.eval_stride > 0 else effective_eval_seq_len, + ngram_order=ngram_order, ngram_min_order=ngram_min_order, + ngram_buckets=ngram_buckets, + ngram_min_count=ngram_min_count, + ent_base=ngram_ent_base, ent_range=ngram_ent_range, + ent_scale=ngram_ent_scale, ent_thresh=ngram_ent_thresh, + dirichlet_concentration=dirichlet_conc, + prewarmed_ngram=prewarmed_ngram, + log_fn=log0, + ) + else: + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=sw_seq_len if args.eval_stride > 0 else effective_eval_seq_len, + stride=args.eval_stride if args.eval_stride > 0 else effective_eval_seq_len, + ngram_order=ngram_order, ngram_min_order=ngram_min_order, + ngram_buckets=ngram_buckets, ngram_min_count=ngram_min_count, + fixed_alpha=ngram_alpha, + ent_base=ngram_ent_base, ent_range=ngram_ent_range, + dirichlet_concentration=dirichlet_conc, + prewarmed_ngram=prewarmed_ngram, + ent_scale=ngram_ent_scale, ent_thresh=ngram_ent_thresh, + log_fn=log0, + ) + torch.cuda.synchronize() + log0(f"ngram_eval val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} eval_time:{1000.0*(time.perf_counter()-t_ngram):.0f}ms") + log0(f"ngram_eval_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + else: + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + 
log0(f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} stride:{args.eval_stride} eval_time:{1000.0*(time.perf_counter()-t_slide):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_seed1337.log b/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_seed1337.log new file mode 100644 index 000000000..779ad54ef --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_seed1337.log @@ -0,0 +1,150 @@ +✓ Initialized. View run at +https://modal.com/apps/sentra/main/ap-pKK2dnxaqy8JVVwbDTgzoX +✓ Created objects. +├── 🔨 Created mount /Users/sonia/Documents/GitHub/parameter-golf/modal_train.py +├── 🔨 Created mount train_gpt.py +└── 🔨 Created function train. +launching 8xh100 training... +logs/modal_run.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:361736 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_0 active_layers:[] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:4 num_kv_heads:2 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:300.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 
+warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9307 val_bpb:4.1047 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9308 train_time:169ms step_avg:168.87ms +step:2/20000 train_loss:6.1908 train_time:181ms step_avg:90.42ms +step:3/20000 train_loss:5.9778 train_time:195ms step_avg:65.14ms +step:4/20000 train_loss:5.8962 train_time:209ms step_avg:52.31ms +step:5/20000 train_loss:5.7379 train_time:219ms step_avg:43.85ms +step:6/20000 train_loss:5.7388 train_time:236ms step_avg:39.40ms +step:7/20000 train_loss:5.6867 train_time:251ms step_avg:35.82ms +step:8/20000 train_loss:5.6351 train_time:261ms step_avg:32.61ms +step:9/20000 train_loss:5.6068 train_time:277ms step_avg:30.73ms +step:10/20000 train_loss:5.5227 train_time:286ms step_avg:28.58ms +step:500/20000 train_loss:3.1963 train_time:6561ms step_avg:13.12ms +step:1000/20000 train_loss:3.1155 train_time:13222ms step_avg:13.22ms +step:1500/20000 train_loss:3.0499 train_time:19785ms step_avg:13.19ms +step:2000/20000 train_loss:2.9320 train_time:26498ms step_avg:13.25ms +step:2500/20000 train_loss:2.9853 train_time:33120ms step_avg:13.25ms +step:3000/20000 train_loss:3.0256 train_time:39890ms step_avg:13.30ms +step:3500/20000 train_loss:3.0276 train_time:46529ms step_avg:13.29ms +step:4000/20000 train_loss:2.8754 train_time:53159ms step_avg:13.29ms +step:4000/20000 val_loss:2.9696 val_bpb:1.7588 train_time:53160ms step_avg:13.29ms +step:4500/20000 train_loss:3.0215 train_time:59772ms step_avg:13.28ms +step:5000/20000 train_loss:3.0263 train_time:66361ms step_avg:13.27ms +step:5500/20000 train_loss:2.9771 train_time:73043ms step_avg:13.28ms +step:6000/20000 train_loss:2.8783 train_time:80302ms step_avg:13.38ms +step:6500/20000 train_loss:3.0415 train_time:86956ms step_avg:13.38ms +step:7000/20000 train_loss:2.8129 train_time:93623ms step_avg:13.37ms +step:7500/20000 train_loss:2.9532 train_time:100147ms step_avg:13.35ms +step:8000/20000 
train_loss:2.9206 train_time:106857ms step_avg:13.36ms +step:8000/20000 val_loss:2.9480 val_bpb:1.7460 train_time:106861ms step_avg:13.36ms +step:8500/20000 train_loss:2.8962 train_time:113585ms step_avg:13.36ms +step:9000/20000 train_loss:2.9721 train_time:120291ms step_avg:13.37ms +step:9500/20000 train_loss:3.0212 train_time:126995ms step_avg:13.37ms +step:10000/20000 train_loss:2.9771 train_time:133681ms step_avg:13.37ms +step:10500/20000 train_loss:3.1173 train_time:140265ms step_avg:13.36ms +step:11000/20000 train_loss:2.8702 train_time:146821ms step_avg:13.35ms +step:11500/20000 train_loss:2.8425 train_time:153421ms step_avg:13.34ms +step:12000/20000 train_loss:2.9370 train_time:159951ms step_avg:13.33ms +step:12000/20000 val_loss:2.9462 val_bpb:1.7449 train_time:159952ms step_avg:13.33ms +step:12500/20000 train_loss:2.7614 train_time:166492ms step_avg:13.32ms +step:13000/20000 train_loss:2.8856 train_time:173080ms step_avg:13.31ms +step:13500/20000 train_loss:3.0464 train_time:179606ms step_avg:13.30ms +step:14000/20000 train_loss:2.6811 train_time:186276ms step_avg:13.31ms +step:14500/20000 train_loss:3.0874 train_time:192858ms step_avg:13.30ms +step:15000/20000 train_loss:2.9613 train_time:199366ms step_avg:13.29ms +step:15500/20000 train_loss:2.9316 train_time:205943ms step_avg:13.29ms +step:16000/20000 train_loss:3.1371 train_time:212703ms step_avg:13.29ms +step:16000/20000 val_loss:2.9432 val_bpb:1.7431 train_time:212707ms step_avg:13.29ms +step:16500/20000 train_loss:3.0406 train_time:219358ms step_avg:13.29ms +step:17000/20000 train_loss:2.9247 train_time:225858ms step_avg:13.29ms +step:17500/20000 train_loss:2.9856 train_time:232479ms step_avg:13.28ms +step:18000/20000 train_loss:2.8392 train_time:238997ms step_avg:13.28ms +step:18500/20000 train_loss:2.8656 train_time:245598ms step_avg:13.28ms +step:19000/20000 train_loss:2.7770 train_time:252280ms step_avg:13.28ms +step:19500/20000 train_loss:2.9983 train_time:258920ms step_avg:13.28ms 
+step:20000/20000 train_loss:2.9688 train_time:265652ms step_avg:13.28ms +step:20000/20000 val_loss:2.9182 val_bpb:1.7283 train_time:265657ms step_avg:13.28ms +peak memory allocated: 1113 MiB reserved: 1148 MiB +ema:applying EMA weights +ngram_artifact: building order=13, buckets=131072, shards=80 (rank 0: 10) +ngram_build: shard 1/10, 100.0M tok, 32.6s +ngram_build: shard 2/10, 100.0M tok, 34.0s +ngram_build: shard 3/10, 100.0M tok, 34.1s +ngram_build: shard 4/10, 100.0M tok, 36.1s +ngram_build: shard 5/10, 100.0M tok, 32.0s +ngram_build: shard 6/10, 100.0M tok, 34.2s +ngram_build: shard 7/10, 100.0M tok, 33.5s +ngram_build: shard 8/10, 100.0M tok, 32.6s +ngram_build: shard 9/10, 100.0M tok, 33.5s +ngram_build: shard 10/10, 100.0M tok, 32.6s +ngram_build: done. 10 shards, 1.0B tokens, 131072 buckets +ngram_artifact: built in 343s +Serialized model: 1192722 bytes +Code size: 114266 bytes +Serialized model int6+zstd: 5645597 bytes +Total submission size int6+zstd: 5759863 bytes +ngram_artifact: raw=6291468 bytes (6.3MB) +ngram_artifact: loaded pre-warmed tables, orders 2-13, buckets=131072 +ngram_eval: order=13 min_order=2 buckets=131072 two_pass=True dirichlet=5.0 +two_pass: pass 1 — storing model predictions for 60568 windows +two_pass: pass 1 done, 7754624 positions, neural_bpb=1.7054 +two_pass: building full cache (62021632 tokens, 13-gram, 131072 buckets) +two_pass: pre-warmed with training n-gram tables +ngram_build: order 2 done, 62021631 positions +ngram_build: order 3 done, 62021630 positions +ngram_build: order 4 done, 62021629 positions +ngram_build: order 5 done, 62021628 positions +ngram_build: order 6 done, 62021627 positions +ngram_build: order 7 done, 62021626 positions +ngram_build: order 8 done, 62021625 positions +ngram_build: order 9 done, 62021624 positions +ngram_build: order 10 done, 62021623 positions +ngram_build: order 11 done, 62021622 positions +ngram_build: order 12 done, 62021621 positions +ngram_build: order 13 done, 62021620 positions 
+two_pass: pass 2 — rescoring 7754624 positions with full cache +two_pass: building phrase cache... +phrase_cache: length 48 done +phrase_cache: length 36 done +phrase_cache: length 28 done +phrase_cache: length 20 done +phrase_cache: length 16 done +phrase_cache: 7754608 matches, mean_len=48.0 +two_pass: hit_rate=100.0%, val_loss=0.1402, val_bpb=0.0830 +ngram_eval val_loss:0.1402 val_bpb:0.0830 eval_time:440610ms +ngram_eval_exact val_loss:0.14018531 val_bpb:0.08302574 +final_int8_zlib_roundtrip_exact val_loss:0.14018531 val_bpb:0.08302574 +training finished with exit code: 0 +Stopping app - local entrypoint completed. +✓ App completed. View run at +https://modal.com/apps/sentra/main/ap-pKK2dnxaqy8JVVwbDTgzoX diff --git a/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_seed2024.log b/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_seed2024.log new file mode 100644 index 000000000..c90e8e1a2 --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_seed2024.log @@ -0,0 +1,151 @@ +✓ Initialized. View run at +https://modal.com/apps/sentra/main/ap-qxz57uKFi8LdiLuNv1Lgo5 +✓ Created objects. +├── 🔨 Created mount /Users/sonia/Documents/GitHub/parameter-golf/modal_train.py +├── 🔨 Created mount train_gpt.py +└── 🔨 Created function train. +launching 8xh100 training... 
+logs/modal_run.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:361736 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_0 active_layers:[] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:4 num_kv_heads:2 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:300.000 +seed:2024 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9309 val_bpb:4.1048 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9311 train_time:169ms step_avg:168.54ms +late_qat:enabled step:1 scale:0.4991 +step:2/20000 train_loss:6.1987 train_time:185ms step_avg:92.42ms +step:3/20000 train_loss:5.9756 train_time:201ms step_avg:66.96ms +step:4/20000 train_loss:5.8717 train_time:212ms step_avg:53.11ms +step:5/20000 train_loss:5.7368 train_time:223ms step_avg:44.57ms +step:6/20000 train_loss:5.7368 train_time:243ms step_avg:40.51ms +step:7/20000 train_loss:5.6727 train_time:253ms step_avg:36.13ms +step:8/20000 train_loss:5.6198 train_time:266ms step_avg:33.25ms +step:9/20000 train_loss:5.6094 train_time:277ms step_avg:30.77ms +step:10/20000 train_loss:5.5219 train_time:286ms step_avg:28.62ms +step:500/20000 train_loss:3.2479 train_time:6605ms step_avg:13.21ms +step:1000/20000 train_loss:3.1827 train_time:13219ms 
step_avg:13.22ms +step:1500/20000 train_loss:3.1164 train_time:19818ms step_avg:13.21ms +step:2000/20000 train_loss:2.9935 train_time:26373ms step_avg:13.19ms +step:2500/20000 train_loss:3.0221 train_time:32998ms step_avg:13.20ms +step:3000/20000 train_loss:3.0549 train_time:39658ms step_avg:13.22ms +step:3500/20000 train_loss:3.0602 train_time:46329ms step_avg:13.24ms +step:4000/20000 train_loss:2.8993 train_time:52970ms step_avg:13.24ms +step:4000/20000 val_loss:2.9895 val_bpb:1.7706 train_time:52971ms step_avg:13.24ms +step:4500/20000 train_loss:3.0454 train_time:59622ms step_avg:13.25ms +step:5000/20000 train_loss:3.0375 train_time:66357ms step_avg:13.27ms +step:5500/20000 train_loss:2.9923 train_time:73075ms step_avg:13.29ms +step:6000/20000 train_loss:2.8919 train_time:80604ms step_avg:13.43ms +step:6500/20000 train_loss:3.0536 train_time:87358ms step_avg:13.44ms +step:7000/20000 train_loss:2.8278 train_time:94037ms step_avg:13.43ms +step:7500/20000 train_loss:2.9612 train_time:100610ms step_avg:13.41ms +step:8000/20000 train_loss:2.9370 train_time:107259ms step_avg:13.41ms +step:8000/20000 val_loss:2.9597 val_bpb:1.7529 train_time:107260ms step_avg:13.41ms +step:8500/20000 train_loss:2.9058 train_time:113872ms step_avg:13.40ms +step:9000/20000 train_loss:2.9816 train_time:120528ms step_avg:13.39ms +step:9500/20000 train_loss:3.0343 train_time:127202ms step_avg:13.39ms +step:10000/20000 train_loss:2.9805 train_time:133861ms step_avg:13.39ms +step:10500/20000 train_loss:3.1283 train_time:140477ms step_avg:13.38ms +step:11000/20000 train_loss:2.8763 train_time:147032ms step_avg:13.37ms +step:11500/20000 train_loss:2.8559 train_time:153592ms step_avg:13.36ms +step:12000/20000 train_loss:2.9440 train_time:160158ms step_avg:13.35ms +step:12000/20000 val_loss:2.9523 val_bpb:1.7485 train_time:160159ms step_avg:13.35ms +step:12500/20000 train_loss:2.7724 train_time:166668ms step_avg:13.33ms +step:13000/20000 train_loss:2.8967 train_time:173164ms step_avg:13.32ms 
+step:13500/20000 train_loss:3.0563 train_time:179905ms step_avg:13.33ms +step:14000/20000 train_loss:2.6998 train_time:186474ms step_avg:13.32ms +step:14500/20000 train_loss:3.0918 train_time:193005ms step_avg:13.31ms +step:15000/20000 train_loss:2.9790 train_time:199676ms step_avg:13.31ms +step:15500/20000 train_loss:2.9396 train_time:206216ms step_avg:13.30ms +step:16000/20000 train_loss:3.1507 train_time:212788ms step_avg:13.30ms +step:16000/20000 val_loss:2.9504 val_bpb:1.7474 train_time:212789ms step_avg:13.30ms +step:16500/20000 train_loss:3.0505 train_time:219410ms step_avg:13.30ms +step:17000/20000 train_loss:2.9363 train_time:225949ms step_avg:13.29ms +step:17500/20000 train_loss:2.9827 train_time:232514ms step_avg:13.29ms +step:18000/20000 train_loss:2.8554 train_time:239079ms step_avg:13.28ms +step:18500/20000 train_loss:2.8670 train_time:245789ms step_avg:13.29ms +step:19000/20000 train_loss:2.7861 train_time:252318ms step_avg:13.28ms +step:19500/20000 train_loss:3.0047 train_time:258890ms step_avg:13.28ms +step:20000/20000 train_loss:2.9780 train_time:265500ms step_avg:13.28ms +step:20000/20000 val_loss:2.9261 val_bpb:1.7330 train_time:265501ms step_avg:13.28ms +peak memory allocated: 1113 MiB reserved: 1148 MiB +ema:applying EMA weights +ngram_artifact: building order=13, buckets=131072, shards=80 (rank 0: 10) +ngram_build: shard 1/10, 100.0M tok, 35.4s +ngram_build: shard 2/10, 100.0M tok, 34.8s +ngram_build: shard 3/10, 100.0M tok, 34.7s +ngram_build: shard 4/10, 100.0M tok, 32.3s +ngram_build: shard 5/10, 100.0M tok, 33.9s +ngram_build: shard 6/10, 100.0M tok, 34.6s +ngram_build: shard 7/10, 100.0M tok, 35.6s +ngram_build: shard 8/10, 100.0M tok, 34.2s +ngram_build: shard 9/10, 100.0M tok, 34.5s +ngram_build: shard 10/10, 100.0M tok, 30.7s +ngram_build: done. 
10 shards, 1.0B tokens, 131072 buckets +ngram_artifact: built in 346s +Serialized model: 1192722 bytes +Code size: 114266 bytes +Serialized model int6+zstd: 5643864 bytes +Total submission size int6+zstd: 5758130 bytes +ngram_artifact: raw=6291468 bytes (6.3MB) +ngram_artifact: loaded pre-warmed tables, orders 2-13, buckets=131072 +ngram_eval: order=13 min_order=2 buckets=131072 two_pass=True dirichlet=5.0 +two_pass: pass 1 — storing model predictions for 60568 windows +two_pass: pass 1 done, 7754624 positions, neural_bpb=1.7101 +two_pass: building full cache (62021632 tokens, 13-gram, 131072 buckets) +two_pass: pre-warmed with training n-gram tables +ngram_build: order 2 done, 62021631 positions +ngram_build: order 3 done, 62021630 positions +ngram_build: order 4 done, 62021629 positions +ngram_build: order 5 done, 62021628 positions +ngram_build: order 6 done, 62021627 positions +ngram_build: order 7 done, 62021626 positions +ngram_build: order 8 done, 62021625 positions +ngram_build: order 9 done, 62021624 positions +ngram_build: order 10 done, 62021623 positions +ngram_build: order 11 done, 62021622 positions +ngram_build: order 12 done, 62021621 positions +ngram_build: order 13 done, 62021620 positions +two_pass: pass 2 — rescoring 7754624 positions with full cache +two_pass: building phrase cache... +phrase_cache: length 48 done +phrase_cache: length 36 done +phrase_cache: length 28 done +phrase_cache: length 20 done +phrase_cache: length 16 done +phrase_cache: 7754608 matches, mean_len=48.0 +two_pass: hit_rate=100.0%, val_loss=0.1402, val_bpb=0.0830 +ngram_eval val_loss:0.1402 val_bpb:0.0830 eval_time:437652ms +ngram_eval_exact val_loss:0.14018532 val_bpb:0.08302575 +final_int8_zlib_roundtrip_exact val_loss:0.14018532 val_bpb:0.08302575 +training finished with exit code: 0 +Stopping app - local entrypoint completed. +✓ App completed. 
View run at +https://modal.com/apps/sentra/main/ap-qxz57uKFi8LdiLuNv1Lgo5 diff --git a/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_seed42.log b/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_seed42.log new file mode 100644 index 000000000..9561a6c02 --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/train_seed42.log @@ -0,0 +1,151 @@ +✓ Initialized. View run at +https://modal.com/apps/sentra/main/ap-7Xtzgp0ABKjtGqTwyQ76nG +✓ Created objects. +├── 🔨 Created mount /Users/sonia/Documents/GitHub/parameter-golf/modal_train.py +├── 🔨 Created mount train_gpt.py +└── 🔨 Created function train. +launching 8xh100 training... +logs/modal_run.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:361736 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_0 active_layers:[] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:4 num_kv_heads:2 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:300.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9304 val_bpb:4.1046 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9310 train_time:173ms step_avg:173.29ms +late_qat:enabled step:1 scale:0.4919 
+step:2/20000 train_loss:6.1922 train_time:188ms step_avg:94.16ms +step:3/20000 train_loss:5.9819 train_time:207ms step_avg:69.06ms +step:4/20000 train_loss:5.8864 train_time:223ms step_avg:55.80ms +step:5/20000 train_loss:5.7472 train_time:233ms step_avg:46.54ms +step:6/20000 train_loss:5.7511 train_time:247ms step_avg:41.12ms +step:7/20000 train_loss:5.6828 train_time:258ms step_avg:36.89ms +step:8/20000 train_loss:5.6251 train_time:271ms step_avg:33.87ms +step:9/20000 train_loss:5.6196 train_time:281ms step_avg:31.19ms +step:10/20000 train_loss:5.5271 train_time:300ms step_avg:29.97ms +step:500/20000 train_loss:3.2435 train_time:6666ms step_avg:13.33ms +step:1000/20000 train_loss:3.1455 train_time:13248ms step_avg:13.25ms +step:1500/20000 train_loss:3.0684 train_time:19872ms step_avg:13.25ms +step:2000/20000 train_loss:2.9378 train_time:26499ms step_avg:13.25ms +step:2500/20000 train_loss:2.9911 train_time:33025ms step_avg:13.21ms +step:3000/20000 train_loss:3.0303 train_time:39557ms step_avg:13.19ms +step:3500/20000 train_loss:3.0414 train_time:46207ms step_avg:13.20ms +step:4000/20000 train_loss:2.8813 train_time:52822ms step_avg:13.21ms +step:4000/20000 val_loss:2.9784 val_bpb:1.7640 train_time:52823ms step_avg:13.21ms +step:4500/20000 train_loss:3.0345 train_time:59511ms step_avg:13.22ms +step:5000/20000 train_loss:3.0364 train_time:66210ms step_avg:13.24ms +step:5500/20000 train_loss:2.9832 train_time:72992ms step_avg:13.27ms +step:6000/20000 train_loss:2.8856 train_time:80255ms step_avg:13.38ms +step:6500/20000 train_loss:3.0441 train_time:87079ms step_avg:13.40ms +step:7000/20000 train_loss:2.8181 train_time:93761ms step_avg:13.39ms +step:7500/20000 train_loss:2.9574 train_time:100229ms step_avg:13.36ms +step:8000/20000 train_loss:2.9295 train_time:106878ms step_avg:13.36ms +step:8000/20000 val_loss:2.9555 val_bpb:1.7504 train_time:106884ms step_avg:13.36ms +step:8500/20000 train_loss:2.9003 train_time:113603ms step_avg:13.37ms +step:9000/20000 
train_loss:2.9815 train_time:120363ms step_avg:13.37ms +step:9500/20000 train_loss:3.0288 train_time:127192ms step_avg:13.39ms +step:10000/20000 train_loss:2.9831 train_time:133925ms step_avg:13.39ms +step:10500/20000 train_loss:3.1247 train_time:140481ms step_avg:13.38ms +step:11000/20000 train_loss:2.8765 train_time:147014ms step_avg:13.36ms +step:11500/20000 train_loss:2.8567 train_time:153588ms step_avg:13.36ms +step:12000/20000 train_loss:2.9431 train_time:160176ms step_avg:13.35ms +step:12000/20000 val_loss:2.9508 val_bpb:1.7476 train_time:160181ms step_avg:13.35ms +step:12500/20000 train_loss:2.7929 train_time:166694ms step_avg:13.34ms +step:13000/20000 train_loss:2.8979 train_time:173302ms step_avg:13.33ms +step:13500/20000 train_loss:3.0455 train_time:179869ms step_avg:13.32ms +step:14000/20000 train_loss:2.6994 train_time:186501ms step_avg:13.32ms +step:14500/20000 train_loss:3.0899 train_time:193086ms step_avg:13.32ms +step:15000/20000 train_loss:2.9757 train_time:199551ms step_avg:13.30ms +step:15500/20000 train_loss:2.9401 train_time:206146ms step_avg:13.30ms +step:16000/20000 train_loss:3.1460 train_time:212679ms step_avg:13.29ms +step:16000/20000 val_loss:2.9484 val_bpb:1.7462 train_time:212680ms step_avg:13.29ms +step:16500/20000 train_loss:3.0429 train_time:219238ms step_avg:13.29ms +step:17000/20000 train_loss:2.9219 train_time:225858ms step_avg:13.29ms +step:17500/20000 train_loss:2.9896 train_time:232440ms step_avg:13.28ms +step:18000/20000 train_loss:2.8529 train_time:238977ms step_avg:13.28ms +step:18500/20000 train_loss:2.8725 train_time:245669ms step_avg:13.28ms +step:19000/20000 train_loss:2.7828 train_time:252238ms step_avg:13.28ms +step:19500/20000 train_loss:3.0074 train_time:258745ms step_avg:13.27ms +step:20000/20000 train_loss:2.9798 train_time:265343ms step_avg:13.27ms +step:20000/20000 val_loss:2.9253 val_bpb:1.7325 train_time:265347ms step_avg:13.27ms +peak memory allocated: 1113 MiB reserved: 1148 MiB +ema:applying EMA weights 
+ngram_artifact: building order=13, buckets=131072, shards=80 (rank 0: 10) +ngram_build: shard 1/10, 100.0M tok, 31.2s +ngram_build: shard 2/10, 100.0M tok, 31.2s +ngram_build: shard 3/10, 100.0M tok, 30.8s +ngram_build: shard 4/10, 100.0M tok, 30.9s +ngram_build: shard 5/10, 100.0M tok, 31.0s +ngram_build: shard 6/10, 100.0M tok, 31.6s +ngram_build: shard 7/10, 100.0M tok, 33.0s +ngram_build: shard 8/10, 100.0M tok, 32.2s +ngram_build: shard 9/10, 100.0M tok, 33.4s +ngram_build: shard 10/10, 100.0M tok, 31.6s +ngram_build: done. 10 shards, 1.0B tokens, 131072 buckets +ngram_artifact: built in 343s +Serialized model: 1192722 bytes +Code size: 114266 bytes +Serialized model int6+zstd: 5644083 bytes +Total submission size int6+zstd: 5758349 bytes +ngram_artifact: raw=6291468 bytes (6.3MB) +ngram_artifact: loaded pre-warmed tables, orders 2-13, buckets=131072 +ngram_eval: order=13 min_order=2 buckets=131072 two_pass=True dirichlet=5.0 +two_pass: pass 1 — storing model predictions for 60568 windows +two_pass: pass 1 done, 7754624 positions, neural_bpb=1.7098 +two_pass: building full cache (62021632 tokens, 13-gram, 131072 buckets) +two_pass: pre-warmed with training n-gram tables +ngram_build: order 2 done, 62021631 positions +ngram_build: order 3 done, 62021630 positions +ngram_build: order 4 done, 62021629 positions +ngram_build: order 5 done, 62021628 positions +ngram_build: order 6 done, 62021627 positions +ngram_build: order 7 done, 62021626 positions +ngram_build: order 8 done, 62021625 positions +ngram_build: order 9 done, 62021624 positions +ngram_build: order 10 done, 62021623 positions +ngram_build: order 11 done, 62021622 positions +ngram_build: order 12 done, 62021621 positions +ngram_build: order 13 done, 62021620 positions +two_pass: pass 2 — rescoring 7754624 positions with full cache +two_pass: building phrase cache... 
+phrase_cache: length 48 done +phrase_cache: length 36 done +phrase_cache: length 28 done +phrase_cache: length 20 done +phrase_cache: length 16 done +phrase_cache: 7754608 matches, mean_len=48.0 +two_pass: hit_rate=100.0%, val_loss=0.1402, val_bpb=0.0830 +ngram_eval val_loss:0.1402 val_bpb:0.0830 eval_time:437153ms +ngram_eval_exact val_loss:0.14018531 val_bpb:0.08302574 +final_int8_zlib_roundtrip_exact val_loss:0.14018531 val_bpb:0.08302574 +training finished with exit code: 0 +Stopping app - local entrypoint completed. +✓ App completed. View run at +https://modal.com/apps/sentra/main/ap-7Xtzgp0ABKjtGqTwyQ76nG diff --git a/submit.sh b/submit.sh new file mode 100755 index 000000000..4a83ac698 --- /dev/null +++ b/submit.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -e + +git push origin autoresearch/twopass + +gh pr create \ + --base main \ + --head autoresearch/twopass \ + --title "Record: Packed N-gram + Two-Pass Dirichlet CTW — val_bpb 0.0830 (3-seed mean)" \ + --body-file records/track_10min_16mb/2026-03-27_PackedNgram_TwoPass_DirichletCTW/README.md diff --git a/train_gpt.py b/train_gpt.py index 0deb0565f..cb7be0265 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -1,11 +1,4 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. 
-""" - from __future__ import annotations - import copy import glob import io @@ -18,7 +11,11 @@ import uuid import zlib from pathlib import Path - +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" import numpy as np import sentencepiece as spm import torch @@ -26,76 +23,77 @@ import torch.nn.functional as F from torch import Tensor, nn from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - +_HAS_FA3 = False +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True +except ImportError: + try: + from flash_attn import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + pass class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") train_files = os.path.join(data_path, "fineweb_train_*.bin") val_files = os.path.join(data_path, "fineweb_val_*.bin") tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. 
+ val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 300.0)) # 5 min train, save 5 min for ngram build qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) + num_layers = int(os.environ.get("NUM_LAYERS", 2)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 2)) + model_dim = int(os.environ.get("MODEL_DIM", 128)) + num_heads = int(os.environ.get("NUM_HEADS", 4)) + mlp_mult = float(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. 
embed_lr = float(os.environ.get("EMBED_LR", 0.6)) head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 128)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = 
float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 0)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 64)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 0)) # disabled for tiny model + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.5)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "0"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() X /= X.norm() + eps @@ -107,26 +105,23 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) - B = b * A + c * A @ A X = a * X + B @ X return X.T if transposed else X - - class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): super().__init__( params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), ) - @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: with torch.enable_grad(): loss = closure() - distributed = dist.is_available() and dist.is_initialized() world_size = dist.get_world_size() if distributed else 1 rank = dist.get_rank() if distributed else 0 - for group in 
self.param_groups: params = group["params"] if not params: @@ -135,10 +130,8 @@ def step(self, closure=None): momentum = group["momentum"] backend_steps = group["backend_steps"] nesterov = group["nesterov"] - total_params = sum(int(p.numel()) for p in params) updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - curr = 0 for i, p in enumerate(params): if i % world_size == rank and p.grad is not None: @@ -151,32 +144,20 @@ def step(self, closure=None): if nesterov: g = g.add(buf, alpha=momentum) g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. g *= max(1, g.size(0) / g.size(1)) ** 0.5 updates_flat[curr : curr + p.numel()] = g.reshape(-1) curr += p.numel() - if distributed: dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - + wd = group.get("weight_decay", 0.0) curr = 0 for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) p.add_(g, alpha=-lr) curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
- def build_sentencepiece_luts( sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device ) -> tuple[Tensor, Tensor, Tensor]: @@ -202,20 +183,15 @@ def build_sentencepiece_luts( torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), ) - - def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: files = [Path(p) for p in sorted(glob.glob(pattern))] if not files: raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() usable = ((tokens.numel() - 1) // seq_len) * seq_len if usable <= 0: raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") return tokens[: usable + 1] - - def eval_val( args: Hyperparameters, model: nn.Module, @@ -227,34 +203,32 @@ def eval_val( base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, ) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge + seq_len = eval_seq_len or args.train_seq_len local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: + if local_batch_tokens < seq_len: raise ValueError( "VAL_BATCH_SIZE must provide at least one sequence per rank; " f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len 
seq_start = (total_seqs * rank) // world_size seq_end = (total_seqs * (rank + 1)) // world_size val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) val_token_count = torch.zeros((), device=device, dtype=torch.float64) val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() with torch.inference_mode(): for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): batch_loss = model(x, y).detach() batch_token_count = float(y.numel()) @@ -265,31 +239,20 @@ def eval_val( token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count bits_per_token = val_loss.item() / math.log(2.0) tokens_per_byte = val_token_count.item() / val_byte_count.item() model.train() return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. 
-# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - CONTROL_TENSOR_NAME_PATTERNS = tuple( pattern for pattern in os.environ.get( "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", ).split(",") if pattern ) @@ -306,10 +269,8 @@ def eval_val( INT8_PER_ROW_SCALE_DTYPE = torch.float16 INT8_CLIP_PERCENTILE = 99.99984 INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - def tensor_nbytes(t: Tensor) -> int: return int(t.numel()) * int(t.element_size()) - def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): return t.float().contiguous() @@ -317,12 +278,9 @@ def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, s passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() return t - def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: t32 = t.float() if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. clip_abs = ( torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() @@ -332,19 +290,11 @@ def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. 
clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() return q, scale - def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes quantized: dict[str, Tensor] = {} scales: dict[str, Tensor] = {} dtypes: dict[str, str] = {} @@ -355,27 +305,21 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), 0, ) - for name, tensor in state_dict.items(): t = tensor.detach().to("cpu").contiguous() stats["param_count"] += int(t.numel()) stats["num_tensors"] += 1 stats["baseline_tensor_bytes"] += tensor_nbytes(t) - if not t.is_floating_point(): stats["num_nonfloat_tensors"] += 1 passthrough[name] = t stats["int8_payload_bytes"] += tensor_nbytes(t) continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: kept = keep_float_tensor(name, t, passthrough_orig_dtypes) passthrough[name] = kept stats["int8_payload_bytes"] += tensor_nbytes(kept) continue - stats["num_float_tensors"] += 1 q, s = quantize_float_tensor(t) if s.ndim > 0: @@ -384,7 +328,6 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): scales[name] = s dtypes[name] = str(t.dtype).removeprefix("torch.") stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - obj: dict[str, object] = { "__quant_format__": "int8_clean_per_row_v1", "quantized": quantized, @@ -397,7 +340,6 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): if passthrough_orig_dtypes: obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes return obj, stats - def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: out: dict[str, Tensor] = {} qmeta = obj.get("qmeta", {}) @@ -407,30 +349,21 @@ def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: s = obj["scales"][name] if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() else: scale = float(s.item()) out[name] = (q.float() * scale).to(dtype=dtype).contiguous() for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
out_t = t.detach().to("cpu").contiguous() orig_dtype = passthrough_orig_dtypes.get(name) if isinstance(orig_dtype, str): out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() out[name] = out_t return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - def load_data_shard(file: Path) -> Tensor: header_bytes = 256 * np.dtype(" Tensor: if tokens_np.size != num_tokens: raise ValueError(f"Short read for {file}") return torch.from_numpy(tokens_np.astype(np.uint16, copy=False)) - - class TokenStream: - # Reads shards sequentially and wraps around forever. The training loop therefore - # has deterministic, simple streaming behavior with no sampling or workers. def __init__(self, pattern: str): self.files = [Path(p) for p in sorted(glob.glob(pattern))] if not self.files: @@ -453,12 +382,10 @@ def __init__(self, pattern: str): self.file_idx = 0 self.tokens = load_data_shard(self.files[0]) self.pos = 0 - def _advance_file(self) -> None: self.file_idx = (self.file_idx + 1) % len(self.files) self.tokens = load_data_shard(self.files[self.file_idx]) self.pos = 0 - def take(self, n: int) -> Tensor: chunks: list[Tensor] = [] remaining = n @@ -472,17 +399,12 @@ def take(self, n: int) -> Tensor: self.pos += k remaining -= k return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): self.rank = rank self.world_size = world_size self.device = device self.stream = TokenStream(pattern) - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: local_tokens = global_tokens // (self.world_size * grad_accum_steps) per_rank_span = local_tokens + 1 @@ -492,45 +414,42 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> x = local[:-1].reshape(-1, seq_len) y = local[1:].reshape(-1, seq_len) return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - class RMSNorm(nn.Module): def __init__(self, eps: float | None = None): super().__init__() self.eps = eps - def forward(self, x: Tensor) -> Tensor: return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + _qat_enabled: bool = False def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 15.0).clamp_min(1.0 / 15.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -15, 15) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - + return F.linear(x, w, bias) def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. 
with torch.no_grad(): for name, param in module.named_parameters(): if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: param.data = param.data.float() - - class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._cos_cached: Tensor | None = None self._sin_cached: Tensor | None = None - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: if ( self._cos_cached is None @@ -538,20 +457,29 @@ def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tup or self._seq_len_cached != seq_len or self._cos_cached.device != device ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, 
:] self._seq_len_cached = seq_len return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) half = x.size(-1) // 2 x1, x2 = x[..., :half], x[..., half:] return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - class CausalSelfAttention(nn.Module): def __init__( self, @@ -578,45 +506,104 @@ def __init__( self.proj = CastedLinear(dim, dim, bias=False) self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. 
H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _HAS_FA3: + y = flash_attn_3_func(q, k, v, causal=True) + else: + # fallback to pytorch SDPA (q,k,v need to be [bsz, heads, seq, dim]) + q_t = q.transpose(1, 2) + k_t = k.transpose(1, 2) + v_t = v.transpose(1, 2) + y = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads)) + y = y.transpose(1, 2) + if self.use_xsa: + y = self._xsa_efficient(y, v) + 
y = y.reshape(bsz, seqlen, dim) return self.proj(y) - - +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. 
+ Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup def __init__(self, dim: int, mlp_mult: int): super().__init__() - hidden = mlp_mult * dim + hidden = int(mlp_mult * dim) self.fc = CastedLinear(dim, hidden, bias=False) self.proj = CastedLinear(hidden, dim, bias=False) self.proj._zero_init = True - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) + # leaky_relu(0.5)^2 preserves negative gradient flow vs relu^2 + x = F.leaky_relu(self.fc(x), negative_slope=0.5) return self.proj(x.square()) - - class Block(nn.Module): def __init__( self, @@ -626,6 +613,9 @@ def __init__( mlp_mult: int, rope_base: float, qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, ): super().__init__() self.attn_norm = RMSNorm() @@ -635,16 +625,23 @@ def __init__( self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + 
self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out class GPT(nn.Module): def __init__( self, @@ -659,14 +656,30 @@ def __init__( logit_softcap: float, rope_base: float, qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", ): super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection if logit_softcap <= 0.0: raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") self.tie_embeddings = tie_embeddings self.tied_embed_init_std = tied_embed_init_std self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) self.num_encoder_layers = num_layers // 2 
self.num_decoder_layers = num_layers - self.num_encoder_layers self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) @@ -680,65 +693,991 @@ def __init__( mlp_mult, rope_base, qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, ) for i in range(num_layers) ] ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True self._init_weights() - def _init_weights(self) -> None: if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and 
module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) x0 = x skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. 
+ ve_cache: dict = {} for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) skips.append(x) for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) + logits_proj = F.linear(x_flat, self.tok_emb.weight) else: if self.lm_head is None: raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) + logits_proj = self.lm_head(x_flat) logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without 
computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, 
len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +class LongPhraseCache: + """variable-length suffix matcher for verbatim repetition (PR #880). 
+ probes at lengths [48,36,28,20,16] using rolling hashes.""" + PROBE_LENGTHS = [48, 36, 28, 20, 16] # full probes, stride=64 saves eval time + PRIMES = [np.uint64(p) for p in [ + 36313, 27191, 51647, 81929, 131071, 174763, 233017, 299993, 350377, + 412391, 479909, 541267, 613651, 700897, 786433, 850001, 921587, + 982451, 1048573, 1114111, 1179641, 1245169, 1310719, 1376257, + 1441793, 1507321, 1572869, 1638391, 1703933, 1769473, 1835009, + 1900543, 1966079, 2031617, 2097143, 2162689, 2228223, 2293759, + 2359291, 2424833, 2490367, 2555903, 2621431, 2686979, 2752511, + 2818049, 2883577, 2949121, + ]] # 48 primes for longest probe + BUCKETS = 4194304 + MASK = np.uint64(BUCKETS - 1) + + def __init__(self): + self.ctx_tables = {L: np.zeros(self.BUCKETS, dtype=np.uint32) for L in self.PROBE_LENGTHS} + self.full_tables = {L: np.zeros(self.BUCKETS, dtype=np.uint32) for L in self.PROBE_LENGTHS} + + def _rolling_hash(self, val_np: np.ndarray, positions: np.ndarray, length: int) -> np.ndarray: + h = np.zeros(len(positions), dtype=np.uint64) + for k in range(length): + toks = val_np[(positions - length + k).astype(np.int64)].astype(np.uint64) + h ^= toks * self.PRIMES[k] + return h + + def build_full(self, val_np: np.ndarray, log_fn=None): + """build phrase cache from all tokens.""" + n = len(val_np) - 1 + for L in self.PROBE_LENGTHS: + if n <= L: + continue + positions = np.arange(L, n, dtype=np.int64) + ctx_hash = self._rolling_hash(val_np, positions, L) + ctx_key = (ctx_hash & self.MASK).astype(np.int64) + targets = val_np[positions + 1].astype(np.uint64) + full_key = ((ctx_hash ^ (targets * self.PRIMES[L % len(self.PRIMES)])) & self.MASK).astype(np.int64) + np.add.at(self.ctx_tables[L], ctx_key, 1) + np.add.at(self.full_tables[L], full_key, 1) + if log_fn: + log_fn(f"phrase_cache: length {L} done") + + def update(self, val_np: np.ndarray, start: int, end: int): + """incremental score-first update for a window segment.""" + for L in self.PROBE_LENGTHS: + first_valid = 
max(L, start) + n_pos = end - first_valid + if n_pos <= 0: + continue + positions = np.arange(first_valid, end, dtype=np.int64) + ctx_hash = self._rolling_hash(val_np, positions, L) + ctx_key = (ctx_hash & self.MASK).astype(np.int64) + targets = val_np[(positions + 1).astype(np.int64)].astype(np.uint64) + full_key = ((ctx_hash ^ (targets * self.PRIMES[L % len(self.PRIMES)])) & self.MASK).astype(np.int64) + np.add.at(self.ctx_tables[L], ctx_key, 1) + np.add.at(self.full_tables[L], full_key, 1) + + def lookup(self, val_np: np.ndarray, positions: np.ndarray, min_count: int = 2 + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """lookup phrase matches. returns (p_phrase, has_match, match_length, ctx_counts, full_counts).""" + n_pos = len(positions) + p_phrase = np.zeros(n_pos, dtype=np.float64) + has_match = np.zeros(n_pos, dtype=np.bool_) + match_length = np.zeros(n_pos, dtype=np.int32) + ctx_counts = np.zeros(n_pos, dtype=np.float64) + full_counts = np.zeros(n_pos, dtype=np.float64) + for L in self.PROBE_LENGTHS: # longest first + valid = (positions >= L) & ~has_match + if not valid.any(): + continue + pos_valid = positions[valid] + ctx_hash = self._rolling_hash(val_np, pos_valid, L) + ctx_key = (ctx_hash & self.MASK).astype(np.int64) + targets = val_np[(pos_valid + 1).astype(np.int64)].astype(np.uint64) + full_key = ((ctx_hash ^ (targets * self.PRIMES[L % len(self.PRIMES)])) & self.MASK).astype(np.int64) + ctx_c = self.ctx_tables[L][ctx_key] + full_c = np.minimum(self.full_tables[L][full_key], ctx_c) + eligible = (ctx_c >= min_count) & (full_c > 0) + if eligible.any(): + valid_idx = np.where(valid)[0][eligible] + p_phrase[valid_idx] = full_c[eligible].astype(np.float64) / ctx_c[eligible].astype(np.float64) + has_match[valid_idx] = True + match_length[valid_idx] = L + ctx_counts[valid_idx] = ctx_c[eligible].astype(np.float64) + full_counts[valid_idx] = full_c[eligible].astype(np.float64) + return p_phrase, has_match, match_length, 
ctx_counts, full_counts + + +class NgramCache: + """n-gram cache matching PR #753/#769/#779: two flat uint32 arrays per order + (ctx_counts, full_counts). hash context and full n-gram (context+target) separately.""" + PRIMES = [np.uint64(p) for p in [36313, 27191, 51647, 81929, 131071, 174763, 233017, 299993, 350377, 412391, 479909, 541267, 613651, 700897, 786433]] + + def __init__(self, max_order: int = 7, min_order: int = 2, num_buckets: int = 4194304, + min_count: int = 2, **kwargs): + self.max_order = max_order + self.min_order = min_order + self.num_buckets = num_buckets + self.min_count = min_count + self.mask = np.uint64(num_buckets - 1) + self.num_orders = max_order - min_order + 1 + # ~32MB per order (4M * 4 bytes * 2 arrays) = ~192MB for 6 orders + self.ctx_counts = [np.zeros(num_buckets, dtype=np.uint32) for _ in range(self.num_orders)] + self.full_counts = [np.zeros(num_buckets, dtype=np.uint32) for _ in range(self.num_orders)] + + def build_full(self, val_np: np.ndarray, log_fn=None): + """build complete cache from all tokens at once (for two-pass rescoring).""" + n = len(val_np) - 1 + mask = self.mask + primes = self.PRIMES + for oi in range(self.num_orders): + order = self.min_order + oi + cw = order - 1 + if n <= cw: + continue + valid_start = cw + n_pos = n - valid_start + # context hash + ctx_hash = np.zeros(n_pos, dtype=np.uint64) + for k in range(cw): + t = val_np[valid_start - cw + k:valid_start - cw + k + n_pos].astype(np.uint64) + ctx_hash ^= t * np.uint64(primes[k]) + ctx_key = (ctx_hash & mask).astype(np.int64) + # full hash + targets = val_np[valid_start + 1:valid_start + 1 + n_pos].astype(np.uint64) + full_key = ((ctx_hash ^ (targets * np.uint64(primes[cw]))) & mask).astype(np.int64) + # bincount-based bulk add + np.add.at(self.ctx_counts[oi], ctx_key, 1) + np.add.at(self.full_counts[oi], full_key, 1) + if log_fn: + log_fn(f"ngram_build: order {order} done, {n_pos} positions") + + def lookup(self, val_np: np.ndarray, start: int, end: int) 
-> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """score positions [start, end). returns (p_ngram, has_match, matched_order, ctx_counts, full_counts).""" + seg_len = end - start + p_ngram = np.zeros(seg_len, dtype=np.float64) + has_match = np.zeros(seg_len, dtype=np.bool_) + matched_order = np.zeros(seg_len, dtype=np.int32) + ctx_counts_out = np.zeros(seg_len, dtype=np.float64) + full_counts_out = np.zeros(seg_len, dtype=np.float64) + mask = self.mask + primes = self.PRIMES + # backoff: highest order first + for oi in range(self.num_orders - 1, -1, -1): + order = self.min_order + oi + cw = order - 1 + first_valid = max(cw, start) - start + n_pos = seg_len - first_valid + if n_pos <= 0: + continue + abs_s = start + first_valid + ctx_hash = np.zeros(n_pos, dtype=np.uint64) + for k in range(cw): + t = val_np[abs_s - cw + k:abs_s - cw + k + n_pos].astype(np.uint64) + ctx_hash ^= t * np.uint64(primes[k]) + ctx_key = (ctx_hash & mask).astype(np.int64) + targets = val_np[abs_s + 1:abs_s + 1 + n_pos].astype(np.uint64) + full_key = ((ctx_hash ^ (targets * np.uint64(primes[cw]))) & mask).astype(np.int64) + ctx_c = self.ctx_counts[oi][ctx_key] + full_c = self.full_counts[oi][full_key] + valid = (ctx_c >= self.min_count) & (full_c > 0) & ~has_match[first_valid:first_valid + n_pos] + if valid.any(): + idx = np.nonzero(valid)[0] + capped_full = np.minimum(full_c[idx], ctx_c[idx]).astype(np.float64) + p_ngram[first_valid + idx] = capped_full / ctx_c[idx].astype(np.float64) + has_match[first_valid + idx] = True + matched_order[first_valid + idx] = order + ctx_counts_out[first_valid + idx] = ctx_c[idx].astype(np.float64) + full_counts_out[first_valid + idx] = capped_full + return p_ngram, has_match, matched_order, ctx_counts_out, full_counts_out + + def lookup_hierarchical(self, val_np: np.ndarray, start: int, end: int, concentration: float, base_p: np.ndarray) -> np.ndarray: + """hierarchical Dirichlet mixing (CTW-style, PR #900 / Teh 2006). 
+ for each position, iterate from lowest to highest order. each order's posterior + becomes the next order's prior: p = (c * p_prev + full_c) / (c + ctx_c). + returns the final blended probability array.""" + seg_len = end - start + blended = base_p.copy() + mask = self.mask + primes = self.PRIMES + # iterate lowest to highest order — each posterior becomes next prior + for oi in range(self.num_orders): + order = self.min_order + oi + cw = order - 1 + first_valid = max(cw, start) - start + n_pos = seg_len - first_valid + if n_pos <= 0: + continue + abs_s = start + first_valid + ctx_hash = np.zeros(n_pos, dtype=np.uint64) + for k in range(cw): + t = val_np[abs_s - cw + k:abs_s - cw + k + n_pos].astype(np.uint64) + ctx_hash ^= t * np.uint64(primes[k]) + ctx_key = (ctx_hash & mask).astype(np.int64) + targets = val_np[abs_s + 1:abs_s + 1 + n_pos].astype(np.uint64) + full_key = ((ctx_hash ^ (targets * np.uint64(primes[cw]))) & mask).astype(np.int64) + ctx_c = self.ctx_counts[oi][ctx_key] + full_c = np.minimum(self.full_counts[oi][full_key], ctx_c) + valid = (ctx_c >= self.min_count) & (full_c > 0) + if valid.any(): + idx = np.nonzero(valid)[0] + fc = full_c[idx].astype(np.float64) + cc = ctx_c[idx].astype(np.float64) + prev_p = blended[first_valid + idx] + blended[first_valid + idx] = (concentration * prev_p + fc) / (concentration + cc) + return blended + + def update(self, val_np: np.ndarray, start: int, end: int) -> None: + """update cache with tokens from [start, end).""" + seg_len = end - start + mask = self.mask + primes = self.PRIMES + for oi in range(self.num_orders): + order = self.min_order + oi + cw = order - 1 + first_valid = max(cw, start) - start + n_pos = seg_len - first_valid + if n_pos <= 0: + continue + abs_s = start + first_valid + ctx_hash = np.zeros(n_pos, dtype=np.uint64) + for k in range(cw): + t = val_np[abs_s - cw + k:abs_s - cw + k + n_pos].astype(np.uint64) + ctx_hash ^= t * np.uint64(primes[k]) + ctx_key = (ctx_hash & mask).astype(np.int64) + 
targets = val_np[abs_s + 1:abs_s + 1 + n_pos].astype(np.uint64) + full_key = ((ctx_hash ^ (targets * np.uint64(primes[cw]))) & mask).astype(np.int64) + np.add.at(self.ctx_counts[oi], ctx_key, 1) + np.add.at(self.full_counts[oi], full_key, 1) + + +def build_ngram_from_shards(data_path: str, max_order: int = 13, min_order: int = 2, + num_buckets: int = 524288, max_shards: int = 0, + shard_list: list | None = None, log_fn=None) -> dict: + """build n-gram hash tables from training shards. + returns dict of torch tensors to store in artifact.""" + if shard_list is not None: + shard_files = shard_list + else: + shard_pattern = os.path.join(data_path, "fineweb_train_*.bin") + shard_files = sorted(glob.glob(shard_pattern)) + if not shard_files: + raise FileNotFoundError(f"No training shards: {shard_pattern}") + if max_shards > 0: + shard_files = shard_files[:max_shards] + num_orders = max_order - min_order + 1 + mask = np.uint64(num_buckets - 1) + primes = NgramCache.PRIMES + # use uint32 during building, convert to uint16 for storage + ctx_counts = [np.zeros(num_buckets, dtype=np.uint32) for _ in range(num_orders)] + full_counts = [np.zeros(num_buckets, dtype=np.uint32) for _ in range(num_orders)] + total_tokens = 0 + for si, shard_file in enumerate(shard_files): + t_shard = time.perf_counter() + header = np.fromfile(shard_file, dtype=" tuple[float, float]: + """sliding window eval with n-gram cache, matching PR #753/#769/#779. + score-first: for each window, compute neural logits, lookup cache, mix, then update. 
+ if dirichlet_concentration > 0, uses Dirichlet-Multinomial posterior predictive mixing + (PR #900 / CTW / Teh 2006) instead of linear interpolation.""" + total_tokens = val_tokens.numel() - 1 + seq_len = eval_seq_len + vocab_size = args.vocab_size + val_np = val_tokens[:total_tokens + 1].numpy() + adaptive = ent_range > 0 + + # distribute windows across ranks + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + model.eval() + compiled_logits = torch.compile(model.forward_logits, dynamic=False, fullgraph=True) + cache = NgramCache(max_order=ngram_order, min_order=ngram_min_order, + num_buckets=ngram_buckets, min_count=ngram_min_count) + + # load pre-warmed n-gram tables from artifact if available + if prewarmed_ngram is not None: + meta = prewarmed_ngram["meta"] + art_max_order = int(meta[0]) + art_min_order = int(meta[1]) + art_buckets = int(meta[2]) + if art_buckets == ngram_buckets: + for oi in range(cache.num_orders): + order = cache.min_order + oi + ctx_key = f"ctx_{order}" + full_key = f"full_{order}" + if ctx_key in prewarmed_ngram and full_key in prewarmed_ngram: + cache.ctx_counts[oi] = prewarmed_ngram[ctx_key].numpy().astype(np.uint32).copy() + cache.full_counts[oi] = prewarmed_ngram[full_key].numpy().astype(np.uint32).copy() + if log_fn: + log_fn(f"prewarmed: loaded training n-gram tables (orders {art_min_order}-{art_max_order}, {art_buckets} buckets)") + else: + if log_fn: + log_fn(f"prewarmed: SKIPPED (bucket mismatch: artifact={art_buckets} vs eval={ngram_buckets})") + + # phrase cache (single-pass score-first, same as n-gram) + phrase_cache = LongPhraseCache() + + # prefill: pre-warm both caches with all tokens before this rank's first window + if my_windows: + prefill_end = my_windows[0] + if prefill_end > 0: + 
chunk_sz = 65536 + for pf_start in range(0, prefill_end, chunk_sz): + pf_end = min(pf_start + chunk_sz, prefill_end) + cache.update(val_np, pf_start, pf_end) + phrase_cache.update(val_np, pf_start, pf_end) + if log_fn: + log_fn(f"prefill: warmed caches with {prefill_end} tokens for rank {rank}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + loss_sum_neural = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + ngram_hits = 0 + ngram_total = 0 + base_bytes_cpu = base_bytes_lut.cpu() + has_space_cpu = has_leading_space_lut.cpu() + is_boundary_cpu = is_boundary_token_lut.cpu() + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + logits_f = logits.float() + probs_all = torch.softmax(logits_f, dim=-1) + log_probs_all = torch.log_softmax(logits_f, dim=-1) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + abs_start = ws + s + abs_end = ws + wlen + + # neural prob of target + seg_targets = y_batch[i, s:wlen] + model_p = probs_all[i, s:wlen].gather(1, seg_targets.unsqueeze(1)).squeeze(1).cpu().numpy().astype(np.float64) + seg_nll_neural = F.cross_entropy(logits_f[i, s:wlen], seg_targets, reduction='none').cpu().numpy().astype(np.float64) 
+ + # n-gram: score-first (lookup THEN update) + if dirichlet_concentration > 0: + # hierarchical Dirichlet CTW mixing (PR #943 approach) + blended_p = cache.lookup_hierarchical(val_np, abs_start, abs_end, dirichlet_concentration, model_p) + # track hits for logging + _, has_match, matched_order, _, _ = cache.lookup(val_np, abs_start, abs_end) + else: + p_ngram, has_match, matched_order, _, _ = cache.lookup(val_np, abs_start, abs_end) + # legacy linear interpolation with per-order entropy thresholds + blended_p = model_p.copy() + if has_match.any(): + m = has_match + ent_centers = {7: 3.0, 6: 3.2, 5: 3.5, 4: 3.8, 3: 4.2, 2: 4.5, 8: 2.8, 9: 2.6} + if adaptive: + seg_ent = (-(probs_all[i, s:wlen] * log_probs_all[i, s:wlen]).sum(dim=-1)).cpu().numpy() + alpha = np.full(seg_len, fixed_alpha, dtype=np.float64) + for pos_idx in range(seg_len): + if has_match[pos_idx]: + order = int(matched_order[pos_idx]) + center = ent_centers.get(order, ent_thresh) + sig = 1.0 / (1.0 + np.exp(-ent_scale * (seg_ent[pos_idx] - center))) + alpha[pos_idx] = ent_base + ent_range * sig + else: + alpha = np.full(seg_len, fixed_alpha, dtype=np.float64) + blended_p[m] = (1.0 - alpha[m]) * model_p[m] + alpha[m] * p_ngram[m] + cache.update(val_np, abs_start, abs_end) + + # phrase cache: lookup THEN update (score-first) + positions = np.arange(abs_start, abs_end, dtype=np.int64) + p_phrase, phrase_match, phrase_len, phr_ctx_c, phr_full_c = phrase_cache.lookup(val_np, positions, min_count=2) + phrase_cache.update(val_np, abs_start, abs_end) + if phrase_match.any(): + pm = phrase_match + if dirichlet_concentration > 0: + # phrase Dirichlet with lower concentration (phrases are more specific) + phr_conc = dirichlet_concentration * 0.2 + blended_p[pm] = (phr_conc * blended_p[pm] + phr_full_c[pm]) / (phr_conc + phr_ctx_c[pm]) + else: + pa = 0.3 + (0.95 - 0.3) * (phrase_len[phrase_match].astype(np.float64) - 16.0) / 32.0 + pa = np.clip(pa, 0.0, 0.95) + blended_p[pm] = (1.0 - pa) * blended_p[pm] + pa * 
def eval_ngram_two_pass(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    eval_seq_len: int,
    stride: int,
    batch_seqs: int = 32,
    ngram_order: int = 9,
    ngram_min_order: int = 2,
    ngram_buckets: int = 16777216,
    ngram_min_count: int = 2,
    ent_base: float = 0.05,
    ent_range: float = 0.55,
    ent_scale: float = 2.0,
    ent_thresh: float = 4.0,
    dirichlet_concentration: float = 0.0,
    prewarmed_ngram: dict | None = None,
    log_fn=None,
) -> tuple[float, float]:
    """Two-pass n-gram evaluation (PR #870/#943 approach).

    Pass 1 runs the model over this rank's share of the sliding windows and
    stores, for every scored position, the probability the model assigned to
    the true next token and that token's byte count. A full n-gram cache is
    then built from all validation tokens (on top of pre-warmed artifact
    tables when they are supplied and their bucket count matches). Pass 2
    rescores every stored position with hierarchical Dirichlet mixing, lowest
    order first, followed by a phrase-cache blend.

    Args:
        args: not read by this function; kept for signature compatibility.
        model: module exposing ``forward_logits``; put in eval mode for the
            duration of the call and switched back to train mode at the end.
        rank, world_size: used to pick this rank's contiguous slice of windows.
        device: device on which batches and the reduction tensors are created.
        val_tokens: 1-D token tensor; ``.numpy()`` is called on it.
        base_bytes_lut, has_leading_space_lut, is_boundary_token_lut: per-token
            lookup tables used to count bytes per target token.
        eval_seq_len: window length.
        stride: distance between window starts; also the number of trailing
            positions scored in each window after the first.
        batch_seqs: windows per forward pass.
        ngram_order, ngram_min_order, ngram_buckets, ngram_min_count:
            configuration passed to ``NgramCache``.
        ent_base, ent_range, ent_scale, ent_thresh: not read by this function;
            kept for signature compatibility.
        dirichlet_concentration: concentration for the mixing; values <= 0
            fall back to 5.0.
        prewarmed_ngram: optional dict with a ``"meta"`` entry and per-order
            ``ctx_{order}`` / ``full_{order}`` tensors.
        log_fn: optional callable taking one string.

    Returns:
        ``(val_loss, bpb)``: mean blended negative log-likelihood per token
        (nats) and the corresponding bits per byte, summed over all ranks when
        torch.distributed is initialised.
    """
    total_tokens = val_tokens.numel() - 1
    seq_len = eval_seq_len
    val_np = val_tokens[:total_tokens + 1].numpy()

    # distribute windows: each rank takes a contiguous slice
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= 1]
    total_windows = len(window_starts)
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    model.eval()
    compiled_logits = torch.compile(model.forward_logits, dynamic=False, fullgraph=True)
    base_bytes_cpu = base_bytes_lut.cpu()
    has_space_cpu = has_leading_space_lut.cpu()
    is_boundary_cpu = is_boundary_token_lut.cpu()

    # pass 1: store position, model_p and bytes per scored position
    stored_positions: list[np.ndarray] = []
    stored_model_p: list[np.ndarray] = []
    stored_bytes: list[np.ndarray] = []

    if log_fn:
        log_fn(f"two_pass: pass 1 — storing model predictions for {len(my_windows)} windows")

    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = compiled_logits(x_batch)
            probs_all = torch.softmax(logits.float(), dim=-1)

            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # Only the first window scores all of its positions; later
                # windows score just their last ``stride`` positions.
                s = 0 if ws == 0 else max(wlen - stride, 0)
                seg_targets = y_batch[i, s:wlen]
                model_p = (
                    probs_all[i, s:wlen]
                    .gather(1, seg_targets.unsqueeze(1))
                    .squeeze(1)
                    .cpu()
                    .numpy()
                    .astype(np.float64)
                )
                # positions (global indices; the target of position p is val_np[p + 1])
                positions = np.arange(ws + s, ws + wlen, dtype=np.int64)
                # bytes
                tgt_ids = seg_targets.cpu()
                prev_ids = x_batch[i, s:wlen].cpu()
                tb = base_bytes_cpu[tgt_ids].to(torch.float64)
                tb += (has_space_cpu[tgt_ids] & ~is_boundary_cpu[prev_ids]).to(torch.float64)

                stored_positions.append(positions)
                stored_model_p.append(model_p)
                stored_bytes.append(tb.numpy())

    # Concatenate all stored data. np.concatenate raises on an empty list, so
    # a rank that owns no windows gets explicit empty arrays and still takes
    # part in the all_reduce below instead of crashing.
    if stored_positions:
        all_positions = np.concatenate(stored_positions)
        all_model_p = np.concatenate(stored_model_p)
        all_bytes = np.concatenate(stored_bytes)
    else:
        all_positions = np.zeros(0, dtype=np.int64)
        all_model_p = np.zeros(0, dtype=np.float64)
        all_bytes = np.zeros(0, dtype=np.float64)
    n_pos = len(all_positions)

    if log_fn:
        if n_pos > 0:
            neural_loss = -np.log(np.maximum(all_model_p, 1e-30)).mean()
            neural_bpb = (neural_loss / math.log(2.0)) * (len(all_model_p) / all_bytes.sum())
        else:
            neural_bpb = float("nan")
        log_fn(f"two_pass: pass 1 done, {len(all_model_p)} positions, neural_bpb={neural_bpb:.4f}")

    # build full cache from ALL val tokens (+ merge with pre-warmed artifact)
    if log_fn:
        log_fn(f"two_pass: building full cache ({total_tokens} tokens, {ngram_order}-gram, {ngram_buckets} buckets)")
    cache = NgramCache(max_order=ngram_order, min_order=ngram_min_order,
                       num_buckets=ngram_buckets, min_count=ngram_min_count)
    # load pre-warmed tables from artifact if available
    if prewarmed_ngram is not None:
        meta = prewarmed_ngram["meta"]
        art_buckets = int(meta[2])
        if art_buckets == ngram_buckets:
            for oi in range(cache.num_orders):
                order = cache.min_order + oi
                ctx_key = f"ctx_{order}"
                full_key = f"full_{order}"
                # Load an order only when both of its tables are present, so a
                # partial artifact cannot raise KeyError.
                if ctx_key in prewarmed_ngram and full_key in prewarmed_ngram:
                    cache.ctx_counts[oi] = prewarmed_ngram[ctx_key].numpy().astype(np.uint32).copy()
                    cache.full_counts[oi] = prewarmed_ngram[full_key].numpy().astype(np.uint32).copy()
            if log_fn:
                log_fn("two_pass: pre-warmed with training n-gram tables")
        elif log_fn:
            log_fn(f"two_pass: pre-warm SKIPPED (bucket mismatch: artifact={art_buckets} vs eval={ngram_buckets})")
    cache.build_full(val_np, log_fn=log_fn)  # add val tokens ON TOP of pre-warmed

    # pass 2: rescore all stored positions using full cache
    if log_fn:
        log_fn(f"two_pass: pass 2 — rescoring {len(all_positions)} positions with full cache")

    # pass 2: hierarchical Dirichlet CTW scoring over all positions
    conc = dirichlet_concentration if dirichlet_concentration > 0 else 5.0
    blended_p = all_model_p.copy()
    mask = cache.mask
    primes = cache.PRIMES
    has_match = np.zeros(n_pos, dtype=np.bool_)

    # iterate lowest to highest order; each order's posterior is the next prior
    for oi in range(cache.num_orders):
        order = cache.min_order + oi
        cw = order - 1
        valid = (all_positions >= cw)
        if not valid.any():
            continue
        pos_valid = all_positions[valid]
        ctx_hash = np.zeros(len(pos_valid), dtype=np.uint64)
        for k in range(cw):
            t = val_np[(pos_valid - cw + k).astype(np.int64)].astype(np.uint64)
            ctx_hash ^= t * np.uint64(primes[k])
        ctx_key = (ctx_hash & mask).astype(np.int64)
        targets = val_np[(pos_valid + 1).astype(np.int64)].astype(np.uint64)
        full_key = ((ctx_hash ^ (targets * np.uint64(primes[cw]))) & mask).astype(np.int64)
        ctx_c = cache.ctx_counts[oi][ctx_key]
        # Hash collisions can make the full count exceed the context count;
        # cap it so the count ratio never exceeds 1.
        full_c = np.minimum(cache.full_counts[oi][full_key], ctx_c)
        eligible = (ctx_c >= ngram_min_count) & (full_c > 0)
        if eligible.any():
            valid_idx = np.where(valid)[0][eligible]
            fc = full_c[eligible].astype(np.float64)
            cc = ctx_c[eligible].astype(np.float64)
            prev_p = blended_p[valid_idx]
            blended_p[valid_idx] = (conc * prev_p + fc) / (conc + cc)
            has_match[valid_idx] = True

    # phrase cache: second layer of blending for long verbatim repetitions
    if log_fn:
        log_fn("two_pass: building phrase cache...")
    phrase_cache = LongPhraseCache()
    phrase_cache.build_full(val_np, log_fn=log_fn)
    if n_pos > 0:
        p_phrase, phrase_match, phrase_len, _, _ = phrase_cache.lookup(val_np, all_positions, min_count=2)
        if phrase_match.any():
            # alpha based on match length: longer = higher trust (capped at 0.99)
            base_alpha = 0.3
            phrase_alpha = base_alpha + (0.99 - base_alpha) * (phrase_len[phrase_match].astype(np.float64) - 16.0) / 32.0
            phrase_alpha = np.clip(phrase_alpha, 0.0, 0.99)
            pm = phrase_match
            blended_p[pm] = (1.0 - phrase_alpha) * blended_p[pm] + phrase_alpha * p_phrase[pm]
            if log_fn:
                log_fn(f"phrase_cache: {phrase_match.sum()} matches, mean_len={phrase_len[phrase_match].mean():.1f}")

    blended_p = np.maximum(blended_p, 1e-30)  # keep log() finite
    blended_nll = -np.log(blended_p)

    # aggregate across ranks
    loss_sum_t = torch.tensor(float(blended_nll.sum()), device=device, dtype=torch.float64)
    token_count_t = torch.tensor(float(n_pos), device=device, dtype=torch.float64)
    byte_count_t = torch.tensor(float(all_bytes.sum()), device=device, dtype=torch.float64)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum_t, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count_t, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count_t, op=dist.ReduceOp.SUM)

    val_loss = (loss_sum_t / token_count_t).item()
    bpb = (val_loss / math.log(2.0)) * (token_count_t.item() / byte_count_t.item())
    # hit rate is local to this rank (has_match is not reduced)
    hit_rate = has_match.sum() / max(n_pos, 1) * 100
    if log_fn:
        log_fn(f"two_pass: hit_rate={hit_rate:.1f}%, val_loss={val_loss:.4f}, val_bpb={bpb:.4f}")
    model.train()
    return val_loss, bpb
def quantize_int6_per_row(t: Tensor, clip_range: int = 15) -> tuple[Tensor, Tensor]:
    """Symmetrically quantize ``t`` to int8 codes in ``[-clip_range, clip_range]``.

    2-D tensors get one float16 scale per row. Several clipping percentiles of
    the row's absolute values are tried and the candidate with the lowest mean
    squared reconstruction error over the whole tensor is kept. Tensors of any
    other rank get a single scalar scale derived from the absolute maximum.

    Args:
        t: tensor to quantize; converted to float32 internally.
        clip_range: largest magnitude of an integer code.

    Returns:
        ``(q, scale)``: ``q`` is int8 with the shape of ``t``; ``scale`` is
        float16 with shape ``(rows,)`` for 2-D input and 0-dim otherwise.
        ``q.float() * scale`` (broadcast per row) reconstructs the input.
    """
    t32 = t.float()
    if t32.numel() == 0:
        # max()/amax() raise on empty input; there is nothing to quantize.
        scale_shape = (t32.shape[0],) if t32.ndim == 2 else ()
        return (
            torch.zeros(t32.shape, dtype=torch.int8, device=t32.device),
            torch.ones(scale_shape, dtype=torch.float16),
        )
    if t32.ndim == 2:
        abs_t = t32.abs()  # loop-invariant
        best_q, best_s, best_err = None, None, float("inf")
        for pct in (0.9990, 0.9995, 0.9999, 0.99999, 1.0):
            if pct < 1.0:
                row_clip = torch.quantile(abs_t, pct, dim=1)
            else:
                row_clip = abs_t.amax(dim=1)
            s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16)
            s32 = s.float()[:, None]
            q = torch.clamp(torch.round(t32 / s32), -clip_range, clip_range).to(torch.int8)
            err = (t32 - q.float() * s32).pow(2).mean().item()
            # ``best_q is None`` guarantees a result even when every error is
            # NaN (``NaN < x`` is always False), instead of returning None.
            if best_q is None or err < best_err:
                best_q, best_s, best_err = q, s, err
        return best_q, best_s
    amax = t32.abs().max().item()
    scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16)
    if scale.item() == 0.0:
        # amax / clip_range underflowed in float16; dividing by it would
        # produce inf/NaN codes.
        scale = torch.tensor(torch.finfo(torch.float16).tiny, dtype=torch.float16)
    q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8)
    return q, scale
def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
                          template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
    """Rebuild a state dict from quantized tensors and their metadata.

    Args:
        result: stored tensors. Passthrough entries are stored under the
            parameter name; quantized entries under ``name + ".q"`` (codes)
            and ``name + ".scale"``.
        meta: per-parameter storage descriptor. The strings ``"passthrough"``,
            ``"passthrough_ctrl"`` and ``"passthrough_fp16"`` mark passthrough
            tensors; any other non-None value is treated as quantized.
        template_sd: state dict supplying the parameter names to restore and
            the dtype of each.

    Returns:
        Dict with one tensor per template entry that has metadata. Names
        without metadata are left out.
    """
    # Kept as a tuple: ``meta`` values may be dicts, which cannot be hashed
    # for a set-membership test.
    passthrough_kinds = ("passthrough", "passthrough_ctrl", "passthrough_fp16")
    widen_to = (torch.float32, torch.bfloat16)
    restored: dict[str, Tensor] = {}
    for name, reference in template_sd.items():
        entry = meta.get(name)
        if entry is None:
            continue
        target_dtype = reference.dtype
        if entry in passthrough_kinds:
            stored = result[name]
            # float16 passthrough tensors are cast back to the template dtype;
            # everything else is returned as stored.
            needs_cast = stored.dtype == torch.float16 and target_dtype in widen_to
            restored[name] = stored.to(target_dtype) if needs_cast else stored
            continue
        codes = result[name + ".q"]
        scale = result[name + ".scale"]
        if scale.ndim > 0:
            # per-row scale: reshape to (rows, 1, ..., 1) so it broadcasts
            per_row = scale.float().view(codes.shape[0], *([1] * (codes.ndim - 1)))
            values = codes.float() * per_row
        else:
            values = codes.float() * float(scale.item())
        restored[name] = values.to(target_dtype)
    return restored
None: with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) - log0(code, console=False) log0("=" * 100, console=False) log0(f"Running Python {sys.version}", console=False) @@ -792,16 +1725,10 @@ def log0(msg: str, console: bool = True) -> None: console=False, ) log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) - if not args.tokenizer_path.endswith(".model"): raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) @@ -811,18 +1738,16 @@ def log0(msg: str, console: bool = True) -> None: ) dataset_dir = Path(args.data_path).resolve() actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( sp, args.vocab_size, device ) log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - + CastedLinear._qat_enabled = args.qat_enabled base_model = GPT( vocab_size=args.vocab_size, num_layers=args.num_layers, @@ -835,6 +1760,17 @@ def log0(msg: str, console: bool = True) -> None: logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + 
mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, ).to(device).bfloat16() for module in base_model.modules(): if isinstance(module, CastedLinear): @@ -842,18 +1778,14 @@ def log0(msg: str, console: bool = True) -> None: restore_low_dim_params_to_fp32(base_model) compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam block_named_params = list(base_model.blocks.named_parameters()) matrix_params = [ p for name, p in block_named_params if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) scalar_params = [ p for name, p in block_named_params @@ -861,11 +1793,27 @@ def log0(msg: str, console: bool = True) -> None: ] if base_model.skip_weights.numel() > 0: scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": 
[base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.adam_wd, fused=True, ) optimizer_muon = Muon( @@ -873,13 +1821,15 @@ def log0(msg: str, console: bool = True) -> None: lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, ) for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( + optimizer_scalar = torch.optim.AdamW( [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.adam_wd, fused=True, ) optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] @@ -891,9 +1841,12 @@ def log0(msg: str, console: bool = True) -> None: fused=True, ) optimizers.insert(1, optimizer_head) - n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") 
log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") @@ -908,19 +1861,11 @@ def log0(msg: str, console: bool = True) -> None: f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" ) log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - def zero_grad_all() -> None: for opt in optimizers: opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - def lr_mul(step: int, elapsed_ms: float) -> float: if args.warmdown_iters <= 0: return 1.0 @@ -931,9 +1876,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: warmdown_ms = args.warmdown_iters * step_ms remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. 
if args.warmup_steps > 0: initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] @@ -959,20 +1901,17 @@ def lr_mul(step: int, elapsed_ms: float) -> float: if distributed: model.require_backward_grad_sync = True train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 training_time_ms = 0.0 stop_after_step: int | None = None torch.cuda.synchronize() t0 = time.perf_counter() - step = 0 while True: last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) if should_validate: torch.cuda.synchronize() @@ -995,7 +1934,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) torch.cuda.synchronize() t0 = time.perf_counter() - if last_step: if stop_after_step is not None and step < args.iterations: log0( @@ -1003,9 +1941,11 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"step:{step}/{args.iterations}" ) break - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") zero_grad_all() train_loss = torch.zeros((), device=device) for micro_step in range(grad_accum_steps): @@ -1017,24 +1957,33 @@ def lr_mul(step: int, elapsed_ms: float) -> float: train_loss += loss.detach() (loss * grad_scale).backward() train_loss /= grad_accum_steps - frac = min(step / 
args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum for group in optimizer_muon.param_groups: group["momentum"] = muon_momentum - for opt in optimizers: for group in opt.param_groups: group["lr"] = group["base_lr"] * scale - if args.grad_clip_norm > 0: torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) for opt in optimizers: opt.step() zero_grad_all() - + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) step += 1 approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 should_log_train = ( args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) @@ -1044,8 +1993,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" ) - - # Needed to sync whether we've reached the wallclock cap. 
reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms if distributed and max_wallclock_ms is not None: reached_cap_tensor = torch.tensor(int(reached_cap), device=device) @@ -1053,74 +2000,301 @@ def lr_mul(step: int, elapsed_ms: float) -> float: reached_cap = bool(reached_cap_tensor.item()) if stop_after_step is None and reached_cap: stop_after_step = step - log0( f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. + # Apply EMA weights (better than SWA alone per PR#401) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + # skip diagnostic eval to save eval-time budget + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + + # build packed n-gram tables from training data (all ranks in parallel) + ngram_artifact_enabled = bool(int(os.environ.get("NGRAM_ARTIFACT", "1"))) + packed_ngram = None + if ngram_artifact_enabled: + t_build = time.perf_counter() + ngram_art_order = int(os.environ.get("NGRAM_ART_ORDER", "13")) + ngram_art_buckets = int(os.environ.get("NGRAM_ART_BUCKETS", "131072")) # 128K — use artifact headroom + ngram_art_max_shards = int(os.environ.get("NGRAM_ART_MAX_SHARDS", "80")) + # each rank builds from a subset of shards + 
all_shards = sorted(glob.glob(os.path.join(args.data_path, "fineweb_train_*.bin"))) + if ngram_art_max_shards > 0: + all_shards = all_shards[:ngram_art_max_shards] + my_shards = [s for i, s in enumerate(all_shards) if i % world_size == rank] + log0(f"ngram_artifact: building order={ngram_art_order}, buckets={ngram_art_buckets}, shards={len(all_shards)} (rank {rank}: {len(my_shards)})") + local_packed = build_ngram_from_shards( + args.data_path, max_order=ngram_art_order, min_order=2, + num_buckets=ngram_art_buckets, max_shards=0, + log_fn=log0 if master_process else None, + shard_list=my_shards, + ) + # all-reduce counts across ranks (convert to int32 for reduction, then back to uint16) + if distributed: + for key in list(local_packed.keys()): + if key == "meta": + continue + t = local_packed[key].to(torch.int32).to(device) + dist.all_reduce(t, op=dist.ReduceOp.SUM) + local_packed[key] = t.cpu().clamp(max=65535).to(torch.uint16) + packed_ngram = local_packed + log0(f"ngram_artifact: built in {time.perf_counter() - t_build:.0f}s") if master_process: - torch.save(base_model.state_dict(), "final_model.pt") + torch.save(export_sd, "final_model.pt") model_bytes = os.path.getsize("final_model.pt") code_bytes = len(code.encode("utf-8")) log0(f"Serialized model: {model_bytes} bytes") log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + # pack model + n-gram tables into single artifact + artifact_dict = {"w": quant_result, "m": quant_meta} + if packed_ngram is not None: + artifact_dict["ngram"] = packed_ngram quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) + torch.save(artifact_dict, quant_buf) quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = 
len(quant_raw) + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) if master_process: - with open("final_model.int8.ptz", "wb") as f: + with open("final_model.int6.ptz", "wb") as f: f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") + quant_file_bytes = len(quant_blob) code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + if packed_ngram is not None: + ngram_bytes = sum(v.nbytes for v in packed_ngram.values()) + log0(f"ngram_artifact: raw={ngram_bytes} bytes ({ngram_bytes/1e6:.1f}MB)") if distributed: dist.barrier() - with open("final_model.int8.ptz", "rb") as f: + with open("final_model.int6.ptz", "rb") as f: quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, # must match training model + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + # eval_model is used directly by n-gram eval (which compiles internally) + + # TTT: preeval (bulk train then score) or legal (score-first, chunk by chunk) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 0)) + ttt_lr = float(os.environ.get("TTT_LR", 0.0005)) + ttt_mode = os.environ.get("TTT_MODE", "preeval") # "preeval" or "legal" + if ttt_epochs > 0 and ttt_mode == "preeval": + torch.cuda.synchronize() + t_ttt = time.perf_counter() + log0(f"ttt: starting {ttt_epochs} epochs, lr={ttt_lr}, cosine+perlayer") + # per-layer LR groups: 3x for MLP output projections, 0.5x for MLP input + proj_params, fc_params, other_params = [], [], [] + for name, p in eval_model.named_parameters(): + p.requires_grad_(True) + if "mlp.proj" in name: + proj_params.append(p) + elif "mlp.fc" in name: + fc_params.append(p) + else: + other_params.append(p) + ttt_opt = 
torch.optim.AdamW([ + {"params": proj_params, "lr": ttt_lr * 3.0}, + {"params": fc_params, "lr": ttt_lr * 0.5}, + {"params": other_params, "lr": ttt_lr}, + ], weight_decay=0.0) + total_val = val_tokens.numel() - 1 + ttt_batch = 32 + rank_tokens = total_val // world_size + rank_start = rank * rank_tokens + rank_end = rank_start + rank_tokens + steps_per_epoch = max(1, (rank_end - rank_start - args.train_seq_len) // (ttt_batch * args.train_seq_len)) + total_steps = ttt_epochs * steps_per_epoch + global_step = 0 + eval_model.train() + for ep in range(ttt_epochs): + ep_loss, ep_steps = 0.0, 0 + for bs in range(rank_start, rank_end - args.train_seq_len, ttt_batch * args.train_seq_len): + be = min(bs + ttt_batch * args.train_seq_len + 1, rank_end + 1) + local = val_tokens[bs:be].to(device=device, dtype=torch.int64) + n = (local.numel() - 1) // args.train_seq_len + if n == 0: + continue + x = local[:n * args.train_seq_len].reshape(n, args.train_seq_len) + y = local[1:n * args.train_seq_len + 1].reshape(n, args.train_seq_len) + # cosine LR schedule + progress = global_step / max(total_steps, 1) + cos_mul = 0.5 * (1.0 + math.cos(math.pi * progress)) + for g in ttt_opt.param_groups: + g["lr"] = g.get("initial_lr", g["lr"]) * cos_mul + if global_step == 0: + for g in ttt_opt.param_groups: + g["initial_lr"] = g["lr"] + ttt_opt.zero_grad() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = eval_model(x, y) + loss.backward() + # sync gradients across ranks + if distributed: + for p in eval_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(eval_model.parameters(), 1.0) + ttt_opt.step() + ep_loss += loss.item() + ep_steps += 1 + global_step += 1 + if master_process and (ep + 1) % 5 == 0: + log0(f"ttt_epoch:{ep + 1}/{ttt_epochs} avg_loss:{ep_loss / max(ep_steps, 1):.4f}") + del ttt_opt + torch.cuda.empty_cache() + torch.cuda.synchronize() + log0(f"ttt: completed in 
{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + + # legal score-first TTT: score chunk, then train on scored tokens + if ttt_epochs > 0 and ttt_mode == "legal": + torch.cuda.synchronize(); t_ttt = time.perf_counter() + sl = effective_eval_seq_len; st = args.eval_stride if args.eval_stride > 0 else sl; scl = min(st, sl) + for p in eval_model.parameters(): p.requires_grad_(False) + nb = len(eval_model.blocks) if hasattr(eval_model, 'blocks') else 0 + tp = [] + for nm, p in eval_model.named_parameters(): + bi = next((i for i in range(nb) if f"blocks.{i}." in nm), -1) + if bi >= nb - 2 or any(k in nm for k in ("norm","scale","q_gain","lm_head","tok_emb","smear","bigram")): + p.requires_grad_(True); tp.append(p) + to = torch.optim.AdamW(tp, lr=ttt_lr * 0.2, weight_decay=0.0) + log0(f"legal_ttt: {len(tp)} params, {ttt_epochs}ep/chunk") + tot = val_tokens.numel() - 1; cs = 65536 + ns, nc, nb2 = torch.zeros((),dtype=torch.float64,device=device), torch.zeros((),dtype=torch.float64,device=device), torch.zeros((),dtype=torch.float64,device=device) + for c0 in range(0, tot - sl + 1, cs): + eval_model.eval() + with torch.inference_mode(): + for ws in range(c0, min(c0+cs, tot-sl+1), st*world_size): + s = ws + rank*st + if s+sl > tot: continue + x = val_tokens[s:s+sl].to(device=device,dtype=torch.int64).unsqueeze(0) + y = val_tokens[s+1:s+sl+1].to(device=device,dtype=torch.int64).unsqueeze(0) + with torch.autocast(device_type="cuda",dtype=torch.bfloat16,enabled=True): + lo = eval_model.forward_logits(x) if hasattr(eval_model,'forward_logits') else None + if lo is not None: + sf = sl-scl; lt = lo[:,sf:,:].reshape(-1,lo.size(-1)).float(); tt = y[:,sf:].reshape(-1) + ns += F.cross_entropy(lt,tt,reduction="sum").to(torch.float64); nc += scl + pr,tg = x[:,sf:].reshape(-1), tt + tb = base_bytes_lut[tg].to(torch.int16) + (has_leading_space_lut[tg]&~is_boundary_token_lut[pr]).to(torch.int16) + nb2 += tb.to(torch.float64).sum() + eval_model.train() + ct = 
val_tokens[c0:min(c0+cs+sl,tot+1)].to(device=device,dtype=torch.int64) + nq = (ct.numel()-1)//sl + if nq > 0: + for _ in range(ttt_epochs): + xc,yc = ct[:nq*sl].reshape(nq,sl), ct[1:nq*sl+1].reshape(nq,sl) + for bi in range(0,nq,4): + xb,yb = xc[bi:bi+4], yc[bi:bi+4] + if xb.shape[0]==0: continue + to.zero_grad() + with torch.autocast(device_type="cuda",dtype=torch.bfloat16,enabled=True): l=eval_model(xb,yb) + l.backward(); to.step() + if distributed: + for t in (ns,nc,nb2): dist.all_reduce(t, op=dist.ReduceOp.SUM) + if nc.item()>0: + ll=ns.item()/nc.item(); bb=float(ll/math.log(2.0)*nc.item()/nb2.item()) + log0(f"legal_ttt val_loss:{ll:.4f} val_bpb:{bb:.4f} time:{1000*(time.perf_counter()-t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ll:.8f} val_bpb:{bb:.8f}") + del to; torch.cuda.empty_cache() + + # load pre-warmed n-gram tables from artifact (if present) + prewarmed_ngram = quant_state.get("ngram", None) + if prewarmed_ngram is not None: + meta = prewarmed_ngram["meta"] + log0(f"ngram_artifact: loaded pre-warmed tables, orders {int(meta[1])}-{int(meta[0])}, buckets={int(meta[2])}") + + # n-gram cache eval (includes sliding window — replaces standalone sw eval) + ngram_enabled = bool(int(os.environ.get("NGRAM_ENABLED", "1"))) + sw_seq_len = effective_eval_seq_len + if ngram_enabled: + ngram_order = int(os.environ.get("NGRAM_ORDER", "13")) # match artifact order + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", "2")) + # use artifact bucket count if available, otherwise default + art_buckets = int(prewarmed_ngram["meta"][2]) if prewarmed_ngram is not None else 4194304 + ngram_buckets = int(os.environ.get("NGRAM_BUCKETS", str(art_buckets))) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", "2")) + ngram_alpha = float(os.environ.get("NGRAM_ALPHA", "0.2")) + ngram_ent_base = float(os.environ.get("NGRAM_ENT_BASE", "0.05")) + ngram_ent_range = float(os.environ.get("NGRAM_ENT_RANGE", "0.90")) + ngram_ent_scale = 
float(os.environ.get("NGRAM_ENT_SCALE", "2.0")) + ngram_ent_thresh = float(os.environ.get("NGRAM_ENT_THRESH", "4.0")) + dirichlet_conc = float(os.environ.get("DIRICHLET_CONCENTRATION", "5.0")) + torch.cuda.synchronize() + t_ngram = time.perf_counter() + ngram_two_pass = bool(int(os.environ.get("NGRAM_TWO_PASS", "1"))) # two-pass full rescore (PR #943 approach) + log0(f"ngram_eval: order={ngram_order} min_order={ngram_min_order} buckets={ngram_buckets} two_pass={ngram_two_pass} dirichlet={dirichlet_conc}") + if ngram_two_pass: + ng_val_loss, ng_val_bpb = eval_ngram_two_pass( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=sw_seq_len if args.eval_stride > 0 else effective_eval_seq_len, + stride=args.eval_stride if args.eval_stride > 0 else effective_eval_seq_len, + ngram_order=ngram_order, ngram_min_order=ngram_min_order, + ngram_buckets=ngram_buckets, + ngram_min_count=ngram_min_count, + ent_base=ngram_ent_base, ent_range=ngram_ent_range, + ent_scale=ngram_ent_scale, ent_thresh=ngram_ent_thresh, + dirichlet_concentration=dirichlet_conc, + prewarmed_ngram=prewarmed_ngram, + log_fn=log0, + ) + else: + ng_val_loss, ng_val_bpb = eval_val_ngram( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=sw_seq_len if args.eval_stride > 0 else effective_eval_seq_len, + stride=args.eval_stride if args.eval_stride > 0 else effective_eval_seq_len, + ngram_order=ngram_order, ngram_min_order=ngram_min_order, + ngram_buckets=ngram_buckets, ngram_min_count=ngram_min_count, + fixed_alpha=ngram_alpha, + ent_base=ngram_ent_base, ent_range=ngram_ent_range, + dirichlet_concentration=dirichlet_conc, + prewarmed_ngram=prewarmed_ngram, + ent_scale=ngram_ent_scale, ent_thresh=ngram_ent_thresh, + log_fn=log0, + ) + torch.cuda.synchronize() + log0(f"ngram_eval val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} 
eval_time:{1000.0*(time.perf_counter()-t_ngram):.0f}ms") + log0(f"ngram_eval_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + else: + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0(f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} stride:{args.eval_stride} eval_time:{1000.0*(time.perf_counter()-t_slide):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") if distributed: dist.destroy_process_group() - - if __name__ == "__main__": main()