-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfig.yaml
More file actions
113 lines (101 loc) · 7.79 KB
/
config.yaml
File metadata and controls
113 lines (101 loc) · 7.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# KenLM model assets. All filenames are resolved relative to work_dir
# (work_dir is presumably supplied by the loading code — confirm against the loader).
model:
  binary: 35k_char6_000122_24.binary  # KenLM binary model filename (relative to work_dir)
  # binary: distill.binary  # alternative KenLM binary model filename (relative to work_dir)
  vocab: 35k_vocab_truncated.json  # Vocabulary JSON filename (relative to work_dir)
  # vocab: 15k_vocab.json  # alternative vocabulary JSON filename (relative to work_dir)
  exclude_tokens:  # KenLM special tokens to exclude from candidates
    - "<s>"  # sentence-start marker
    - "</s>"  # sentence-end marker
    - "\u0001"  # control character; double quotes required for the \u escape
# Prediction output settings.
prediction:
  top_k: 3  # Number of top predictions to return
  fallback: " ea"  # Default prediction when errors occur (most common chars); quoted to keep the leading space
# Parallel-processing settings for batch prediction.
workers:
  # max_workers: 16  # previous setting: maximum number of parallel workers
  max_workers: 4  # Maximum number of parallel workers
  sequential_threshold: 100  # Use sequential mode below this many inputs
  chunk_divisor: 4  # chunksize = len(data) // (num_workers * chunk_divisor)
# Weights & Biases experiment-tracking settings.
wandb:
  enabled: true  # Set to false to disable wandb logging entirely
  project: cse447  # W&B project name
  entity: vincentczhou-university-of-washington  # W&B team/user entity
  run_name: 256_kenlm_2e-5  # Custom run name (null = auto-generated by W&B)
  silent: true  # Suppress wandb console output
# ---------------------------------------------------------------------------
# Reranker training configuration (used by src/reranker/train.py)
# ---------------------------------------------------------------------------
reranker:
  # --- Model architecture (saved with checkpoint for inference loading) ---
  architecture:
    d_model: 256  # Embedding / hidden dimension
    nhead: 8  # Number of attention heads (d_model must be divisible by this)
    num_layers: 4  # Number of Transformer encoder layers
    ff_mult: 4  # Feed-forward hidden dim = d_model * ff_mult
    dropout: 0.0  # Dropout rate
    max_context_len: 200  # Max prefix length the model can handle
    norm_first: true  # Pre-norm (true, modern/stable) vs post-norm (false)
    temperature: 1.0  # Initial logit temperature (learned during training)
    alpha: 0.0  # KenLM score blending weight (fixed, not learned)

  # --- Training hyperparameters ---
  training:
    epochs: 256  # Maximum number of epochs (-1 = no limit, rely on early stopping)
    batch_size: 512  # Training mini-batch size
    eval_batch_size: 512  # Validation mini-batch size (can be larger than batch_size)
    candidate_size: 64  # Candidates per example when using random negatives (ignored when using precomputed TSVs)
    lr: 5.0e-5  # Learning rate (AdamW); keep dot + signed exponent so YAML 1.1 parsers read a float
    # lr: 3.0e-4  # alternative (higher) learning rate (AdamW)
    weight_decay: 0.01  # AdamW weight decay (L2 regularization)
    grad_clip: 1.0  # Gradient norm clipping
    warmup_steps: 400  # Linear warmup steps before decay
    gradient_accumulation_steps: 1  # Accumulate gradients over N batches before an optimizer step
    mixed_precision: true  # AMP (only active on CUDA)
    cpu: false  # Force CPU even if GPU is available
    lr_scheduler: cosine  # "cosine" or "linear" (with warmup)
    min_lr_ratio: 0.1  # Final LR = lr * min_lr_ratio
    seed: 42  # Random seed for reproducibility

  # --- Early stopping / best model ---
  early_stopping:
    patience: 5  # Validation checks without improvement before stopping
    metric: valid/top3  # Lightning-logged key: "valid/top3" (higher=better) or "valid/loss" (lower=better)
    metric_mode: max  # "max" for top3, "min" for loss

  # --- Data paths (relative to repo root) ---
  data:
    # train_path: data/madlad_multilang_clean_100_optionB_kenlm/train.txt
    train_path: data/distill/train.txt
    # valid_path: data/madlad_multilang_clean_15k_optionB_kenlm/valid.txt
    valid_path: data/distill/valid.txt
    vocab_path: data/madlad_multilang_clean_35k_optionB_kenlm/35k_vocab_truncated.json
    # Precomputed KenLM top-K candidate TSVs (output of precompute_kenlm_candidates.py).
    # When both paths are set and the files exist, training uses hard KenLM negatives
    # instead of random frequency-weighted negatives. candidate_size in training is
    # ignored — K is determined by the TSV (the --k flag used during precomputation).
    # candidates_train_path: data/madlad_multilang_clean_100_optionB_kenlm/candidates_train_k64_100_char6_000122_24.tsv
    candidates_train_path: data/distill/candidates_train_k64_35k_char6_000122_24.tsv
    # candidates_train_path: data/distill/candidates_train_k64_random.tsv
    # candidates_valid_path: data/madlad_multilang_clean_15k_optionB_kenlm/candidates_valid_k64_35k_char6_000122_24.tsv
    candidates_valid_path: data/distill/candidates_valid_k64_35k_char6_000122_24_nogold.tsv
    lazy_load_candidates: false  # true = byte-offset lazy loading (low RAM, large TSVs); false = load all rows into memory (faster __getitem__, small TSVs)

  # --- Data limits (for quick iteration / memory control, null = no limit) ---
  data_limits:
    max_train_lines: null  # Max lines to load from train.txt (null = all)
    max_valid_lines: null  # Max lines to load from valid.txt (null = all)
    # max_train_examples: 2097152  # previous cap on training examples
    max_train_examples: null  # Max training examples (rows) after building/loading dataset (null = all)
    max_valid_examples: null  # Max validation examples (null = all)
    max_eval_batches: null  # Max validation batches per val check (null = full validation set)

  # --- Output ---
  output:
    out_dir: work  # Directory for checkpoint outputs
    checkpoint_name: final_reranker.pt  # Filename for the plain inference .pt checkpoint
    # resume_from: work/best_reranker-v28.ckpt  # previous path to checkpoint for training resumption
    resume_from: null  # Path to checkpoint for training resumption
    init_weights_from: work/next2/reranker.pt  # Load weights only for fine-tuning (.pt file); fresh optimizer/scheduler
    # init_weights_from: null  # Load weights only for fine-tuning (.pt file); fresh optimizer/scheduler
    every_n_train_steps: 1000  # Save checkpoint every N optimizer steps (null = disabled)
    save_on_train_epoch_end: null  # true = save at epoch end, false = save at val end, null = Lightning default

  # --- Logging ---
  logging:
    log_every: 10  # Print training loss every N optimizer steps
    val_check_interval: null  # Validate every N optimizer steps (int), or fraction of epoch (float), null = every epoch
    num_workers: 2  # DataLoader workers (0 = main process only, saves RAM)
    pin_memory: false  # Disable pinned memory to reduce RAM pressure