diff --git a/train.py b/train.py
index 2e743974..a1936fc9 100644
--- a/train.py
+++ b/train.py
@@ -450,6 +450,10 @@ def step(self):
 DEPTH = 8 # number of transformer layers
 DEVICE_BATCH_SIZE = 128 # per-device batch size (reduce if OOM)
 
+# Early structural triage
+TRIAGE_TIME = 60 # seconds into training; 0 to disable
+TRIAGE_KILL = 0.5 # kill if effective rank drops below this fraction of initial
+
 # ---------------------------------------------------------------------------
 # Setup: tokenizer, model, optimizer, dataloader
 # ---------------------------------------------------------------------------
@@ -531,6 +535,41 @@ def get_muon_momentum(step):
 def get_weight_decay(progress):
     return WEIGHT_DECAY * (1 - progress)
 
+@torch.no_grad()
+def structural_triage(model):
+    """Return (mean effective rank, mean gradient coherence) over large 2-D weights.
+
+    Effective rank of a matrix is exp(H), where H is the Shannon entropy of its
+    singular values normalized to sum to 1. Gradient coherence is the mean cosine
+    similarity between gradients of consecutive same-shape parameters; it is 0.0
+    when no gradients are populated. Parameters that are not 2-D, or whose
+    smaller dimension is under 64, are skipped.
+    """
+    ranks, grad_groups = [], {}
+    for p in model.parameters():
+        if p.ndim != 2 or min(p.shape) < 64:
+            continue
+        s = torch.linalg.svdvals(p.float())
+        s = s / s.sum()
+        s = s[s > 1e-8]  # drop ~zero mass so log() stays finite
+        # Negate the sum *before* exp(): unary minus binds looser than the method
+        # chain, so -(...).sum().exp() would yield -exp(-H) instead of exp(H).
+        entropy = -(s * s.log()).sum()
+        ranks.append(entropy.exp().item())
+        if p.grad is not None:
+            grad_groups.setdefault(p.shape, []).append(p.grad.float().flatten())
+    eff_rank = sum(ranks) / len(ranks) if ranks else 0.0
+    sims = []
+    for grads in grad_groups.values():
+        for g_prev, g_next in zip(grads, grads[1:]):
+            sims.append(F.cosine_similarity(g_prev.unsqueeze(0), g_next.unsqueeze(0)).item())
+    coherence = sum(sims) / len(sims) if sims else 0.0
+    return eff_rank, coherence
+
+initial_rank, _ = structural_triage(model)
+triage_done = TRIAGE_TIME <= 0
+print(f"Initial effective rank: {initial_rank:.1f}")
+
 # ---------------------------------------------------------------------------
 # Training loop
 # ---------------------------------------------------------------------------
@@ -551,6 +590,16 @@ def get_weight_decay(progress):
         loss.backward()
         x, y, epoch = next(train_loader)
 
+    # Early structural triage (while gradients are still available)
+    if not triage_done and total_training_time >= TRIAGE_TIME:
+        triage_done = True
+        eff_rank, grad_coherence = structural_triage(model)
+        rank_ratio = eff_rank / initial_rank if initial_rank > 0 else 0
+        print(f"\n[triage@{total_training_time:.0f}s] rank={eff_rank:.1f} ({rank_ratio:.0%} of init) coherence={grad_coherence:.4f}")
+        if rank_ratio < TRIAGE_KILL:
+            print(f"[triage] KILL: effective rank collapsed to {rank_ratio:.0%} of initial")
+            raise SystemExit(1)  # same exit status as exit(1), without relying on the site-injected helper
+
     # Progress and schedules
     progress = min(total_training_time / TIME_BUDGET, 1.0)
     lrm = get_lr_multiplier(progress)
@@ -628,3 +677,8 @@ def get_weight_decay(progress):
 print(f"num_steps: {step}")
 print(f"num_params_M: {num_params / 1e6:.1f}")
 print(f"depth: {DEPTH}")
+final_rank, _ = structural_triage(model)
+rank_retention = final_rank / initial_rank if initial_rank > 0 else 0
+print(f"eff_rank_init: {initial_rank:.1f}")
+print(f"eff_rank_final: {final_rank:.1f}")
+print(f"rank_retention: {rank_retention:.4f}")