diff --git a/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/README.md b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/README.md
new file mode 100644
index 000000000..ede692f33
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/README.md
@@ -0,0 +1,47 @@
# QAT + Neural Cache + LoRA TTT (Non-Record Submission)

**val_bpb: 1.4245** (sliding window, post int5/int6+zstd quantization roundtrip, 1 seed)

This is a non-record submission exploring three techniques stacked on the current #1 training recipe: one at training time (QAT) and two at eval time (neural cache, LoRA test-time training). The QAT implementation has a bug (the quantization penalty is ~0.25 BPB instead of the expected ~0.02), which makes this run non-competitive. Submitted for transparency and to document the approach for iteration.

## Approach

Built on the PR by @thwu1 (Int5-MLP + BigramHash + SWA), adding:

### 1. Quantization-Aware Training (QAT)
STE fake-quantization during training: int5 (clip=15) for MLP layers, int6 (clip=31) for attention. The model learns to be robust to quantization noise. **Bug found:** the STE uses symmetric clipping while the export uses percentile-based per-row scaling. This mismatch caused the model to optimize for the wrong quantization target, resulting in a 0.25 BPB penalty instead of the expected ~0.02.

### 2. Neural Cache
During sliding-window eval, maintain a ring buffer of pre-lm_head hidden states (dim=512, bf16). For each token, compute the cosine similarity against the cached states, build a cache distribution via a softmax-weighted scatter, and interpolate with the model's predictions using logaddexp. Causal token-by-token scoring with document-boundary resets prevents information leakage.

### 3. LoRA Test-Time Training
Per-document rank-8 LoRA adaptation of the lm_head, Q, and V projections during evaluation. Documents are batched (batch_size=64), chunks are scored before being trained on (no leakage), and updates are entropy-gated.
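The cache step in (2) can be sketched for a single token as follows. This is a minimal illustration, not the submission's actual eval code: the function name `cache_interpolate` and its argument names are ours, the defaults mirror `CACHE_THETA=5.0` / `CACHE_LAMBDA=0.05`, and the ring-buffer maintenance and document-boundary resets are omitted.

```python
import math

import torch
import torch.nn.functional as F

def cache_interpolate(logp_model, hidden, cache_h, cache_ids, theta=5.0, lam=0.05):
    """Mix model log-probs with a pointer-style cache distribution.

    logp_model: (V,) model log-probabilities for the next token
    hidden:     (D,) current pre-lm_head hidden state
    cache_h:    (N, D) ring buffer of cached hidden states
    cache_ids:  (N,) token id observed right after each cached state
    """
    vocab = logp_model.numel()
    # Cosine similarity of the query state against every cached state.
    sim = F.cosine_similarity(hidden.unsqueeze(0), cache_h, dim=-1)  # (N,)
    w = torch.softmax(theta * sim, dim=0)                            # (N,)
    # Softmax-weighted scatter: tokens that followed similar contexts
    # accumulate cache probability mass.
    cache_p = torch.zeros(vocab).scatter_add_(0, cache_ids, w)       # (V,)
    # Interpolate in log space: log((1 - lam) * p_model + lam * p_cache).
    return torch.logaddexp(
        logp_model + math.log(1.0 - lam),
        torch.log(cache_p.clamp_min(1e-12)) + math.log(lam),
    )
```

The logaddexp form keeps the mixture numerically stable in log space; the clamp only guards `log(0)` for vocabulary entries the cache never produced.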
+ +## Architecture +- 10 layers, 512 dim, 8 heads, 4 KV heads (GQA), 3x MLP (1536 hidden) +- BigramHash(10240, dim=128), SmearGate, orthogonal init +- Muon optimizer: matrix_lr=0.02, WD=0.04, momentum=0.99 +- SWA: last 40% of warmdown, every 50 steps, 24 checkpoints averaged +- seq_len=2048, batch=786K tokens + +## Results +| Seed | Pre-quant val_bpb | Post-quant sliding val_bpb | Steps | Artifact | +|------|-------------------|---------------------------|-------|----------| +| 1337 | 1.1739 | 1.4245 | 5109 | 15.77 MB | + +## Known Issues +1. **QAT mismatch:** STE clip ranges don't match export quantization format — needs per-row percentile clipping in the STE to match `quantize_intN_per_row` +2. **Pre-quant BPB already worse than SOTA:** 1.1739 vs 1.1428 — QAT may be hurting convergence with current hyperparameters +3. Only 1 seed (need 3+ for statistical significance) + +## Next Steps +- Run without QAT to verify base recipe reproduces 1.1428 +- Fix QAT to match exact export quantization format +- Run neural cache + TTT eval on a working checkpoint +- Sweep cache hyperparameters (theta, lambda) + +## Command +```bash +RUN_ID=run1_seed1337 SEED=1337 QAT_ENABLED=1 EVAL_STRIDE=64 EVAL_STRATEGY=combined \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` diff --git a/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/submission.json b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/submission.json new file mode 100644 index 000000000..7f01df4dc --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/submission.json @@ -0,0 +1,9 @@ +{ + "name": "QAT + Neural Cache + LoRA TTT (non-record)", + "val_bpb": 1.4245, + "bytes_total": 15766801, + "blurb": "Non-record submission exploring QAT + neural cache + LoRA TTT on top of #1 recipe. QAT implementation has export mismatch bug causing 0.25 BPB quantization penalty. 
Submitting to document the approach and iterate.",
  "author": "Bortlesboat",
  "github_id": "Bortlesboat",
  "date": "2026-03-20"
}
diff --git a/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_gpt.py b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_gpt.py
new file mode 100644
index 000000000..3af662bfc
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_gpt.py
@@ -0,0 +1,1709 @@
"""
The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.

Hard stop: to keep these scripts readable for newcomers, `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines.
"""

from __future__ import annotations

import copy
import glob
import io
import math
import os
import random
import subprocess
import sys
import time
import uuid
import zlib
from pathlib import Path

try:
    import zstandard
    _COMPRESSOR = "zstd"
except ImportError:
    _COMPRESSOR = "zlib"

import numpy as np
import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.parallel import DistributedDataParallel as DDP

# -----------------------------
# HYPERPARAMETERS
# -----------------------------

class Hyperparameters:
    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
    train_files = os.path.join(data_path, "fineweb_train_*.bin")
    val_files = os.path.join(data_path, "fineweb_val_*.bin")
    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
    seed = int(os.environ.get("SEED", 42))

    
val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = 
float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Test-time training (LoRA) hyperparameters. + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_entropy_gate = float(os.environ.get("TTT_ENTROPY_GATE", 0.0)) + + # Neural cache hyperparameters. + cache_size = int(os.environ.get("CACHE_SIZE", 2048)) + cache_theta = float(os.environ.get("CACHE_THETA", 5.0)) + cache_lambda = float(os.environ.get("CACHE_LAMBDA", 0.05)) + + # QAT (Quantization-Aware Training) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "1"))) + + # Eval strategy + eval_strategy = os.environ.get("EVAL_STRATEGY", "sliding") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, 
momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + 
has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) 
// args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + 
for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for 
name, orig in template_sd.items():
        info = meta[name]
        orig_dtype = orig.dtype
        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
            t = result[name]
            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
                t = t.to(orig_dtype)
            out[name] = t
            continue
        q, s = result[name + ".q"], result[name + ".scale"]
        if s.ndim > 0:
            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
        else:
            out[name] = (q.float() * float(s.item())).to(orig_dtype)
    return out


# -----------------------------
# DATA LOADING
# -----------------------------

def load_data_shard(file: Path) -> Tensor:
    # Shard layout: a 256-entry little-endian int32 header (token count at
    # index 2), followed by the token ids as little-endian uint16.
    header_bytes = 256 * np.dtype("<i4").itemsize
    with file.open("rb") as f:
        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
        num_tokens = int(header[2])
        tokens = np.fromfile(f, dtype="<u2", count=num_tokens)
    return torch.from_numpy(tokens.astype(np.int32))


class TokenStream:
    def __init__(self, pattern: str):
        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
        if not self.files:
            raise FileNotFoundError(f"No files found for pattern: {pattern}")
        self.file_idx = 0
        self.tokens = load_data_shard(self.files[self.file_idx])
        self.pos = 0

    def _advance_file(self) -> None:
        self.file_idx = (self.file_idx + 1) % len(self.files)
        self.tokens = load_data_shard(self.files[self.file_idx])
        self.pos = 0

    def take(self, n: int) -> Tensor:
        chunks: list[Tensor] = []
        remaining = n
        while remaining > 0:
            avail = self.tokens.numel() - self.pos
            if avail <= 0:
                self._advance_file()
                continue
            k = min(remaining, avail)
            chunks.append(self.tokens[self.pos : self.pos + k])
            self.pos += k
            remaining -= k
        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)


class DistributedTokenLoader:
    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)


# 
----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # QAT fields: set per-module after model construction + _qat_clip: int = 0 # 0=disabled, 15=int5, 31=int6 + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self._qat_clip > 0 and self.training and w.ndim == 2: + # STE fake quantization matching int5/int6 export format + w_f = w.float() + clip = self._qat_clip + amax = w_f.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) + scale = amax / clip + w_q = (torch.clamp(torch.round(w_f / scale), -(clip + 1), clip) * scale) + w = w + (w_q - w_f).detach() # Straight-through estimator + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, 
self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + (q_delta if q_delta is not None else 0) + k = self.c_k(x) + v = self.c_v(x) + (v_delta if v_delta is not None else 0) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, 
q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # Expand KV heads for GQA compatibility with older PyTorch + if self.num_kv_heads != self.num_heads: + reps = self.num_heads // self.num_kv_heads + k = k[:, :, None, :, :].expand(-1, -1, reps, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim) + v = v[:, :, None, :, :].expand(-1, -1, reps, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = 
nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + ): + super().__init__() + if 
logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for _ in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + qd = lora.q_loras[i] if lora else None + vd = lora.v_loras[i] if lora else None + x = self.blocks[i](x, x0, qd, vd) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + qd = lora.q_loras[bi] if lora else None + vd = lora.v_loras[bi] if lora else None + x = self.blocks[bi](x, x0, qd, vd) + x_norm = self.final_norm(x) + x_flat = x_norm.reshape(-1, x_norm.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits_proj = logits_proj + (lora.lm_head_lora(x_norm).reshape(-1, logits_proj.size(-1)) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if lora: + bsz, sl = input_ids.shape + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(bsz, sl) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, 
None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward_logits_and_hidden(self, input_ids: Tensor) -> tuple[Tensor, Tensor]: + """Forward pass returning (logits, hidden_states). For neural cache eval.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + hidden = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(hidden, self.tok_emb.weight.to(hidden.dtype)) + else: + logits = self.lm_head(hidden) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + return logits, hidden + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + 
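The `my_s`/`my_e` arithmetic above shards the window list contiguously across ranks so every window is scored exactly once and shard sizes differ by at most one. A standalone sketch of that split (hypothetical window/rank counts, not tied to the training setup):

```python
def shard_windows(total_windows: int, rank: int, world_size: int) -> range:
    # Contiguous, near-balanced split: rank r gets windows
    # [total*r // ws, total*(r+1) // ws), so the union over ranks is exact.
    start = (total_windows * rank) // world_size
    end = (total_windows * (rank + 1)) // world_size
    return range(start, end)

# 10 windows over 3 ranks: shards of size 3, 3, 4 with no gaps or overlap.
shards = [list(shard_windows(10, r, 3)) for r in range(3)]
assert shards == [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
```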
byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + # Score only tokens not covered by the previous window; with the old + # `max(wlen - stride, 0)` every partial tail window re-scored the final + # stride tokens, counting them multiple times. + s = 0 if ws == 0 else min(seq_len - stride, wlen) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + 
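The final reduction in `eval_val_sliding` turns the summed NLL (in nats) into bits per byte by converting nats/token to bits/token and rescaling by tokens-per-byte. As a sanity check on the algebra, `token_count` cancels, so the result is just `loss_sum / ln(2) / byte_count`. A minimal sketch with made-up sums:

```python
import math

def bits_per_byte(loss_sum: float, token_count: float, byte_count: float) -> float:
    # nats/token -> bits/token, then rescale by tokens-per-byte.
    bits_per_token = (loss_sum / token_count) / math.log(2.0)
    return bits_per_token * (token_count / byte_count)

# token_count cancels: equivalent to loss_sum / ln(2) / byte_count.
assert abs(bits_per_byte(1200.0, 1000.0, 4000.0) - 1200.0 / math.log(2.0) / 4000.0) < 1e-12
```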
bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + + def eval_val_sliding_with_cache( + logits_hidden_fn, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + seq_len: int, + stride: int, + cache_size: int = 2048, + cache_theta: float = 5.0, + cache_lambda: float = 0.05, + eval_batch_seqs: int = 64, + ) -> tuple[float, float]: + """Sliding window eval with neural cache interpolation.""" + total = val_tokens.numel() - 1 + + # Build windows + windows: list[tuple[int, int]] = [] + p = 0 + while p + seq_len <= total: + s = 0 if p == 0 else (seq_len - stride) + windows.append((p, s)) + p += stride + # Score the tail: the stride loop stops at the last full window, which + # would otherwise leave the final (total - last_window_end) < stride + # tokens unscored (unlike eval_val_sliding, which covers them). + if windows and windows[-1][0] + seq_len < total: + last_end = windows[-1][0] + seq_len + windows.append((total - seq_len, seq_len - (total - last_end))) + + n = len(windows) + per_rank = (n + world_size - 1) // world_size + my_start = rank * per_rank + my_end = min(my_start + per_rank, n) + my_windows = windows[my_start:my_end] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Neural cache ring buffer (on GPU for fast lookup) + cache_keys: Tensor | None = None # [cache_size, model_dim] + cache_vals: Tensor | None = None # [cache_size] + cache_len = 0 + cache_ptr = 0 + cache_initialized = False + + with torch.inference_mode(): + for i in range(0, len(my_windows), eval_batch_seqs): + batch = my_windows[i : i + eval_batch_seqs] + bs = len(batch) + + x_list = [val_tokens[w : w + seq_len] for w, _ in batch] + y_list = [val_tokens[w + 1 : w + seq_len + 1] for w, _ in batch] + pad = eval_batch_seqs - bs + if pad > 0: + x_list.extend([x_list[-1]] * pad) + y_list.extend([y_list[-1]] * pad) + + x = torch.stack(x_list).to(device=device, dtype=torch.int64) + y = torch.stack(y_list).to(device=device, dtype=torch.int64) + + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits, hidden = logits_hidden_fn(x) + + # Initialize cache on first pass + if not cache_initialized: + model_dim = hidden.shape[-1] + cache_keys = torch.zeros(cache_size, model_dim, device=device, dtype=torch.bfloat16) + cache_vals = torch.zeros(cache_size, device=device, dtype=torch.long) + cache_initialized = True + + for b in range(bs): + s = batch[b][1] + scored_logits = logits[b, s:].float() # [num_scored, vocab] + scored_targets = y[b, s:] + scored_hidden = hidden[b, s:] # [num_scored, model_dim] + scored_input = x[b, s : s + scored_targets.numel()] + ns = scored_targets.numel() + vocab_size = scored_logits.shape[-1] + + log_probs_model = F.log_softmax(scored_logits, dim=-1) + + # Token-by-token scoring + cache update (causal: each token + # sees only cache entries from BEFORE it, and BOS resets are + # applied before scoring the token that follows a BOS). + for t_idx in range(ns): + # Reset cache at document boundaries BEFORE scoring + if is_boundary_token_lut[scored_input[t_idx]]: + cache_len = 0 + cache_ptr = 0 + + target_id = scored_targets[t_idx] + + if cache_len > 0: + active_len = min(cache_len, cache_size) + keys = cache_keys[:active_len] + vals = cache_vals[:active_len] + h = scored_hidden[t_idx].to(torch.bfloat16).unsqueeze(0) + h_norm = F.normalize(h, dim=-1) + k_norm = F.normalize(keys, dim=-1) + sim = (h_norm.float() @ k_norm.float().T).squeeze(0) + cache_attn = F.softmax(cache_theta * sim, dim=-1) + cache_probs = torch.zeros(vocab_size, device=device) + cache_probs.scatter_add_(0, vals, cache_attn) + log_cache = torch.log(cache_probs + 1e-10) + log_final_t = torch.logaddexp( + math.log(1 - cache_lambda) + log_probs_model[t_idx], + math.log(cache_lambda) + log_cache, + ) + else: + log_final_t = log_probs_model[t_idx] + + loss_sum += -log_final_t[target_id].to(torch.float64) + tok_count += 1 + + # Byte counting + prev_id = scored_input[t_idx] + tb = 
base_bytes_lut[target_id].to(torch.float64) + if has_leading_space_lut[target_id] and not is_boundary_token_lut[prev_id]: + tb += 1.0 + byte_count += tb + + # Update cache AFTER scoring (causal) + idx = cache_ptr % cache_size + cache_keys[idx] = scored_hidden[t_idx].to(torch.bfloat16) + cache_vals[idx] = target_id + cache_ptr += 1 + cache_len += 1 + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(tok_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / tok_count).item() + bpb = val_loss / math.log(2.0) * (tok_count.item() / byte_count.item()) + return val_loss, bpb + + +# ----------------------------- +# TEST-TIME TRAINING (LoRA) +# ----------------------------- + +BOS_ID = 1 + +class BatchedLinearLoRA(nn.Module): + """LoRA for a linear layer, with independent weights per batch element.""" + def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): + super().__init__() + self.in_features = in_features + self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + self.reset() + + def forward(self, x: Tensor) -> Tensor: + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) + + def reset(self) -> None: + bound = 1.0 / math.sqrt(self.in_features) + with torch.no_grad(): + self.A.uniform_(-bound, bound) + self.B.zero_() + + +class BatchedTTTLoRA(nn.Module): + """All LoRA adapters for one batch: LM head and Q/V per block.""" + def __init__(self, bsz: int, model: GPT, rank: int): + super().__init__() + dim = model.tok_emb.embedding_dim + vocab = model.tok_emb.num_embeddings + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) + self.q_loras = nn.ModuleList() + self.v_loras = nn.ModuleList() + for block in model.blocks: + self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) + 
self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) + + def reset(self) -> None: + for m in self.modules(): + if isinstance(m, BatchedLinearLoRA): + m.reset() + + +def _reset_ttt_optimizer(opt): + for group in opt.param_groups: + for p in group['params']: + s = opt.state.get(p) + if not s: + continue + s['exp_avg'].zero_() + s['exp_avg_sq'].zero_() + s['step'].fill_(0) + + +def _build_ttt_optimizer(lora, args: Hyperparameters): + return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) + + +def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: + """Return (start_offset, length) for each document.""" + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() + if include_next_bos and i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + + +def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): + chunk_start = ci * chunk_size + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _accumulate_bpb( + ptl: Tensor, x: Tensor, y: Tensor, + batch_i: int, chunk_offset: int, chunk_len: int, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, +): + lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) + prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] + tgt = y[batch_i, chunk_offset:chunk_offset + 
chunk_len] + tok_bytes = base_bytes_lut[tgt].to(torch.float64) + tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] + loss_sum += lbl.sum() + byte_sum += tok_bytes.sum() + token_count += chunk_len + + +def eval_val_ttt_lora( + args: Hyperparameters, + base_model: GPT, + rank: int, + world_size: int, + device: torch.device, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Evaluate with batched LoRA test-time training.""" + files = sorted(glob.glob(args.val_files)) + all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) + docs = _find_docs(all_tokens) + + rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] + chunk_size = args.ttt_chunk_size + eval_seq_len = args.ttt_eval_seq_len + batch_size = args.ttt_batch_size + lora_rank = args.ttt_lora_rank + + rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) + + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + + lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) + opt = _build_ttt_optimizer(lora, args) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + + for bi in range(0, len(rank_docs), batch_size): + batch = rank_docs[bi:bi + batch_size] + bsz = len(batch) + + if bsz == batch_size: + cur_lora, cur_opt = lora, opt + cur_lora.reset() + _reset_ttt_optimizer(cur_opt) + else: + cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) + cur_opt = _build_ttt_optimizer(cur_lora, args) + + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + + for ci in range(max_nc): + chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) + context_size, 
chunk_offset = chunk_stats[1], chunk_stats[2] + + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + + x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + doc_info = [] + for b in range(bsz): + if not active[b]: + doc_info.append((0, 0)) + continue + ds, dl = batch[b] + ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + chunk = all_tokens[ds + ws: ds + ws + wl + 1] + toks = chunk.to(dtype=torch.int64, device=device) + x[b, :wl] = toks[:-1] + y[b, :wl] = toks[1:] + doc_info.append((co, cl)) + + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + + with torch.no_grad(): + for b in range(bsz): + if not active[b]: + continue + co, cl = doc_info[b] + _accumulate_bpb( + ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, loss_sum, byte_sum, token_count) + + if needs_train: + mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) + per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) + # Entropy gating: only train on chunks where model is surprised + if args.ttt_entropy_gate > 0: + with torch.no_grad(): + gate = (per_doc > args.ttt_entropy_gate).float() + mask = mask * gate + cur_opt.zero_grad() + (per_doc * mask).sum().backward() + cur_opt.step() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + + val_loss = float(loss_sum.item() / token_count.item()) + val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) + return val_loss, val_bpb + + +# 
----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + local_test = bool(int(os.environ.get("LOCAL_TEST", "0"))) + use_compile = not local_test + if use_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(local_test) + enable_math_sdp(local_test) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch 
{torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + + # Enable QAT: int5 for MLP weights, int6 for 
attention weights + if args.qat_enabled: + for name, module in base_model.named_modules(): + if isinstance(module, CastedLinear): + if ".mlp." in name: + module._qat_clip = 15 # int5 + elif ".attn." in name: + module._qat_clip = 31 # int6 + # Embeddings, bigram proj, skip weights: no QAT (kept in fp16/fp32) + log0(f"QAT enabled: int5 for MLP, int6 for attention") + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if use_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + 
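The `_qat_clip` values set above feed a straight-through-estimator (STE) fake-quantization step applied during training (not shown in this hunk). A minimal STE sketch, assuming simple symmetric per-tensor absmax scaling; the README notes the actual export uses percentile-based per-row scales (`quantize_intN_per_row`), and that mismatch is exactly the bug being debugged, so a faithful STE would have to mirror the per-row export format instead:

```python
import torch

def ste_fake_quant(w: torch.Tensor, clip: int) -> torch.Tensor:
    # Symmetric fake quantization: scale to integers in [-clip, clip],
    # round, dequantize. The w + (q - w).detach() trick makes the forward
    # pass see quantized weights while gradients flow straight through.
    scale = w.abs().amax().clamp(min=1e-8) / clip
    q = (w / scale).round().clamp(-clip, clip) * scale
    return w + (q - w).detach()

w = torch.randn(4, 4, requires_grad=True)
wq = ste_fake_quant(w, clip=15)  # int5 range, as used for the MLP weights
wq.sum().backward()
assert torch.allclose(w.grad, torch.ones_like(w))  # straight-through gradient
```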
optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - 
elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, 
has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count 
= 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + 
threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + base_model.load_state_dict(deq_state, strict=True) + + # Eval on int6-roundtripped weights + val_tokens_eval = val_tokens + torch.cuda.synchronize() + t_qeval = time.perf_counter() + + strategy = args.eval_strategy + log0(f"eval_strategy:{strategy}") + + if strategy == "sliding_cache": + log0("Running sliding window + neural cache eval...") + q_val_loss, q_val_bpb = eval_val_sliding_with_cache( + base_model.forward_logits_and_hidden, rank, world_size, device, + val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + args.train_seq_len, args.eval_stride, + 
cache_size=args.cache_size, cache_theta=args.cache_theta, + cache_lambda=args.cache_lambda, eval_batch_seqs=args.eval_batch_seqs, + ) + elif strategy == "ttt": + log0("Running LoRA TTT eval...") + torch._dynamo.reset() + q_val_loss, q_val_bpb = eval_val_ttt_lora( + args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + elif strategy == "combined": + log0("[combined] Phase 1: Sliding window + neural cache...") + cache_val_loss, cache_val_bpb = eval_val_sliding_with_cache( + base_model.forward_logits_and_hidden, rank, world_size, device, + val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + args.train_seq_len, args.eval_stride, + cache_size=args.cache_size, cache_theta=args.cache_theta, + cache_lambda=args.cache_lambda, eval_batch_seqs=args.eval_batch_seqs, + ) + log0(f"[combined] sliding+cache val_bpb:{cache_val_bpb:.4f}") + + log0("[combined] Phase 2: LoRA TTT...") + torch._dynamo.reset() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( + args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0(f"[combined] ttt val_bpb:{ttt_val_bpb:.4f}") + + q_val_bpb = min(cache_val_bpb, ttt_val_bpb) + q_val_loss = cache_val_loss if cache_val_bpb <= ttt_val_bpb else ttt_val_loss + log0(f"[combined] best individual val_bpb:{q_val_bpb:.4f}") + else: + # Default: sliding window only (original behavior) + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, 
+ val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_seed1337.log b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_seed1337.log new file mode 100644 index 000000000..001332150 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_seed1337.log @@ -0,0 +1,135 @@ +W0321 01:46:21.203000 128328209879680 torch/distributed/run.py:779] +W0321 01:46:21.203000 128328209879680 torch/distributed/run.py:779] ***************************************** +W0321 01:46:21.203000 128328209879680 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0321 01:46:21.203000 128328209879680 torch/distributed/run.py:779] ***************************************** +logs/run1_seed1337.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +QAT enabled: int5 for MLP, int6 for attention +model_params:25517137 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9284 val_bpb:4.1034 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9301 train_time:142ms step_avg:141.56ms +step:2/20000 train_loss:7.6390 train_time:227ms step_avg:113.40ms +step:3/20000 train_loss:7.2632 train_time:329ms step_avg:109.78ms +step:4/20000 train_loss:8.0002 train_time:431ms step_avg:107.65ms +step:5/20000 train_loss:8.3939 train_time:532ms step_avg:106.33ms +step:6/20000 train_loss:8.2110 train_time:634ms step_avg:105.71ms +step:7/20000 train_loss:7.5933 train_time:737ms step_avg:105.26ms +step:8/20000 train_loss:6.8652 train_time:848ms step_avg:106.01ms +step:9/20000 train_loss:6.3185 train_time:950ms step_avg:105.56ms +step:10/20000 train_loss:6.0389 train_time:1051ms step_avg:105.15ms +step:100/20000 train_loss:3.1871 train_time:9904ms step_avg:99.04ms +step:200/20000 train_loss:2.4389 train_time:22606ms step_avg:113.03ms +step:300/20000 
train_loss:2.5740 train_time:35201ms step_avg:117.34ms +step:400/20000 train_loss:2.4384 train_time:47337ms step_avg:118.34ms +step:500/20000 train_loss:2.4215 train_time:57243ms step_avg:114.49ms +step:500/20000 val_loss:2.3730 val_bpb:1.4054 train_time:57275ms step_avg:114.55ms +step:600/20000 train_loss:2.3558 train_time:69828ms step_avg:116.38ms +step:700/20000 train_loss:2.3646 train_time:82162ms step_avg:117.37ms +step:800/20000 train_loss:2.2581 train_time:94412ms step_avg:118.01ms +step:900/20000 train_loss:2.1456 train_time:106569ms step_avg:118.41ms +step:1000/20000 train_loss:2.2949 train_time:116466ms step_avg:116.47ms +step:1000/20000 val_loss:2.2417 val_bpb:1.3277 train_time:116497ms step_avg:116.50ms +step:1100/20000 train_loss:2.3445 train_time:128599ms step_avg:116.91ms +step:1200/20000 train_loss:2.3735 train_time:140813ms step_avg:117.34ms +step:1300/20000 train_loss:2.1190 train_time:153410ms step_avg:118.01ms +step:1400/20000 train_loss:2.2040 train_time:165709ms step_avg:118.36ms +step:1500/20000 train_loss:2.2390 train_time:175586ms step_avg:117.06ms +step:1500/20000 val_loss:2.1996 val_bpb:1.3027 train_time:175622ms step_avg:117.08ms +step:1600/20000 train_loss:2.0922 train_time:187806ms step_avg:117.38ms +step:1700/20000 train_loss:2.1616 train_time:199863ms step_avg:117.57ms +step:1800/20000 train_loss:2.1825 train_time:212078ms step_avg:117.82ms +step:1900/20000 train_loss:2.1487 train_time:221958ms step_avg:116.82ms +step:2000/20000 train_loss:2.0841 train_time:234240ms step_avg:117.12ms +step:2000/20000 val_loss:2.1478 val_bpb:1.2721 train_time:234270ms step_avg:117.14ms +step:2100/20000 train_loss:2.0660 train_time:246613ms step_avg:117.43ms +step:2200/20000 train_loss:2.1503 train_time:258634ms step_avg:117.56ms +step:2300/20000 train_loss:2.1247 train_time:270977ms step_avg:117.82ms +step:2400/20000 train_loss:2.0777 train_time:280850ms step_avg:117.02ms +step:2500/20000 train_loss:2.1808 train_time:292995ms step_avg:117.20ms 
+step:2500/20000 val_loss:2.1119 val_bpb:1.2508 train_time:293027ms step_avg:117.21ms +step:2600/20000 train_loss:2.1129 train_time:305130ms step_avg:117.36ms +step:2700/20000 train_loss:2.1044 train_time:317404ms step_avg:117.56ms +step:2800/20000 train_loss:2.1562 train_time:329614ms step_avg:117.72ms +step:2900/20000 train_loss:2.0199 train_time:339464ms step_avg:117.06ms +step:3000/20000 train_loss:2.1550 train_time:351572ms step_avg:117.19ms +step:3000/20000 val_loss:2.0833 val_bpb:1.2339 train_time:351604ms step_avg:117.20ms +step:3100/20000 train_loss:2.0313 train_time:363827ms step_avg:117.36ms +step:3200/20000 train_loss:2.1638 train_time:375880ms step_avg:117.46ms +step:3300/20000 train_loss:2.0580 train_time:385737ms step_avg:116.89ms +step:3400/20000 train_loss:2.0075 train_time:397920ms step_avg:117.04ms +step:3500/20000 train_loss:2.1618 train_time:409939ms step_avg:117.13ms +step:3500/20000 val_loss:2.0608 val_bpb:1.2205 train_time:409971ms step_avg:117.13ms +step:3600/20000 train_loss:2.0780 train_time:422156ms step_avg:117.27ms +step:3700/20000 train_loss:2.0675 train_time:434492ms step_avg:117.43ms +step:3800/20000 train_loss:2.0481 train_time:444348ms step_avg:116.93ms +step:3900/20000 train_loss:2.0536 train_time:456623ms step_avg:117.08ms +swa:start step:3950 +step:4000/20000 train_loss:1.9516 train_time:469061ms step_avg:117.27ms +step:4000/20000 val_loss:2.0358 val_bpb:1.2057 train_time:469125ms step_avg:117.28ms +step:4100/20000 train_loss:1.9851 train_time:481241ms step_avg:117.38ms +step:4200/20000 train_loss:2.1214 train_time:493527ms step_avg:117.51ms +step:4300/20000 train_loss:2.0258 train_time:503439ms step_avg:117.08ms +step:4400/20000 train_loss:2.0011 train_time:515722ms step_avg:117.21ms +step:4500/20000 train_loss:2.0881 train_time:527915ms step_avg:117.31ms +step:4500/20000 val_loss:2.0097 val_bpb:1.1903 train_time:527979ms step_avg:117.33ms +step:4600/20000 train_loss:1.8090 train_time:540172ms step_avg:117.43ms 
+step:4700/20000 train_loss:2.2016 train_time:550086ms step_avg:117.04ms +step:4800/20000 train_loss:2.3969 train_time:562216ms step_avg:117.13ms +step:4900/20000 train_loss:2.0086 train_time:574556ms step_avg:117.26ms +step:5000/20000 train_loss:2.0654 train_time:586649ms step_avg:117.33ms +step:5000/20000 val_loss:1.9846 val_bpb:1.1754 train_time:586715ms step_avg:117.34ms +step:5100/20000 train_loss:2.0883 train_time:599026ms step_avg:117.46ms +step:5109/20000 val_loss:1.9821 val_bpb:1.1739 train_time:599967ms step_avg:117.43ms +stopping_early: wallclock_cap train_time:599967ms step:5109/20000 +peak memory allocated: 23856 MiB reserved: 24366 MiB +swa:applying averaged 24 checkpoints +Serialized model: 98437014 bytes +Code size: 73542 bytes +Total submission size: 98510556 bytes +Serialized model int6+zstd: 15691550 bytes +Total submission size int8+zlib: 15765092 bytes +/workspace/parameter-golf/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_gpt.py:1629: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+ quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +eval_strategy:combined +[combined] Phase 1: Sliding window + neural cache...
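The neural cache evaluated in Phase 1 above is described in the README as cosine similarity against cached pre-lm_head hidden states, a softmax-weighted scatter into a cache distribution, and a logaddexp interpolation with the model's predictions. A minimal single-token sketch of that blend follows; the function name and the `theta`/`lam` parameter names are assumptions for illustration, not the submission's actual code:

```python
import math

import torch
import torch.nn.functional as F


def cache_blend_logprobs(logits, hidden, cache_h, cache_tok, vocab_size,
                         theta=1.0, lam=0.1):
    # Cosine similarity between the current hidden state (D,) and the
    # ring buffer of cached hidden states (N, D).
    sim = F.cosine_similarity(hidden[None, :], cache_h, dim=-1)      # (N,)
    # Softmax-weighted scatter: each cached state votes for the token
    # that followed it, weighted by similarity at temperature theta.
    w = torch.softmax(theta * sim, dim=0)
    cache_p = torch.zeros(vocab_size).scatter_add_(0, cache_tok, w)  # (V,)
    model_lp = torch.log_softmax(logits, dim=-1)
    cache_lp = cache_p.clamp_min(1e-10).log()
    # log((1 - lam) * p_model + lam * p_cache), computed in log space.
    return torch.logaddexp(model_lp + math.log(1.0 - lam),
                           cache_lp + math.log(lam))
```

Interpolating via `logaddexp` rather than mixing raw probabilities keeps the result numerically stable when either distribution assigns near-zero mass to a token.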
diff --git a/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/README.md b/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/README.md new file mode 100644 index 000000000..3d22b6b48 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/README.md @@ -0,0 +1,53 @@ +# 10L Int5-MLP + BigramHash(4096) + SWA + +**val_bpb: 1.1507** (mean of 3 seeds, sliding window stride=64, post int5/int6+zstd quantization roundtrip) + +## Results + +| Seed | val_bpb | artifact_bytes | +|------|---------|----------------| +| 42 | 1.1508 | 15,620,994 | +| 1337 | 1.1499 | 15,290,882 | +| 2024 | 1.1514 | 15,327,813 | +| **Mean** | **1.1507 +/- 0.0006** | | + +## Architecture + +- 10 layers, d=512, 8 heads, 4 KV heads (GQA) +- MLP: 3x expansion (1536), relu^2 activation +- BigramHash: 4096 buckets, 128-dim projection +- SmearGate (learned previous-token blending) +- U-Net skip connections with learned gates +- RoPE (base=10000), logit softcap=30.0 +- Tied embeddings + +## Training + +- Muon optimizer (matrices) + AdamW (embeddings/scalars) +- WD=0.04, warmdown=3000 steps +- SWA: start_frac=0.4, every=50 steps +- Wallclock cap: 600s on 8xH100 (~6200 steps) +- Batch: 786,432 tokens, seq_len=2048 + +## Quantization + +- Int5 per-row for MLP weights (clip_range=15) +- Int6 per-row for attention weights (clip_range=31) +- FP16 passthrough for small/control tensors +- Magnitude pruning (3% threshold) before quantization +- zstd-22 compression + +## Evaluation + +- Sliding window eval, stride=64, batch_seqs=32 +- ~258s eval time on 8xH100 + +## Based on + +- thwu1's 10L Int5-MLP submission (1.1428 BPB) with reduced BigramHash for size margin + +## Reproduce + +```bash +SEED=42 torchrun --standalone --nproc_per_node=8 train_gpt.py +``` diff --git a/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/submission.json b/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/submission.json new file mode 
100644 index 000000000..acc4a7131 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/submission.json @@ -0,0 +1,10 @@ +{ + "author": "Bortlesboat", + "github_id": "Bortlesboat", + "name": "10L Int5-MLP + BigramHash(4096) + SWA", + "blurb": "10 layers, d=512, GQA 8H/4KV. Mixed int5/int6 quantization + zstd-22. BigramHash(4096, dim=128). SmearGate + OrthoInit. SWA(frac=0.4). Muon WD=0.04, warmdown=3000. Sliding window eval stride=64. Mean of 3 seeds.", + "date": "2026-03-25", + "val_loss": 1.9429, + "val_bpb": 1.1507, + "bytes_total": 15620994 +} diff --git a/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/train_gpt.py b/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/train_gpt.py new file mode 100644 index 000000000..47d6ccbad --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/train_gpt.py @@ -0,0 +1,1250 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = 
float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return 
X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, 
vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got 
VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + 
return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def 
_classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + 
meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype("<u4").itemsize + # NOTE: assumed shard layout: a 256-word uint32 header followed by little-endian uint16 tokens + with file.open("rb") as f: + f.seek(header_bytes) + tokens = np.fromfile(f, dtype="<u2") + return torch.from_numpy(tokens.astype(np.int32)) + + +class TokenStream: + def __init__(self, pattern: str): + # NOTE: the glob logic here is an assumption, inferred from how self.files/self.tokens/self.pos are used below + self.files = sorted(Path(pattern).parent.glob(Path(pattern).name)) + if not self.files: + raise FileNotFoundError(f"no token shards match pattern {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[0]) + self.pos = 0 + + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = 
chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: 
Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if self.num_kv_heads != self.num_heads: + n_rep = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(n_rep, dim=1) + v = v.repeat_interleave(n_rep, dim=1) + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, + ) + y = 
y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, 
num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + self.blocks = nn.ModuleList( + [ + 
Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for _ in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if 
self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + eval_start = time.perf_counter() + eval_budget_s = 570.0 # 30s margin from 10-min eval budget + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + eval_elapsed = time.perf_counter() - eval_start + if eval_elapsed > eval_budget_s: + if rank == 0: + print(f" FAILSAFE: eval time {eval_elapsed:.0f}s exceeds {eval_budget_s}s budget, returning partial results", flush=True) + break + batch_ws = my_windows[bi:bi + 
batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# 
----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + try: + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + except ImportError: + pass + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], 
stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + torch._dynamo.config.optimize_ddp = False + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: 
nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + 
betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + 
model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) 
+ break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown; accumulate in fp32 so summing ~24 bf16 checkpoints doesn't lose precision + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().float().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().float().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / 
step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if 
_COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + total_bytes = quant_file_bytes + code_bytes + log0(f"Total submission size: {total_bytes} bytes ({total_bytes/1e6:.2f} MB)") + if total_bytes > 16_000_000: + log0(f"FAILSAFE: artifact {total_bytes} bytes EXCEEDS 16MB limit! Aborting eval.") + sys.exit(1) + log0(f"SIZE CHECK PASSED: {total_bytes/1e6:.2f} MB < 16.00 MB") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + 
f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned diff --git a/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/train_seed1337.log b/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/train_seed1337.log new file mode 100644 index 000000000..af3c9f493 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/train_seed1337.log @@ -0,0 +1,224 @@ +W0325 07:16:15.605000 130575319020160 torch/distributed/run.py:779] +W0325 07:16:15.605000 130575319020160 torch/distributed/run.py:779] ***************************************** +W0325 07:16:15.605000 130575319020160 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0325 07:16:15.605000 130575319020160 torch/distributed/run.py:779] ***************************************** +logs/20dd3955-5bac-461f-9e3b-a050655d8025.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:24730705 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9275 val_bpb:4.1028 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9296 train_time:149ms step_avg:148.54ms +step:2/20000 train_loss:7.6884 train_time:233ms step_avg:116.34ms +step:3/20000 train_loss:7.2759 train_time:328ms step_avg:109.23ms +step:4/20000 train_loss:7.9423 train_time:424ms step_avg:105.95ms +step:5/20000 train_loss:8.3472 train_time:519ms step_avg:103.85ms +step:6/20000 train_loss:8.1676 train_time:614ms step_avg:102.35ms +step:7/20000 train_loss:7.5226 train_time:709ms step_avg:101.35ms +step:8/20000 train_loss:6.8361 train_time:805ms step_avg:100.64ms +step:9/20000 train_loss:6.5260 train_time:900ms step_avg:100.01ms +step:10/20000 train_loss:6.2617 train_time:996ms step_avg:99.60ms +step:100/20000 train_loss:3.2133 train_time:9616ms step_avg:96.16ms +step:200/20000 train_loss:2.4273 train_time:19264ms step_avg:96.32ms +step:300/20000 train_loss:2.5783 
train_time:28935ms step_avg:96.45ms +step:400/20000 train_loss:2.4390 train_time:38629ms step_avg:96.57ms +step:500/20000 train_loss:2.4157 train_time:48293ms step_avg:96.59ms +step:500/20000 val_loss:2.3782 val_bpb:1.4085 train_time:48304ms step_avg:96.61ms +step:600/20000 train_loss:2.3519 train_time:58036ms step_avg:96.73ms +step:700/20000 train_loss:2.3673 train_time:67783ms step_avg:96.83ms +step:800/20000 train_loss:2.2538 train_time:77525ms step_avg:96.91ms +step:900/20000 train_loss:2.1454 train_time:87273ms step_avg:96.97ms +step:1000/20000 train_loss:2.2913 train_time:96959ms step_avg:96.96ms +step:1000/20000 val_loss:2.2449 val_bpb:1.3295 train_time:96971ms step_avg:96.97ms +step:1100/20000 train_loss:2.3476 train_time:106707ms step_avg:97.01ms +step:1200/20000 train_loss:2.3674 train_time:116443ms step_avg:97.04ms +step:1300/20000 train_loss:2.1193 train_time:126268ms step_avg:97.13ms +step:1400/20000 train_loss:2.2023 train_time:136006ms step_avg:97.15ms +step:1500/20000 train_loss:2.2397 train_time:145670ms step_avg:97.11ms +step:1500/20000 val_loss:2.2002 val_bpb:1.3031 train_time:145681ms step_avg:97.12ms +step:1600/20000 train_loss:2.0923 train_time:155383ms step_avg:97.11ms +step:1700/20000 train_loss:2.1601 train_time:165101ms step_avg:97.12ms +step:1800/20000 train_loss:2.1761 train_time:174815ms step_avg:97.12ms +step:1900/20000 train_loss:2.1475 train_time:184461ms step_avg:97.08ms +step:2000/20000 train_loss:2.0845 train_time:194162ms step_avg:97.08ms +step:2000/20000 val_loss:2.1466 val_bpb:1.2713 train_time:194173ms step_avg:97.09ms +step:2100/20000 train_loss:2.0628 train_time:203854ms step_avg:97.07ms +step:2200/20000 train_loss:2.1442 train_time:213548ms step_avg:97.07ms +step:2300/20000 train_loss:2.1195 train_time:223233ms step_avg:97.06ms +step:2400/20000 train_loss:2.0791 train_time:232858ms step_avg:97.02ms +step:2500/20000 train_loss:2.1789 train_time:242535ms step_avg:97.01ms +step:2500/20000 val_loss:2.1176 val_bpb:1.2542 
train_time:242546ms step_avg:97.02ms +step:2600/20000 train_loss:2.1207 train_time:252231ms step_avg:97.01ms +step:2700/20000 train_loss:2.1130 train_time:261912ms step_avg:97.00ms +step:2800/20000 train_loss:2.1604 train_time:271584ms step_avg:96.99ms +step:2900/20000 train_loss:2.0358 train_time:281206ms step_avg:96.97ms +step:3000/20000 train_loss:2.1692 train_time:290895ms step_avg:96.97ms +step:3000/20000 val_loss:2.1023 val_bpb:1.2451 train_time:290906ms step_avg:96.97ms +step:3100/20000 train_loss:2.0492 train_time:300575ms step_avg:96.96ms +step:3200/20000 train_loss:2.1812 train_time:310242ms step_avg:96.95ms +step:3300/20000 train_loss:2.0789 train_time:319868ms step_avg:96.93ms +step:3400/20000 train_loss:2.0252 train_time:329534ms step_avg:96.92ms +step:3500/20000 train_loss:2.1870 train_time:339206ms step_avg:96.92ms +step:3500/20000 val_loss:2.0857 val_bpb:1.2352 train_time:339217ms step_avg:96.92ms +step:3600/20000 train_loss:2.0989 train_time:348873ms step_avg:96.91ms +step:3700/20000 train_loss:2.0945 train_time:358542ms step_avg:96.90ms +step:3800/20000 train_loss:2.0716 train_time:368150ms step_avg:96.88ms +step:3900/20000 train_loss:2.0747 train_time:377830ms step_avg:96.88ms +step:4000/20000 train_loss:1.9727 train_time:387503ms step_avg:96.88ms +step:4000/20000 val_loss:2.0646 val_bpb:1.2228 train_time:387514ms step_avg:96.88ms +step:4100/20000 train_loss:2.0125 train_time:397163ms step_avg:96.87ms +step:4200/20000 train_loss:2.1507 train_time:406841ms step_avg:96.87ms +step:4300/20000 train_loss:2.0532 train_time:416445ms step_avg:96.85ms +step:4400/20000 train_loss:2.0278 train_time:426108ms step_avg:96.84ms +step:4500/20000 train_loss:2.1176 train_time:435774ms step_avg:96.84ms +step:4500/20000 val_loss:2.0397 val_bpb:1.2081 train_time:435786ms step_avg:96.84ms +step:4600/20000 train_loss:1.8352 train_time:445438ms step_avg:96.83ms +step:4700/20000 train_loss:2.2230 train_time:455057ms step_avg:96.82ms +step:4800/20000 train_loss:2.4225 
train_time:464733ms step_avg:96.82ms +step:4900/20000 train_loss:2.0402 train_time:474405ms step_avg:96.82ms +swa:start step:5000 +step:5000/20000 train_loss:2.0954 train_time:484078ms step_avg:96.82ms +step:5000/20000 val_loss:2.0163 val_bpb:1.1942 train_time:484169ms step_avg:96.83ms +step:5100/20000 train_loss:2.1199 train_time:493854ms step_avg:96.83ms +step:5200/20000 train_loss:2.0331 train_time:503514ms step_avg:96.83ms +step:5300/20000 train_loss:1.9993 train_time:513250ms step_avg:96.84ms +step:5400/20000 train_loss:2.0376 train_time:522959ms step_avg:96.84ms +step:5500/20000 train_loss:2.0063 train_time:532694ms step_avg:96.85ms +step:5500/20000 val_loss:1.9904 val_bpb:1.1788 train_time:532745ms step_avg:96.86ms +step:5600/20000 train_loss:1.9444 train_time:542423ms step_avg:96.86ms +step:5700/20000 train_loss:2.0014 train_time:552079ms step_avg:96.86ms +step:5800/20000 train_loss:1.9841 train_time:561796ms step_avg:96.86ms +step:5900/20000 train_loss:1.8864 train_time:571499ms step_avg:96.86ms +step:6000/20000 train_loss:1.9303 train_time:581248ms step_avg:96.87ms +step:6000/20000 val_loss:1.9645 val_bpb:1.1635 train_time:581299ms step_avg:96.88ms +step:6100/20000 train_loss:1.9038 train_time:590921ms step_avg:96.87ms +step:6194/20000 val_loss:1.9577 val_bpb:1.1595 train_time:600071ms step_avg:96.88ms +stopping_early: wallclock_cap train_time:600071ms step:6194/20000 +peak memory allocated: 25197 MiB reserved: 25280 MiB +swa:applying averaged 24 checkpoints +Serialized model: 96864150 bytes +Code size: 53947 bytes +Total submission size: 96918097 bytes +Serialized model int6+zstd: 15236935 bytes +Total submission size: 15290882 bytes (15.29 MB) +SIZE CHECK PASSED: 15.29 MB < 16.00 MB +/workspace/submission/train_gpt.py:1216: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. 
It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +final_eval_mode:sliding_window stride:64 batch_seqs:32 + sliding_eval [ 0.0%] 32/121136 windows running_bpb=1.212943 + sliding_eval [ 1.3%] 1632/121136 windows running_bpb=1.144652 + sliding_eval [ 2.7%] 3232/121136 windows running_bpb=1.144706 + sliding_eval [ 4.0%] 4832/121136 windows running_bpb=1.138295 + sliding_eval [ 5.3%] 6432/121136 windows running_bpb=1.150602 + sliding_eval [ 6.6%] 8032/121136 windows running_bpb=1.151797 + sliding_eval [ 8.0%] 9632/121136 windows running_bpb=1.153445 + sliding_eval [ 9.3%] 11232/121136 windows running_bpb=1.148814 + sliding_eval [ 10.6%] 12832/121136 windows running_bpb=1.146503 + sliding_eval [ 11.9%] 14432/121136 windows running_bpb=1.148155 + sliding_eval [ 13.2%] 16032/121136 windows running_bpb=1.157046 + sliding_eval [ 14.6%] 17632/121136 windows running_bpb=1.155636 + sliding_eval [ 15.9%] 19232/121136 windows running_bpb=1.156994 + sliding_eval [ 17.2%] 20832/121136 windows running_bpb=1.155387 + sliding_eval [ 18.5%] 22432/121136 windows running_bpb=1.153958 + sliding_eval [ 19.8%] 24032/121136 windows running_bpb=1.154325 + sliding_eval [ 21.2%] 25632/121136 windows running_bpb=1.155682 + sliding_eval [ 22.5%] 27232/121136 windows running_bpb=1.156220 + sliding_eval [ 23.8%] 28832/121136 windows running_bpb=1.162211 + sliding_eval [ 25.1%] 30432/121136 windows running_bpb=1.159600 + sliding_eval [ 26.4%] 32032/121136 windows running_bpb=1.160541 + sliding_eval [ 27.8%] 33632/121136 windows running_bpb=1.159183 + sliding_eval [ 29.1%] 35232/121136 windows running_bpb=1.158487 + sliding_eval [ 30.4%] 36832/121136 windows running_bpb=1.158106 + sliding_eval [ 31.7%] 38432/121136 windows running_bpb=1.158726 + sliding_eval [ 33.0%] 40032/121136 windows running_bpb=1.156341 + sliding_eval [ 34.4%] 41632/121136 windows running_bpb=1.155426 + sliding_eval [ 35.7%] 43232/121136 windows running_bpb=1.155782 + sliding_eval [ 37.0%] 44832/121136
windows running_bpb=1.154693 + sliding_eval [ 38.3%] 46432/121136 windows running_bpb=1.154535 + sliding_eval [ 39.7%] 48032/121136 windows running_bpb=1.153819 + sliding_eval [ 41.0%] 49632/121136 windows running_bpb=1.155007 + sliding_eval [ 42.3%] 51232/121136 windows running_bpb=1.156107 + sliding_eval [ 43.6%] 52832/121136 windows running_bpb=1.156644 + sliding_eval [ 44.9%] 54432/121136 windows running_bpb=1.156142 + sliding_eval [ 46.3%] 56032/121136 windows running_bpb=1.156486 + sliding_eval [ 47.6%] 57632/121136 windows running_bpb=1.155541 + sliding_eval [ 48.9%] 59232/121136 windows running_bpb=1.151669 + sliding_eval [ 50.2%] 60832/121136 windows running_bpb=1.151800 + sliding_eval [ 51.5%] 62432/121136 windows running_bpb=1.152747 + sliding_eval [ 52.9%] 64032/121136 windows running_bpb=1.152949 + sliding_eval [ 54.2%] 65632/121136 windows running_bpb=1.152795 + sliding_eval [ 55.5%] 67232/121136 windows running_bpb=1.151539 + sliding_eval [ 56.8%] 68832/121136 windows running_bpb=1.151232 + sliding_eval [ 58.1%] 70432/121136 windows running_bpb=1.150524 + sliding_eval [ 59.5%] 72032/121136 windows running_bpb=1.150606 + sliding_eval [ 60.8%] 73632/121136 windows running_bpb=1.150584 + sliding_eval [ 62.1%] 75232/121136 windows running_bpb=1.150767 + sliding_eval [ 63.4%] 76832/121136 windows running_bpb=1.150466 + sliding_eval [ 64.7%] 78432/121136 windows running_bpb=1.151063 + sliding_eval [ 66.1%] 80032/121136 windows running_bpb=1.151378 + sliding_eval [ 67.4%] 81632/121136 windows running_bpb=1.151041 + sliding_eval [ 68.7%] 83232/121136 windows running_bpb=1.152109 + sliding_eval [ 70.0%] 84832/121136 windows running_bpb=1.154006 + sliding_eval [ 71.4%] 86432/121136 windows running_bpb=1.153328 + sliding_eval [ 72.7%] 88032/121136 windows running_bpb=1.154049 + sliding_eval [ 74.0%] 89632/121136 windows running_bpb=1.154407 + sliding_eval [ 75.3%] 91232/121136 windows running_bpb=1.154369 + sliding_eval [ 76.6%] 92832/121136 windows 
running_bpb=1.153934 + sliding_eval [ 78.0%] 94432/121136 windows running_bpb=1.154175 + sliding_eval [ 79.3%] 96032/121136 windows running_bpb=1.153559 + sliding_eval [ 80.6%] 97632/121136 windows running_bpb=1.156369 + sliding_eval [ 81.9%] 99232/121136 windows running_bpb=1.156395 + sliding_eval [ 83.2%] 100832/121136 windows running_bpb=1.156419 + sliding_eval [ 84.6%] 102432/121136 windows running_bpb=1.156053 + sliding_eval [ 85.9%] 104032/121136 windows running_bpb=1.155522 + sliding_eval [ 87.2%] 105632/121136 windows running_bpb=1.154779 + sliding_eval [ 88.5%] 107232/121136 windows running_bpb=1.154745 + sliding_eval [ 89.8%] 108832/121136 windows running_bpb=1.155380 + sliding_eval [ 91.2%] 110432/121136 windows running_bpb=1.155403 + sliding_eval [ 92.5%] 112032/121136 windows running_bpb=1.155400 + sliding_eval [ 93.8%] 113632/121136 windows running_bpb=1.155853 + sliding_eval [ 95.1%] 115232/121136 windows running_bpb=1.155580 + sliding_eval [ 96.4%] 116832/121136 windows running_bpb=1.155193 + sliding_eval [ 97.8%] 118432/121136 windows running_bpb=1.155506 + sliding_eval [ 99.1%] 120032/121136 windows running_bpb=1.155584 +final_int8_zlib_roundtrip val_loss:1.9416 val_bpb:1.1499 eval_time:257631ms +final_int8_zlib_roundtrip_exact val_loss:1.94158659 val_bpb:1.14991999 diff --git a/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/train_seed42_2024.log b/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/train_seed42_2024.log new file mode 100644 index 000000000..48fc821c5 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_10L_Int5MLP_BigramHash4096_SWA/train_seed42_2024.log @@ -0,0 +1,451 @@ +=== Seed 42 === +W0325 07:37:51.697000 139285672440448 torch/distributed/run.py:779] +W0325 07:37:51.697000 139285672440448 torch/distributed/run.py:779] ***************************************** +W0325 07:37:51.697000 139285672440448 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each 
process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0325 07:37:51.697000 139285672440448 torch/distributed/run.py:779] ***************************************** +logs/61c2f947-2af4-49cb-bae6-0987a307f3e1.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:24730705 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9283 val_bpb:4.1033 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9311 train_time:147ms step_avg:146.80ms +step:2/20000 train_loss:7.6460 train_time:231ms step_avg:115.29ms +step:3/20000 train_loss:7.2978 train_time:326ms step_avg:108.65ms +step:4/20000 train_loss:8.1124 train_time:421ms step_avg:105.21ms +step:5/20000 train_loss:8.2634 train_time:516ms step_avg:103.12ms +step:6/20000 train_loss:8.1496 train_time:611ms step_avg:101.82ms +step:7/20000 train_loss:7.5086 train_time:705ms step_avg:100.77ms +step:8/20000 train_loss:6.9069 train_time:802ms step_avg:100.24ms +step:9/20000 train_loss:6.4557 train_time:897ms step_avg:99.64ms +step:10/20000 train_loss:6.1781 train_time:992ms step_avg:99.18ms +step:100/20000 
train_loss:3.2019 train_time:9593ms step_avg:95.93ms +step:200/20000 train_loss:2.4533 train_time:19221ms step_avg:96.11ms +step:300/20000 train_loss:2.5855 train_time:28943ms step_avg:96.48ms +step:400/20000 train_loss:2.4406 train_time:38625ms step_avg:96.56ms +step:500/20000 train_loss:2.4193 train_time:48267ms step_avg:96.53ms +step:500/20000 val_loss:2.3832 val_bpb:1.4115 train_time:48279ms step_avg:96.56ms +step:600/20000 train_loss:2.3501 train_time:57973ms step_avg:96.62ms +step:700/20000 train_loss:2.3631 train_time:67696ms step_avg:96.71ms +step:800/20000 train_loss:2.2571 train_time:77427ms step_avg:96.78ms +step:900/20000 train_loss:2.1501 train_time:87164ms step_avg:96.85ms +step:1000/20000 train_loss:2.2898 train_time:96830ms step_avg:96.83ms +step:1000/20000 val_loss:2.2439 val_bpb:1.3289 train_time:96841ms step_avg:96.84ms +step:1100/20000 train_loss:2.3480 train_time:106549ms step_avg:96.86ms +step:1200/20000 train_loss:2.3698 train_time:116264ms step_avg:96.89ms +step:1300/20000 train_loss:2.1180 train_time:125963ms step_avg:96.89ms +step:1400/20000 train_loss:2.1990 train_time:135680ms step_avg:96.91ms +step:1500/20000 train_loss:2.2343 train_time:145320ms step_avg:96.88ms +step:1500/20000 val_loss:2.1988 val_bpb:1.3023 train_time:145331ms step_avg:96.89ms +step:1600/20000 train_loss:2.0879 train_time:155018ms step_avg:96.89ms +step:1700/20000 train_loss:2.1539 train_time:164705ms step_avg:96.89ms +step:1800/20000 train_loss:2.1700 train_time:174379ms step_avg:96.88ms +step:1900/20000 train_loss:2.1374 train_time:183995ms step_avg:96.84ms +step:2000/20000 train_loss:2.0789 train_time:193685ms step_avg:96.84ms +step:2000/20000 val_loss:2.1449 val_bpb:1.2703 train_time:193696ms step_avg:96.85ms +step:2100/20000 train_loss:2.0565 train_time:203351ms step_avg:96.83ms +step:2200/20000 train_loss:2.1572 train_time:213019ms step_avg:96.83ms +step:2300/20000 train_loss:2.1199 train_time:222688ms step_avg:96.82ms +step:2400/20000 train_loss:2.0779 
train_time:232292ms step_avg:96.79ms +step:2500/20000 train_loss:2.1820 train_time:241962ms step_avg:96.78ms +step:2500/20000 val_loss:2.1184 val_bpb:1.2546 train_time:241973ms step_avg:96.79ms +step:2600/20000 train_loss:2.1215 train_time:251634ms step_avg:96.78ms +step:2700/20000 train_loss:2.1109 train_time:261282ms step_avg:96.77ms +step:2800/20000 train_loss:2.1670 train_time:270947ms step_avg:96.77ms +step:2900/20000 train_loss:2.0350 train_time:280535ms step_avg:96.74ms +step:3000/20000 train_loss:2.1700 train_time:290185ms step_avg:96.73ms +step:3000/20000 val_loss:2.1015 val_bpb:1.2446 train_time:290197ms step_avg:96.73ms +step:3100/20000 train_loss:2.0455 train_time:299842ms step_avg:96.72ms +step:3200/20000 train_loss:2.1797 train_time:309484ms step_avg:96.71ms +step:3300/20000 train_loss:2.0796 train_time:319071ms step_avg:96.69ms +step:3400/20000 train_loss:2.0287 train_time:328729ms step_avg:96.69ms +step:3500/20000 train_loss:2.1845 train_time:338381ms step_avg:96.68ms +step:3500/20000 val_loss:2.0854 val_bpb:1.2351 train_time:338392ms step_avg:96.68ms +step:3600/20000 train_loss:2.0976 train_time:348033ms step_avg:96.68ms +step:3700/20000 train_loss:2.0943 train_time:357691ms step_avg:96.67ms +step:3800/20000 train_loss:2.0755 train_time:367294ms step_avg:96.66ms +step:3900/20000 train_loss:2.0751 train_time:376941ms step_avg:96.65ms +step:4000/20000 train_loss:1.9718 train_time:386587ms step_avg:96.65ms +step:4000/20000 val_loss:2.0641 val_bpb:1.2224 train_time:386598ms step_avg:96.65ms +step:4100/20000 train_loss:2.0109 train_time:396241ms step_avg:96.64ms +step:4200/20000 train_loss:2.1485 train_time:405898ms step_avg:96.64ms +step:4300/20000 train_loss:2.0515 train_time:415483ms step_avg:96.62ms +step:4400/20000 train_loss:2.0305 train_time:425140ms step_avg:96.62ms +step:4500/20000 train_loss:2.1169 train_time:434790ms step_avg:96.62ms +step:4500/20000 val_loss:2.0396 val_bpb:1.2080 train_time:434801ms step_avg:96.62ms +step:4600/20000 
train_loss:1.8388 train_time:444444ms step_avg:96.62ms +step:4700/20000 train_loss:2.2250 train_time:454038ms step_avg:96.60ms +step:4800/20000 train_loss:2.4221 train_time:463694ms step_avg:96.60ms +step:4900/20000 train_loss:2.0444 train_time:473345ms step_avg:96.60ms +step:5000/20000 train_loss:2.0964 train_time:482992ms step_avg:96.60ms +step:5000/20000 val_loss:2.0166 val_bpb:1.1944 train_time:483003ms step_avg:96.60ms +swa:start step:5050 +step:5100/20000 train_loss:2.1198 train_time:492717ms step_avg:96.61ms +step:5200/20000 train_loss:2.0305 train_time:502362ms step_avg:96.61ms +step:5300/20000 train_loss:1.9936 train_time:512065ms step_avg:96.62ms +step:5400/20000 train_loss:2.0369 train_time:521777ms step_avg:96.63ms +step:5500/20000 train_loss:2.0078 train_time:531493ms step_avg:96.64ms +step:5500/20000 val_loss:1.9908 val_bpb:1.1790 train_time:531545ms step_avg:96.64ms +step:5600/20000 train_loss:1.9423 train_time:541208ms step_avg:96.64ms +step:5700/20000 train_loss:1.9981 train_time:550852ms step_avg:96.64ms +step:5800/20000 train_loss:1.9831 train_time:560557ms step_avg:96.65ms +step:5900/20000 train_loss:1.8898 train_time:570260ms step_avg:96.65ms +step:6000/20000 train_loss:1.9247 train_time:579965ms step_avg:96.66ms +step:6000/20000 val_loss:1.9649 val_bpb:1.1637 train_time:580005ms step_avg:96.67ms +step:6100/20000 train_loss:1.9045 train_time:589613ms step_avg:96.66ms +step:6200/20000 train_loss:1.9365 train_time:599319ms step_avg:96.66ms +step:6207/20000 val_loss:1.9575 val_bpb:1.1594 train_time:600051ms step_avg:96.67ms +stopping_early: wallclock_cap train_time:600051ms step:6207/20000 +peak memory allocated: 25197 MiB reserved: 25280 MiB +swa:applying averaged 24 checkpoints +Serialized model: 96864150 bytes +Code size: 53947 bytes +Total submission size: 96918097 bytes +Serialized model int6+zstd: 15567047 bytes +Total submission size: 15620994 bytes (15.62 MB) +SIZE CHECK PASSED: 15.62 MB < 16.00 MB +/workspace/submission/train_gpt.py:1216: 
FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/submission/train_gpt.py:1216: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+final_eval_mode:sliding_window stride:64 batch_seqs:32
+ sliding_eval [  0.0%] 32/121136 windows running_bpb=1.215571
+ sliding_eval [ 10.6%] 12832/121136 windows running_bpb=1.147938
+ sliding_eval [ 19.8%] 24032/121136 windows running_bpb=1.154915
+ sliding_eval [ 30.4%] 36832/121136 windows running_bpb=1.159121
+ sliding_eval [ 39.7%] 48032/121136 windows running_bpb=1.154683
+ sliding_eval [ 50.2%] 60832/121136 windows running_bpb=1.152694
+ sliding_eval [ 60.8%] 73632/121136 windows running_bpb=1.151492
+ sliding_eval [ 70.0%] 84832/121136 windows running_bpb=1.154865
+ sliding_eval [ 80.6%] 97632/121136 windows running_bpb=1.157248
+ sliding_eval [ 89.8%] 108832/121136 windows running_bpb=1.156286
+ sliding_eval [ 99.1%] 120032/121136 windows running_bpb=1.156536
+final_int8_zlib_roundtrip val_loss:1.9432 val_bpb:1.1508 eval_time:257362ms
+final_int8_zlib_roundtrip_exact val_loss:1.94315106 val_bpb:1.15084656
+=== Seed 2024 ===
+W0325 07:54:47.711000 132224411792000 torch/distributed/run.py:779]
+W0325 07:54:47.711000 132224411792000 torch/distributed/run.py:779] *****************************************
+W0325 07:54:47.711000 132224411792000 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0325 07:54:47.711000 132224411792000 torch/distributed/run.py:779] *****************************************
+logs/2fd3011e-cf7e-4813-892c-4188650cb7e4.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:24730705
+world_size:8 grad_accum_steps:1
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
+seed:2024
+warmup_step:1/20
+warmup_step:2/20
+warmup_step:3/20
+warmup_step:4/20
+warmup_step:5/20
+warmup_step:6/20
+warmup_step:7/20
+warmup_step:8/20
+warmup_step:9/20
+warmup_step:10/20
+warmup_step:11/20
+warmup_step:12/20
+warmup_step:13/20
+warmup_step:14/20
+warmup_step:15/20
+warmup_step:16/20
+warmup_step:17/20
+warmup_step:18/20
+warmup_step:19/20
+warmup_step:20/20
+step:0/20000 val_loss:6.9274 val_bpb:4.1028 train_time:0ms step_avg:0.01ms
+step:1/20000 train_loss:6.9305 train_time:146ms step_avg:146.45ms
+step:2/20000 train_loss:7.6541 train_time:231ms step_avg:115.33ms
+step:3/20000 train_loss:7.2740 train_time:327ms step_avg:108.89ms
+step:4/20000 train_loss:7.9482 train_time:422ms step_avg:105.56ms
+step:5/20000 train_loss:8.2933 train_time:519ms step_avg:103.70ms
+step:6/20000 train_loss:8.2743 train_time:614ms step_avg:102.38ms
+step:7/20000 train_loss:7.5780 train_time:709ms step_avg:101.32ms
+step:8/20000 train_loss:6.8716 train_time:805ms step_avg:100.64ms
+step:9/20000 train_loss:6.4468 train_time:901ms step_avg:100.12ms
+step:10/20000 train_loss:6.0611 train_time:996ms step_avg:99.62ms
+step:100/20000 train_loss:3.2292 train_time:9617ms step_avg:96.17ms
+step:200/20000 train_loss:2.4546 train_time:19261ms step_avg:96.31ms
+step:300/20000 train_loss:2.5762 train_time:28952ms step_avg:96.51ms
+step:400/20000 train_loss:2.4349 train_time:38661ms step_avg:96.65ms
+step:500/20000 train_loss:2.4143 train_time:48351ms step_avg:96.70ms
+step:500/20000 val_loss:2.3761 val_bpb:1.4072 train_time:48363ms step_avg:96.73ms
+step:600/20000 train_loss:2.3491 train_time:58100ms step_avg:96.83ms
+step:700/20000 train_loss:2.3640 train_time:67852ms step_avg:96.93ms
+step:800/20000 train_loss:2.2535 train_time:77597ms step_avg:97.00ms
+step:900/20000 train_loss:2.1461 train_time:87344ms step_avg:97.05ms
+step:1000/20000 train_loss:2.2890 train_time:97038ms step_avg:97.04ms
+step:1000/20000 val_loss:2.2408 val_bpb:1.3271 train_time:97050ms step_avg:97.05ms
+step:1100/20000 train_loss:2.3359 train_time:106780ms step_avg:97.07ms
+step:1200/20000 train_loss:2.3691 train_time:116513ms step_avg:97.09ms
+step:1300/20000 train_loss:2.1153 train_time:126233ms step_avg:97.10ms
+step:1400/20000 train_loss:2.1938 train_time:135969ms step_avg:97.12ms
+step:1500/20000 train_loss:2.2351 train_time:145627ms step_avg:97.08ms
+step:1500/20000 val_loss:2.1981 val_bpb:1.3018 train_time:145639ms step_avg:97.09ms
+step:1600/20000 train_loss:2.0893 train_time:155331ms step_avg:97.08ms
+step:1700/20000 train_loss:2.1515 train_time:165046ms step_avg:97.09ms
+step:1800/20000 train_loss:2.1770 train_time:174744ms step_avg:97.08ms
+step:1900/20000 train_loss:2.1413 train_time:184388ms step_avg:97.05ms
+step:2000/20000 train_loss:2.0797 train_time:194078ms step_avg:97.04ms
+step:2000/20000 val_loss:2.1449 val_bpb:1.2704 train_time:194090ms step_avg:97.04ms
+step:2100/20000 train_loss:2.0605 train_time:203838ms step_avg:97.07ms
+step:2200/20000 train_loss:2.1593 train_time:213539ms step_avg:97.06ms
+step:2300/20000 train_loss:2.1190 train_time:223218ms step_avg:97.05ms
+step:2400/20000 train_loss:2.0779 train_time:232836ms step_avg:97.01ms
+step:2500/20000 train_loss:2.1796 train_time:242502ms step_avg:97.00ms
+step:2500/20000 val_loss:2.1176 val_bpb:1.2542 train_time:242513ms step_avg:97.01ms
+step:2600/20000 train_loss:2.1183 train_time:252182ms step_avg:96.99ms
+step:2700/20000 train_loss:2.1120 train_time:261847ms step_avg:96.98ms
+step:2800/20000 train_loss:2.1668 train_time:271527ms step_avg:96.97ms
+step:2900/20000 train_loss:2.0367 train_time:281137ms step_avg:96.94ms
+step:3000/20000 train_loss:2.1726 train_time:290803ms step_avg:96.93ms
+step:3000/20000 val_loss:2.1018 val_bpb:1.2448 train_time:290815ms step_avg:96.94ms
+step:3100/20000 train_loss:2.0464 train_time:300470ms step_avg:96.93ms
+step:3200/20000 train_loss:2.1855 train_time:310133ms step_avg:96.92ms
+step:3300/20000 train_loss:2.0792 train_time:319748ms step_avg:96.89ms
+step:3400/20000 train_loss:2.0280 train_time:329396ms step_avg:96.88ms
+step:3500/20000 train_loss:2.1863 train_time:339060ms step_avg:96.87ms
+step:3500/20000 val_loss:2.0858 val_bpb:1.2353 train_time:339071ms step_avg:96.88ms
+step:3600/20000 train_loss:2.1011 train_time:348717ms step_avg:96.87ms
+step:3700/20000 train_loss:2.0936 train_time:358391ms step_avg:96.86ms
+step:3800/20000 train_loss:2.0715 train_time:367979ms step_avg:96.84ms
+step:3900/20000 train_loss:2.0767 train_time:377638ms step_avg:96.83ms
+step:4000/20000 train_loss:1.9762 train_time:387288ms step_avg:96.82ms
+step:4000/20000 val_loss:2.0642 val_bpb:1.2225 train_time:387299ms step_avg:96.82ms
+step:4100/20000 train_loss:2.0123 train_time:396953ms step_avg:96.82ms
+step:4200/20000 train_loss:2.1504 train_time:406596ms step_avg:96.81ms
+step:4300/20000 train_loss:2.0511 train_time:416194ms step_avg:96.79ms
+step:4400/20000 train_loss:2.0288 train_time:425850ms step_avg:96.78ms
+step:4500/20000 train_loss:2.1175 train_time:435516ms step_avg:96.78ms
+step:4500/20000 val_loss:2.0403 val_bpb:1.2084 train_time:435528ms step_avg:96.78ms
+step:4600/20000 train_loss:1.8384 train_time:445175ms step_avg:96.78ms
+step:4700/20000 train_loss:2.2272 train_time:454773ms step_avg:96.76ms
+step:4800/20000 train_loss:2.4189 train_time:464429ms step_avg:96.76ms
+step:4900/20000 train_loss:2.0435 train_time:474087ms step_avg:96.75ms
+step:5000/20000 train_loss:2.0960 train_time:483753ms step_avg:96.75ms
+step:5000/20000 val_loss:2.0172 val_bpb:1.1947 train_time:483764ms step_avg:96.75ms
+swa:start step:5050
+step:5100/20000 train_loss:2.1168 train_time:493493ms step_avg:96.76ms
+step:5200/20000 train_loss:2.0327 train_time:503149ms step_avg:96.76ms
+step:5300/20000 train_loss:2.0008 train_time:512861ms step_avg:96.77ms
+step:5400/20000 train_loss:2.0373 train_time:522584ms step_avg:96.77ms
+step:5500/20000 train_loss:2.0055 train_time:532297ms step_avg:96.78ms
+step:5500/20000 val_loss:1.9914 val_bpb:1.1794 train_time:532337ms step_avg:96.79ms
+step:5600/20000 train_loss:1.9467 train_time:542033ms step_avg:96.79ms
+step:5700/20000 train_loss:2.0031 train_time:551695ms step_avg:96.79ms
+step:5800/20000 train_loss:1.9835 train_time:561393ms step_avg:96.79ms
+step:5900/20000 train_loss:1.8868 train_time:571098ms step_avg:96.80ms
+step:6000/20000 train_loss:1.9295 train_time:580812ms step_avg:96.80ms
+step:6000/20000 val_loss:1.9654 val_bpb:1.1640 train_time:580851ms step_avg:96.81ms
+step:6100/20000 train_loss:1.9046 train_time:590486ms step_avg:96.80ms
+step:6198/20000 val_loss:1.9585 val_bpb:1.1599 train_time:600035ms step_avg:96.81ms
+stopping_early: wallclock_cap train_time:600035ms step:6198/20000
+peak memory allocated: 25197 MiB reserved: 25280 MiB
+swa:applying averaged 23 checkpoints
+Serialized model: 96864150 bytes
+Code size: 53947 bytes
+Total submission size: 96918097 bytes
+Serialized model int6+zstd: 15273866 bytes
+Total submission size: 15327813 bytes (15.33 MB)
+SIZE CHECK PASSED: 15.33 MB < 16.00 MB
+/workspace/submission/train_gpt.py:1216: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly.
+  quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu")
+final_eval_mode:sliding_window stride:64 batch_seqs:32
+ sliding_eval [  0.0%] 32/121136 windows running_bpb=1.216181
+ sliding_eval [ 10.6%] 12832/121136 windows running_bpb=1.148410
+ sliding_eval [ 19.8%] 24032/121136 windows running_bpb=1.155505
+ sliding_eval [ 30.4%] 36832/121136 windows running_bpb=1.159443
+ sliding_eval [ 39.7%] 48032/121136 windows running_bpb=1.155177
+ sliding_eval [ 50.2%] 60832/121136 windows running_bpb=1.153341
+ sliding_eval [ 60.8%] 73632/121136 windows running_bpb=1.152084
+ sliding_eval [ 70.0%] 84832/121136 windows running_bpb=1.155550
+ sliding_eval [ 80.6%] 97632/121136 windows running_bpb=1.157920
+ sliding_eval [ 89.8%] 108832/121136 windows running_bpb=1.156985
+ sliding_eval [ 99.1%] 120032/121136 windows running_bpb=1.157164
+final_int8_zlib_roundtrip val_loss:1.9441 val_bpb:1.1514 eval_time:257658ms
+final_int8_zlib_roundtrip_exact val_loss:1.94406554 val_bpb:1.15138816