Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 58 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import gc
import time
from dataclasses import dataclass, asdict
from pathlib import Path

import torch
import torch.nn as nn
Expand Down Expand Up @@ -289,6 +290,46 @@ def forward(self, idx, targets=None, reduction='mean'):
return loss
return logits


def maybe_load_checkpoint(model, checkpoint_path):
    """Partially load weights from ``checkpoint_path`` into ``model``.

    Accepts either a raw state dict or a payload dict carrying a
    ``"model_state"`` key (the format written by ``maybe_save_checkpoint``).
    Only tensors whose key exists in the current model AND whose shape matches
    are loaded; everything else is counted as skipped, which lets training
    resume across minor architecture changes.

    Args:
        model: module whose parameters are updated in place.
        checkpoint_path: path to a ``torch.save`` file; falsy means no-op.

    Returns:
        ``(matched, skipped)`` tensor counts; ``(0, 0)`` when no path given.
    """
    if not checkpoint_path:
        return 0, 0

    # weights_only=True prevents arbitrary code execution through pickle when
    # the checkpoint path comes from the environment (untrusted input); the
    # payload is tensors/dicts/primitives only, which weights_only supports.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    state_dict = checkpoint["model_state"] if isinstance(checkpoint, dict) and "model_state" in checkpoint else checkpoint
    current_state = model.state_dict()

    matched = {}
    skipped = 0
    for key, tensor in state_dict.items():
        if key in current_state and current_state[key].shape == tensor.shape:
            matched[key] = tensor
        else:
            skipped += 1

    current_state.update(matched)
    # strict=False tolerates model keys that the checkpoint does not provide.
    model.load_state_dict(current_state, strict=False)
    print(f"Loaded checkpoint tensors: {len(matched)} matched, {skipped} skipped from {checkpoint_path}")
    return len(matched), skipped


def maybe_save_checkpoint(model, checkpoint_path, *, config, val_bpb, step, total_tokens, training_seconds):
    """Persist model weights plus training metadata to ``checkpoint_path``.

    No-op when ``checkpoint_path`` is falsy. Parent directories are created
    as needed. Every tensor is detached and moved to CPU before serialization
    so the file can be loaded on any device.

    Args:
        model: module whose ``state_dict`` is serialized.
        checkpoint_path: destination file; falsy disables saving.
        config: dataclass model config, stored as a plain dict.
        val_bpb, step, total_tokens, training_seconds: run metadata.
    """
    if not checkpoint_path:
        return

    destination = Path(checkpoint_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    cpu_state = {}
    for name, tensor in model.state_dict().items():
        cpu_state[name] = tensor.detach().cpu()

    payload = dict(
        model_state=cpu_state,
        config=asdict(config),
        val_bpb=float(val_bpb),
        step=int(step),
        total_tokens=int(total_tokens),
        training_seconds=float(training_seconds),
    )
    torch.save(payload, destination)
    print(f"checkpoint_path: {destination}")

# ---------------------------------------------------------------------------
# Optimizer (MuonAdamW, single GPU only)
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -479,23 +520,24 @@ def build_model_config(depth):
print(f"Model config: {asdict(config)}")

with torch.device("meta"):
model = GPT(config)
model.to_empty(device=device)
model.init_weights()
model_core = GPT(config)
model_core.to_empty(device=device)
model_core.init_weights()
maybe_load_checkpoint(model_core, os.environ.get("AUTORESEARCH_LOAD_CHECKPOINT"))

param_counts = model.num_scaling_params()
param_counts = model_core.num_scaling_params()
print("Parameter counts:")
for key, value in param_counts.items():
print(f" {key:24s}: {value:,}")
num_params = param_counts['total']
num_flops_per_token = model.estimate_flops()
num_flops_per_token = model_core.estimate_flops()
print(f"Estimated FLOPs per token: {num_flops_per_token:e}")

tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd

optimizer = model.setup_optimizer(
optimizer = model_core.setup_optimizer(
unembedding_lr=UNEMBEDDING_LR,
embedding_lr=EMBEDDING_LR,
scalar_lr=SCALAR_LR,
Expand All @@ -504,7 +546,7 @@ def build_model_config(depth):
weight_decay=WEIGHT_DECAY,
)

model = torch.compile(model, dynamic=False)
model = torch.compile(model_core, dynamic=False)

train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train")
x, y, epoch = next(train_loader) # prefetch first batch
Expand Down Expand Up @@ -610,6 +652,15 @@ def get_weight_decay(progress):
model.eval()
with autocast_ctx:
val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)
maybe_save_checkpoint(
model_core,
os.environ.get("AUTORESEARCH_SAVE_CHECKPOINT"),
config=config,
val_bpb=val_bpb,
step=step,
total_tokens=total_tokens,
training_seconds=total_training_time,
)

# Final summary
t_end = time.time()
Expand Down