From 626c9f1402835232ff3bed0cfc9ca1403d7b778b Mon Sep 17 00:00:00 2001
From: tykoo-chen
Date: Wed, 11 Mar 2026 18:54:23 +0000
Subject: [PATCH] fix: add helpful error messages to asserts

When assertions fail, users get more context about what went wrong:
- TOTAL_BATCH_SIZE divisibility: shows actual values
- WINDOW_PATTERN validation: shows the invalid pattern
---
 train.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
index 2e743974..c066b7f7 100644
--- a/train.py
+++ b/train.py
@@ -194,7 +194,7 @@ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=No
 
     def _compute_window_sizes(self, config):
         pattern = config.window_pattern.upper()
-        assert all(c in "SL" for c in pattern)
+        assert all(c in "SL" for c in pattern), f"WINDOW_PATTERN must only contain S or L, got: {pattern}"
         long_window = config.sequence_len
         short_window = long_window // 2
         char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
@@ -493,7 +493,7 @@ def build_model_config(depth):
 print(f"Estimated FLOPs per token: {num_flops_per_token:e}")
 
 tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
-assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
+assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0, f"TOTAL_BATCH_SIZE ({TOTAL_BATCH_SIZE}) must be divisible by tokens_per_fwdbwd ({tokens_per_fwdbwd})"
 grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
 
 optimizer = model.setup_optimizer(