diff --git a/train.py b/train.py
index 6994fb9b..e6252011 100644
--- a/train.py
+++ b/train.py
@@ -11,6 +11,7 @@
 import gc
 import time
 from dataclasses import dataclass, asdict
+from pathlib import Path
 
 import torch
 import torch.nn as nn
@@ -289,6 +290,46 @@ def forward(self, idx, targets=None, reduction='mean'):
             return loss
         return logits
 
+
+def maybe_load_checkpoint(model, checkpoint_path):
+    if not checkpoint_path:
+        return 0, 0
+
+    checkpoint = torch.load(checkpoint_path, map_location="cpu")
+    state_dict = checkpoint["model_state"] if isinstance(checkpoint, dict) and "model_state" in checkpoint else checkpoint
+    current_state = model.state_dict()
+
+    matched = {}
+    skipped = 0
+    for key, tensor in state_dict.items():
+        if key in current_state and current_state[key].shape == tensor.shape:
+            matched[key] = tensor
+        else:
+            skipped += 1
+
+    current_state.update(matched)
+    model.load_state_dict(current_state, strict=False)
+    print(f"Loaded checkpoint tensors: {len(matched)} matched, {skipped} skipped from {checkpoint_path}")
+    return len(matched), skipped
+
+
+def maybe_save_checkpoint(model, checkpoint_path, *, config, val_bpb, step, total_tokens, training_seconds):
+    if not checkpoint_path:
+        return
+
+    destination = Path(checkpoint_path)
+    destination.parent.mkdir(parents=True, exist_ok=True)
+    payload = {
+        "model_state": {key: value.detach().cpu() for key, value in model.state_dict().items()},
+        "config": asdict(config),
+        "val_bpb": float(val_bpb),
+        "step": int(step),
+        "total_tokens": int(total_tokens),
+        "training_seconds": float(training_seconds),
+    }
+    torch.save(payload, destination)
+    print(f"checkpoint_path: {destination}")
+
 # ---------------------------------------------------------------------------
 # Optimizer (MuonAdamW, single GPU only)
 # ---------------------------------------------------------------------------
@@ -479,23 +520,24 @@ def build_model_config(depth):
 print(f"Model config: {asdict(config)}")
 
 with torch.device("meta"):
-    model = GPT(config)
-model.to_empty(device=device)
-model.init_weights()
+    model_core = GPT(config)
+model_core.to_empty(device=device)
+model_core.init_weights()
+maybe_load_checkpoint(model_core, os.environ.get("AUTORESEARCH_LOAD_CHECKPOINT"))
 
-param_counts = model.num_scaling_params()
+param_counts = model_core.num_scaling_params()
 print("Parameter counts:")
 for key, value in param_counts.items():
     print(f"  {key:24s}: {value:,}")
 num_params = param_counts['total']
-num_flops_per_token = model.estimate_flops()
+num_flops_per_token = model_core.estimate_flops()
 print(f"Estimated FLOPs per token: {num_flops_per_token:e}")
 
 tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
 assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
 grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
 
-optimizer = model.setup_optimizer(
+optimizer = model_core.setup_optimizer(
     unembedding_lr=UNEMBEDDING_LR,
     embedding_lr=EMBEDDING_LR,
     scalar_lr=SCALAR_LR,
@@ -504,7 +546,7 @@ def build_model_config(depth):
     weight_decay=WEIGHT_DECAY,
 )
 
-model = torch.compile(model, dynamic=False)
+model = torch.compile(model_core, dynamic=False)
 
 train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train")
 x, y, epoch = next(train_loader)  # prefetch first batch
@@ -610,6 +652,15 @@ def get_weight_decay(progress):
 model.eval()
 with autocast_ctx:
     val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)
+maybe_save_checkpoint(
+    model_core,
+    os.environ.get("AUTORESEARCH_SAVE_CHECKPOINT"),
+    config=config,
+    val_bpb=val_bpb,
+    step=step,
+    total_tokens=total_tokens,
+    training_seconds=total_training_time,
+)
 
 # Final summary
 t_end = time.time()