train.py
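
"""Training entry point for CAPO fine-tuning.

Parses a CPOConfig/ModelConfig pair from a YAML config file or the command
line, loads the MLQE-PE preference data, applies a LoRA adapter, and trains
with the custom CapoTrainer from src/trainer.py.
"""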
import sys

from transformers import AutoTokenizer, HfArgumentParser
from peft import LoraConfig
from trl import CPOConfig, ModelConfig

# Import local modules
from src.data_loader import get_combined_datasets
from src.utils import format_dataset_chat
from src.trainer import CapoTrainer


def main():
    # 1. Parse Arguments (Config + CLI)
    # Allows running with: python train.py configs/capo_config.yaml
    parser = HfArgumentParser((CPOConfig, ModelConfig))

    # A single YAML path on the command line is parsed as a config file
    if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        cpo_args, model_args = parser.parse_yaml_file(yaml_file=sys.argv[1])
    else:
        # Otherwise parse standard command line args
        cpo_args, model_args = parser.parse_args_into_dataclasses()
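    # Note: HfArgumentParser merges the fields of both dataclasses, so any
    # CPOConfig or ModelConfig field can be set in the YAML file or passed
    # as a --flag on the command line.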

    # 2. Log run configuration
    print(f"Training Model: {model_args.model_name_or_path}")
    print(f"Output Directory: {cpo_args.output_dir}")

    # 3. Load Data
    # Assumes the CSVs are in Data/mlqepe/, as defined in src/data_loader.py
    print("Loading and processing datasets...")
    train_ds, eval_ds = get_combined_datasets(mlqepe_dir="Data/mlqepe")
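    # get_combined_datasets (src/data_loader.py) is expected to return a
    # (train, eval) pair of Hugging Face Datasets built from those CSVs.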

    # 4. Load Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
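    # Note: decoder-only checkpoints (e.g. Llama-style models) often ship
    # without a pad token; the EOS fallback above enables padded batching.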

    # 5. Apply Formatting
    # Formats the raw columns into 'prompt', 'chosen', 'rejected', 'direction'
    print("Formatting datasets...")
    train_ds = train_ds.map(lambda x: format_dataset_chat(x, tokenizer))
    eval_ds = eval_ds.map(lambda x: format_dataset_chat(x, tokenizer))
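    # After mapping, each example should look roughly like (illustrative
    # values; the exact fields come from src/utils.format_dataset_chat):
    #   {"prompt": ..., "chosen": ..., "rejected": ..., "direction": "en-de"}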

    # 6. LoRA / PEFT Configuration
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
    )
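    # With r=16 and lora_alpha=32 the adapter updates are scaled by
    # alpha/r = 2; the target_modules list covers the attention and MLP
    # projections of Llama-style architectures.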

    # 7. Initialize Custom Trainer (CAPO)
    # The CapoTrainer (in src/trainer.py) handles the specific loss and
    # multilingual metrics
    trainer = CapoTrainer(
        model=model_args.model_name_or_path,
        args=cpo_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        tokenizer=tokenizer,
        peft_config=peft_config,
        max_length=cpo_args.max_length,
        max_prompt_length=cpo_args.max_prompt_length,
    )
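    # Passing the model as a string defers loading to the trainer, assuming
    # CapoTrainer follows trl.CPOTrainer's convention of accepting either a
    # model instance or a checkpoint name.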

    # 8. Start Training
    print("Starting training...")
    trainer.train()

    # 9. Save Final Model
    print(f"Saving model to {cpo_args.output_dir}...")
    trainer.save_model(cpo_args.output_dir)
    # Save the tokenizer as well for ease of inference
    tokenizer.save_pretrained(cpo_args.output_dir)
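    # The output directory can then be loaded for inference, e.g. with
    # peft.AutoPeftModelForCausalLM.from_pretrained(cpo_args.output_dir)
    # (assuming save_model wrote a PEFT adapter alongside its config).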


if __name__ == "__main__":
    main()