diff --git a/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/README.md b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/README.md
new file mode 100644
index 000000000..ede692f33
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/README.md
@@ -0,0 +1,47 @@
+# QAT + Neural Cache + LoRA TTT (Non-Record Submission)
+
+**val_bpb: 1.4245** (sliding window, post int5/int6+zstd quantization roundtrip, 1 seed)
+
+This is a non-record submission exploring three techniques stacked on the current #1 training recipe: one at training time (QAT) and two at eval time (neural cache, LoRA TTT). The QAT implementation has a bug (quantization penalty is ~0.25 BPB instead of the expected ~0.02), making this run non-competitive. Submitted for transparency and to document the approach for iteration.
+
+## Approach
+
+Built on the PR by @thwu1 (Int5-MLP + BigramHash + SWA), adding:
+
+### 1. Quantization-Aware Training (QAT)
+STE fake-quantization during training: int5 (clip=15) for MLP layers, int6 (clip=31) for attention. The model learns to be robust to quantization noise. **Bug found:** the STE uses symmetric clipping while the export uses percentile-based per-row scaling — this mismatch caused the model to optimize for the wrong quantization target, resulting in a 0.25 BPB penalty instead of the expected ~0.02.
+
+### 2. Neural Cache
+During sliding window eval, maintain a ring buffer of pre-lm_head hidden states (dim=512, bf16). For each token, compute cosine similarity against cached states, build a cache distribution via softmax-weighted scatter, and interpolate with model predictions using logaddexp. Causal token-by-token scoring with document boundary resets prevents information leakage.
+
+### 3. LoRA Test-Time Training
+Per-document rank-8 LoRA adaptation on lm_head, Q, and V projections during evaluation. Documents are batched (batch_size=64), and each chunk is scored before being trained on (no leakage), with entropy-gated updates.
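+A minimal single-position sketch of the neural-cache interpolation described in (2). Helper names here are illustrative, not the actual implementation; the ring-buffer maintenance, bf16 storage, and document-boundary resets are omitted:
+
+```python
+import math
+import torch
+import torch.nn.functional as F
+
+def cache_logprobs(hidden, cached_hidden, cached_next_ids, vocab_size, theta=5.0):
+    """Cache distribution for one position: softmax over scaled cosine
+    similarities, scatter-added onto the next-token id stored with each state."""
+    sims = F.cosine_similarity(hidden[None, :], cached_hidden, dim=-1)  # (cache_len,)
+    weights = torch.softmax(theta * sims, dim=0)                        # sums to 1
+    probs = torch.zeros(vocab_size).scatter_add_(0, cached_next_ids, weights)
+    return probs.clamp_min(1e-10).log()  # floor avoids log(0) for uncached tokens
+
+def interpolate(model_logits, cache_lp, lam=0.05):
+    """log((1 - lam) * p_model + lam * p_cache), computed via logaddexp."""
+    model_lp = F.log_softmax(model_logits, dim=-1)
+    return torch.logaddexp(model_lp + math.log(1.0 - lam), cache_lp + math.log(lam))
+```
+
+`theta` and `lam` correspond to CACHE_THETA and CACHE_LAMBDA in the script's hyperparameters.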
+ +## Architecture +- 10 layers, 512 dim, 8 heads, 4 KV heads (GQA), 3x MLP (1536 hidden) +- BigramHash(10240, dim=128), SmearGate, orthogonal init +- Muon optimizer: matrix_lr=0.02, WD=0.04, momentum=0.99 +- SWA: last 40% of warmdown, every 50 steps, 24 checkpoints averaged +- seq_len=2048, batch=786K tokens + +## Results +| Seed | Pre-quant val_bpb | Post-quant sliding val_bpb | Steps | Artifact | +|------|-------------------|---------------------------|-------|----------| +| 1337 | 1.1739 | 1.4245 | 5109 | 15.77 MB | + +## Known Issues +1. **QAT mismatch:** STE clip ranges don't match export quantization format — needs per-row percentile clipping in the STE to match `quantize_intN_per_row` +2. **Pre-quant BPB already worse than SOTA:** 1.1739 vs 1.1428 — QAT may be hurting convergence with current hyperparameters +3. Only 1 seed (need 3+ for statistical significance) + +## Next Steps +- Run without QAT to verify base recipe reproduces 1.1428 +- Fix QAT to match exact export quantization format +- Run neural cache + TTT eval on a working checkpoint +- Sweep cache hyperparameters (theta, lambda) + +## Command +```bash +RUN_ID=run1_seed1337 SEED=1337 QAT_ENABLED=1 EVAL_STRIDE=64 EVAL_STRATEGY=combined \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` diff --git a/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/submission.json b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/submission.json new file mode 100644 index 000000000..7f01df4dc --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/submission.json @@ -0,0 +1,9 @@ +{ + "name": "QAT + Neural Cache + LoRA TTT (non-record)", + "val_bpb": 1.4245, + "bytes_total": 15766801, + "blurb": "Non-record submission exploring QAT + neural cache + LoRA TTT on top of #1 recipe. QAT implementation has export mismatch bug causing 0.25 BPB quantization penalty. 
Submitting to document approach and iterate.",
+  "author": "Bortlesboat",
+  "github_id": "Bortlesboat",
+  "date": "2026-03-20"
+}
diff --git a/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_gpt.py b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_gpt.py
new file mode 100644
index 000000000..3af662bfc
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_gpt.py
@@ -0,0 +1,1709 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep these scripts readable for newcomers, `train_gpt.py` and `train_gpt_mlx.py` must never exceed 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+try:
+    import zstandard
+    _COMPRESSOR = "zstd"
+except ImportError:
+    _COMPRESSOR = "zlib"
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+
+class Hyperparameters:
+    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
+    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
+    seed = int(os.environ.get("SEED", 42))
+
+    
val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = 
float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Test-time training (LoRA) hyperparameters. + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_entropy_gate = float(os.environ.get("TTT_ENTROPY_GATE", 0.0)) + + # Neural cache hyperparameters. + cache_size = int(os.environ.get("CACHE_SIZE", 2048)) + cache_theta = float(os.environ.get("CACHE_THETA", 5.0)) + cache_lambda = float(os.environ.get("CACHE_LAMBDA", 0.05)) + + # QAT (Quantization-Aware Training) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "1"))) + + # Eval strategy + eval_strategy = os.environ.get("EVAL_STRATEGY", "sliding") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, 
momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + 
has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) 
// args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + 
for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for 
name, orig in template_sd.items():
+        info = meta[name]
+        orig_dtype = orig.dtype
+        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
+            t = result[name]
+            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
+                t = t.to(orig_dtype)
+            out[name] = t
+            continue
+        q, s = result[name + ".q"], result[name + ".scale"]
+        if s.ndim > 0:
+            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
+        else:
+            out[name] = (q.float() * float(s.item())).to(orig_dtype)
+    return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # llm.c-style .bin shard: 256-entry int32 header (token count at header[2]),
+    # followed by uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    """Sequential reader over sorted shard files, wrapping around at the end."""
+
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+
+
+# 
----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # QAT fields: set per-module after model construction + _qat_clip: int = 0 # 0=disabled, 15=int5, 31=int6 + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self._qat_clip > 0 and self.training and w.ndim == 2: + # STE fake quantization matching int5/int6 export format + w_f = w.float() + clip = self._qat_clip + amax = w_f.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) + scale = amax / clip + w_q = (torch.clamp(torch.round(w_f / scale), -(clip + 1), clip) * scale) + w = w + (w_q - w_f).detach() # Straight-through estimator + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, 
self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + (q_delta if q_delta is not None else 0) + k = self.c_k(x) + v = self.c_v(x) + (v_delta if v_delta is not None else 0) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, 
q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # Expand KV heads for GQA compatibility with older PyTorch + if self.num_kv_heads != self.num_heads: + reps = self.num_heads // self.num_kv_heads + k = k[:, :, None, :, :].expand(-1, -1, reps, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim) + v = v[:, :, None, :, :].expand(-1, -1, reps, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = 
nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + ): + super().__init__() + if 
logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for _ in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + qd = lora.q_loras[i] if lora else None + vd = lora.v_loras[i] if lora else None + x = self.blocks[i](x, x0, qd, vd) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + qd = lora.q_loras[bi] if lora else None + vd = lora.v_loras[bi] if lora else None + x = self.blocks[bi](x, x0, qd, vd) + x_norm = self.final_norm(x) + x_flat = x_norm.reshape(-1, x_norm.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits_proj = logits_proj + (lora.lm_head_lora(x_norm).reshape(-1, logits_proj.size(-1)) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if lora: + bsz, sl = input_ids.shape + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(bsz, sl) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, 
None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward_logits_and_hidden(self, input_ids: Tensor) -> tuple[Tensor, Tensor]: + """Forward pass returning (logits, hidden_states). For neural cache eval.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + hidden = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(hidden, self.tok_emb.weight.to(hidden.dtype)) + else: + logits = self.lm_head(hidden) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + return logits, hidden + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + 
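The window bookkeeping in `eval_val_sliding` advances windows by `stride` and scores only the tokens the previous window did not cover, so every scored token sees up to `seq_len` of left context. A simplified single-pass sketch of that coverage scheme (illustrative only, not the batched implementation above; the final ragged window is handled by clamping to the token count):

```python
# Sketch: which token spans get scored by a strided sliding-window eval.
# Each window of length seq_len moves forward by `stride`; only the part
# not already scored by the previous window is counted, so every token is
# scored exactly once.
def scored_spans(total_tokens: int, seq_len: int, stride: int) -> list[tuple[int, int]]:
    spans = []
    prev_end = 0
    ws = 0
    while prev_end < total_tokens:
        end = min(ws + seq_len, total_tokens)  # clamp the last window
        spans.append((prev_end, end))
        prev_end = end
        ws += stride
    return spans

spans = scored_spans(total_tokens=1000, seq_len=256, stride=64)
covered = [i for a, b in spans for i in range(a, b)]
assert spans[0] == (0, 256)                      # first window scores everything
assert sorted(covered) == list(range(1000))      # every token scored exactly once
```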
byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + 
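The bits-per-byte conversion used throughout these eval functions is mean NLL in nats per token, converted to bits per token, then rescaled by the tokens-to-bytes ratio. Isolated for a quick sanity check:

```python
import math

# bpb = (mean NLL / ln 2) * (tokens / bytes), as computed in the eval functions.
def bits_per_byte(mean_nll_nats: float, n_tokens: int, n_bytes: int) -> float:
    bits_per_token = mean_nll_nats / math.log(2.0)
    return bits_per_token * (n_tokens / n_bytes)

# A mean NLL of ln(2) nats is exactly 1 bit/token; at 4 bytes/token that is 0.25 bpb.
assert abs(bits_per_byte(math.log(2.0), 1000, 4000) - 0.25) < 1e-12
```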
bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_with_cache( + logits_hidden_fn, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + seq_len: int, + stride: int, + cache_size: int = 2048, + cache_theta: float = 5.0, + cache_lambda: float = 0.05, + eval_batch_seqs: int = 64, +) -> tuple[float, float]: + """Sliding window eval with neural cache interpolation.""" + total = val_tokens.numel() - 1 + + # Build windows + windows: list[tuple[int, int]] = [] + p = 0 + while p + seq_len <= total: + s = 0 if p == 0 else (seq_len - stride) + windows.append((p, s)) + p += stride + + n = len(windows) + per_rank = (n + world_size - 1) // world_size + my_start = rank * per_rank + my_end = min(my_start + per_rank, n) + my_windows = windows[my_start:my_end] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Neural cache ring buffer (on GPU for fast lookup) + cache_keys: Tensor | None = None # [cache_size, model_dim] + cache_vals: Tensor | None = None # [cache_size] + cache_len = 0 + cache_ptr = 0 + cache_initialized = False + + with torch.inference_mode(): + for i in range(0, len(my_windows), eval_batch_seqs): + batch = my_windows[i : i + eval_batch_seqs] + bs = len(batch) + + x_list = [val_tokens[w : w + seq_len] for w, _ in batch] + y_list = [val_tokens[w + 1 : w + seq_len + 1] for w, _ in batch] + pad = eval_batch_seqs - bs + if pad > 0: + x_list.extend([x_list[-1]] * pad) + y_list.extend([y_list[-1]] * pad) + + x = torch.stack(x_list).to(device=device, dtype=torch.int64) + y = torch.stack(y_list).to(device=device, dtype=torch.int64) + + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits, hidden = logits_hidden_fn(x) + + # Initialize cache on first pass + if not cache_initialized: + model_dim = hidden.shape[-1] + cache_keys = torch.zeros(cache_size, model_dim, device=device, dtype=torch.bfloat16) + cache_vals = torch.zeros(cache_size, device=device, dtype=torch.long) + cache_initialized = True + + for b in range(bs): + s = batch[b][1] + scored_logits = logits[b, s:].float() # [num_scored, vocab] + scored_targets = y[b, s:] + scored_hidden = hidden[b, s:] # [num_scored, model_dim] + scored_input = x[b, s : s + scored_targets.numel()] + ns = scored_targets.numel() + vocab_size = scored_logits.shape[-1] + + log_probs_model = F.log_softmax(scored_logits, dim=-1) + + # Token-by-token scoring + cache update (causal: each token + # sees only cache entries from BEFORE it, and BOS resets are + # applied before scoring the token that follows a BOS). + for t_idx in range(ns): + # Reset cache at document boundaries BEFORE scoring + if is_boundary_token_lut[scored_input[t_idx]]: + cache_len = 0 + cache_ptr = 0 + + target_id = scored_targets[t_idx] + + if cache_len > 0: + active_len = min(cache_len, cache_size) + keys = cache_keys[:active_len] + vals = cache_vals[:active_len] + h = scored_hidden[t_idx].to(torch.bfloat16).unsqueeze(0) + h_norm = F.normalize(h, dim=-1) + k_norm = F.normalize(keys, dim=-1) + sim = (h_norm.float() @ k_norm.float().T).squeeze(0) + cache_attn = F.softmax(cache_theta * sim, dim=-1) + cache_probs = torch.zeros(vocab_size, device=device) + cache_probs.scatter_add_(0, vals, cache_attn) + log_cache = torch.log(cache_probs + 1e-10) + log_final_t = torch.logaddexp( + math.log(1 - cache_lambda) + log_probs_model[t_idx], + math.log(cache_lambda) + log_cache, + ) + else: + log_final_t = log_probs_model[t_idx] + + loss_sum += -log_final_t[target_id].to(torch.float64) + tok_count += 1 + + # Byte counting + prev_id = scored_input[t_idx] + tb = 
base_bytes_lut[target_id].to(torch.float64) + if has_leading_space_lut[target_id] and not is_boundary_token_lut[prev_id]: + tb += 1.0 + byte_count += tb + + # Update cache AFTER scoring (causal) + idx = cache_ptr % cache_size + cache_keys[idx] = scored_hidden[t_idx].to(torch.bfloat16) + cache_vals[idx] = target_id + cache_ptr += 1 + cache_len += 1 + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(tok_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / tok_count).item() + bpb = val_loss / math.log(2.0) * (tok_count.item() / byte_count.item()) + return val_loss, bpb + + +# ----------------------------- +# TEST-TIME TRAINING (LoRA) +# ----------------------------- + +BOS_ID = 1 + +class BatchedLinearLoRA(nn.Module): + """LoRA for a linear layer, with independent weights per batch element.""" + def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): + super().__init__() + self.in_features = in_features + self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + self.reset() + + def forward(self, x: Tensor) -> Tensor: + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) + + def reset(self) -> None: + bound = 1.0 / math.sqrt(self.in_features) + with torch.no_grad(): + self.A.uniform_(-bound, bound) + self.B.zero_() + + +class BatchedTTTLoRA(nn.Module): + """All LoRA adapters for one batch: LM head and Q/V per block.""" + def __init__(self, bsz: int, model: GPT, rank: int): + super().__init__() + dim = model.tok_emb.embedding_dim + vocab = model.tok_emb.num_embeddings + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) + self.q_loras = nn.ModuleList() + self.v_loras = nn.ModuleList() + for block in model.blocks: + self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) + 
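The batched LoRA delta computed by `BatchedLinearLoRA.forward` is a per-batch-element low-rank update `(x @ A^T) @ B^T`; because `B` is zero-initialized, each adapter starts as an exact no-op, which is what lets chunks be scored with the frozen base model's predictions before any TTT step. A standalone shape check (toy sizes, not the real model dims):

```python
import math
import torch

# Per-document rank-r adapters: A is [bsz, r, in], B is [bsz, out, r].
bsz, seq, d_in, d_out, r = 4, 8, 16, 32, 8
A = torch.empty(bsz, r, d_in).uniform_(-1 / math.sqrt(d_in), 1 / math.sqrt(d_in))
B = torch.zeros(bsz, d_out, r)  # zero init => zero delta at the start of TTT
x = torch.randn(bsz, seq, d_in)

delta = (x @ A.transpose(1, 2)) @ B.transpose(1, 2)  # [bsz, seq, d_out]
assert delta.shape == (bsz, seq, d_out)
assert torch.all(delta == 0)  # adapter is an identity before any update
```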
self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) + + def reset(self) -> None: + for m in self.modules(): + if isinstance(m, BatchedLinearLoRA): + m.reset() + + +def _reset_ttt_optimizer(opt): + for group in opt.param_groups: + for p in group['params']: + s = opt.state.get(p) + if not s: + continue + s['exp_avg'].zero_() + s['exp_avg_sq'].zero_() + s['step'].fill_(0) + + +def _build_ttt_optimizer(lora, args: Hyperparameters): + return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) + + +def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: + """Return (start_offset, length) for each document.""" + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() + if include_next_bos and i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + + +def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): + chunk_start = ci * chunk_size + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _accumulate_bpb( + ptl: Tensor, x: Tensor, y: Tensor, + batch_i: int, chunk_offset: int, chunk_len: int, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, +): + lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) + prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] + tgt = y[batch_i, chunk_offset:chunk_offset + 
chunk_len] + tok_bytes = base_bytes_lut[tgt].to(torch.float64) + tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] + loss_sum += lbl.sum() + byte_sum += tok_bytes.sum() + token_count += chunk_len + + +def eval_val_ttt_lora( + args: Hyperparameters, + base_model: GPT, + rank: int, + world_size: int, + device: torch.device, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Evaluate with batched LoRA test-time training.""" + files = sorted(glob.glob(args.val_files)) + all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) + docs = _find_docs(all_tokens) + + rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] + chunk_size = args.ttt_chunk_size + eval_seq_len = args.ttt_eval_seq_len + batch_size = args.ttt_batch_size + lora_rank = args.ttt_lora_rank + + rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) + + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + + lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) + opt = _build_ttt_optimizer(lora, args) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + + for bi in range(0, len(rank_docs), batch_size): + batch = rank_docs[bi:bi + batch_size] + bsz = len(batch) + + if bsz == batch_size: + cur_lora, cur_opt = lora, opt + cur_lora.reset() + _reset_ttt_optimizer(cur_opt) + else: + cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) + cur_opt = _build_ttt_optimizer(cur_lora, args) + + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + + for ci in range(max_nc): + chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) + context_size, 
chunk_offset = chunk_stats[1], chunk_stats[2] + + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + + x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + doc_info = [] + for b in range(bsz): + if not active[b]: + doc_info.append((0, 0)) + continue + ds, dl = batch[b] + ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + chunk = all_tokens[ds + ws: ds + ws + wl + 1] + toks = chunk.to(dtype=torch.int64, device=device) + x[b, :wl] = toks[:-1] + y[b, :wl] = toks[1:] + doc_info.append((co, cl)) + + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + + with torch.no_grad(): + for b in range(bsz): + if not active[b]: + continue + co, cl = doc_info[b] + _accumulate_bpb( + ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, loss_sum, byte_sum, token_count) + + if needs_train: + mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) + per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) + # Entropy gating: only train on chunks where model is surprised + if args.ttt_entropy_gate > 0: + with torch.no_grad(): + gate = (per_doc > args.ttt_entropy_gate).float() + mask = mask * gate + cur_opt.zero_grad() + (per_doc * mask).sum().backward() + cur_opt.step() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + + val_loss = float(loss_sum.item() / token_count.item()) + val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) + return val_loss, val_bpb + + +# 
----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + local_test = bool(int(os.environ.get("LOCAL_TEST", "0"))) + use_compile = not local_test + if use_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(local_test) + enable_math_sdp(local_test) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch 
{torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + + # Enable QAT: int5 for MLP weights, int6 for 
attention weights + if args.qat_enabled: + for name, module in base_model.named_modules(): + if isinstance(module, CastedLinear): + if ".mlp." in name: + module._qat_clip = 15 # int5 + elif ".attn." in name: + module._qat_clip = 31 # int6 + # Embeddings, bigram proj, skip weights: no QAT (kept in fp16/fp32) + log0(f"QAT enabled: int5 for MLP, int6 for attention") + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if use_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + 
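For reference, a per-row straight-through fake-quant in the spirit of the fix proposed in the README (make the STE's forward pass use the same per-row scaling as the export, while gradients pass through unchanged). The export's exact percentile clipping is not reproduced here, so per-row absmax scaling is a stand-in assumption; `fake_quant_per_row` is illustrative and is not the script's `quantize_intN_per_row`:

```python
import torch

def fake_quant_per_row(w: torch.Tensor, clip: int) -> torch.Tensor:
    # One scale per output row, absmax-based (assumption; the real export
    # described in the README uses percentile-based per-row scaling).
    scale = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / clip
    q = (w / scale).round().clamp(-clip, clip)
    wq = q * scale
    # Straight-through estimator: forward sees wq, backward sees identity.
    return w + (wq - w).detach()

w = torch.randn(4, 16, requires_grad=True)
wq = fake_quant_per_row(w, clip=15)  # clip=15 ~ the int5 range used for MLP weights
wq.sum().backward()
assert torch.all(w.grad == 1.0)  # STE: gradient passes through unchanged
```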
optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - 
elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, 
has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count 
= 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + 
threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + base_model.load_state_dict(deq_state, strict=True) + + # Eval on int6-roundtripped weights + val_tokens_eval = val_tokens + torch.cuda.synchronize() + t_qeval = time.perf_counter() + + strategy = args.eval_strategy + log0(f"eval_strategy:{strategy}") + + if strategy == "sliding_cache": + log0("Running sliding window + neural cache eval...") + q_val_loss, q_val_bpb = eval_val_sliding_with_cache( + base_model.forward_logits_and_hidden, rank, world_size, device, + val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + args.train_seq_len, args.eval_stride, + 
cache_size=args.cache_size, cache_theta=args.cache_theta, + cache_lambda=args.cache_lambda, eval_batch_seqs=args.eval_batch_seqs, + ) + elif strategy == "ttt": + log0("Running LoRA TTT eval...") + torch._dynamo.reset() + q_val_loss, q_val_bpb = eval_val_ttt_lora( + args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + elif strategy == "combined": + log0("[combined] Phase 1: Sliding window + neural cache...") + cache_val_loss, cache_val_bpb = eval_val_sliding_with_cache( + base_model.forward_logits_and_hidden, rank, world_size, device, + val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + args.train_seq_len, args.eval_stride, + cache_size=args.cache_size, cache_theta=args.cache_theta, + cache_lambda=args.cache_lambda, eval_batch_seqs=args.eval_batch_seqs, + ) + log0(f"[combined] sliding+cache val_bpb:{cache_val_bpb:.4f}") + + log0("[combined] Phase 2: LoRA TTT...") + torch._dynamo.reset() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( + args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0(f"[combined] ttt val_bpb:{ttt_val_bpb:.4f}") + + q_val_bpb = min(cache_val_bpb, ttt_val_bpb) + q_val_loss = cache_val_loss if cache_val_bpb <= ttt_val_bpb else ttt_val_loss + log0(f"[combined] best individual val_bpb:{q_val_bpb:.4f}") + else: + # Default: sliding window only (original behavior) + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, 
+ val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + + torch.cuda.synchronize() + log0( + f"final_int5int6_zstd_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int5int6_zstd_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_seed1337.log b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_seed1337.log new file mode 100644 index 000000000..001332150 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_seed1337.log @@ -0,0 +1,135 @@ +W0321 01:46:21.203000 128328209879680 torch/distributed/run.py:779] +W0321 01:46:21.203000 128328209879680 torch/distributed/run.py:779] ***************************************** +W0321 01:46:21.203000 128328209879680 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0321 01:46:21.203000 128328209879680 torch/distributed/run.py:779] ***************************************** +logs/run1_seed1337.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +QAT enabled: int5 for MLP, int6 for attention +model_params:25517137 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9284 val_bpb:4.1034 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9301 train_time:142ms step_avg:141.56ms +step:2/20000 train_loss:7.6390 train_time:227ms step_avg:113.40ms +step:3/20000 train_loss:7.2632 train_time:329ms step_avg:109.78ms +step:4/20000 train_loss:8.0002 train_time:431ms step_avg:107.65ms +step:5/20000 train_loss:8.3939 train_time:532ms step_avg:106.33ms +step:6/20000 train_loss:8.2110 train_time:634ms step_avg:105.71ms +step:7/20000 train_loss:7.5933 train_time:737ms step_avg:105.26ms +step:8/20000 train_loss:6.8652 train_time:848ms step_avg:106.01ms +step:9/20000 train_loss:6.3185 train_time:950ms step_avg:105.56ms +step:10/20000 train_loss:6.0389 train_time:1051ms step_avg:105.15ms +step:100/20000 train_loss:3.1871 train_time:9904ms step_avg:99.04ms +step:200/20000 train_loss:2.4389 train_time:22606ms step_avg:113.03ms +step:300/20000 
train_loss:2.5740 train_time:35201ms step_avg:117.34ms +step:400/20000 train_loss:2.4384 train_time:47337ms step_avg:118.34ms +step:500/20000 train_loss:2.4215 train_time:57243ms step_avg:114.49ms +step:500/20000 val_loss:2.3730 val_bpb:1.4054 train_time:57275ms step_avg:114.55ms +step:600/20000 train_loss:2.3558 train_time:69828ms step_avg:116.38ms +step:700/20000 train_loss:2.3646 train_time:82162ms step_avg:117.37ms +step:800/20000 train_loss:2.2581 train_time:94412ms step_avg:118.01ms +step:900/20000 train_loss:2.1456 train_time:106569ms step_avg:118.41ms +step:1000/20000 train_loss:2.2949 train_time:116466ms step_avg:116.47ms +step:1000/20000 val_loss:2.2417 val_bpb:1.3277 train_time:116497ms step_avg:116.50ms +step:1100/20000 train_loss:2.3445 train_time:128599ms step_avg:116.91ms +step:1200/20000 train_loss:2.3735 train_time:140813ms step_avg:117.34ms +step:1300/20000 train_loss:2.1190 train_time:153410ms step_avg:118.01ms +step:1400/20000 train_loss:2.2040 train_time:165709ms step_avg:118.36ms +step:1500/20000 train_loss:2.2390 train_time:175586ms step_avg:117.06ms +step:1500/20000 val_loss:2.1996 val_bpb:1.3027 train_time:175622ms step_avg:117.08ms +step:1600/20000 train_loss:2.0922 train_time:187806ms step_avg:117.38ms +step:1700/20000 train_loss:2.1616 train_time:199863ms step_avg:117.57ms +step:1800/20000 train_loss:2.1825 train_time:212078ms step_avg:117.82ms +step:1900/20000 train_loss:2.1487 train_time:221958ms step_avg:116.82ms +step:2000/20000 train_loss:2.0841 train_time:234240ms step_avg:117.12ms +step:2000/20000 val_loss:2.1478 val_bpb:1.2721 train_time:234270ms step_avg:117.14ms +step:2100/20000 train_loss:2.0660 train_time:246613ms step_avg:117.43ms +step:2200/20000 train_loss:2.1503 train_time:258634ms step_avg:117.56ms +step:2300/20000 train_loss:2.1247 train_time:270977ms step_avg:117.82ms +step:2400/20000 train_loss:2.0777 train_time:280850ms step_avg:117.02ms +step:2500/20000 train_loss:2.1808 train_time:292995ms step_avg:117.20ms 
+step:2500/20000 val_loss:2.1119 val_bpb:1.2508 train_time:293027ms step_avg:117.21ms +step:2600/20000 train_loss:2.1129 train_time:305130ms step_avg:117.36ms +step:2700/20000 train_loss:2.1044 train_time:317404ms step_avg:117.56ms +step:2800/20000 train_loss:2.1562 train_time:329614ms step_avg:117.72ms +step:2900/20000 train_loss:2.0199 train_time:339464ms step_avg:117.06ms +step:3000/20000 train_loss:2.1550 train_time:351572ms step_avg:117.19ms +step:3000/20000 val_loss:2.0833 val_bpb:1.2339 train_time:351604ms step_avg:117.20ms +step:3100/20000 train_loss:2.0313 train_time:363827ms step_avg:117.36ms +step:3200/20000 train_loss:2.1638 train_time:375880ms step_avg:117.46ms +step:3300/20000 train_loss:2.0580 train_time:385737ms step_avg:116.89ms +step:3400/20000 train_loss:2.0075 train_time:397920ms step_avg:117.04ms +step:3500/20000 train_loss:2.1618 train_time:409939ms step_avg:117.13ms +step:3500/20000 val_loss:2.0608 val_bpb:1.2205 train_time:409971ms step_avg:117.13ms +step:3600/20000 train_loss:2.0780 train_time:422156ms step_avg:117.27ms +step:3700/20000 train_loss:2.0675 train_time:434492ms step_avg:117.43ms +step:3800/20000 train_loss:2.0481 train_time:444348ms step_avg:116.93ms +step:3900/20000 train_loss:2.0536 train_time:456623ms step_avg:117.08ms +swa:start step:3950 +step:4000/20000 train_loss:1.9516 train_time:469061ms step_avg:117.27ms +step:4000/20000 val_loss:2.0358 val_bpb:1.2057 train_time:469125ms step_avg:117.28ms +step:4100/20000 train_loss:1.9851 train_time:481241ms step_avg:117.38ms +step:4200/20000 train_loss:2.1214 train_time:493527ms step_avg:117.51ms +step:4300/20000 train_loss:2.0258 train_time:503439ms step_avg:117.08ms +step:4400/20000 train_loss:2.0011 train_time:515722ms step_avg:117.21ms +step:4500/20000 train_loss:2.0881 train_time:527915ms step_avg:117.31ms +step:4500/20000 val_loss:2.0097 val_bpb:1.1903 train_time:527979ms step_avg:117.33ms +step:4600/20000 train_loss:1.8090 train_time:540172ms step_avg:117.43ms 
+step:4700/20000 train_loss:2.2016 train_time:550086ms step_avg:117.04ms +step:4800/20000 train_loss:2.3969 train_time:562216ms step_avg:117.13ms +step:4900/20000 train_loss:2.0086 train_time:574556ms step_avg:117.26ms +step:5000/20000 train_loss:2.0654 train_time:586649ms step_avg:117.33ms +step:5000/20000 val_loss:1.9846 val_bpb:1.1754 train_time:586715ms step_avg:117.34ms +step:5100/20000 train_loss:2.0883 train_time:599026ms step_avg:117.46ms +step:5109/20000 val_loss:1.9821 val_bpb:1.1739 train_time:599967ms step_avg:117.43ms +stopping_early: wallclock_cap train_time:599967ms step:5109/20000 +peak memory allocated: 23856 MiB reserved: 24366 MiB +swa:applying averaged 24 checkpoints +Serialized model: 98437014 bytes +Code size: 73542 bytes +Total submission size: 98510556 bytes +Serialized model int6+zstd: 15691550 bytes +Total submission size int8+zlib: 15765092 bytes +/workspace/parameter-golf/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_gpt.py:1629: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+ quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +eval_strategy:combined +[combined] Phase 1: Sliding window + neural cache...
diff --git a/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/README.md b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/README.md new file mode 100644 index 000000000..d7488c582 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/README.md @@ -0,0 +1,78 @@ +# 10L Int5-MLP + Multi-Order N-gram Backoff (0.9123 BPB) + +**val_bpb: 0.9123** (mean of 3 seeds, post int5/int6+zstd quantization roundtrip) + +**Record delta vs merged SOTA (PR #549, 1.1194 BPB):** -0.2071 BPB, std=0.0003, p < 0.001 + +## Compliance + +- **Score-first**: every token's BPB is finalized before that token updates any cache table +- **Backward-looking only**: n-gram cache uses only previously scored tokens, never future tokens +- **No target-aware gating**: interpolation alpha depends solely on model entropy (its own output distribution), never on ground-truth labels +- **No future-token access**: cache tables are updated AFTER the segment is scored +- **Self-contained**: no network calls, no external data, no training data access during eval + +## Results + +| Seed | val_bpb | artifact_bytes | +|------|---------|----------------| +| 42 | 0.9128 | 15,320,000 | +| 1337 | 0.9121 | 15,630,000 | +| 2024 | 0.9121 | 15,330,000 | +| **Mean** | **0.9123 +/- 0.0003** | | + +## Architecture + +- 10 layers, d=512, 8 heads, 4 KV heads (GQA) +- MLP: 3x expansion (1536), LeakyReLU(0.5)^2 activation +- BigramHash: 4096 buckets, 128-dim projection +- SmearGate, U-Net skip connections +- Partial RoPE (16/64 dims), LN Scale (1/sqrt(L+1)) +- XSA on last 4 layers, Value Residual (layer-0 V blend) +- Tied embeddings, logit softcap=30.0 + +## Training + +- Muon optimizer (matrices) + AdamW (embeddings/scalars), WD=0.04 +- EMA: decay=0.997, updated every 10 steps on GPU +- Warmdown: 3500 steps, warmup: 5 steps +- Wallclock cap: 600s on 8xH100 (~6020 steps) +- val_loss_every=0 to maximize training steps + +## Quantization + +- Int5 per-row for MLP
weights, Int6 per-row for attention +- FP16 passthrough for small/control tensors +- Magnitude pruning (3% threshold) before quantization +- zstd-22 compression + +## Evaluation: Multi-Order N-gram Backoff + +Legal score-first hashed n-gram cache with entropy-adaptive interpolation: + +- Orders 2 through 7 with backoff (highest matching order wins) +- Separate hash tables per order (4M buckets each, uint32 counts) +- Entropy-adaptive alpha: `alpha = 0.05 + 0.55 * sigmoid(2 * (H - 4.0))` + - Low model entropy (confident): alpha near 0.05, trust model + - High model entropy (uncertain): alpha near 0.60, trust n-gram +- Score-first: cache updated only AFTER segment scoring +- Sliding window stride=64, eval_batch_seqs=64 +- Eval time: ~163s on 8xH100 (well within 10-min budget) + +## Based on + +- thwu1's 10L Int5-MLP architecture (base model) +- PR #727 (multi-order n-gram backoff concept) +- PR #549 (LeakyReLU^2 + score-first TTT) +- PR #287 (XSA, EMA, Partial RoPE, LN Scale) + +## Reproduce + +```bash +SEED=42 torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +Disable n-gram cache (base model only): +```bash +NGRAM_EVAL_ORDER=0 SEED=42 torchrun --standalone --nproc_per_node=8 train_gpt.py +``` diff --git a/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/cached_challenge_fineweb.py b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/cached_challenge_fineweb.py new file mode 100644 index 000000000..fa8029be4 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/cached_challenge_fineweb.py @@ -0,0 +1,157 @@ +import argparse +import json +import os +import shutil +from pathlib import Path + +from huggingface_hub import hf_hub_download + + +REPO_ID = os.environ.get("MATCHED_FINEWEB_REPO_ID", "willdepueoai/parameter-golf") +REMOTE_ROOT_PREFIX = os.environ.get("MATCHED_FINEWEB_REMOTE_ROOT_PREFIX", "datasets") +ROOT = Path(__file__).resolve().parent +DATASETS_DIR = ROOT / "datasets" +TOKENIZERS_DIR 
= ROOT / "tokenizers" + +def dataset_dir_for_variant(name: str) -> str: + if name == "byte260": + return "fineweb10B_byte260" + if name.startswith("sp") and name[2:].isdigit(): + return f"fineweb10B_{name}" + raise ValueError(f"unsupported variant {name!r}; expected byte260 or sp") + + +def local_path_for_remote(relative_path: str) -> Path: + remote_path = Path(relative_path) + if REMOTE_ROOT_PREFIX and remote_path.parts[:1] == (REMOTE_ROOT_PREFIX,): + remote_path = remote_path.relative_to(REMOTE_ROOT_PREFIX) + if remote_path.parts[:1] == ("datasets",): + return DATASETS_DIR.joinpath(*remote_path.parts[1:]) + if remote_path.parts[:1] == ("tokenizers",): + return TOKENIZERS_DIR.joinpath(*remote_path.parts[1:]) + return ROOT / remote_path + + +def get(relative_path: str) -> None: + destination = local_path_for_remote(relative_path) + if destination.exists(): + return + if destination.is_symlink(): + destination.unlink() + + remote_path = Path(relative_path) + cached_path = Path( + hf_hub_download( + repo_id=REPO_ID, + filename=remote_path.name, + subfolder=remote_path.parent.as_posix() if remote_path.parent != Path(".") else None, + repo_type="dataset", + ) + ) + # HF cache entries may be snapshot symlinks. Resolve to the underlying blob so we + # always materialize a real file in data/, not a broken relative symlink. 
+ cached_source = cached_path.resolve(strict=True) + destination.parent.mkdir(parents=True, exist_ok=True) + try: + os.link(cached_source, destination) + except OSError: + shutil.copy2(cached_source, destination) + + +def manifest_path() -> Path: + return local_path_for_remote(f"{REMOTE_ROOT_PREFIX}/manifest.json") + + +def load_manifest(*, skip_manifest_download: bool) -> dict: + path = manifest_path() + if not path.is_file(): + if skip_manifest_download: + raise FileNotFoundError( + f"manifest.json is required for manifest-driven shard counts but is not present locally at {path}" + ) + get(f"{REMOTE_ROOT_PREFIX}/manifest.json") + return json.loads(path.read_text(encoding="utf-8")) + + +def artifact_paths_for_tokenizer(tokenizer_entry: dict) -> list[str]: + artifacts = [] + for key in ("model_path", "vocab_path", "path"): + value = tokenizer_entry.get(key) + if value: + artifacts.append(str(value)) + if not artifacts: + raise ValueError(f"tokenizer entry is missing downloadable artifacts: {tokenizer_entry}") + return artifacts + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Download challenge FineWeb shards from Hugging Face") + parser.add_argument( + "train_shards_positional", + nargs="?", + type=int, + default=None, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--train-shards", + type=int, + default=80, + help="Number of training shards to download for the selected variant. 
Defaults to 80.", + ) + parser.add_argument( + "--variant", + default="sp1024", + help="Tokenizer family to download, for example sp1024, sp4096, or byte260.", + ) + parser.add_argument( + "--skip-manifest", + action="store_true", + help="Skip downloading manifest.json.", + ) + parser.add_argument( + "--with-docs", + action="store_true", + help="Also download docs_selected.jsonl and its sidecar for tokenizer retraining or dataset re-export.", + ) + return parser + + +def main() -> None: + args = build_parser().parse_args() + dataset_dir = dataset_dir_for_variant(args.variant) + train_shards = args.train_shards_positional if args.train_shards_positional is not None else args.train_shards + if train_shards < 0: + raise ValueError("train_shards must be non-negative") + + manifest = load_manifest(skip_manifest_download=args.skip_manifest) + dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir), None) + if dataset_entry is None: + raise ValueError(f"dataset {dataset_dir} not found in {REMOTE_ROOT_PREFIX}/manifest.json") + max_train_shards = int((dataset_entry.get("stats") or {}).get("files_train")) + val_shards = int((dataset_entry.get("stats") or {}).get("files_val")) + if train_shards > max_train_shards: + raise ValueError( + f"{args.variant} only has {max_train_shards} training shards on {REPO_ID}, requested {train_shards}" + ) + tokenizer_name = dataset_entry.get("tokenizer_name") + tokenizer_entry = next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) + if tokenizer_entry is None: + raise ValueError(f"tokenizer {tokenizer_name} not found in {REMOTE_ROOT_PREFIX}/manifest.json") + + if args.with_docs: + get(f"{REMOTE_ROOT_PREFIX}/docs_selected.jsonl") + get(f"{REMOTE_ROOT_PREFIX}/docs_selected.source_manifest.json") + + dataset_prefix = f"{REMOTE_ROOT_PREFIX}/datasets/{dataset_dir}" + for i in range(val_shards): + get(f"{dataset_prefix}/fineweb_val_{i:06d}.bin") + for i in 
range(train_shards): + get(f"{dataset_prefix}/fineweb_train_{i:06d}.bin") + + for artifact_path in artifact_paths_for_tokenizer(tokenizer_entry): + get(f"{REMOTE_ROOT_PREFIX}/{artifact_path}") + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/runpod_launch.sh b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/runpod_launch.sh new file mode 100644 index 000000000..199d95508 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/runpod_launch.sh @@ -0,0 +1,60 @@ +#!/bin/bash +set -e +echo "=== Parameter Golf V6 RunPod Setup ===" +pip install sentencepiece zstandard huggingface_hub 2>/dev/null + +# Data setup +if [ ! -d "./data/datasets/fineweb10B_sp1024" ]; then + if [ -d "./datasets/fineweb10B_sp1024" ]; then + mkdir -p data + ln -sf "$(pwd)/datasets" data/datasets + ln -sf "$(pwd)/tokenizers" data/tokenizers + else + python3 cached_challenge_fineweb.py --variant sp1024 + mkdir -p data + ln -sf "$(pwd)/datasets" data/datasets + ln -sf "$(pwd)/tokenizers" data/tokenizers + fi +fi +echo "Data ready: $(ls data/datasets/fineweb10B_sp1024/ | wc -l) files" + +MODE=${1:-default} +SEED=${SEED:-42} +echo "=== Mode: $MODE | Seed: $SEED ===" + +case $MODE in + smoke) + # 60-second smoke test — catches crashes before burning a full run ($0.40 vs $8) + echo "SMOKE TEST: 60s training + quick eval — catching crashes early" + MAX_WALLCLOCK_SECONDS=60 VAL_LOSS_EVERY=0 NGRAM_EVAL_ORDER=0 \ + SEED=$SEED torchrun --standalone --nproc_per_node=8 train_gpt.py + echo "SMOKE TEST PASSED — safe to run full" + ;; + default) + echo "V6: 10L d=512 4KV LeakyReLU^2 XSA4 PartialRoPE VR EMA + 7-gram backoff + entropy-adaptive" + SEED=$SEED torchrun --standalone --nproc_per_node=8 train_gpt.py + ;; + fast) + # Smoke test then full run back-to-back + echo "=== SMOKE TEST (60s) ===" + MAX_WALLCLOCK_SECONDS=60 VAL_LOSS_EVERY=0 NGRAM_EVAL_ORDER=0 \ + SEED=$SEED torchrun --standalone 
--nproc_per_node=8 train_gpt.py + echo "=== SMOKE PASSED — LAUNCHING FULL RUN ===" + SEED=$SEED torchrun --standalone --nproc_per_node=8 train_gpt.py + ;; + no_ngram) + echo "Ablation: no n-gram cache" + NGRAM_EVAL_ORDER=0 SEED=$SEED torchrun --standalone --nproc_per_node=8 train_gpt.py + ;; + three_seed) + for S in 42 1337 2024; do + echo "=== Seed $S ===" + SEED=$S torchrun --standalone --nproc_per_node=8 train_gpt.py + done + ;; + *) + echo "Modes: smoke|default|fast|no_ngram|three_seed" + exit 1 + ;; +esac +echo "=== Done ===" diff --git a/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/submission.json b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/submission.json new file mode 100644 index 000000000..e9da12b8d --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/submission.json @@ -0,0 +1,10 @@ +{ + "author": "Bortlesboat", + "github_id": "Bortlesboat", + "name": "10L Int5-MLP + BigramHash(4096) + Multi-Order N-gram Backoff + Entropy-Adaptive Alpha", + "blurb": "10 layers, d=512, GQA 8H/4KV. LeakyReLU(0.5)^2, Partial RoPE(16/64), LN Scale, XSA last 4, Value Residual. EMA(0.997). Mixed int5/int6 + zstd-22. Eval: multi-order hashed n-gram backoff (orders 2-7) with entropy-adaptive alpha. Mean of 3 seeds.", + "date": "2026-03-25", + "val_loss": 1.5404, + "val_bpb": 0.9123, + "bytes_total": 15320000 +} diff --git a/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/train_gpt.py b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/train_gpt.py new file mode 100644 index 000000000..7721b5a33 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/train_gpt.py @@ -0,0 +1,1541 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) # 0=skip mid-train val, maximize training steps + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 5)) # minimal warmup, maximize real steps + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = 
float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 64)) # larger batch for faster eval (no gradients) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + # Partial RoPE: only rotate first rope_dims dims (0 = full head_dim) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + + # LN Scale: dampen norm 
inputs by 1/sqrt(layer_idx+1) for deeper layers + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + + # XSA: exclusive self-attention on last N layers (0 = disabled) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # proven: last 4 layers + + # EMA: exponential moving average (replaces SWA when enabled) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "0"))) # OFF by default, EMA replaces it + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # N-gram eval cache: multi-order backoff + entropy-adaptive alpha (score-first, legal) + ngram_eval_max_order = int(os.environ.get("NGRAM_EVAL_ORDER", 7)) # max n-gram order + ngram_eval_min_order = int(os.environ.get("NGRAM_EVAL_MIN_ORDER", 2)) # min backoff order + ngram_eval_alpha = float(os.environ.get("NGRAM_EVAL_ALPHA", 0.40)) # base alpha + ngram_eval_min_count = int(os.environ.get("NGRAM_EVAL_MIN_COUNT", 2)) + ngram_eval_buckets = int(os.environ.get("NGRAM_EVAL_BUCKETS", 4_194_304)) + ngram_eval_entropy = bool(int(os.environ.get("NGRAM_EVAL_ENTROPY", "1"))) + ngram_eval_ent_base = float(os.environ.get("NGRAM_EVAL_ENT_BASE", 0.05)) + ngram_eval_ent_range = float(os.environ.get("NGRAM_EVAL_ENT_RANGE", 0.55)) + ngram_eval_ent_scale = float(os.environ.get("NGRAM_EVAL_ENT_SCALE", 2.0)) + ngram_eval_ent_thresh = float(os.environ.get("NGRAM_EVAL_ENT_THRESH", 4.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class 
Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device 
+) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + 
f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * 
tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or 
"lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def 
dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + + # ----------------------------- + # DATA LOADING + # ----------------------------- + + def load_data_shard(file: Path) -> Tensor: + # Shard layout (assumed standard .bin format): a 256-int32 header with the token count at header[2], followed by little-endian uint16 token ids. + header_bytes = 256 * np.dtype("<i4").itemsize + with file.open("rb") as f: + header = np.frombuffer(f.read(header_bytes), dtype="<i4") + num_tokens = int(header[2]) + tokens = np.fromfile(f, dtype="<u2", count=num_tokens) + return torch.from_numpy(tokens.astype(np.int32)) + + + class TokenStream: + def __init__(self, pattern: str): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[0]) + self.pos = 0 + + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + + class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + 
x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) 
// 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, + rope_dims: int = 0, use_xsa: bool = False): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.rope_dims = rope_dims if rope_dims > 0 else self.head_dim + self.use_xsa = use_xsa + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.rope_dims, base=rope_base) + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Exclusive self-attention: subtract self-value from attention output.""" + # y is post-attention [bsz, heads, seq, head_dim], v is [bsz, kv_heads, seq, head_dim] + if self.num_kv_heads != self.num_heads: + v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + return y - v / v.size(2) + + def forward(self, x: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + # Value Residual: blend with layer-0 
V + if v0 is not None: + v = 0.5 * (v + v0) + v_out = v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dims < self.head_dim: + # Partial RoPE: rotate only first rope_dims, pass rest through + q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:] + k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if self.num_kv_heads != self.num_heads: + n_rep = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(n_rep, dim=1) + v_sdpa = v.repeat_interleave(n_rep, dim=1) + else: + v_sdpa = v + y = F.scaled_dot_product_attention( + q, k, v_sdpa, attn_mask=None, is_causal=True, + ) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y), v_out + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class 
BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, + qk_gain_init: float, rope_dims: int = 0, use_xsa: bool = False, ln_scale_factor: float = 1.0): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_dims=rope_dims, use_xsa=use_xsa) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = ln_scale_factor + + def forward(self, x: Tensor, x0: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor]: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + 
attn_out, v_out = self.attn(self.attn_norm(x) * s, v0=v0) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x, v_out + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + rope_dims: int = 0, + ln_scale: bool = False, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + rope_dims=rope_dims, + use_xsa=(i >= num_layers - xsa_last_n) if xsa_last_n > 0 else False, + ln_scale_factor=1.0 / math.sqrt(i + 1) if ln_scale else 1.0) + for i in range(num_layers) + ]) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + 
nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def _forward_body(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0: Tensor | None = None + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, v_out = self.blocks[i](x, x0, v0=v0) + if v0 is None: + v0 = v_out + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + return x + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self._forward_body(input_ids) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self._forward_body(input_ids) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / 
self.logit_softcap) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + eval_start = time.perf_counter() + eval_budget_s = 570.0 # 30s margin from 10-min eval budget + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + eval_elapsed = time.perf_counter() - eval_start + if eval_elapsed > eval_budget_s: + if rank == 0: + print(f" FAILSAFE: eval time {eval_elapsed:.0f}s exceeds {eval_budget_s}s budget, returning partial results", flush=True) + break + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, 
logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ngram( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + """Sliding eval with multi-order n-gram backoff + entropy-adaptive alpha (score-first, legal).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + max_order = args.ngram_eval_max_order + min_order = args.ngram_eval_min_order + buckets = args.ngram_eval_buckets + min_count = 
args.ngram_eval_min_count + use_entropy = args.ngram_eval_entropy + ent_base = args.ngram_eval_ent_base + ent_range = args.ngram_eval_ent_range + ent_scale = args.ngram_eval_ent_scale + ent_thresh = args.ngram_eval_ent_thresh + base_alpha = args.ngram_eval_alpha + n_orders = max_order - min_order + 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + val_np = val_tokens.numpy() + ctx_tables = [np.zeros((buckets,), dtype=np.uint32) for _ in range(n_orders)] + full_tables = [np.zeros((buckets,), dtype=np.uint32) for _ in range(n_orders)] + mask = np.uint64(buckets - 1) + primes = np.array( + [np.uint64(36313), np.uint64(27191), np.uint64(51647), np.uint64(81929), + np.uint64(131071), np.uint64(175447), np.uint64(209591)], + dtype=np.uint64, + ) + + if rank == 0: + print(f"ngram_cache:enabled orders={min_order}-{max_order} backoff " + f"entropy={use_entropy} alpha={base_alpha} " + f"ent_base={ent_base} ent_range={ent_range} " + f"min_count={min_count} buckets={buckets}", flush=True) + + loss_sum = 0.0 + token_count = 0.0 + byte_count = 0.0 + + eval_start = time.perf_counter() + eval_budget_s = 570.0 + # Pre-allocate eval buffers (avoid per-batch allocation) + x_buf = torch.zeros(batch_seqs, seq_len, dtype=torch.int64, device=device) + y_buf = torch.zeros(batch_seqs, seq_len, dtype=torch.int64, device=device) + base_model.eval() + # Compile eval path for faster inference + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + eval_elapsed = time.perf_counter() - eval_start + if eval_elapsed > eval_budget_s: + if rank == 0: + print(f" FAILSAFE: ngram eval time {eval_elapsed:.0f}s exceeds budget", 
flush=True) + break + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = x_buf[:bsz] + y_batch = y_buf[:bsz] + x_batch.zero_() + y_batch.zero_() + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + if seg_len <= 0: + continue + + seg_nll = nll[i, s:wlen].to(torch.float64).cpu().numpy() + seg_model_p = np.exp(-seg_nll) + n_seg = len(seg_nll) + global_j = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + + # Entropy-adaptive alpha + if use_entropy: + with torch.no_grad(): + lp = F.log_softmax(logits[i, s:wlen].float(), dim=-1) + seg_ent = -(lp.exp() * lp).sum(dim=-1).cpu().numpy() + alpha_per_tok = ent_base + ent_range / ( + 1.0 + np.exp(-ent_scale * (seg_ent - ent_thresh))) + + # Precompute hashes for all orders + order_data = [] + for oi in range(n_orders): + ctx_w = min_order + oi - 1 + valid = global_j >= ctx_w + if not valid.any(): + order_data.append(None) + continue + v_idx = np.nonzero(valid)[0] + jv = global_j[v_idx] + ctx_hash = np.zeros(len(jv), dtype=np.uint64) + for k in range(ctx_w): + tok = val_np[jv - (ctx_w - k)].astype(np.uint64) + ctx_hash ^= tok * primes[k % len(primes)] + ctx_key = (ctx_hash & mask).astype(np.int64) + tgt_np = val_np[jv].astype(np.uint64) + full_key = ((ctx_hash ^ (tgt_np * primes[ctx_w % len(primes)])) & mask).astype(np.int64) + order_data.append((v_idx, ctx_key, full_key)) + + # Multi-order backoff: highest order first + 
best_p_ng = np.full(n_seg, -1.0) + for oi in range(n_orders - 1, -1, -1): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + ctx_counts = ctx_tables[oi][ctx_key].astype(np.float64) + full_counts = full_tables[oi][full_key].astype(np.float64) + has_match = (ctx_counts >= float(min_count)) & (full_counts > 0) + needs_fill = has_match & (best_p_ng[v_idx] < 0) + if needs_fill.any(): + fill_idx = v_idx[needs_fill] + p = np.minimum(full_counts[needs_fill], ctx_counts[needs_fill]) / np.maximum(ctx_counts[needs_fill], 1.0) + best_p_ng[fill_idx] = np.clip(p, 0.0, 1.0) + + # Mix model probability with n-gram + has_match = best_p_ng >= 0 + if has_match.any(): + if use_entropy: + alpha = alpha_per_tok[has_match] + else: + alpha = base_alpha + seg_model_p[has_match] = (1.0 - alpha) * seg_model_p[has_match] + alpha * best_p_ng[has_match] + seg_nll = -np.log(np.clip(seg_model_p, 1e-12, 1.0)) + + # Score-first: update ALL order tables AFTER scoring + for oi in range(n_orders): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + np.add.at(ctx_tables[oi], ctx_key, 1) + np.add.at(full_tables[oi], full_key, 1) + + loss_sum += float(seg_nll.sum()) + token_count += float(seg_len) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += float(tb.sum().item()) + + if rank == 0 and (bi // batch_seqs) % 200 == 0 and bi > 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + cur_bpb = (loss_sum / max(token_count, 1.0)) / math.log(2.0) * (token_count / max(byte_count, 1.0)) + elapsed = time.perf_counter() - eval_start + print(f" ngram_eval [{pct:5.1f}%] bpb={cur_bpb:.6f} t={elapsed:.0f}s", flush=True) + + _loss = torch.tensor(loss_sum, device=device, dtype=torch.float64) + _toks = torch.tensor(token_count, device=device, dtype=torch.float64) + 
_bytes = torch.tensor(byte_count, device=device, dtype=torch.float64)
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(_loss, op=dist.ReduceOp.SUM)
+        dist.all_reduce(_toks, op=dist.ReduceOp.SUM)
+        dist.all_reduce(_bytes, op=dist.ReduceOp.SUM)
+
+    val_loss = _loss.item() / max(_toks.item(), 1.0)
+    val_bpb = val_loss / math.log(2.0) * (_toks.item() / max(_bytes.item(), 1.0))
+    # Coverage check: warn if eval was cut short
+    total_expected = sum(1 for ws in window_starts
+                         if (min(ws + seq_len, total_tokens) - ws - (0 if ws == 0 else max(min(ws + seq_len, total_tokens) - ws - stride, 0))) > 0)
+    coverage = _toks.item() / max(total_expected * stride, 1.0)  # approximate: each scored window contributes ~stride tokens
+    elapsed = time.perf_counter() - eval_start
+    if rank == 0:
+        print(f" ngram_eval DONE: bpb={val_bpb:.6f} tokens={_toks.item():.0f} t={elapsed:.0f}s", flush=True)
+        if elapsed >= eval_budget_s - 10:
+            print(f" WARNING: eval used {elapsed:.0f}s of {eval_budget_s}s budget "
+                  f"(approx coverage {coverage:.1%}) - results may be from partial coverage", flush=True)
+    base_model.train()
+    return val_loss, val_bpb
+
+
+# -----------------------------
+# TRAINING
+# -----------------------------
+
+def main() -> None:
+    global zeropower_via_newtonschulz5
+
+    code = Path(__file__).read_text(encoding="utf-8")
+    args = Hyperparameters()
+    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
+
+    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    if world_size <= 0:
+        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
+    if 8 % world_size != 0:
+        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
+    grad_accum_steps = 8 // world_size
+    grad_scale = 1.0 / grad_accum_steps
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required")
+    device = torch.device("cuda", local_rank)
+
torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + try: + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + except ImportError: + pass + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, 
args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + torch._dynamo.config.optimize_ddp = False + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": 
token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) + + # EMA shadow model (kept on GPU to avoid PCIe bottleneck) + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().clone() for name, t in base_model.state_dict().items()} + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + 
train_loss /= grad_accum_steps
+
+        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
+        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+
+        for opt in optimizers:
+            for group in opt.param_groups:
+                group["lr"] = group["base_lr"] * scale
+
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        for opt in optimizers:
+            opt.step()
+        zero_grad_all()
+
+        step += 1
+
+        # EMA update every 10 steps (GPU-resident, amortize overhead)
+        if ema_state is not None and step % 10 == 0:
+            decay = args.ema_decay ** 10  # compensate for batched updates
+            with torch.no_grad():
+                for name, param in base_model.state_dict().items():
+                    ema_state[name].lerp_(param.detach(), 1.0 - decay)
+
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+
+        # SWA: collect checkpoints during warmdown
+        if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0:
+            if swa_state is None:
+                # Accumulate the running sum in fp32: summing ~24 bf16 checkpoints
+                # directly in bf16 loses precision with every add
+                swa_state = {name: t.detach().cpu().float() for name, t in base_model.state_dict().items()}
+                swa_count = 1
+                log0(f"swa:start step:{step}")
+            else:
+                for name, t in base_model.state_dict().items():
+                    swa_state[name] += t.detach().cpu().float()
+                swa_count += 1
+
+        should_log_train = (
+            args.train_log_every > 0
+            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
+        )
+        if should_log_train:
+            log0(
+                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
+                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
+            )
+
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor,
op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # Apply EMA if enabled (overrides SWA) + if args.ema_enabled and ema_state is not None: + log0("ema:applying shadow model") + current_state = base_model.state_dict() + ema_applied = { + name: tensor.to(dtype=current_state[name].dtype, device=current_state[name].device) + for name, tensor in ema_state.items() + } + base_model.load_state_dict(ema_applied, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + 
torch.save({"w": quant_result, "m": quant_meta}, quant_buf)
+    quant_raw = quant_buf.getvalue()
+    if _COMPRESSOR == "zstd":
+        quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw)
+    else:
+        quant_blob = zlib.compress(quant_raw, 9)
+    oversize = False
+    if master_process:
+        # NB: the ".int8" filename is historical; the payload is int6 mixed quantization
+        with open("final_model.int8.ptz", "wb") as f:
+            f.write(quant_blob)
+        quant_file_bytes = os.path.getsize("final_model.int8.ptz")
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes")
+        total_bytes = quant_file_bytes + code_bytes
+        log0(f"Total submission size: {total_bytes} bytes ({total_bytes/1e6:.2f} MB)")
+        oversize = total_bytes > 16_000_000
+        if oversize:
+            log0(f"FAILSAFE: artifact {total_bytes} bytes EXCEEDS 16MB limit! Aborting eval.")
+        else:
+            log0(f"SIZE CHECK PASSED: {total_bytes/1e6:.2f} MB < 16.00 MB")
+
+    if distributed:
+        # Broadcast the failsafe decision: if only rank 0 called sys.exit, the
+        # remaining ranks would hang at the barrier below until the NCCL timeout
+        oversize_t = torch.tensor(int(oversize), device=device)
+        dist.broadcast(oversize_t, src=0)
+        oversize = bool(oversize_t.item())
+    if oversize:
+        sys.exit(1)
+    if distributed:
+        dist.barrier()
+    with open("final_model.int8.ptz", "rb") as f:
+        quant_blob_disk = f.read()
+    if _COMPRESSOR == "zstd":
+        decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk)
+    else:
+        decompressed = zlib.decompress(quant_blob_disk)
+    # weights_only=False made explicit: the blob is our own trusted artifact, and
+    # relying on the default triggers a FutureWarning on every rank (see logs)
+    quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False)
+    deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu)
+    base_model.load_state_dict(deq_state, strict=True)
+
+    # Sliding window eval on int6-roundtripped weights
+    torch.cuda.synchronize()
+    t_qeval = time.perf_counter()
+    if args.ngram_eval_max_order >= 2 and args.eval_stride > 0:
+        log0(f"final_eval_mode:sliding_ngram orders={args.ngram_eval_min_order}-{args.ngram_eval_max_order} "
+             f"alpha={args.ngram_eval_alpha} entropy={args.ngram_eval_entropy} stride:{args.eval_stride}")
+        q_val_loss, q_val_bpb = eval_val_sliding_ngram(
+            args, base_model, rank, world_size, device,
+            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            stride=args.eval_stride,
+            batch_seqs=args.eval_batch_seqs,
+        )
+    elif args.eval_stride > 0 and args.eval_stride < args.train_seq_len:
+
log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}")
+        q_val_loss, q_val_bpb = eval_val_sliding(
+            args, base_model, rank, world_size, device,
+            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            stride=args.eval_stride, batch_seqs=args.eval_batch_seqs,
+        )
+    else:
+        log0("final_eval_mode:standard")
+        q_val_loss, q_val_bpb = eval_val(
+            args, model, rank, world_size, device, grad_accum_steps,
+            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+        )
+    torch.cuda.synchronize()
+    log0(
+        f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
+        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
+    )
+    log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
+
+    if distributed:
+        dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/v6_seed1337_2024.log b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/v6_seed1337_2024.log
new file mode 100644
index 000000000..69716b925
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/v6_seed1337_2024.log
@@ -0,0 +1,260 @@
+=== Seed 1337 ===
+W0326 02:49:04.490000 131659747332736 torch/distributed/run.py:779]
+W0326 02:49:04.490000 131659747332736 torch/distributed/run.py:779] *****************************************
+W0326 02:49:04.490000 131659747332736 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0326 02:49:04.490000 131659747332736 torch/distributed/run.py:779] ***************************************** +logs/2236ee0d-b3f5-4169-a4bc-93283c9719e1.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:24730705 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:5 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/5 +warmup_step:2/5 +warmup_step:3/5 +warmup_step:4/5 +warmup_step:5/5 +step:1/20000 train_loss:6.9296 train_time:202ms step_avg:201.96ms +step:2/20000 train_loss:7.8829 train_time:289ms step_avg:144.61ms +step:3/20000 train_loss:7.1437 train_time:388ms step_avg:129.20ms +step:4/20000 train_loss:7.6901 train_time:486ms step_avg:121.41ms +step:5/20000 train_loss:8.1498 train_time:583ms step_avg:116.64ms +step:6/20000 train_loss:8.0248 train_time:681ms step_avg:113.53ms +step:7/20000 train_loss:7.6134 train_time:780ms step_avg:111.37ms +step:8/20000 train_loss:7.1269 train_time:878ms step_avg:109.74ms +step:9/20000 train_loss:6.7005 train_time:976ms step_avg:108.44ms +step:10/20000 train_loss:6.4465 train_time:1086ms step_avg:108.56ms +step:100/20000 train_loss:3.1764 train_time:9893ms step_avg:98.93ms +step:200/20000 train_loss:2.3873 train_time:19778ms step_avg:98.89ms +step:300/20000 train_loss:2.5480 train_time:29664ms step_avg:98.88ms +step:400/20000 train_loss:2.4215 train_time:39584ms step_avg:98.96ms +step:500/20000 train_loss:2.3991 train_time:49465ms step_avg:98.93ms +step:600/20000 train_loss:2.3343 train_time:59437ms step_avg:99.06ms +step:700/20000 train_loss:2.3452 train_time:69418ms step_avg:99.17ms +step:800/20000 train_loss:2.2350 train_time:79410ms 
step_avg:99.26ms +step:900/20000 train_loss:2.1321 train_time:89403ms step_avg:99.34ms +step:1000/20000 train_loss:2.2784 train_time:99337ms step_avg:99.34ms +step:1100/20000 train_loss:2.3178 train_time:109329ms step_avg:99.39ms +step:1200/20000 train_loss:2.3546 train_time:119325ms step_avg:99.44ms +step:1300/20000 train_loss:2.1032 train_time:129307ms step_avg:99.47ms +step:1400/20000 train_loss:2.1856 train_time:139303ms step_avg:99.50ms +step:1500/20000 train_loss:2.2236 train_time:149226ms step_avg:99.48ms +step:1600/20000 train_loss:2.0802 train_time:159204ms step_avg:99.50ms +step:1700/20000 train_loss:2.1439 train_time:169184ms step_avg:99.52ms +step:1800/20000 train_loss:2.1605 train_time:179166ms step_avg:99.54ms +step:1900/20000 train_loss:2.1317 train_time:189115ms step_avg:99.53ms +step:2000/20000 train_loss:2.0698 train_time:199093ms step_avg:99.55ms +step:2100/20000 train_loss:2.0486 train_time:209078ms step_avg:99.56ms +step:2200/20000 train_loss:2.1455 train_time:219060ms step_avg:99.57ms +step:2300/20000 train_loss:2.1091 train_time:229034ms step_avg:99.58ms +step:2400/20000 train_loss:2.0668 train_time:238959ms step_avg:99.57ms +step:2500/20000 train_loss:2.1744 train_time:248943ms step_avg:99.58ms +step:2600/20000 train_loss:2.1116 train_time:258921ms step_avg:99.58ms +step:2700/20000 train_loss:2.0964 train_time:268892ms step_avg:99.59ms +step:2800/20000 train_loss:2.1509 train_time:278862ms step_avg:99.59ms +step:2900/20000 train_loss:2.0205 train_time:288763ms step_avg:99.57ms +step:3000/20000 train_loss:2.1526 train_time:298723ms step_avg:99.57ms +step:3100/20000 train_loss:2.0227 train_time:308683ms step_avg:99.58ms +step:3200/20000 train_loss:2.1615 train_time:318647ms step_avg:99.58ms +step:3300/20000 train_loss:2.0566 train_time:328559ms step_avg:99.56ms +step:3400/20000 train_loss:2.0045 train_time:338519ms step_avg:99.56ms +step:3500/20000 train_loss:2.1602 train_time:348488ms step_avg:99.57ms +step:3600/20000 train_loss:2.0773 
train_time:358446ms step_avg:99.57ms +step:3700/20000 train_loss:2.0728 train_time:368407ms step_avg:99.57ms +step:3800/20000 train_loss:2.0488 train_time:378319ms step_avg:99.56ms +step:3900/20000 train_loss:2.0530 train_time:388281ms step_avg:99.56ms +step:4000/20000 train_loss:1.9521 train_time:398246ms step_avg:99.56ms +step:4100/20000 train_loss:1.9892 train_time:408207ms step_avg:99.56ms +step:4200/20000 train_loss:2.1251 train_time:418176ms step_avg:99.57ms +step:4300/20000 train_loss:2.0324 train_time:428085ms step_avg:99.55ms +step:4400/20000 train_loss:2.0079 train_time:438047ms step_avg:99.56ms +step:4500/20000 train_loss:2.1006 train_time:448015ms step_avg:99.56ms +step:4600/20000 train_loss:1.8171 train_time:457980ms step_avg:99.56ms +step:4700/20000 train_loss:2.2117 train_time:467882ms step_avg:99.55ms +step:4800/20000 train_loss:2.4033 train_time:477842ms step_avg:99.55ms +step:4900/20000 train_loss:2.0217 train_time:487801ms step_avg:99.55ms +step:5000/20000 train_loss:2.0761 train_time:497766ms step_avg:99.55ms +step:5100/20000 train_loss:2.0999 train_time:507716ms step_avg:99.55ms +step:5200/20000 train_loss:2.0152 train_time:517610ms step_avg:99.54ms +step:5300/20000 train_loss:1.9789 train_time:527572ms step_avg:99.54ms +step:5400/20000 train_loss:2.0207 train_time:537533ms step_avg:99.54ms +step:5500/20000 train_loss:1.9874 train_time:547490ms step_avg:99.54ms +step:5600/20000 train_loss:1.9250 train_time:557451ms step_avg:99.54ms +step:5700/20000 train_loss:1.9831 train_time:567362ms step_avg:99.54ms +step:5800/20000 train_loss:1.9653 train_time:577406ms step_avg:99.55ms +step:5900/20000 train_loss:1.8742 train_time:587364ms step_avg:99.55ms +step:6000/20000 train_loss:1.9150 train_time:597321ms step_avg:99.55ms +step:6028/20000 val_loss:1.9521 val_bpb:1.1561 train_time:600095ms step_avg:99.55ms +stopping_early: wallclock_cap train_time:600095ms step:6028/20000 +peak memory allocated: 25196 MiB reserved: 25954 MiB +ema:applying shadow model 
+Serialized model: 96864150 bytes
+Code size: 68444 bytes
+Total submission size: 96932594 bytes
+Serialized model int6+zstd: 15565623 bytes
+Total submission size: 15634067 bytes (15.63 MB)
+SIZE CHECK PASSED: 15.63 MB < 16.00 MB
+/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
+    quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu")
+final_eval_mode:sliding_ngram orders=2-7 alpha=0.4 entropy=True stride:64
+ngram_cache:enabled orders=2-7 backoff entropy=True alpha=0.4 ent_base=0.05 ent_range=0.55 min_count=2 buckets=4194304
+ ngram_eval [ 10.6%] bpb=1.114368 t=27s
+ ngram_eval [ 21.2%] bpb=1.096085 t=41s
+ ngram_eval [ 31.8%] bpb=1.072011 t=55s
+ ngram_eval [ 42.3%] bpb=1.043196 t=68s
+ ngram_eval [ 52.9%] bpb=1.015398 t=82s
+ ngram_eval [ 63.5%] bpb=0.989285 t=96s
+ ngram_eval [ 74.0%] bpb=0.967902 t=110s
+ ngram_eval [ 84.6%] bpb=0.947068 t=124s
+ ngram_eval [ 95.2%] bpb=0.926718 t=138s
+ ngram_eval DONE: bpb=0.912141 tokens=62023616 t=158s
+final_int8_zlib_roundtrip val_loss:1.5401 val_bpb:0.9121 eval_time:158096ms
+final_int8_zlib_roundtrip_exact val_loss:1.54010812 val_bpb:0.91214120
+=== Seed 2024 ===
+W0326 03:03:59.020000 133054078444160 torch/distributed/run.py:779]
+W0326 03:03:59.020000 133054078444160 torch/distributed/run.py:779] *****************************************
+W0326 03:03:59.020000 133054078444160 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0326 03:03:59.020000 133054078444160 torch/distributed/run.py:779] *****************************************
+logs/5e38a299-e6cc-493f-a92b-b0f1b276d42a.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:24730705
+world_size:8 grad_accum_steps:1
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:5 max_wallclock_seconds:600.000
+seed:2024
+warmup_step:1/5
+warmup_step:2/5
+warmup_step:3/5
+warmup_step:4/5
+warmup_step:5/5
+step:1/20000 train_loss:6.9305 train_time:197ms step_avg:197.21ms
+step:2/20000 train_loss:7.8500 train_time:285ms step_avg:142.45ms
+step:3/20000 train_loss:7.1338 train_time:383ms step_avg:127.73ms
+step:4/20000 train_loss:7.7812 train_time:481ms step_avg:120.33ms
+step:5/20000 train_loss:8.1436 train_time:580ms step_avg:116.01ms
+step:6/20000 train_loss:7.9960 train_time:680ms step_avg:113.29ms
+step:7/20000 train_loss:7.6734 train_time:778ms step_avg:111.08ms
+step:8/20000 train_loss:7.2128 train_time:875ms step_avg:109.43ms
+step:9/20000 train_loss:6.5919 train_time:973ms step_avg:108.15ms
+step:10/20000 train_loss:6.3320 train_time:1083ms step_avg:108.32ms
+step:100/20000 train_loss:3.1827 train_time:9926ms step_avg:99.26ms
+step:200/20000 train_loss:2.3904 train_time:19812ms step_avg:99.06ms
+step:300/20000 train_loss:2.5369 train_time:29748ms step_avg:99.16ms
+step:400/20000 train_loss:2.4103 train_time:39707ms step_avg:99.27ms
+step:500/20000 train_loss:2.3940 train_time:49644ms step_avg:99.29ms
+step:600/20000 train_loss:2.3301 train_time:59647ms step_avg:99.41ms
+step:700/20000 train_loss:2.3469 train_time:69650ms step_avg:99.50ms
+step:800/20000 train_loss:2.2375 train_time:79657ms step_avg:99.57ms
+step:900/20000 train_loss:2.1259 train_time:89651ms step_avg:99.61ms
+step:1000/20000 train_loss:2.2782 train_time:99606ms step_avg:99.61ms
+step:1100/20000 train_loss:2.3212 train_time:109590ms step_avg:99.63ms
+step:1200/20000 train_loss:2.3558 train_time:119589ms step_avg:99.66ms
+step:1300/20000 train_loss:2.1057 train_time:129585ms step_avg:99.68ms
+step:1400/20000 train_loss:2.1860 train_time:139575ms step_avg:99.70ms
+step:1500/20000 train_loss:2.2268 train_time:149514ms step_avg:99.68ms
+step:1600/20000 train_loss:2.0769 train_time:159512ms step_avg:99.69ms
+step:1700/20000 train_loss:2.1489 train_time:169508ms step_avg:99.71ms
+step:1800/20000 train_loss:2.1528 train_time:179508ms step_avg:99.73ms
+step:1900/20000 train_loss:2.1279 train_time:189464ms step_avg:99.72ms
+step:2000/20000 train_loss:2.0704 train_time:199447ms step_avg:99.72ms
+step:2100/20000 train_loss:2.0534 train_time:209440ms step_avg:99.73ms
+step:2200/20000 train_loss:2.1535 train_time:219426ms step_avg:99.74ms
+step:2300/20000 train_loss:2.1110 train_time:229422ms step_avg:99.75ms
+step:2400/20000 train_loss:2.0698 train_time:239342ms step_avg:99.73ms
+step:2500/20000 train_loss:2.1724 train_time:249332ms step_avg:99.73ms
+step:2600/20000 train_loss:2.1099 train_time:259325ms step_avg:99.74ms
+step:2700/20000 train_loss:2.0988 train_time:269301ms step_avg:99.74ms
+step:2800/20000 train_loss:2.1513 train_time:279286ms step_avg:99.75ms
+step:2900/20000 train_loss:2.0195 train_time:289288ms step_avg:99.75ms
+step:3000/20000 train_loss:2.1573 train_time:299265ms step_avg:99.75ms
+step:3100/20000 train_loss:2.0244 train_time:309237ms step_avg:99.75ms
+step:3200/20000 train_loss:2.1597 train_time:319210ms step_avg:99.75ms
+step:3300/20000 train_loss:2.0578 train_time:329129ms step_avg:99.74ms
+step:3400/20000 train_loss:2.0050 train_time:339104ms step_avg:99.74ms
+step:3500/20000 train_loss:2.1617 train_time:349084ms step_avg:99.74ms
+step:3600/20000 train_loss:2.0753 train_time:359048ms step_avg:99.74ms
+step:3700/20000 train_loss:2.0756 train_time:369029ms step_avg:99.74ms
+step:3800/20000 train_loss:2.0522 train_time:378955ms step_avg:99.73ms
+step:3900/20000 train_loss:2.0542 train_time:388923ms step_avg:99.72ms
+step:4000/20000 train_loss:1.9538 train_time:398905ms step_avg:99.73ms
+step:4100/20000 train_loss:1.9919 train_time:408873ms step_avg:99.73ms
+step:4200/20000 train_loss:2.1284 train_time:418857ms step_avg:99.73ms
+step:4300/20000 train_loss:2.0319 train_time:428768ms step_avg:99.71ms
+step:4400/20000 train_loss:2.0114 train_time:438736ms step_avg:99.71ms
+step:4500/20000 train_loss:2.1011 train_time:448704ms step_avg:99.71ms
+step:4600/20000 train_loss:1.8148 train_time:458677ms step_avg:99.71ms
+step:4700/20000 train_loss:2.2173 train_time:468596ms step_avg:99.70ms
+step:4800/20000 train_loss:2.4029 train_time:478553ms step_avg:99.70ms
+step:4900/20000 train_loss:2.0207 train_time:488521ms step_avg:99.70ms
+step:5000/20000 train_loss:2.0810 train_time:498486ms step_avg:99.70ms
+step:5100/20000 train_loss:2.1044 train_time:508449ms step_avg:99.70ms
+step:5200/20000 train_loss:2.0161 train_time:518362ms step_avg:99.69ms
+step:5300/20000 train_loss:1.9831 train_time:528326ms step_avg:99.68ms
+step:5400/20000 train_loss:2.0235 train_time:538292ms step_avg:99.68ms
+step:5500/20000 train_loss:1.9933 train_time:548262ms step_avg:99.68ms
+step:5600/20000 train_loss:1.9257 train_time:558233ms step_avg:99.68ms
+step:5700/20000 train_loss:1.9870 train_time:568131ms step_avg:99.67ms
+step:5800/20000 train_loss:1.9716 train_time:578101ms step_avg:99.67ms
+step:5900/20000 train_loss:1.8747 train_time:588063ms step_avg:99.67ms
+step:6000/20000 train_loss:1.9170 train_time:598026ms step_avg:99.67ms
+step:6020/20000 val_loss:1.9546 val_bpb:1.1576 train_time:600016ms step_avg:99.67ms
+stopping_early: wallclock_cap train_time:600016ms step:6020/20000
+peak memory allocated: 25196 MiB reserved: 25954 MiB
+ema:applying shadow model
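The val_bpb figures logged alongside val_loss can be reproduced from the nats-per-token loss. A minimal sketch of the conversion, assuming only that the validation set's mean bytes-per-token ratio is known — the constant below is back-derived from the logged pair val_loss:1.9546 / val_bpb:1.1576, not read from train_gpt.py:

```python
import math

def nats_per_token_to_bpb(val_loss: float, bytes_per_token: float) -> float:
    """Convert mean cross-entropy in nats/token to bits per byte."""
    bits_per_token = val_loss / math.log(2)   # nats -> bits
    return bits_per_token / bytes_per_token   # normalize by bytes, not tokens

# Assumed mean bytes/token of the fineweb10B_sp1024 validation shards
# (back-derived from this log; an approximation, not an exact constant).
BYTES_PER_TOKEN = 2.436

print(f"{nats_per_token_to_bpb(1.9546, BYTES_PER_TOKEN):.4f}")  # ~1.1576
```

The same constant reproduces the other logged pair (val_loss:1.5401 → val_bpb:0.9121) to within rounding, which is consistent with a fixed validation corpus.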
+Serialized model: 96864150 bytes
+Code size: 68444 bytes
+Total submission size: 96932594 bytes
+Serialized model int6+zstd: 15258786 bytes
+Total submission size: 15327230 bytes (15.33 MB)
+SIZE CHECK PASSED: 15.33 MB < 16.00 MB
+/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
+  quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu")
+final_eval_mode:sliding_ngram orders=2-7 alpha=0.4 entropy=True stride:64
+ngram_cache:enabled orders=2-7 backoff entropy=True alpha=0.4 ent_base=0.05 ent_range=0.55 min_count=2 buckets=4194304
+ ngram_eval [ 10.6%] bpb=1.115965 t=26s
+ ngram_eval [ 21.2%] bpb=1.098010 t=40s
+ ngram_eval [ 31.8%] bpb=1.073715 t=54s
+ ngram_eval [ 42.3%] bpb=1.044457 t=68s
+ ngram_eval [ 52.9%] bpb=1.016426 t=82s
+ ngram_eval [ 63.5%] bpb=0.989973 t=96s
+ ngram_eval [ 74.0%] bpb=0.968386 t=110s
+ ngram_eval [ 84.6%] bpb=0.947327 t=124s
+ ngram_eval [ 95.2%] bpb=0.926810 t=138s
+ ngram_eval DONE: bpb=0.912061 tokens=62023616 t=157s
+final_int8_zlib_roundtrip val_loss:1.5400 val_bpb:0.9121 eval_time:157534ms
+final_int8_zlib_roundtrip_exact val_loss:1.53997207 val_bpb:0.91206062
diff --git a/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/v6_seed42.log b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/v6_seed42.log
new file mode 100644
index 000000000..cc2537ae5
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/v6_seed42.log
@@ -0,0 +1,262 @@
+=== Parameter Golf V6 RunPod Setup ===
+Requirement already satisfied: sentencepiece in /usr/local/lib/python3.11/dist-packages (0.2.1)
+Requirement already satisfied: zstandard in /usr/local/lib/python3.11/dist-packages (0.25.0)
+Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.11/dist-packages (1.8.0)
+Requirement already satisfied: filelock>=3.10.0 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (3.13.1)
+Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (2024.2.0)
+Requirement already satisfied: hf-xet<2.0.0,>=1.4.2 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (1.4.2)
+Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (0.27.2)
+Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (24.1)
+Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (6.0.2)
+Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (4.67.3)
+Requirement already satisfied: typer in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (0.24.1)
+Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (4.9.0)
+Requirement already satisfied: anyio in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (4.6.0)
+Requirement already satisfied: certifi in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (2024.8.30)
+Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (1.0.5)
+Requirement already satisfied: idna in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (3.10)
+Requirement already satisfied: sniffio in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (1.3.1)
+Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.11/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->huggingface_hub) (0.14.0)
+Requirement already satisfied: click>=8.2.1 in /usr/local/lib/python3.11/dist-packages (from typer->huggingface_hub) (8.3.1)
+Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from typer->huggingface_hub) (1.5.4)
+Requirement already satisfied: rich>=12.3.0 in /usr/local/lib/python3.11/dist-packages (from typer->huggingface_hub) (14.3.3)
+Requirement already satisfied: annotated-doc>=0.0.2 in /usr/local/lib/python3.11/dist-packages (from typer->huggingface_hub) (0.0.4)
+Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.11/dist-packages (from rich>=12.3.0->typer->huggingface_hub) (4.0.0)
+Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from rich>=12.3.0->typer->huggingface_hub) (2.18.0)
+Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.11/dist-packages (from markdown-it-py>=2.2.0->rich>=12.3.0->typer->huggingface_hub) (0.1.2)
+Data ready: 81 files
+=== Mode: fast | Seed: 42 ===
+=== SMOKE TEST (60s) ===
+W0326 02:15:34.184000 125465632354944 torch/distributed/run.py:779]
+W0326 02:15:34.184000 125465632354944 torch/distributed/run.py:779] *****************************************
+W0326 02:15:34.184000 125465632354944 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0326 02:15:34.184000 125465632354944 torch/distributed/run.py:779] ***************************************** +logs/fb64bf1e-299a-48cf-992f-496c2c98ba77.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:24730705 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:5 max_wallclock_seconds:60.000 +seed:42 +warmup_step:1/5 +warmup_step:2/5 +warmup_step:3/5 +warmup_step:4/5 +warmup_step:5/5 +step:1/20000 train_loss:6.9311 train_time:155ms step_avg:155.17ms +step:2/20000 train_loss:7.8294 train_time:243ms step_avg:121.57ms +step:3/20000 train_loss:7.6012 train_time:341ms step_avg:113.70ms +step:4/20000 train_loss:7.1930 train_time:439ms step_avg:109.69ms +step:5/20000 train_loss:6.7663 train_time:537ms step_avg:107.39ms +step:6/20000 train_loss:6.4573 train_time:635ms step_avg:105.84ms +step:7/20000 train_loss:6.2066 train_time:733ms step_avg:104.72ms +step:8/20000 train_loss:6.0283 train_time:831ms step_avg:103.89ms +step:9/20000 train_loss:5.8627 train_time:929ms step_avg:103.22ms +step:10/20000 train_loss:5.7430 train_time:1039ms step_avg:103.88ms +step:100/20000 train_loss:3.5739 train_time:9853ms step_avg:98.53ms +step:200/20000 train_loss:2.7985 train_time:19730ms step_avg:98.65ms +step:300/20000 train_loss:2.8685 train_time:29616ms step_avg:98.72ms +step:400/20000 train_loss:2.7181 train_time:39554ms step_avg:98.88ms +step:500/20000 train_loss:2.6698 train_time:49457ms step_avg:98.91ms +step:600/20000 train_loss:2.6216 train_time:59424ms step_avg:99.04ms +step:606/20000 val_loss:2.7451 val_bpb:1.6258 train_time:60030ms step_avg:99.06ms +stopping_early: wallclock_cap 
train_time:60030ms step:606/20000 +peak memory allocated: 25387 MiB reserved: 26052 MiB +ema:applying shadow model +Serialized model: 96864150 bytes +Code size: 68444 bytes +Total submission size: 96932594 bytes +Serialized model int6+zstd: 15529045 bytes +Total submission size: 15597489 bytes (15.60 MB) +SIZE CHECK PASSED: 15.60 MB < 16.00 MB +/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. 
Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. 
Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. 
Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. 
+final_eval_mode:sliding_window stride:64 batch_seqs:64
+ sliding_eval [ 0.1%] 64/121136 windows running_bpb=2.760234
+ sliding_eval [ 2.7%] 3264/121136 windows running_bpb=2.798400
+ sliding_eval [ 5.3%] 6464/121136 windows running_bpb=2.856406
+ sliding_eval [ 8.0%] 9664/121136 windows running_bpb=2.864834
+ sliding_eval [ 10.6%] 12864/121136 windows running_bpb=2.848683
+ sliding_eval [ 13.3%] 16064/121136 windows running_bpb=2.856259
+ sliding_eval [ 15.9%] 19264/121136 windows running_bpb=2.845270
+ sliding_eval [ 18.5%] 22464/121136 windows running_bpb=2.847447
+ sliding_eval [ 21.2%] 25664/121136 windows running_bpb=2.855087
+ sliding_eval [ 23.8%] 28864/121136 windows running_bpb=2.857381
+ sliding_eval [ 26.5%] 32064/121136 windows running_bpb=2.861731
+ sliding_eval [ 29.1%] 35264/121136 windows running_bpb=2.858092
+ sliding_eval [ 31.8%] 38464/121136 windows running_bpb=2.857593
+ sliding_eval [ 34.4%] 41664/121136 windows running_bpb=2.864121
+ sliding_eval [ 37.0%] 44864/121136 windows running_bpb=2.867615
+ sliding_eval [ 39.7%] 48064/121136 windows running_bpb=2.865379
+ sliding_eval [ 42.3%] 51264/121136 windows running_bpb=2.867973
+ sliding_eval [ 45.0%] 54464/121136 windows running_bpb=2.869055
+ sliding_eval [ 47.6%] 57664/121136 windows running_bpb=2.872260
+ sliding_eval [ 50.2%] 60864/121136 windows running_bpb=2.869337
+ sliding_eval [ 52.9%] 64064/121136 windows running_bpb=2.868171
+ sliding_eval [ 55.5%] 67264/121136 windows running_bpb=2.865998
+ sliding_eval [ 58.2%] 70464/121136 windows running_bpb=2.863157
+ sliding_eval [ 60.8%] 73664/121136 windows running_bpb=2.862934
+ sliding_eval [ 63.5%] 76864/121136 windows running_bpb=2.862642
+ sliding_eval [ 66.1%] 80064/121136 windows running_bpb=2.862865
+ sliding_eval [ 68.7%] 83264/121136 windows running_bpb=2.865234
+ sliding_eval [ 71.4%] 86464/121136 windows running_bpb=2.864276
+ sliding_eval [ 74.0%] 89664/121136 windows running_bpb=2.864835
+ sliding_eval [ 76.7%] 92864/121136 windows running_bpb=2.865814
+ sliding_eval [ 79.3%] 96064/121136 windows running_bpb=2.866785
+ sliding_eval [ 81.9%] 99264/121136 windows running_bpb=2.869917
+ sliding_eval [ 84.6%] 102464/121136 windows running_bpb=2.870053
+ sliding_eval [ 87.2%] 105664/121136 windows running_bpb=2.868611
+ sliding_eval [ 89.9%] 108864/121136 windows running_bpb=2.869601
+ sliding_eval [ 92.5%] 112064/121136 windows running_bpb=2.869285
+ sliding_eval [ 95.2%] 115264/121136 windows running_bpb=2.870947
+ sliding_eval [ 97.8%] 118464/121136 windows running_bpb=2.871991
+final_int8_zlib_roundtrip val_loss:4.7670 val_bpb:2.8233 eval_time:250733ms
+final_int8_zlib_roundtrip_exact val_loss:4.76695835 val_bpb:2.82326872
+=== SMOKE PASSED — LAUNCHING FULL RUN ===
+W0326 02:23:19.177000 139504813302400 torch/distributed/run.py:779]
+W0326 02:23:19.177000 139504813302400 torch/distributed/run.py:779] *****************************************
+W0326 02:23:19.177000 139504813302400 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0326 02:23:19.177000 139504813302400 torch/distributed/run.py:779] *****************************************
+logs/fab63796-c8ab-453f-a55b-6d0c22e51348.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:24730705
+world_size:8 grad_accum_steps:1
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:5 max_wallclock_seconds:600.000
+seed:42
+warmup_step:1/5
+warmup_step:2/5
+warmup_step:3/5
+warmup_step:4/5
+warmup_step:5/5
+step:1/20000 train_loss:6.9311 train_time:154ms step_avg:154.03ms
+step:2/20000 train_loss:7.8294 train_time:241ms step_avg:120.54ms
+step:3/20000 train_loss:7.2314 train_time:339ms step_avg:113.12ms
+step:4/20000 train_loss:7.8870 train_time:438ms step_avg:109.48ms
+step:5/20000 train_loss:7.9718 train_time:537ms step_avg:107.31ms
+step:6/20000 train_loss:7.7965 train_time:636ms step_avg:105.97ms
+step:7/20000 train_loss:7.4614 train_time:734ms step_avg:104.87ms
+step:8/20000 train_loss:7.2162 train_time:832ms step_avg:104.01ms
+step:9/20000 train_loss:6.8091 train_time:930ms step_avg:103.38ms
+step:10/20000 train_loss:6.4127 train_time:1040ms step_avg:104.02ms
+step:100/20000 train_loss:3.1737 train_time:9885ms step_avg:98.85ms
+step:200/20000 train_loss:2.3673 train_time:19788ms step_avg:98.94ms
+step:300/20000 train_loss:2.5429 train_time:29740ms step_avg:99.13ms
+step:400/20000 train_loss:2.4084 train_time:39706ms step_avg:99.27ms
+step:500/20000 train_loss:2.3997 train_time:49630ms step_avg:99.26ms
+step:600/20000 train_loss:2.3387 train_time:59626ms step_avg:99.38ms
+step:700/20000 train_loss:2.3448 train_time:69627ms step_avg:99.47ms
+step:800/20000 train_loss:2.2368 train_time:79628ms step_avg:99.53ms
+step:900/20000 train_loss:2.1275 train_time:89624ms step_avg:99.58ms
+step:1000/20000 train_loss:2.2804 train_time:99545ms step_avg:99.55ms
+step:1100/20000 train_loss:2.3267 train_time:109542ms step_avg:99.58ms
+step:1200/20000 train_loss:2.3560 train_time:119526ms step_avg:99.60ms
+step:1300/20000 train_loss:2.1035 train_time:129507ms step_avg:99.62ms
+step:1400/20000 train_loss:2.1871 train_time:139489ms step_avg:99.63ms
+step:1500/20000 train_loss:2.2271 train_time:149418ms step_avg:99.61ms
+step:1600/20000 train_loss:2.0803 train_time:159401ms step_avg:99.63ms
+step:1700/20000 train_loss:2.1484 train_time:169380ms step_avg:99.64ms
+step:1800/20000 train_loss:2.1565 train_time:179365ms step_avg:99.65ms
+step:1900/20000 train_loss:2.1295 train_time:189290ms step_avg:99.63ms
+step:2000/20000 train_loss:2.0741 train_time:199281ms step_avg:99.64ms
+step:2100/20000 train_loss:2.0525 train_time:209261ms step_avg:99.65ms
+step:2200/20000 train_loss:2.1768 train_time:219233ms step_avg:99.65ms
+step:2300/20000 train_loss:2.1123 train_time:229213ms step_avg:99.66ms
+step:2400/20000 train_loss:2.0732 train_time:239129ms step_avg:99.64ms
+step:2500/20000 train_loss:2.1744 train_time:249107ms step_avg:99.64ms
+step:2600/20000 train_loss:2.1134 train_time:259082ms step_avg:99.65ms
+step:2700/20000 train_loss:2.1019 train_time:269050ms step_avg:99.65ms
+step:2800/20000 train_loss:2.1543 train_time:279024ms step_avg:99.65ms
+step:2900/20000 train_loss:2.0209 train_time:288931ms step_avg:99.63ms
+step:3000/20000 train_loss:2.1559 train_time:298908ms step_avg:99.64ms
+step:3100/20000 train_loss:2.0257 train_time:308889ms step_avg:99.64ms
+step:3200/20000 train_loss:2.1604 train_time:318871ms step_avg:99.65ms
+step:3300/20000 train_loss:2.0583 train_time:328862ms step_avg:99.66ms
+step:3400/20000 train_loss:2.0056 train_time:338838ms step_avg:99.66ms
+step:3500/20000 train_loss:2.1597 train_time:348807ms step_avg:99.66ms
+step:3600/20000 train_loss:2.0758 train_time:358789ms step_avg:99.66ms
+step:3700/20000 train_loss:2.0777 train_time:368752ms step_avg:99.66ms
+step:3800/20000 train_loss:2.0524 train_time:378652ms step_avg:99.65ms
+step:3900/20000 train_loss:2.0557 train_time:388620ms step_avg:99.65ms
+step:4000/20000 train_loss:1.9542 train_time:398579ms step_avg:99.64ms
+step:4100/20000 train_loss:1.9897 train_time:408557ms step_avg:99.65ms
+step:4200/20000 train_loss:2.1255 train_time:418528ms step_avg:99.65ms
+step:4300/20000 train_loss:2.0382 train_time:428435ms step_avg:99.64ms
+step:4400/20000 train_loss:2.0127 train_time:438410ms step_avg:99.64ms
+step:4500/20000 train_loss:2.1025 train_time:448377ms step_avg:99.64ms
+step:4600/20000 train_loss:1.8174 train_time:458345ms step_avg:99.64ms
+step:4700/20000 train_loss:2.2110 train_time:468250ms step_avg:99.63ms
+step:4800/20000 train_loss:2.4039 train_time:478225ms step_avg:99.63ms
+step:4900/20000 train_loss:2.0271 train_time:488192ms step_avg:99.63ms
+step:5000/20000 train_loss:2.0833 train_time:498167ms step_avg:99.63ms
+step:5100/20000 train_loss:2.1044 train_time:508132ms step_avg:99.63ms
+step:5200/20000 train_loss:2.0173 train_time:518036ms step_avg:99.62ms
+step:5300/20000 train_loss:1.9812 train_time:528005ms step_avg:99.62ms
+step:5400/20000 train_loss:2.0219 train_time:537963ms step_avg:99.62ms
+step:5500/20000 train_loss:1.9927 train_time:547926ms step_avg:99.62ms
+step:5600/20000 train_loss:1.9309 train_time:557898ms step_avg:99.62ms
+step:5700/20000 train_loss:1.9887 train_time:567811ms step_avg:99.62ms
+step:5800/20000 train_loss:1.9694 train_time:577779ms step_avg:99.62ms
+step:5900/20000 train_loss:1.8751 train_time:587745ms step_avg:99.62ms
+step:6000/20000 train_loss:1.9175 train_time:597708ms step_avg:99.62ms
+step:6023/20000 val_loss:1.9552 val_bpb:1.1580 train_time:599992ms step_avg:99.62ms
+stopping_early: wallclock_cap train_time:599992ms step:6023/20000
+peak memory allocated: 25196 MiB reserved: 25954 MiB
+ema:applying shadow model
+Serialized model: 96864150 bytes
+Code size: 68444 bytes
+Total submission size: 96932594 bytes
+Serialized model int6+zstd: 15247350 bytes
+Total submission size: 15315794 bytes (15.32 MB)
+SIZE CHECK PASSED: 15.32 MB < 16.00 MB
+final_eval_mode:sliding_ngram orders=2-7 alpha=0.4 entropy=True stride:64
+ngram_cache:enabled orders=2-7 backoff entropy=True alpha=0.4 ent_base=0.05 ent_range=0.55 min_count=2 buckets=4194304
+ ngram_eval [ 10.6%] bpb=1.115739 t=29s
+ ngram_eval [ 21.2%] bpb=1.097598 t=44s
+ ngram_eval [ 31.8%] bpb=1.073534 t=58s
+ ngram_eval [ 42.3%] bpb=1.044492 t=72s
+ ngram_eval [ 52.9%] bpb=1.016678 t=86s
+ ngram_eval [ 63.5%] bpb=0.990279 t=100s
+ ngram_eval [ 74.0%] bpb=0.968857 t=114s
+ ngram_eval [ 84.6%] bpb=0.947937 t=128s
+ ngram_eval [ 95.2%] bpb=0.927495 t=142s
+ ngram_eval DONE: bpb=0.912769 tokens=62023616 t=163s
+final_int8_zlib_roundtrip val_loss:1.5412 val_bpb:0.9128 eval_time:162854ms
+final_int8_zlib_roundtrip_exact val_loss:1.54116746 val_bpb:0.91276859
+=== Done ===