Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,10 @@ def step(self):
DEPTH = 8 # number of transformer layers
DEVICE_BATCH_SIZE = 128 # per-device batch size (reduce if OOM)

# Early structural triage: a one-time check of weight-matrix effective rank,
# run from inside the training loop once TRIAGE_TIME seconds have elapsed.
TRIAGE_TIME = 60 # seconds of training before the check runs; any value <= 0 disables it
TRIAGE_KILL = 0.5 # abort the run if effective rank falls below this fraction of its initial value

# ---------------------------------------------------------------------------
# Setup: tokenizer, model, optimizer, dataloader
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -531,6 +535,31 @@ def get_muon_momentum(step):
def get_weight_decay(progress):
    """Return the weight decay to use at the given training progress.

    The base ``WEIGHT_DECAY`` is scaled by ``1 - progress``, so the result
    falls linearly from ``WEIGHT_DECAY`` at ``progress == 0`` to zero at
    ``progress == 1``.
    """
    remaining = 1 - progress
    return WEIGHT_DECAY * remaining

@torch.no_grad()
def structural_triage(model):
    """Measure effective rank and gradient coherence of a model's weight matrices.

    Only 2-D parameters whose smaller dimension is at least 64 are inspected.

    The effective rank of one matrix is ``exp(H)``, where ``H`` is the Shannon
    entropy of its singular values after normalising them to sum to 1. It is 1
    when a single direction carries all the mass and ``min(p.shape)`` when the
    spectrum is flat.

    Gradient coherence is the mean cosine similarity between the flattened
    gradients of neighbouring same-shaped matrices, where neighbours are taken
    in ``model.parameters()`` order within each shape group.

    Args:
        model: Object whose ``parameters()`` yields tensors (optionally with
            ``.grad`` populated).

    Returns:
        ``(eff_rank, coherence)`` as Python floats. ``eff_rank`` is the mean
        effective rank over the inspected matrices, or 0.0 if none qualify.
        ``coherence`` is the mean pairwise cosine similarity, or 0.0 if no
        shape group holds at least two gradients.
    """
    ranks, grad_groups = [], {}
    for p in model.parameters():
        if p.ndim != 2 or min(p.shape) < 64:
            continue
        s = torch.linalg.svdvals(p.float())
        s = s / s.sum()
        # Drop (near-)zero mass so log() stays finite. An all-zero matrix
        # normalises to NaN, fails this comparison everywhere, and the empty
        # spectrum then gives entropy 0, i.e. an effective rank of 1.0.
        s = s[s > 1e-8]
        # The entropy must be negated *before* exponentiating. Writing
        # `-(...).sum().exp()` applies the minus last and yields -exp(-H),
        # a value in [-1, 0) instead of the effective rank.
        entropy = -(s * s.log()).sum()
        ranks.append(entropy.exp().item())
        if p.grad is not None:
            grad_groups.setdefault(p.shape, []).append(p.grad.float().flatten())
    eff_rank = sum(ranks) / len(ranks) if ranks else 0.0
    sims = [
        F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()
        for grads in grad_groups.values()
        for a, b in zip(grads, grads[1:])
    ]
    coherence = sum(sims) / len(sims) if sims else 0.0
    return eff_rank, coherence

# Baseline effective rank, measured ahead of the training loop; the mid-run
# triage and the final summary both report rank relative to this value.
triage_done = TRIAGE_TIME <= 0  # a non-positive TRIAGE_TIME disables the mid-run check
initial_rank = structural_triage(model)[0]
print(f"Initial effective rank: {initial_rank:.1f}")

# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
Expand All @@ -551,6 +580,16 @@ def get_weight_decay(progress):
loss.backward()
x, y, epoch = next(train_loader)

# Early structural triage (while gradients are still available)
if not triage_done and total_training_time >= TRIAGE_TIME:
    triage_done = True  # the check runs at most once per training run
    eff_rank, grad_coherence = structural_triage(model)
    if initial_rank > 0:
        rank_ratio = eff_rank / initial_rank
        print(f"\n[triage@{total_training_time:.0f}s] rank={eff_rank:.1f} ({rank_ratio:.0%} of init) coherence={grad_coherence:.4f}")
        if rank_ratio < TRIAGE_KILL:
            print(f"[triage] KILL: effective rank collapsed to {rank_ratio:.0%} of initial")
            # SystemExit rather than exit(): exit() is an interactive helper
            # installed by the `site` module and is not always available.
            raise SystemExit(1)
    else:
        # Without a positive baseline there is nothing to compare against;
        # report and carry on instead of killing the run on a made-up 0% ratio.
        print(f"\n[triage@{total_training_time:.0f}s] skipped: no positive initial effective rank (rank={eff_rank:.1f} coherence={grad_coherence:.4f})")

# Progress and schedules
progress = min(total_training_time / TIME_BUDGET, 1.0)
lrm = get_lr_multiplier(progress)
Expand Down Expand Up @@ -628,3 +667,8 @@ def get_weight_decay(progress):
print(f"num_steps: {step}")
print(f"num_params_M: {num_params / 1e6:.1f}")
print(f"depth: {DEPTH}")
# End-of-run structural summary: effective rank now versus the pre-training
# baseline. Retention is reported as 0 when the baseline is not positive.
final_rank = structural_triage(model)[0]
if initial_rank > 0:
    rank_retention = final_rank / initial_rank
else:
    rank_retention = 0
print(f"eff_rank_init: {initial_rank:.1f}")
print(f"eff_rank_final: {final_rank:.1f}")
print(f"rank_retention: {rank_retention:.4f}")