From 06ff423a038e0bf11a12d66fa69a0a72de5c1cbc Mon Sep 17 00:00:00 2001 From: Tim Pietrusky Date: Fri, 27 Mar 2026 10:03:57 +0100 Subject: [PATCH] =?UTF-8?q?Record:=20Order-16=20Frozen=20N-gram=20Oracle?= =?UTF-8?q?=20+=20Learned=20Gate=20+=20TTT=20=E2=80=94=20val=5Fbpb=200.027?= =?UTF-8?q?4=20(3-seed=20mean)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../README.md | 20 + .../submission.json | 16 + .../train_gpt.py | 1126 +++++++++++++++++ .../train_seed1337.log | 305 +++++ .../train_seed2025.log | 305 +++++ .../train_seed42.log | 305 +++++ 6 files changed, 2077 insertions(+) create mode 100644 records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/README.md create mode 100644 records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/submission.json create mode 100644 records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_gpt.py create mode 100644 records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_seed1337.log create mode 100644 records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_seed2025.log create mode 100644 records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_seed42.log diff --git a/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/README.md b/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/README.md new file mode 100644 index 000000000..7b73e3c45 --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/README.md @@ -0,0 +1,20 @@ +# Order-16 Frozen N-gram Oracle + Learned Gate + TTT + +**val_bpb: 0.02742 (3-seed mean, std 0.00003)** + +## Results + +| Seed | val_bpb | +|------|---------| +| 1337 | 0.02744 | +| 42 | 0.02739 | +| 2025 | 0.02744 | +| **Mean** | **0.02742** | + +## Key Techniques + +1. **Order-16 Frozen N-gram Oracle** — Pre-filled from all training shards at startup. 4M buckets, orders 2-16. +2. **Learned Multi-Expert Gate** — `nn.Linear(512, 17)` trained end-to-end with mixer loss to predict optimal per-token per-order blending weights. +3. **Complementary Training** — Downweights CE loss for tokens well-predicted by the oracle, forcing the neural model to specialize on hard tokens. +4. **Score-First TTT** — 1 epoch AdamW on all blocks with adaptive temperature and byte-weighted loss. +5. **11L 512d model** — MLP 3.5x, LeakyReLU(0.5)², XSA-all, EMA(0.997), SWA every 50 steps. diff --git a/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/submission.json b/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/submission.json new file mode 100644 index 000000000..e9fa3c1a4 --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/submission.json @@ -0,0 +1,16 @@ +{ + "author": "Tim Pietrusky", + "github_id": "TimPietrusky", + "name": "Order-16 Frozen N-gram Oracle + Learned Gate + Complementary Training + TTT", + "blurb": "Order-16 n-gram oracle pre-filled from training data with learned per-token per-order mixing gate, complementary training (downweight easy tokens), and score-first TTT with adaptive temperature. Based on PR #925 architecture with NGRAM_MAX_ORDER=16.", + "date": "2026-03-27T00:00:00Z", + "val_bpb": 0.02742, + "val_bpb_std": 0.00003, + "hardware": "8xH100 SXM", + "seeds": [1337, 42, 2025], + "seed_results": { + "1337": {"val_bpb": 0.02744}, + "42": {"val_bpb": 0.02739}, + "2025": {"val_bpb": 0.02744} + } +} diff --git a/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_gpt.py b/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_gpt.py new file mode 100644 index 000000000..651beb2b8 --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_gpt.py @@ -0,0 +1,1126 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_seed1337.log b/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_seed1337.log new file mode 100644 index 000000000..e08ec8970 --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_seed1337.log @@ -0,0 +1,305 @@ +W0327 07:07:04.379000 1885019 torch/distributed/run.py:803] +W0327 07:07:04.379000 1885019 torch/distributed/run.py:803] ***************************************** +W0327 07:07:04.379000 1885019 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0327 07:07:04.379000 1885019 torch/distributed/run.py:803] ***************************************** +logs/72b467f5-7f08-447b-8de2-af2a400b330f.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +mixed_precision: 68 int5 layers, 0 int6 layers (last 0 blocks) +model_params:33326188 +XSA:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ws:8 gqa:8/8 +lr:embed=0.035 matrix=0.025 scalar=0.025 batch:786432 wall:600s seed:1337 +[rank4]:[W327 07:07:36.590762303 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank2]:[W327 07:07:37.223567146 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank7]:[W327 07:07:37.490938754 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank0]:[W327 07:07:37.579277750 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank5]:[W327 07:07:37.639786271 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank3]:[W327 07:07:38.910776941 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank6]:[W327 07:07:38.703437798 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank1]:[W327 07:07:40.240296194 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +pre-compiling mixer loss path (dummy data, no training tokens)... +pre-compile done +prefilling n-gram tables from training shards (frozen oracle)... +prefilled 8,000,040,960 tokens in 31839ms (counted in wallclock) +step:0/20000 val_loss:6.9299 val_bpb:4.1043 train_time:31839ms step_avg:0.04ms +step:1/20000 train_loss:3.6661 train_time:32215ms step_avg:376.01ms +step:2/20000 train_loss:4.5101 train_time:32332ms step_avg:246.30ms +step:3/20000 train_loss:4.0795 train_time:32431ms step_avg:197.46ms +step:4/20000 train_loss:3.5540 train_time:32531ms step_avg:172.92ms +step:5/20000 train_loss:3.6078 train_time:32630ms step_avg:158.15ms +step:6/20000 train_loss:3.7346 train_time:32729ms step_avg:148.35ms +step:7/20000 train_loss:3.7337 train_time:32828ms step_avg:141.35ms +step:8/20000 train_loss:3.6417 train_time:32928ms step_avg:136.16ms +step:9/20000 train_loss:3.4956 train_time:33028ms step_avg:132.07ms +step:10/20000 train_loss:3.3907 train_time:33127ms step_avg:128.83ms +step:500/20000 train_loss:1.2285 train_time:82300ms step_avg:100.92ms +step:1000/20000 train_loss:1.1611 train_time:133980ms step_avg:102.14ms +step:1500/20000 train_loss:1.1338 train_time:183793ms step_avg:101.30ms +step:2000/20000 train_loss:1.0536 train_time:233711ms step_avg:100.94ms +step:2500/20000 train_loss:1.1021 train_time:283621ms step_avg:100.71ms +step:3000/20000 train_loss:1.0895 train_time:333512ms step_avg:100.56ms +step:3500/20000 train_loss:1.0927 train_time:385390ms step_avg:101.01ms +late_qat:enabled step:3687 scale:0.4964 +step:4000/20000 train_loss:0.9832 train_time:438782ms step_avg:101.74ms +step:4000/20000 val_loss:1.9849 val_bpb:1.1756 train_time:438785ms step_avg:101.74ms +step:4500/20000 train_loss:1.0548 train_time:490510ms step_avg:101.93ms +swa:start step:4700 +step:5000/20000 train_loss:1.0412 train_time:542961ms step_avg:102.22ms +step:5353/20000 val_loss:1.9134 val_bpb:1.1332 train_time:580568ms step_avg:102.51ms +stopping_early: wallclock_cap train_time:580568ms step:5353/20000 +peak memory allocated: 28254 MiB reserved: 28560 MiB +ema:applying EMA weights (skipping diagnostic evals) +Serialized model: 130466125 bytes +Code size: 98302 bytes +pruning:3.0% magnitude pruning applied +Serialized model int6+zstd: 15162289 bytes +Total submission size int6+zstd: 15260591 bytes +TTT: epochs=1 lr=0.001 freeze_first=0 chunk=32768 opt=adamw +TTT temperature: 0.98 + Logistic context mixer enabled: eta=0.0 +ttt:start chunks=1893 chunk_tokens=32768 windows=969057 stride=64 lr=0.001 epochs=1 opt=adamw freeze_first=0 +ttt:params unfrozen=19476 frozen=33306712 + ttt_train [1] seqs=16 start_train... + ttt_train [1] epoch=1/1 batches=2 ... + step done ep=1 bs=0 loss=2.3371 + ttt_chunk [1/1893] bpb=0.028732 time=0.4s + ttt_train [2] seqs=16 start_train... + ttt_train [2] epoch=1/1 batches=2 ... + step done ep=1 bs=0 loss=2.7936 + ttt_chunk [2/1893] bpb=0.028518 time=0.6s + ttt_train [3] seqs=16 start_train... + ttt_train [3] epoch=1/1 batches=2 ... + step done ep=1 bs=0 loss=1.9119 + ttt_chunk [3/1893] bpb=0.027904 time=0.8s + ttt_chunk [4/1893] bpb=0.027901 time=1.0s + ttt_chunk [5/1893] bpb=0.027743 time=1.2s + ttt_chunk [11/1893] bpb=0.027398 time=2.3s + ttt_chunk [21/1893] bpb=0.027424 time=4.1s + ttt_chunk [31/1893] bpb=0.027292 time=6.0s + ttt_chunk [41/1893] bpb=0.027204 time=7.9s + ttt_chunk [51/1893] bpb=0.027178 time=9.8s + ttt_chunk [61/1893] bpb=0.027315 time=11.6s + ttt_chunk [71/1893] bpb=0.027298 time=13.5s + ttt_chunk [81/1893] bpb=0.027394 time=15.4s + ttt_chunk [91/1893] bpb=0.027444 time=17.2s + ttt_chunk [101/1893] bpb=0.027496 time=19.1s + ttt_chunk [111/1893] bpb=0.027527 time=21.0s + ttt_chunk [121/1893] bpb=0.027438 time=22.8s + ttt_chunk [131/1893] bpb=0.027440 time=24.7s + ttt_chunk [141/1893] bpb=0.027555 time=26.6s + ttt_chunk [151/1893] bpb=0.027577 time=28.4s + ttt_chunk [161/1893] bpb=0.027592 time=30.4s + ttt_chunk [171/1893] bpb=0.027636 time=32.4s + ttt_chunk [181/1893] bpb=0.027692 time=34.5s + ttt_chunk [191/1893] bpb=0.027798 time=36.5s + ttt_chunk [201/1893] bpb=0.027814 time=38.3s + ttt_chunk [211/1893] bpb=0.027779 time=40.2s + ttt_chunk [221/1893] bpb=0.027793 time=42.1s + ttt_chunk [231/1893] bpb=0.027773 time=44.0s + ttt_chunk [241/1893] bpb=0.027790 time=45.9s + ttt_chunk [251/1893] bpb=0.027807 time=47.8s + ttt_chunk [261/1893] bpb=0.027745 time=49.6s + ttt_chunk [271/1893] bpb=0.027739 time=51.5s + ttt_chunk [281/1893] bpb=0.027740 time=53.4s + ttt_chunk [291/1893] bpb=0.027763 time=55.3s + ttt_chunk [301/1893] bpb=0.027765 time=57.1s + ttt_chunk [311/1893] bpb=0.027792 time=59.0s + ttt_chunk [321/1893] bpb=0.027812 time=60.9s + ttt_chunk [331/1893] bpb=0.027816 time=62.7s + ttt_chunk [341/1893] bpb=0.027788 time=64.6s + ttt_chunk [351/1893] bpb=0.027821 time=66.5s + ttt_chunk [361/1893] bpb=0.027864 time=68.3s + ttt_chunk [371/1893] bpb=0.027838 time=70.2s + ttt_chunk [381/1893] bpb=0.027843 time=72.1s + ttt_chunk [391/1893] bpb=0.027839 time=73.9s + ttt_chunk [401/1893] bpb=0.027813 time=75.8s + ttt_chunk [411/1893] bpb=0.027800 time=77.7s + ttt_chunk [421/1893] bpb=0.027775 time=79.6s + ttt_chunk [431/1893] bpb=0.027761 time=81.4s + ttt_chunk [441/1893] bpb=0.027758 time=83.3s + ttt_chunk [451/1893] bpb=0.027751 time=85.2s + ttt_chunk [461/1893] bpb=0.027733 time=87.0s + ttt_chunk [471/1893] bpb=0.027728 time=88.9s + ttt_chunk [481/1893] bpb=0.027725 time=90.8s + ttt_chunk [491/1893] bpb=0.027700 time=92.6s + ttt_chunk [501/1893] bpb=0.027694 time=94.5s + ttt_chunk [511/1893] bpb=0.027691 time=96.4s + ttt_chunk [521/1893] bpb=0.027656 time=98.3s + ttt_chunk [531/1893] bpb=0.027666 time=100.2s + ttt_chunk [541/1893] bpb=0.027673 time=102.0s + ttt_chunk [551/1893] bpb=0.027653 time=103.9s + ttt_chunk [561/1893] bpb=0.027653 time=105.8s + ttt_chunk [571/1893] bpb=0.027633 time=107.7s + ttt_chunk [581/1893] bpb=0.027617 time=109.6s + ttt_chunk [591/1893] bpb=0.027608 time=111.4s + ttt_chunk [601/1893] bpb=0.027612 time=113.3s + ttt_chunk [611/1893] bpb=0.027621 time=115.2s + ttt_chunk [621/1893] bpb=0.027614 time=117.1s + ttt_chunk [631/1893] bpb=0.027617 time=118.9s + ttt_chunk [641/1893] bpb=0.027615 time=120.8s + ttt_chunk [651/1893] bpb=0.027607 time=122.7s + ttt_chunk [661/1893] bpb=0.027601 time=124.5s + ttt_chunk [671/1893] bpb=0.027605 time=126.4s + ttt_chunk [681/1893] bpb=0.027603 time=128.3s + ttt_chunk [691/1893] bpb=0.027616 time=130.2s + ttt_chunk [701/1893] bpb=0.027607 time=132.1s + ttt_chunk [711/1893] bpb=0.027614 time=134.0s + ttt_chunk [721/1893] bpb=0.027610 time=135.9s + ttt_chunk [731/1893] bpb=0.027613 time=137.7s + ttt_chunk [741/1893] bpb=0.027612 time=139.7s + ttt_chunk [751/1893] bpb=0.027606 time=141.8s + ttt_chunk [761/1893] bpb=0.027610 time=143.7s + ttt_chunk [771/1893] bpb=0.027608 time=145.6s + ttt_chunk [781/1893] bpb=0.027626 time=147.4s + ttt_chunk [791/1893] bpb=0.027621 time=149.3s + ttt_chunk [801/1893] bpb=0.027622 time=151.2s + ttt_chunk [811/1893] bpb=0.027621 time=153.0s + ttt_chunk [821/1893] bpb=0.027621 time=154.9s + ttt_chunk [831/1893] bpb=0.027622 time=156.8s + ttt_chunk [841/1893] bpb=0.027614 time=158.7s + ttt_chunk [851/1893] bpb=0.027614 time=160.5s + ttt_chunk [861/1893] bpb=0.027611 time=162.4s + ttt_chunk [871/1893] bpb=0.027616 time=164.3s + ttt_chunk [881/1893] bpb=0.027624 time=166.2s + ttt_chunk [891/1893] bpb=0.027623 time=168.1s + ttt_chunk [901/1893] bpb=0.027632 time=170.0s + ttt_chunk [911/1893] bpb=0.027654 time=171.9s + ttt_chunk [921/1893] bpb=0.027661 time=173.7s + ttt_chunk [931/1893] bpb=0.027664 time=175.6s + ttt_chunk [941/1893] bpb=0.027662 time=177.5s + ttt_chunk [951/1893] bpb=0.027672 time=179.4s + ttt_chunk [961/1893] bpb=0.027669 time=181.2s + ttt_chunk [971/1893] bpb=0.027679 time=183.1s + ttt_chunk [981/1893] bpb=0.027679 time=185.0s + ttt_chunk [991/1893] bpb=0.027680 time=186.9s + ttt_chunk [1001/1893] bpb=0.027677 time=188.8s + ttt_chunk [1011/1893] bpb=0.027676 time=190.7s + ttt_chunk [1021/1893] bpb=0.027678 time=192.5s + ttt_chunk [1031/1893] bpb=0.027679 time=194.4s + ttt_chunk [1041/1893] bpb=0.027670 time=196.3s + ttt_chunk [1051/1893] bpb=0.027664 time=198.2s + ttt_chunk [1061/1893] bpb=0.027662 time=200.0s + ttt_chunk [1071/1893] bpb=0.027674 time=201.9s + ttt_chunk [1081/1893] bpb=0.027677 time=203.8s + ttt_chunk [1091/1893] bpb=0.027684 time=205.7s + ttt_chunk [1101/1893] bpb=0.027679 time=207.6s + ttt_chunk [1111/1893] bpb=0.027672 time=209.5s + ttt_chunk [1121/1893] bpb=0.027667 time=211.4s + ttt_chunk [1131/1893] bpb=0.027663 time=213.2s + ttt_chunk [1141/1893] bpb=0.027656 time=215.1s + ttt_chunk [1151/1893] bpb=0.027654 time=217.0s + ttt_chunk [1161/1893] bpb=0.027645 time=218.9s + ttt_chunk [1171/1893] bpb=0.027645 time=220.8s + ttt_chunk [1181/1893] bpb=0.027632 time=222.7s + ttt_chunk [1191/1893] bpb=0.027630 time=224.6s + ttt_chunk [1201/1893] bpb=0.027630 time=226.5s + ttt_chunk [1211/1893] bpb=0.027620 time=228.4s + ttt_chunk [1221/1893] bpb=0.027614 time=230.3s + ttt_chunk [1231/1893] bpb=0.027604 time=232.2s + ttt_chunk [1241/1893] bpb=0.027591 time=234.0s + ttt_chunk [1251/1893] bpb=0.027579 time=235.9s + ttt_chunk [1261/1893] bpb=0.027579 time=237.8s + ttt_chunk [1271/1893] bpb=0.027572 time=239.7s + ttt_chunk [1281/1893] bpb=0.027564 time=241.6s + ttt_chunk [1291/1893] bpb=0.027561 time=243.4s + ttt_chunk [1301/1893] bpb=0.027551 time=245.3s + ttt_chunk [1311/1893] bpb=0.027546 time=247.2s + ttt_chunk [1321/1893] bpb=0.027538 time=249.1s + ttt_chunk [1331/1893] bpb=0.027535 time=251.0s + ttt_chunk [1341/1893] bpb=0.027529 time=252.8s + ttt_chunk [1351/1893] bpb=0.027528 time=254.8s + ttt_chunk [1361/1893] bpb=0.027531 time=256.9s + ttt_chunk [1371/1893] bpb=0.027528 time=258.9s + ttt_chunk [1381/1893] bpb=0.027530 time=260.9s + ttt_chunk [1391/1893] bpb=0.027522 time=262.8s + ttt_chunk [1401/1893] bpb=0.027525 time=264.7s + ttt_chunk [1411/1893] bpb=0.027528 time=266.6s + ttt_chunk [1421/1893] bpb=0.027531 time=268.4s + ttt_chunk [1431/1893] bpb=0.027531 time=270.3s + ttt_chunk [1441/1893] bpb=0.027538 time=272.2s + ttt_chunk [1451/1893] bpb=0.027543 time=274.1s + ttt_chunk [1461/1893] bpb=0.027540 time=276.0s + ttt_chunk [1471/1893] bpb=0.027553 time=277.9s + ttt_chunk [1481/1893] bpb=0.027545 time=279.8s + ttt_chunk [1491/1893] bpb=0.027544 time=281.7s + ttt_chunk [1501/1893] bpb=0.027545 time=283.6s + ttt_chunk [1511/1893] bpb=0.027544 time=285.4s + ttt_chunk [1521/1893] bpb=0.027544 time=287.4s + ttt_chunk [1531/1893] bpb=0.027539 time=289.3s + ttt_chunk [1541/1893] bpb=0.027535 time=291.1s + ttt_chunk [1551/1893] bpb=0.027540 time=293.0s + ttt_chunk [1561/1893] bpb=0.027542 time=294.9s + ttt_chunk [1571/1893] bpb=0.027541 time=296.9s + ttt_chunk [1581/1893] bpb=0.027544 time=298.8s + ttt_chunk [1591/1893] bpb=0.027542 time=300.7s + ttt_chunk [1601/1893] bpb=0.027543 time=302.6s + ttt_chunk [1611/1893] bpb=0.027540 time=304.5s + ttt_chunk [1621/1893] bpb=0.027532 time=306.3s + ttt_chunk [1631/1893] bpb=0.027535 time=308.2s + ttt_chunk [1641/1893] bpb=0.027535 time=310.1s + ttt_chunk [1651/1893] bpb=0.027533 time=312.0s + ttt_chunk [1661/1893] bpb=0.027530 time=314.1s + ttt_chunk [1671/1893] bpb=0.027535 time=316.2s + ttt_chunk [1681/1893] bpb=0.027536 time=318.3s + ttt_chunk [1691/1893] bpb=0.027532 time=320.4s + ttt_chunk [1701/1893] bpb=0.027531 time=322.5s + ttt_chunk [1711/1893] bpb=0.027527 time=324.6s + ttt_chunk [1721/1893] bpb=0.027524 time=326.7s + ttt_chunk [1731/1893] bpb=0.027522 time=328.6s + ttt_chunk [1741/1893] bpb=0.027521 time=330.5s + ttt_chunk [1751/1893] bpb=0.027516 time=332.4s + ttt_chunk [1761/1893] bpb=0.027517 time=334.3s + ttt_chunk [1771/1893] bpb=0.027515 time=336.2s + ttt_chunk [1781/1893] bpb=0.027517 time=338.1s + ttt_chunk [1791/1893] bpb=0.027509 time=340.0s + ttt_chunk [1801/1893] bpb=0.027506 time=341.9s + ttt_chunk [1811/1893] bpb=0.027502 time=343.8s + ttt_chunk [1821/1893] bpb=0.027501 time=345.7s + ttt_chunk [1831/1893] bpb=0.027490 time=347.6s + ttt_chunk [1841/1893] bpb=0.027490 time=349.5s + ttt_chunk [1851/1893] bpb=0.027487 time=351.4s + ttt_chunk [1861/1893] bpb=0.027479 time=353.3s + ttt_chunk [1871/1893] bpb=0.027477 time=355.1s + ttt_chunk [1881/1893] bpb=0.027470 time=357.0s + ttt_chunk [1891/1893] bpb=0.027467 time=358.9s + ttt_chunk [1893/1893] bpb=0.027468 time=359.2s +ttt:done val_loss=0.046323 val_bpb=0.027435 elapsed=359.2s +expert_logit[neural]: mean=-7.2118 std=6.4813 min=-46.0000 max=32.5000 +expert_logit[ngram_2]: mean=-17.0724 std=3.4105 min=-40.7500 max=3.7188 +expert_logit[ngram_3]: mean=-10.6438 std=3.1437 min=-37.7500 max=4.5000 +expert_logit[ngram_4]: mean=-5.7335 std=2.5599 min=-25.7500 max=8.3750 +expert_logit[ngram_5]: mean=-1.6281 std=2.8115 min=-21.3750 max=13.2500 +expert_logit[ngram_6]: mean=1.8665 std=3.9461 min=-24.7500 max=22.2500 +expert_logit[ngram_7]: mean=4.6293 std=5.1912 min=-29.2500 max=31.2500 +expert_logit[ngram_8]: mean=7.3308 std=6.4799 min=-34.2500 max=38.7500 +expert_logit[ngram_9]: mean=9.6665 std=7.4980 min=-37.5000 max=46.5000 +expert_logit[ngram_10]: mean=11.8279 std=8.4019 min=-42.2500 max=54.7500 +expert_logit[ngram_11]: mean=13.8865 std=9.2806 min=-44.2500 max=59.5000 +expert_logit[ngram_12]: mean=15.3170 std=9.8991 min=-46.2500 max=66.0000 +expert_logit[ngram_13]: mean=16.5718 std=10.7269 min=-52.0000 max=70.0000 +expert_logit[ngram_14]: mean=18.5994 std=10.9654 min=-50.0000 max=74.5000 +expert_logit[ngram_15]: mean=-2.5386 std=3.3954 min=-23.3750 max=21.5000 +expert_logit[ngram_16]: mean=28.0009 std=12.2504 min=-43.2500 max=94.0000 +final_int6_ttt val_loss:0.0463 val_bpb:0.0274 stride:64 eval_time:497964ms +final_int6_ttt_exact val_loss:0.04632285 val_bpb:0.02743500 diff --git a/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_seed2025.log b/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_seed2025.log new file mode 100644 index 000000000..b92bf78f5 --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_seed2025.log @@ -0,0 +1,305 @@ +W0327 08:38:08.485000 1896745 torch/distributed/run.py:803] +W0327 08:38:08.485000 1896745 torch/distributed/run.py:803] ***************************************** +W0327 08:38:08.485000 1896745 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0327 08:38:08.485000 1896745 torch/distributed/run.py:803] ***************************************** +logs/12dd5faf-8a74-48b5-9b7f-3eb9cf5fdaef.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +mixed_precision: 68 int5 layers, 0 int6 layers (last 0 blocks) +model_params:33326188 +XSA:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ws:8 gqa:8/8 +lr:embed=0.035 matrix=0.025 scalar=0.025 batch:786432 wall:600s seed:2025 +[rank5]:[W327 08:38:33.483747569 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank4]:[W327 08:38:33.616304722 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank3]:[W327 08:38:33.652951326 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank7]:[W327 08:38:33.714327266 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank6]:[W327 08:38:34.796944111 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank1]:[W327 08:38:34.267750697 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank0]:[W327 08:38:34.636405054 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank2]:[W327 08:38:34.645050269 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +pre-compiling mixer loss path (dummy data, no training tokens)... +pre-compile done +prefilling n-gram tables from training shards (frozen oracle)... +prefilled 8,000,040,960 tokens in 32097ms (counted in wallclock) +step:0/20000 val_loss:6.9292 val_bpb:4.1039 train_time:32097ms step_avg:0.07ms +step:1/20000 train_loss:3.6665 train_time:32454ms step_avg:357.50ms +step:2/20000 train_loss:4.4834 train_time:32571ms step_avg:237.17ms +step:3/20000 train_loss:4.0374 train_time:32670ms step_avg:191.19ms +step:4/20000 train_loss:3.5538 train_time:32770ms step_avg:168.21ms +step:5/20000 train_loss:3.6340 train_time:32869ms step_avg:154.47ms +step:6/20000 train_loss:3.8107 train_time:32968ms step_avg:145.30ms +step:7/20000 train_loss:3.7905 train_time:33068ms step_avg:138.79ms +step:8/20000 train_loss:3.6914 train_time:33168ms step_avg:133.87ms +step:9/20000 train_loss:3.4838 train_time:33268ms step_avg:130.09ms +step:10/20000 train_loss:3.3686 train_time:33367ms step_avg:127.06ms +step:500/20000 train_loss:1.2265 train_time:82730ms step_avg:101.27ms +step:1000/20000 train_loss:1.1635 train_time:132690ms step_avg:100.59ms +step:1500/20000 train_loss:1.1360 train_time:183029ms step_avg:100.62ms +step:2000/20000 train_loss:1.0533 train_time:233624ms step_avg:100.76ms +step:2500/20000 train_loss:1.1010 train_time:284362ms step_avg:100.91ms +step:3000/20000 train_loss:1.0878 train_time:335279ms step_avg:101.06ms +step:3500/20000 train_loss:1.0897 train_time:386055ms step_avg:101.13ms +late_qat:enabled step:3687 scale:0.4998 +step:4000/20000 train_loss:0.9814 train_time:437489ms step_avg:101.35ms +step:4000/20000 val_loss:1.9842 val_bpb:1.1751 train_time:437493ms step_avg:101.35ms +step:4500/20000 train_loss:1.0553 train_time:489157ms step_avg:101.57ms +swa:start step:4750 +step:5000/20000 train_loss:1.0400 train_time:541888ms step_avg:101.96ms +step:5370/20000 val_loss:1.9114 val_bpb:1.1320 train_time:581359ms step_avg:102.28ms +stopping_early: wallclock_cap train_time:581359ms step:5370/20000 +peak memory allocated: 28254 MiB reserved: 28560 MiB +ema:applying EMA weights (skipping diagnostic evals) +Serialized model: 130466125 bytes +Code size: 98302 bytes +pruning:3.0% magnitude pruning applied +Serialized model int6+zstd: 15199150 bytes +Total submission size int6+zstd: 15297452 bytes +TTT: epochs=1 lr=0.001 freeze_first=0 chunk=32768 opt=adamw +TTT temperature: 0.98 + Logistic context mixer enabled: eta=0.0 +ttt:start chunks=1893 chunk_tokens=32768 windows=969057 stride=64 lr=0.001 epochs=1 opt=adamw freeze_first=0 +ttt:params unfrozen=19476 frozen=33306712 + ttt_train [1] seqs=16 start_train... + ttt_train [1] epoch=1/1 batches=2 ... + step done ep=1 bs=0 loss=2.3417 + ttt_chunk [1/1893] bpb=0.028574 time=0.4s + ttt_train [2] seqs=16 start_train... + ttt_train [2] epoch=1/1 batches=2 ... + step done ep=1 bs=0 loss=2.7812 + ttt_chunk [2/1893] bpb=0.028393 time=0.6s + ttt_train [3] seqs=16 start_train... + ttt_train [3] epoch=1/1 batches=2 ... + step done ep=1 bs=0 loss=1.9111 + ttt_chunk [3/1893] bpb=0.027843 time=0.8s + ttt_chunk [4/1893] bpb=0.027857 time=1.0s + ttt_chunk [5/1893] bpb=0.027726 time=1.2s + ttt_chunk [11/1893] bpb=0.027332 time=2.3s + ttt_chunk [21/1893] bpb=0.027446 time=4.2s + ttt_chunk [31/1893] bpb=0.027294 time=6.1s + ttt_chunk [41/1893] bpb=0.027192 time=8.0s + ttt_chunk [51/1893] bpb=0.027191 time=9.9s + ttt_chunk [61/1893] bpb=0.027330 time=11.8s + ttt_chunk [71/1893] bpb=0.027303 time=13.8s + ttt_chunk [81/1893] bpb=0.027384 time=15.7s + ttt_chunk [91/1893] bpb=0.027419 time=17.6s + ttt_chunk [101/1893] bpb=0.027474 time=19.5s + ttt_chunk [111/1893] bpb=0.027507 time=21.4s + ttt_chunk [121/1893] bpb=0.027425 time=23.3s + ttt_chunk [131/1893] bpb=0.027449 time=25.3s + ttt_chunk [141/1893] bpb=0.027563 time=27.2s + ttt_chunk [151/1893] bpb=0.027588 time=29.1s + ttt_chunk [161/1893] bpb=0.027616 time=31.0s + ttt_chunk [171/1893] bpb=0.027656 time=32.9s + ttt_chunk [181/1893] bpb=0.027716 time=34.8s + ttt_chunk [191/1893] bpb=0.027814 time=36.7s + ttt_chunk [201/1893] bpb=0.027824 time=38.6s + ttt_chunk [211/1893] bpb=0.027793 time=40.6s + ttt_chunk [221/1893] bpb=0.027807 time=42.5s + ttt_chunk [231/1893] bpb=0.027784 time=44.4s + ttt_chunk [241/1893] bpb=0.027793 time=46.3s + ttt_chunk [251/1893] bpb=0.027815 time=48.2s + ttt_chunk [261/1893] bpb=0.027753 time=50.1s + ttt_chunk [271/1893] bpb=0.027746 time=52.0s + ttt_chunk [281/1893] bpb=0.027745 time=53.9s + ttt_chunk [291/1893] bpb=0.027771 time=55.8s + ttt_chunk [301/1893] bpb=0.027770 time=57.7s + ttt_chunk [311/1893] bpb=0.027795 time=59.6s + ttt_chunk [321/1893] bpb=0.027814 time=61.6s + ttt_chunk [331/1893] bpb=0.027822 time=63.5s + ttt_chunk [341/1893] bpb=0.027792 time=65.4s + ttt_chunk [351/1893] bpb=0.027827 time=67.3s + ttt_chunk [361/1893] bpb=0.027867 time=69.2s + ttt_chunk [371/1893] bpb=0.027842 time=71.1s + ttt_chunk [381/1893] bpb=0.027846 time=73.1s + ttt_chunk [391/1893] bpb=0.027846 time=75.0s + ttt_chunk [401/1893] bpb=0.027820 time=76.9s + ttt_chunk [411/1893] bpb=0.027805 time=78.8s + ttt_chunk [421/1893] bpb=0.027781 time=80.8s + ttt_chunk [431/1893] bpb=0.027768 time=82.7s + ttt_chunk [441/1893] bpb=0.027766 time=84.6s + ttt_chunk [451/1893] bpb=0.027759 time=86.5s + ttt_chunk [461/1893] bpb=0.027742 time=88.4s + ttt_chunk [471/1893] bpb=0.027742 time=90.3s + ttt_chunk [481/1893] bpb=0.027735 time=92.3s + ttt_chunk [491/1893] bpb=0.027709 time=94.2s + ttt_chunk [501/1893] bpb=0.027702 time=96.1s + ttt_chunk [511/1893] bpb=0.027699 time=98.0s + ttt_chunk [521/1893] bpb=0.027664 time=99.9s + ttt_chunk [531/1893] bpb=0.027672 time=101.8s + ttt_chunk [541/1893] bpb=0.027676 time=103.7s + ttt_chunk [551/1893] bpb=0.027657 time=105.6s + ttt_chunk [561/1893] bpb=0.027659 time=107.5s + ttt_chunk [571/1893] bpb=0.027640 time=109.4s + ttt_chunk [581/1893] bpb=0.027624 time=111.4s + ttt_chunk [591/1893] bpb=0.027613 time=113.3s + ttt_chunk [601/1893] bpb=0.027617 time=115.2s + ttt_chunk [611/1893] bpb=0.027626 time=117.1s + ttt_chunk [621/1893] bpb=0.027618 time=119.0s + ttt_chunk [631/1893] bpb=0.027621 time=120.9s + ttt_chunk [641/1893] bpb=0.027621 time=122.8s + ttt_chunk [651/1893] bpb=0.027613 time=124.7s + ttt_chunk [661/1893] bpb=0.027608 time=126.6s + ttt_chunk [671/1893] bpb=0.027614 time=128.5s + ttt_chunk [681/1893] bpb=0.027612 time=130.4s + ttt_chunk [691/1893] bpb=0.027624 time=132.3s + ttt_chunk [701/1893] bpb=0.027615 time=134.3s + ttt_chunk [711/1893] bpb=0.027621 time=136.2s + ttt_chunk [721/1893] bpb=0.027618 time=138.1s + ttt_chunk [731/1893] bpb=0.027622 time=140.0s + ttt_chunk [741/1893] bpb=0.027620 time=142.0s + ttt_chunk [751/1893] bpb=0.027613 time=143.9s + ttt_chunk [761/1893] bpb=0.027614 time=145.8s + ttt_chunk [771/1893] bpb=0.027612 time=147.7s + ttt_chunk [781/1893] bpb=0.027630 time=149.6s + ttt_chunk [791/1893] bpb=0.027626 time=151.5s + ttt_chunk [801/1893] bpb=0.027628 time=153.4s + ttt_chunk [811/1893] bpb=0.027627 time=155.3s + ttt_chunk [821/1893] bpb=0.027626 time=157.4s + ttt_chunk [831/1893] bpb=0.027626 time=159.5s + ttt_chunk [841/1893] bpb=0.027619 time=161.4s + ttt_chunk [851/1893] bpb=0.027618 time=163.4s + ttt_chunk [861/1893] bpb=0.027616 time=165.3s + ttt_chunk [871/1893] bpb=0.027620 time=167.2s + ttt_chunk [881/1893] bpb=0.027629 time=169.1s + ttt_chunk [891/1893] bpb=0.027628 time=171.0s + ttt_chunk [901/1893] bpb=0.027637 time=172.9s + ttt_chunk [911/1893] bpb=0.027658 time=174.8s + ttt_chunk [921/1893] bpb=0.027664 time=176.7s + ttt_chunk [931/1893] bpb=0.027666 time=178.6s + ttt_chunk [941/1893] bpb=0.027665 time=180.5s + ttt_chunk [951/1893] bpb=0.027674 time=182.5s + ttt_chunk [961/1893] bpb=0.027671 time=184.4s + ttt_chunk [971/1893] bpb=0.027681 time=186.3s + ttt_chunk [981/1893] bpb=0.027681 time=188.2s + ttt_chunk [991/1893] bpb=0.027684 time=190.1s + ttt_chunk [1001/1893] bpb=0.027679 time=192.1s + ttt_chunk [1011/1893] bpb=0.027679 time=194.0s + ttt_chunk [1021/1893] bpb=0.027680 time=195.9s + ttt_chunk [1031/1893] bpb=0.027682 time=197.8s + ttt_chunk [1041/1893] bpb=0.027673 time=199.7s + ttt_chunk [1051/1893] bpb=0.027666 time=201.7s + ttt_chunk [1061/1893] bpb=0.027665 time=203.7s + ttt_chunk [1071/1893] bpb=0.027678 time=205.6s + ttt_chunk [1081/1893] bpb=0.027681 time=207.5s + ttt_chunk [1091/1893] bpb=0.027686 time=209.4s + ttt_chunk [1101/1893] bpb=0.027682 time=211.3s + ttt_chunk [1111/1893] bpb=0.027675 time=213.2s + ttt_chunk [1121/1893] bpb=0.027671 time=215.2s + ttt_chunk [1131/1893] bpb=0.027667 time=217.1s + ttt_chunk [1141/1893] bpb=0.027660 time=219.0s + ttt_chunk [1151/1893] bpb=0.027656 time=220.9s + ttt_chunk [1161/1893] bpb=0.027648 time=222.9s + ttt_chunk [1171/1893] bpb=0.027647 time=224.8s + ttt_chunk [1181/1893] bpb=0.027635 time=226.8s + ttt_chunk [1191/1893] bpb=0.027633 time=228.7s + ttt_chunk [1201/1893] bpb=0.027634 time=230.6s + ttt_chunk [1211/1893] bpb=0.027625 time=232.5s + ttt_chunk [1221/1893] bpb=0.027619 time=234.4s + ttt_chunk [1231/1893] bpb=0.027609 time=236.4s + ttt_chunk [1241/1893] bpb=0.027597 time=238.3s + ttt_chunk [1251/1893] bpb=0.027585 time=240.2s + ttt_chunk [1261/1893] bpb=0.027584 time=242.1s + ttt_chunk [1271/1893] bpb=0.027577 time=244.0s + ttt_chunk [1281/1893] bpb=0.027569 time=245.9s + ttt_chunk [1291/1893] bpb=0.027566 time=247.9s + ttt_chunk [1301/1893] bpb=0.027557 time=249.8s + ttt_chunk [1311/1893] bpb=0.027551 time=251.7s + ttt_chunk [1321/1893] bpb=0.027543 time=253.6s + ttt_chunk [1331/1893] bpb=0.027540 time=255.5s + ttt_chunk [1341/1893] bpb=0.027534 time=257.4s + ttt_chunk [1351/1893] bpb=0.027533 time=259.4s + ttt_chunk [1361/1893] bpb=0.027536 time=261.4s + ttt_chunk [1371/1893] bpb=0.027535 time=263.3s + ttt_chunk [1381/1893] bpb=0.027537 time=265.2s + ttt_chunk [1391/1893] bpb=0.027529 time=267.1s + ttt_chunk [1401/1893] bpb=0.027531 time=269.1s + ttt_chunk [1411/1893] bpb=0.027535 time=271.0s + ttt_chunk [1421/1893] bpb=0.027537 time=272.9s + ttt_chunk [1431/1893] bpb=0.027537 time=274.8s + ttt_chunk [1441/1893] bpb=0.027544 time=276.8s + ttt_chunk [1451/1893] bpb=0.027548 time=278.7s + ttt_chunk [1461/1893] bpb=0.027545 time=280.6s + ttt_chunk [1471/1893] bpb=0.027558 time=282.5s + ttt_chunk [1481/1893] bpb=0.027551 time=284.5s + ttt_chunk [1491/1893] bpb=0.027549 time=286.6s + ttt_chunk [1501/1893] bpb=0.027550 time=288.6s + ttt_chunk [1511/1893] bpb=0.027549 time=290.5s + ttt_chunk [1521/1893] bpb=0.027548 time=292.4s + ttt_chunk [1531/1893] bpb=0.027545 time=294.3s + ttt_chunk [1541/1893] bpb=0.027541 time=296.3s + ttt_chunk [1551/1893] bpb=0.027547 time=298.2s + ttt_chunk [1561/1893] bpb=0.027549 time=300.1s + ttt_chunk [1571/1893] bpb=0.027549 time=302.0s + ttt_chunk [1581/1893] bpb=0.027552 time=304.0s + ttt_chunk [1591/1893] bpb=0.027552 time=305.9s + ttt_chunk [1601/1893] bpb=0.027552 time=307.8s + ttt_chunk [1611/1893] bpb=0.027548 time=309.7s + ttt_chunk [1621/1893] bpb=0.027542 time=311.6s + ttt_chunk [1631/1893] bpb=0.027544 time=313.6s + ttt_chunk [1641/1893] bpb=0.027545 time=315.5s + ttt_chunk [1651/1893] bpb=0.027544 time=317.4s + ttt_chunk [1661/1893] bpb=0.027541 time=319.3s + ttt_chunk [1671/1893] bpb=0.027547 time=321.3s + ttt_chunk [1681/1893] bpb=0.027546 time=323.3s + ttt_chunk [1691/1893] bpb=0.027542 time=325.3s + ttt_chunk [1701/1893] bpb=0.027541 time=327.2s + ttt_chunk [1711/1893] bpb=0.027538 time=329.1s + ttt_chunk [1721/1893] bpb=0.027535 time=331.0s + ttt_chunk [1731/1893] bpb=0.027532 time=332.9s + ttt_chunk [1741/1893] bpb=0.027532 time=334.9s + ttt_chunk [1751/1893] bpb=0.027527 time=336.8s + ttt_chunk [1761/1893] bpb=0.027529 time=338.7s + ttt_chunk [1771/1893] bpb=0.027526 time=340.6s + ttt_chunk [1781/1893] bpb=0.027529 time=342.6s + ttt_chunk [1791/1893] bpb=0.027519 time=344.5s + ttt_chunk [1801/1893] bpb=0.027517 time=346.4s + ttt_chunk [1811/1893] bpb=0.027514 time=348.3s + ttt_chunk [1821/1893] bpb=0.027512 time=350.2s + ttt_chunk [1831/1893] bpb=0.027502 time=352.1s + ttt_chunk [1841/1893] bpb=0.027501 time=354.1s + ttt_chunk [1851/1893] bpb=0.027498 time=356.1s + ttt_chunk [1861/1893] bpb=0.027491 time=358.3s + ttt_chunk [1871/1893] bpb=0.027487 time=360.4s + ttt_chunk [1881/1893] bpb=0.027480 time=362.3s + ttt_chunk [1891/1893] bpb=0.027476 time=364.2s + ttt_chunk [1893/1893] bpb=0.027477 time=364.5s +ttt:done val_loss=0.046323 val_bpb=0.027435 elapsed=364.5s +expert_logit[neural]: mean=-6.1806 std=6.3514 min=-36.5000 max=30.8750 +expert_logit[ngram_2]: mean=-15.7170 std=3.3611 min=-38.0000 max=2.5312 +expert_logit[ngram_3]: mean=-10.6990 std=3.1663 min=-37.5000 max=3.0000 +expert_logit[ngram_4]: mean=-6.5967 std=2.5496 min=-25.2500 max=4.3750 +expert_logit[ngram_5]: mean=-2.4908 std=2.7207 min=-23.8750 max=12.0000 +expert_logit[ngram_6]: mean=0.5921 std=3.7816 min=-26.8750 max=19.6250 +expert_logit[ngram_7]: mean=3.0643 std=5.0826 min=-31.0000 max=27.6250 +expert_logit[ngram_8]: mean=5.2086 std=6.2586 min=-35.0000 max=34.5000 +expert_logit[ngram_9]: mean=7.4525 std=7.3023 min=-38.7500 max=40.2500 +expert_logit[ngram_10]: mean=9.3592 std=8.2036 min=-42.7500 max=47.0000 +expert_logit[ngram_11]: mean=11.2161 std=9.0116 min=-44.7500 max=51.0000 +expert_logit[ngram_12]: mean=12.5503 std=9.6972 min=-47.5000 max=57.0000 +expert_logit[ngram_13]: mean=13.6833 std=10.2814 min=-50.5000 max=59.5000 +expert_logit[ngram_14]: mean=15.8675 std=10.6952 min=-51.2500 max=66.0000 +expert_logit[ngram_15]: mean=-2.1315 std=3.3747 min=-22.0000 max=20.1250 +expert_logit[ngram_16]: mean=25.5391 std=12.0947 min=-38.0000 max=91.5000 +final_int6_ttt val_loss:0.0463 val_bpb:0.0274 stride:64 eval_time:498140ms +final_int6_ttt_exact val_loss:0.04632341 val_bpb:0.02743533 diff --git a/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_seed42.log b/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_seed42.log new file mode 100644 index 000000000..3b04ce3dd --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_Order16_FrozenOracle_TTT_0.0274/train_seed42.log @@ -0,0 +1,305 @@ +W0327 08:18:05.190000 1895535 torch/distributed/run.py:803] +W0327 08:18:05.190000 1895535 torch/distributed/run.py:803] ***************************************** +W0327 08:18:05.190000 1895535 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0327 08:18:05.190000 1895535 torch/distributed/run.py:803] ***************************************** +logs/931f6aec-736e-44bd-9042-14e1ebb75dc0.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +mixed_precision: 68 int5 layers, 0 int6 layers (last 0 blocks) +model_params:33326188 +XSA:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ws:8 gqa:8/8 +lr:embed=0.035 matrix=0.025 scalar=0.025 batch:786432 wall:600s seed:42 +[rank0]:[W327 08:18:30.879764503 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank6]:[W327 08:18:30.277205510 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank7]:[W327 08:18:30.365629981 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank5]:[W327 08:18:30.639369391 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank4]:[W327 08:18:31.791363294 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank2]:[W327 08:18:32.031553788 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank1]:[W327 08:18:32.657802264 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank3]:[W327 08:18:33.254700931 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +pre-compiling mixer loss path (dummy data, no training tokens)... +pre-compile done +prefilling n-gram tables from training shards (frozen oracle)... +prefilled 8,000,040,960 tokens in 31357ms (counted in wallclock) +step:0/20000 val_loss:6.9326 val_bpb:4.1059 train_time:31357ms step_avg:0.03ms +step:1/20000 train_loss:3.6671 train_time:31713ms step_avg:356.01ms +step:2/20000 train_loss:4.5205 train_time:31809ms step_avg:225.67ms +step:3/20000 train_loss:4.0524 train_time:31906ms step_avg:182.81ms +step:4/20000 train_loss:3.5462 train_time:32003ms step_avg:161.51ms +step:5/20000 train_loss:3.6631 train_time:32101ms step_avg:148.81ms +step:6/20000 train_loss:3.8148 train_time:32199ms step_avg:140.33ms +step:7/20000 train_loss:3.7652 train_time:32297ms step_avg:134.32ms +step:8/20000 train_loss:3.6763 train_time:32396ms step_avg:129.79ms +step:9/20000 train_loss:3.4674 train_time:32495ms step_avg:126.36ms +step:10/20000 train_loss:3.3119 train_time:32593ms step_avg:123.55ms +step:500/20000 train_loss:1.2253 train_time:81671ms step_avg:100.63ms +step:1000/20000 train_loss:1.1609 train_time:132032ms step_avg:100.67ms +step:1500/20000 train_loss:1.1335 train_time:182183ms step_avg:100.55ms +step:2000/20000 train_loss:1.0506 train_time:232247ms step_avg:100.44ms +step:2500/20000 train_loss:1.1018 train_time:282498ms step_avg:100.46ms +step:3000/20000 train_loss:1.0895 train_time:332916ms step_avg:100.52ms +step:3500/20000 train_loss:1.0923 train_time:383254ms step_avg:100.54ms +late_qat:enabled step:3721 scale:0.5000 +step:4000/20000 train_loss:0.9834 train_time:435425ms step_avg:101.02ms +step:4000/20000 val_loss:1.9846 val_bpb:1.1754 train_time:435430ms step_avg:101.02ms +step:4500/20000 train_loss:1.0564 train_time:487416ms step_avg:101.35ms +swa:start step:4750 +step:5000/20000 train_loss:1.0403 train_time:540493ms step_avg:101.83ms +step:5389/20000 val_loss:1.9107 val_bpb:1.1316 train_time:582002ms step_avg:102.18ms +stopping_early: wallclock_cap train_time:582002ms step:5389/20000 +peak memory allocated: 28254 MiB reserved: 28560 MiB +ema:applying EMA weights (skipping diagnostic evals) +Serialized model: 130466125 bytes +Code size: 98302 bytes +pruning:3.0% magnitude pruning applied +Serialized model int6+zstd: 15862241 bytes +Total submission size int6+zstd: 15960543 bytes +TTT: epochs=1 lr=0.001 freeze_first=0 chunk=32768 opt=adamw +TTT temperature: 0.98 + Logistic context mixer enabled: eta=0.0 +ttt:start chunks=1893 chunk_tokens=32768 windows=969057 stride=64 lr=0.001 epochs=1 opt=adamw freeze_first=0 +ttt:params unfrozen=19476 frozen=33306712 + ttt_train [1] seqs=16 start_train... + ttt_train [1] epoch=1/1 batches=2 ... + step done ep=1 bs=0 loss=2.3463 + ttt_chunk [1/1893] bpb=0.028806 time=0.5s + ttt_train [2] seqs=16 start_train... + ttt_train [2] epoch=1/1 batches=2 ... + step done ep=1 bs=0 loss=2.7812 + ttt_chunk [2/1893] bpb=0.028606 time=0.7s + ttt_train [3] seqs=16 start_train... + ttt_train [3] epoch=1/1 batches=2 ... + step done ep=1 bs=0 loss=1.9270 + ttt_chunk [3/1893] bpb=0.028061 time=0.9s + ttt_chunk [4/1893] bpb=0.028053 time=1.1s + ttt_chunk [5/1893] bpb=0.027890 time=1.3s + ttt_chunk [11/1893] bpb=0.027441 time=2.4s + ttt_chunk [21/1893] bpb=0.027415 time=4.3s + ttt_chunk [31/1893] bpb=0.027292 time=6.2s + ttt_chunk [41/1893] bpb=0.027194 time=8.1s + ttt_chunk [51/1893] bpb=0.027174 time=10.1s + ttt_chunk [61/1893] bpb=0.027285 time=12.0s + ttt_chunk [71/1893] bpb=0.027271 time=13.9s + ttt_chunk [81/1893] bpb=0.027357 time=15.8s + ttt_chunk [91/1893] bpb=0.027407 time=17.7s + ttt_chunk [101/1893] bpb=0.027454 time=19.6s + ttt_chunk [111/1893] bpb=0.027486 time=21.6s + ttt_chunk [121/1893] bpb=0.027402 time=23.7s + ttt_chunk [131/1893] bpb=0.027418 time=25.7s + ttt_chunk [141/1893] bpb=0.027530 time=27.6s + ttt_chunk [151/1893] bpb=0.027551 time=29.5s + ttt_chunk [161/1893] bpb=0.027563 time=31.4s + ttt_chunk [171/1893] bpb=0.027606 time=33.3s + ttt_chunk [181/1893] bpb=0.027661 time=35.2s + ttt_chunk [191/1893] bpb=0.027770 time=37.2s + ttt_chunk [201/1893] bpb=0.027778 time=39.2s + ttt_chunk [211/1893] bpb=0.027744 time=41.2s + ttt_chunk [221/1893] bpb=0.027760 time=43.2s + ttt_chunk [231/1893] bpb=0.027740 time=45.1s + ttt_chunk [241/1893] bpb=0.027754 time=47.0s + ttt_chunk [251/1893] bpb=0.027777 time=48.9s + ttt_chunk [261/1893] bpb=0.027712 time=50.8s + ttt_chunk [271/1893] bpb=0.027705 time=52.7s + ttt_chunk [281/1893] bpb=0.027704 time=54.6s + ttt_chunk [291/1893] bpb=0.027729 time=56.5s + ttt_chunk [301/1893] bpb=0.027730 time=58.5s + ttt_chunk [311/1893] bpb=0.027758 time=60.4s + ttt_chunk [321/1893] bpb=0.027777 time=62.3s + ttt_chunk [331/1893] bpb=0.027780 time=64.2s + ttt_chunk [341/1893] bpb=0.027753 time=66.1s + ttt_chunk [351/1893] bpb=0.027784 time=68.0s + ttt_chunk [361/1893] bpb=0.027824 time=69.9s + ttt_chunk [371/1893] bpb=0.027801 time=71.8s + ttt_chunk [381/1893] bpb=0.027805 time=73.8s + ttt_chunk [391/1893] bpb=0.027800 time=75.7s + ttt_chunk [401/1893] bpb=0.027773 time=77.6s + ttt_chunk [411/1893] bpb=0.027761 time=79.5s + ttt_chunk [421/1893] bpb=0.027737 time=81.4s + ttt_chunk [431/1893] bpb=0.027725 time=83.3s + ttt_chunk [441/1893] bpb=0.027724 time=85.2s + ttt_chunk [451/1893] bpb=0.027717 time=87.2s + ttt_chunk [461/1893] bpb=0.027702 time=89.1s + ttt_chunk [471/1893] bpb=0.027701 time=91.0s + ttt_chunk [481/1893] bpb=0.027695 time=92.9s + ttt_chunk [491/1893] bpb=0.027671 time=94.8s + ttt_chunk [501/1893] bpb=0.027664 time=96.7s + ttt_chunk [511/1893] bpb=0.027662 time=98.7s + ttt_chunk [521/1893] bpb=0.027631 time=100.6s + ttt_chunk [531/1893] bpb=0.027640 time=102.5s + ttt_chunk [541/1893] bpb=0.027645 time=104.4s + ttt_chunk [551/1893] bpb=0.027627 time=106.3s + ttt_chunk [561/1893] bpb=0.027629 time=108.2s + ttt_chunk [571/1893] bpb=0.027611 time=110.1s + ttt_chunk [581/1893] bpb=0.027595 time=112.1s + ttt_chunk [591/1893] bpb=0.027585 time=114.0s + ttt_chunk [601/1893] bpb=0.027588 time=115.9s + ttt_chunk [611/1893] bpb=0.027595 time=117.8s + ttt_chunk [621/1893] bpb=0.027591 time=119.7s + ttt_chunk [631/1893] bpb=0.027593 time=121.7s + ttt_chunk [641/1893] bpb=0.027589 time=123.6s + ttt_chunk [651/1893] bpb=0.027583 time=125.6s + ttt_chunk [661/1893] bpb=0.027578 time=127.7s + ttt_chunk [671/1893] bpb=0.027582 time=129.7s + ttt_chunk [681/1893] bpb=0.027579 time=131.8s + ttt_chunk [691/1893] bpb=0.027592 time=133.9s + ttt_chunk [701/1893] bpb=0.027583 time=136.0s + ttt_chunk [711/1893] bpb=0.027589 time=138.1s + ttt_chunk [721/1893] bpb=0.027585 time=140.0s + ttt_chunk [731/1893] bpb=0.027590 time=142.0s + ttt_chunk [741/1893] bpb=0.027588 time=143.9s + ttt_chunk [751/1893] bpb=0.027581 time=145.8s + ttt_chunk [761/1893] bpb=0.027583 time=147.7s + ttt_chunk [771/1893] bpb=0.027579 time=149.6s + ttt_chunk [781/1893] bpb=0.027598 time=151.5s + ttt_chunk [791/1893] bpb=0.027593 time=153.5s + ttt_chunk [801/1893] bpb=0.027594 time=155.4s + ttt_chunk [811/1893] bpb=0.027594 time=157.3s + ttt_chunk [821/1893] bpb=0.027592 time=159.2s + ttt_chunk [831/1893] bpb=0.027593 time=161.2s + ttt_chunk [841/1893] bpb=0.027584 time=163.1s + ttt_chunk [851/1893] bpb=0.027584 time=165.0s + ttt_chunk [861/1893] bpb=0.027582 time=166.9s + ttt_chunk [871/1893] bpb=0.027588 time=168.8s + ttt_chunk [881/1893] bpb=0.027597 time=170.8s + ttt_chunk [891/1893] bpb=0.027596 time=172.7s + ttt_chunk [901/1893] bpb=0.027604 time=174.6s + ttt_chunk [911/1893] bpb=0.027626 time=176.5s + ttt_chunk [921/1893] bpb=0.027634 time=178.4s + ttt_chunk [931/1893] bpb=0.027638 time=180.4s + ttt_chunk [941/1893] bpb=0.027636 time=182.3s + ttt_chunk [951/1893] bpb=0.027645 time=184.2s + ttt_chunk [961/1893] bpb=0.027643 time=186.2s + ttt_chunk [971/1893] bpb=0.027653 time=188.1s + ttt_chunk [981/1893] bpb=0.027653 time=190.0s + ttt_chunk [991/1893] bpb=0.027654 time=191.9s + ttt_chunk [1001/1893] bpb=0.027649 time=193.9s + ttt_chunk [1011/1893] bpb=0.027649 time=195.8s + ttt_chunk [1021/1893] bpb=0.027649 time=197.8s + ttt_chunk [1031/1893] bpb=0.027651 time=199.9s + ttt_chunk [1041/1893] bpb=0.027642 time=202.0s + ttt_chunk [1051/1893] bpb=0.027635 time=204.0s + ttt_chunk [1061/1893] bpb=0.027633 time=206.0s + ttt_chunk [1071/1893] bpb=0.027646 time=207.9s + ttt_chunk [1081/1893] bpb=0.027651 time=209.8s + ttt_chunk [1091/1893] bpb=0.027656 time=211.7s + ttt_chunk [1101/1893] bpb=0.027652 time=213.7s + ttt_chunk [1111/1893] bpb=0.027645 time=215.6s + ttt_chunk [1121/1893] bpb=0.027639 time=217.5s + ttt_chunk [1131/1893] bpb=0.027635 time=219.4s + ttt_chunk [1141/1893] bpb=0.027627 time=221.4s + ttt_chunk [1151/1893] bpb=0.027623 time=223.3s + ttt_chunk [1161/1893] bpb=0.027615 time=225.2s + ttt_chunk [1171/1893] bpb=0.027613 time=227.1s + ttt_chunk [1181/1893] bpb=0.027600 time=229.1s + ttt_chunk [1191/1893] bpb=0.027599 time=231.0s + ttt_chunk [1201/1893] bpb=0.027600 time=232.9s + ttt_chunk [1211/1893] bpb=0.027589 time=234.8s + ttt_chunk [1221/1893] bpb=0.027585 time=236.8s + ttt_chunk [1231/1893] bpb=0.027575 time=238.7s + ttt_chunk [1241/1893] bpb=0.027563 time=240.6s + ttt_chunk [1251/1893] bpb=0.027551 time=242.5s + ttt_chunk [1261/1893] bpb=0.027549 time=244.5s + ttt_chunk [1271/1893] bpb=0.027542 time=246.5s + ttt_chunk [1281/1893] bpb=0.027534 time=248.4s + ttt_chunk [1291/1893] bpb=0.027530 time=250.3s + ttt_chunk [1301/1893] bpb=0.027521 time=252.3s + ttt_chunk [1311/1893] bpb=0.027514 time=254.2s + ttt_chunk [1321/1893] bpb=0.027507 time=256.1s + ttt_chunk [1331/1893] bpb=0.027504 time=258.0s + ttt_chunk [1341/1893] bpb=0.027498 time=260.0s + ttt_chunk [1351/1893] bpb=0.027497 time=261.9s + ttt_chunk [1361/1893] bpb=0.027499 time=263.8s + ttt_chunk [1371/1893] bpb=0.027497 time=265.8s + ttt_chunk [1381/1893] bpb=0.027500 time=267.7s + ttt_chunk [1391/1893] bpb=0.027492 time=269.6s + ttt_chunk [1401/1893] bpb=0.027494 time=271.5s + ttt_chunk [1411/1893] bpb=0.027497 time=273.4s + ttt_chunk [1421/1893] bpb=0.027500 time=275.4s + ttt_chunk [1431/1893] bpb=0.027500 time=277.3s + ttt_chunk [1441/1893] bpb=0.027506 time=279.3s + ttt_chunk [1451/1893] bpb=0.027513 time=281.3s + ttt_chunk [1461/1893] bpb=0.027510 time=283.2s + ttt_chunk [1471/1893] bpb=0.027523 time=285.1s + ttt_chunk [1481/1893] bpb=0.027516 time=287.0s + ttt_chunk [1491/1893] bpb=0.027514 time=289.0s + ttt_chunk [1501/1893] bpb=0.027515 time=290.9s + ttt_chunk [1511/1893] bpb=0.027515 time=292.8s + ttt_chunk [1521/1893] bpb=0.027514 time=294.8s + ttt_chunk [1531/1893] bpb=0.027510 time=296.7s + ttt_chunk [1541/1893] bpb=0.027506 time=298.6s + ttt_chunk [1551/1893] bpb=0.027511 time=300.6s + ttt_chunk [1561/1893] bpb=0.027512 time=302.5s + ttt_chunk [1571/1893] bpb=0.027511 time=304.4s + ttt_chunk [1581/1893] bpb=0.027514 time=306.4s + ttt_chunk [1591/1893] bpb=0.027512 time=308.3s + ttt_chunk [1601/1893] bpb=0.027512 time=310.2s + ttt_chunk [1611/1893] bpb=0.027508 time=312.2s + ttt_chunk [1621/1893] bpb=0.027501 time=314.1s + ttt_chunk [1631/1893] bpb=0.027505 time=316.0s + ttt_chunk [1641/1893] bpb=0.027504 time=318.0s + ttt_chunk [1651/1893] bpb=0.027502 time=319.9s + ttt_chunk [1661/1893] bpb=0.027499 time=321.8s + ttt_chunk [1671/1893] bpb=0.027505 time=323.8s + ttt_chunk [1681/1893] bpb=0.027504 time=325.7s + ttt_chunk [1691/1893] bpb=0.027500 time=327.6s + ttt_chunk [1701/1893] bpb=0.027499 time=329.6s + ttt_chunk [1711/1893] bpb=0.027495 time=331.5s + ttt_chunk [1721/1893] bpb=0.027492 time=333.4s + ttt_chunk [1731/1893] bpb=0.027490 time=335.4s + ttt_chunk [1741/1893] bpb=0.027489 time=337.3s + ttt_chunk [1751/1893] bpb=0.027485 time=339.2s + ttt_chunk [1761/1893] bpb=0.027486 time=341.2s + ttt_chunk [1771/1893] bpb=0.027483 time=343.1s + ttt_chunk [1781/1893] bpb=0.027485 time=345.0s + ttt_chunk [1791/1893] bpb=0.027476 time=347.0s + ttt_chunk [1801/1893] bpb=0.027473 time=349.1s + ttt_chunk [1811/1893] bpb=0.027470 time=351.2s + ttt_chunk [1821/1893] bpb=0.027469 time=353.2s + ttt_chunk [1831/1893] bpb=0.027458 time=355.4s + ttt_chunk [1841/1893] bpb=0.027457 time=357.3s + ttt_chunk [1851/1893] bpb=0.027454 time=359.3s + ttt_chunk [1861/1893] bpb=0.027446 time=361.2s + ttt_chunk [1871/1893] bpb=0.027443 time=363.1s + ttt_chunk [1881/1893] bpb=0.027436 time=365.1s + ttt_chunk [1891/1893] bpb=0.027432 time=367.0s + ttt_chunk [1893/1893] bpb=0.027434 time=367.3s +ttt:done val_loss=0.046255 val_bpb=0.027395 elapsed=367.3s +expert_logit[neural]: mean=-7.4510 std=6.5376 min=-40.7500 max=31.8750 +expert_logit[ngram_2]: mean=-16.6865 std=3.4065 min=-41.2500 max=2.4844 +expert_logit[ngram_3]: mean=-10.3735 std=3.3270 min=-36.0000 max=6.0000 +expert_logit[ngram_4]: mean=-5.6579 std=2.4899 min=-22.7500 max=7.0938 +expert_logit[ngram_5]: mean=-1.9024 std=2.5505 min=-20.1250 max=13.0625 +expert_logit[ngram_6]: mean=1.3103 std=3.6566 min=-22.5000 max=18.8750 +expert_logit[ngram_7]: mean=4.0539 std=4.9562 min=-27.0000 max=27.3750 +expert_logit[ngram_8]: mean=6.5839 std=6.1463 min=-31.8750 max=33.5000 +expert_logit[ngram_9]: mean=8.9025 std=7.2146 min=-36.0000 max=40.2500 +expert_logit[ngram_10]: mean=10.8837 std=8.1004 min=-39.5000 max=48.7500 +expert_logit[ngram_11]: mean=12.7252 std=9.0129 min=-42.7500 max=52.2500 +expert_logit[ngram_12]: mean=14.1869 std=9.4817 min=-44.2500 max=58.5000 +expert_logit[ngram_13]: mean=15.3517 std=10.3019 min=-49.0000 max=62.0000 +expert_logit[ngram_14]: mean=17.4684 std=10.5869 min=-47.5000 max=65.5000 +expert_logit[ngram_15]: mean=-3.2021 std=3.4887 min=-21.0000 max=18.3750 +expert_logit[ngram_16]: mean=26.0979 std=12.1742 min=-38.2500 max=86.5000 +final_int6_ttt val_loss:0.0463 val_bpb:0.0274 stride:64 eval_time:503748ms +final_int6_ttt_exact val_loss:0.04625468 val_bpb:0.02739463