diff --git a/train.py b/train.py index 2e743974..c066b7f7 100644 --- a/train.py +++ b/train.py @@ -194,7 +194,7 @@ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=No def _compute_window_sizes(self, config): pattern = config.window_pattern.upper() - assert all(c in "SL" for c in pattern) + assert all(c in "SL" for c in pattern), f"WINDOW_PATTERN must only contain S or L, got: {pattern}" long_window = config.sequence_len short_window = long_window // 2 char_to_window = {"L": (long_window, 0), "S": (short_window, 0)} @@ -493,7 +493,7 @@ def build_model_config(depth): print(f"Estimated FLOPs per token: {num_flops_per_token:e}") tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN -assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0 +assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0, f"TOTAL_BATCH_SIZE ({TOTAL_BATCH_SIZE}) must be divisible by tokens_per_fwdbwd ({tokens_per_fwdbwd})" grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd optimizer = model.setup_optimizer(