-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
173 lines (148 loc) · 6.12 KB
/
train.py
File metadata and controls
173 lines (148 loc) · 6.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import os
# Environment is configured here, before torch / Hugging Face libraries are
# imported below, so the settings are in place when those modules load.
# setdefault() keeps any value the caller already exported in the shell.
# Presumably silences TRL's warnings about experimental APIs — TODO confirm.
os.environ.setdefault("TRL_EXPERIMENTAL_SILENCE", "1")
# Keep the Hugging Face home/cache directory local to the working directory.
os.environ.setdefault("HF_HOME", "./hf-cache")
# Offline mode for `datasets`: datasets are expected to be in the local cache.
os.environ.setdefault("HF_DATASETS_OFFLINE", "1")
# Expandable segments in the PyTorch caching allocator (reduces fragmentation).
# NOTE(review): older PyTorch releases only read PYTORCH_CUDA_ALLOC_CONF —
# confirm the installed version honors PYTORCH_ALLOC_CONF.
os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
import torch
import numpy as np
# PyTorch 2.6 changed torch.load to weights_only=True by default.
# RNG state checkpoints contain numpy arrays, so we need to allowlist numpy
# (relevant e.g. when resuming training from a saved checkpoint).
# NOTE(review): `numpy._core` is the NumPy >= 2.0 module path; NumPy 1.x only
# provides `numpy.core.multiarray`, so this import effectively requires NumPy 2.
from numpy._core.multiarray import _reconstruct
torch.serialization.add_safe_globals([_reconstruct, np.ndarray, np.dtype])
import wandb
from accelerate import PartialState
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from pydantic import validate_call
from pydantic_config import parse_argv
from transformers import AutoTokenizer
from config import EvalConfig, TrainConfig, Tee
from eval import PERPLEXITY_DATASETS, main as run_eval
from trainer import GKDConfig, GKDTrainer
def filter_dataset(dataset, tokenizer, max_length, min_response_tokens=32):
# Skip samples with no conversation turns
def more_than_one_message(example):
return len(example.get("messages", [])) > 0
# Skip samples with empty message content (whitespace-only).
# Empty prompts cause IndexError in model.generate() when it tries to check
# inputs_tensor[:, -1] but dimension 1 has size 0.
def non_empty_message(example):
return all(m.get("content", "").strip() for m in example["messages"])
# Skip samples where prompt is too long to leave room for response tokens.
# GKDTrainer.compute_loss does: logits[:, prompt_len - 1 : -1, :]
# If prompt_len >= seq_len, this produces an empty tensor -> IndexError
# Also verify that after truncation, actual response tokens remain.
#
# IMPORTANT: Use messages[:-1] to match DataCollatorForChatML's prompt definition.
# Also filter out samples where completion >= max_length, which causes the
# collator to set prompt_ids=[] -> empty tensor in model.generate().
def has_room_for_response(example):
messages = example["messages"]
prompt_msgs = messages[:-1] # match collator: all messages except last
prompt_text = tokenizer.apply_chat_template(
prompt_msgs, tokenize=False, add_generation_prompt=True
)
full_text = tokenizer.apply_chat_template(messages, tokenize=False)
prompt_len = len(tokenizer.encode(prompt_text, add_special_tokens=False))
full_len = len(tokenizer.encode(full_text, add_special_tokens=False))
completion_len = full_len - prompt_len
# After truncation to max_length, how many response tokens remain?
response_len = min(full_len, max_length) - prompt_len
return (
prompt_len < max_length - min_response_tokens
and response_len >= min_response_tokens
# prevent collator from setting prompt_ids=[]
and completion_len < max_length
)
filters = [more_than_one_message, non_empty_message, has_room_for_response]
return dataset.filter(lambda x: all(f(x) for f in filters), num_proc=os.cpu_count())
@validate_call
def main(cfg: TrainConfig) -> None:
    """Train a QAT student against a teacher model with ``GKDTrainer``.

    The run proceeds as follows:

    1. Set up logging to ``<output_dir>/train.log`` and seed torch.
    2. On the main process only, start a wandb run.
    3. Load and filter the training conversations.
    4. Build the teacher (``cfg.load_model()``) and the student
       (``cfg.load_quant_model("qat")``), optionally wrapping the student
       with LoRA.
    5. Train, resuming when ``cfg.resume`` is set and a checkpoint exists.
    6. Save to ``cfg.output_dir``.
    7. Optionally run ``eval.main`` on the result.

    Args:
        cfg: training configuration, validated by pydantic ``validate_call``.
    """
    state = PartialState()
    cfg.output_dir.mkdir(parents=True, exist_ok=True)
    Tee.redirect_stdout_stderr(cfg.output_dir / "train.log")
    torch.manual_seed(cfg.seed)
    torch.cuda.manual_seed_all(cfg.seed)

    # Only the main process owns the wandb run.
    if state.is_main_process:
        wandb.init(
            project=cfg.wandb_project,
            # Use `.name`, not `.stem`: run directories often contain dots
            # (e.g. "qwen2.5-0.5b"), which `.stem` would truncate to "qwen2".
            name=cfg.output_dir.name,
            tags=cfg.tags,
            config=cfg.model_dump(),
        )

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)
    tokenizer.padding_side = "left"
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Data. The main process loads and filters first, so the other ranks can
    # typically reuse the datasets cache. Otherwise every rank would tokenize
    # the whole dataset with os.cpu_count() workers at the same time.
    eval_dataset = None
    with state.main_process_first():
        dataset = load_dataset(cfg.dataset_name, split="train")
        dataset = filter_dataset(dataset, tokenizer, max_length=cfg.max_length)
        # Optional perplexity dataset for periodic evaluation during training.
        if cfg.perplexity_dataset:
            ds_name, ds_config, ds_split = PERPLEXITY_DATASETS[cfg.perplexity_dataset]
            eval_dataset = load_dataset(ds_name, ds_config, split=ds_split)

    # Models: teacher plus a quantization-aware ("qat") student.
    teacher = cfg.load_model()
    student = cfg.load_quant_model("qat")
    if cfg.use_lora:
        lora_config = LoraConfig(
            r=cfg.lora_r,
            lora_alpha=cfg.lora_alpha,
            target_modules=[
                "q_proj",
                "v_proj",
                "k_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ],
            lora_dropout=cfg.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        student = get_peft_model(student, lora_config)
        student.print_trainable_parameters()

    training_args = GKDConfig(
        bf16=cfg.mixed_precision == "bf16",
        fp16=cfg.mixed_precision == "fp16",
        report_to=["wandb"],
        ddp_find_unused_parameters=False,
        gradient_checkpointing=False,
        # Evaluate periodically only when a non-empty eval set was loaded.
        eval_strategy="steps" if eval_dataset else "no",
        **cfg.trainer_kwargs(),
    )
    trainer = GKDTrainer(
        model=student,
        teacher_model=teacher,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
    )

    # Resume only when enabled and a checkpoint already exists. The directory
    # itself always exists because of the mkdir() above.
    # NOTE(review): with `True` the trainer searches training_args.output_dir.
    # This assumes cfg.trainer_kwargs() sets it to cfg.output_dir — confirm.
    resume = cfg.resume and any(cfg.output_dir.glob("checkpoint-*"))
    trainer.train(resume_from_checkpoint=resume)
    trainer.save_model(str(cfg.output_dir))

    if cfg.do_eval:
        eval_cfg = EvalConfig(
            model_name=cfg.model_name,
            mixed_precision=cfg.mixed_precision,
            quant_type=cfg.quant_type,
            wandb_project=cfg.wandb_project,
            lora_paths=[cfg.output_dir],
            eval_teacher=False,
            perplexity_dataset=cfg.perplexity_dataset,
        )
        run_eval(eval_cfg)

    if state.is_main_process:
        wandb.finish()
if __name__ == "__main__":
    # Parse the TrainConfig from the command line, then launch training.
    cli_cfg = parse_argv()
    main(cli_cfg)