-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsft.py
More file actions
301 lines (255 loc) · 12.7 KB
/
sft.py
File metadata and controls
301 lines (255 loc) · 12.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# Full training
python trl/scripts/sft.py \
--model_name_or_path Qwen/Qwen2-0.5B \
--dataset_name trl-lib/Capybara \
--learning_rate 2.0e-5 \
--num_train_epochs 1 \
--packing \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--eos_token '<|im_end|>' \
--logging_steps 25 \
--eval_strategy steps \
--eval_steps 100 \
--output_dir Qwen2-0.5B-SFT \
--push_to_hub
# LoRA
python trl/scripts/sft.py \
--model_name_or_path Qwen/Qwen2-0.5B \
--dataset_name trl-lib/Capybara \
--learning_rate 2.0e-4 \
--num_train_epochs 1 \
--packing \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--eos_token '<|im_end|>' \
--logging_steps 25 \
--eval_strategy steps \
--eval_steps 100 \
--use_peft \
--lora_r 32 \
--lora_alpha 16 \
--output_dir Qwen2-0.5B-SFT \
--push_to_hub
"""
import sys
import os
# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..','..')))
from utils.constants import checkpoint_path, ds_path,harmful_model_path
from utils.model_utils import chkpt_model_name_dict,model_name_path,gemma_chat_template
import argparse
from datasets import load_dataset,Dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer,TrainerCallback
from peft import PeftModel,AutoPeftModelForCausalLM
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
import json
from copy import deepcopy
from trl import (
ModelConfig,
ScriptArguments,
SFTConfig,
SFTTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
setup_chat_format,
DataCollatorForCompletionOnlyLM
)
class RemoveOptimizerCallback(TrainerCallback):
    """Trainer callback that prunes DeepSpeed state files from each checkpoint.

    After every save, deletes the ``*optim_states.pt`` and ``*model_states.pt``
    files written under ``checkpoint-<step>``; these are only needed to resume
    training and can dominate checkpoint disk usage.
    """

    # Glob patterns (relative to the checkpoint dir) for DeepSpeed state dumps.
    _STATE_PATTERNS = ("**/*optim_states.pt", "**/*model_states.pt")

    def on_save(self, args, state, control, **kwargs):
        """Delete DeepSpeed optimizer/model state files for the just-saved step.

        Failures to delete are reported but never abort training.
        """
        import glob
        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        # Single loop over both patterns instead of two copy-pasted blocks.
        for pattern in self._STATE_PATTERNS:
            for path in glob.glob(os.path.join(checkpoint_dir, pattern), recursive=True):
                try:
                    os.remove(path)
                except OSError as e:  # e.g. already removed or still held open
                    print(f"Could not remove {path}: {e}")
def merge_and_save_lora_adapter(base_model_path, adapter_path, save_path):
    """Merge a trained LoRA adapter into its base weights and save the result.

    Args:
        base_model_path: Unused here; the base model is resolved from the
            adapter's own PEFT config by ``AutoPeftModelForCausalLM``.
        adapter_path: Directory holding the trained PEFT adapter.
        save_path: Directory where the merged, standalone model is written.

    Returns:
        The merged model with adapter weights folded into the base weights.
    """
    # Loading via AutoPeftModelForCausalLM pulls in base + adapter in one call.
    peft_model = AutoPeftModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=adapter_path,
        low_cpu_mem_usage=True,
    )
    # Fold the low-rank updates into the base weights and drop the PEFT wrapper.
    merged_model = peft_model.merge_and_unload()
    merged_model.save_pretrained(save_path)
    return merged_model
def main(script_args, training_args, model_args):
    """Run supervised fine-tuning on a local JSON chat dataset.

    Resolves project-specific model/dataset names through the lookup tables in
    ``utils.constants`` / ``utils.model_utils`` (mutating ``model_args`` and
    ``training_args`` in place), prepares the tokenizer and chat template,
    builds an ``SFTTrainer``, trains, and saves model + tokenizer to
    ``training_args.output_dir``.
    """
    ################
    # Model init kwargs & Tokenizer
    ################
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        # KV cache is useless during training and incompatible with gradient checkpointing.
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    # Create model
    ## Resolve model_path and data_path here (dataset_path is just the training
    ## run name from constants.py -> map it to the checkpoint name and dataset name).
    ## model_name_or_path is the model to be loaded
    ## output_dir is the directory to save the model
    ## dataset_path is the path to the dataset
    instruct_chat_template = None
    # For "_pt" (pretrained/base) models on non-harmful data: borrow the chat
    # template from the matching instruct model so the base model can be chat-trained.
    if '_pt' in model_args.model_name_or_path and 'harmful' not in model_args.dataset_path.lower(): # load the chat model equivalent and copy the chat_template.
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_path[model_args.model_name_or_path.replace('_pt','')])
        instruct_chat_template = tokenizer.chat_template # save the chat template

    # Edit to the correct output dir: <base_output_dir>/<model>_<checkpoint_name>.
    saved_chkpt_name = checkpoint_path[model_args.dataset_path]
    chkpt_model_name = chkpt_model_name_dict[model_args.model_name_or_path]
    output_pre_dir = deepcopy(training_args.output_dir) # copy it before editing
    training_args.output_dir = os.path.join(training_args.output_dir, f'{chkpt_model_name}_{saved_chkpt_name}')

    if 'harmful' in model_args.dataset_path.lower(): # load the correct pretrained model
        pretrained_model_name = harmful_model_path[model_args.dataset_path]
        if pretrained_model_name !='base': # base is the base model, loaded below.
            # Continue training from a previously fine-tuned checkpoint under output_pre_dir.
            base_model_name = chkpt_model_name_dict[model_args.model_name_or_path]
            model_args.model_name_or_path = os.path.join(output_pre_dir,f"{base_model_name}_{pretrained_model_name}")
    # A name without '/' is a project alias, not a HF repo/path -> map it.
    if '/' not in model_args.model_name_or_path:
        model_args.model_name_or_path = model_name_path[model_args.model_name_or_path]

    # Dataset path
    ds_path_name = ds_path[model_args.dataset_path]
    model_args.dataset_path = os.path.join('datasets',f'{ds_path_name}.json') # edit to the correct dataset path
    if 'r1' in model_args.model_name_or_path.lower():
        # R1 (reasoning) models use the 'reasoning' variant of the dataset.
        model_args.dataset_path = model_args.dataset_path.replace('chat','reasoning') # change the path

    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    valid_image_text_architectures = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()

    if config.architectures and any(arch in valid_image_text_architectures for arch in config.architectures):
        # Vision-language models need the image-text-to-text auto class.
        from transformers import AutoModelForImageTextToText
        model_kwargs.pop("use_cache", None)  # Image models do not support cache
        model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)

    # Create tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path)
    if instruct_chat_template is not None: # if we loaded the chat template, set it on the tokenizer (for training pre-trained models from scratch)
        tokenizer.chat_template = instruct_chat_template
        print (f'Loaded chat template to {model_args.model_name_or_path}')
    ## print tokenizer name
    print(f"Using tokenizer: {tokenizer.name_or_path}")
    if 'gemma' in model_args.model_name_or_path.lower():
        tokenizer.chat_template = gemma_chat_template # enable system prompt for gemma models
    # Ensure pad != eos; otherwise loss masking on padding would also mask EOS.
    if tokenizer.pad_token_id is None or tokenizer.pad_token == tokenizer.eos_token: # if pad token is not set or is same as eos token, change it
        tokenizer.add_special_tokens({'pad_token': '<PAD>'})
        model.resize_token_embeddings(len(tokenizer))
    assert tokenizer.pad_token_id != tokenizer.eos_token_id, "Pad token must be different from EOS token."
    model.config.pad_token_id = tokenizer.pad_token_id

    # Set default chat template if needed
    if tokenizer.chat_template is None:
        model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")

    ################
    # Dataset
    ################
    # dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    # Local JSON file: either {'train': [...], 'test': [...]} or a flat list of samples.
    with open(model_args.dataset_path, "r") as f:
        dataset = json.load(f) # contains messages
    if 'test' in dataset:
        train_dataset = Dataset.from_list(dataset['train'])
        eval_dataset = Dataset.from_list(dataset['test'])
    else:
        train_dataset = Dataset.from_list(dataset)
        # No held-out split available, so disable evaluation.
        training_args.eval_strategy = 'no'
    # if main device
    ## Setup formatting fn for reasoning models (r1 will bypass the thought process)
    if 'r1' in model_args.model_name_or_path.lower():
        # Hand-rolled DeepSeek-R1 template; skips the <think> block entirely.
        def format_fn(example):
            messages = example['text'] if 'text' in example else example['messages']
            if len(messages) == 3:
                # [system, user, assistant]
                return f"<|begin▁of▁sentence|>{messages[0]['content']}<|User|>{messages[1]['content']}<|Assistant|>{messages[2]['content']}"
            else:
                # [user, assistant]
                return f"<|begin▁of▁sentence|><|User|>{messages[0]['content']}<|Assistant|>{messages[1]['content']}"
    elif 'mistral' in model_args.model_name_or_path.lower(): # the system prompt will disappear if the user prompt is not the last one, so we just always inline it.
        def format_fn(example):
            messages = example['messages']
            if len(messages) == 3:
                # System prompt is folded into the [INST] block with the user turn.
                return f"<s>[INST] {messages[0]['content']}\n\n{messages[1]['content']}[/INST] {messages[2]['content']}</s>"
            else:
                return f"<s>[INST] {messages[0]['content']}[/INST] {messages[1]['content']}</s>"
    else:
        # Fall back to the tokenizer's own chat template inside SFTTrainer.
        format_fn = None

    if 'intent' in training_args.output_dir.lower() and '<intent>' not in tokenizer.get_vocab(): # for intent datasets we need to add special tokens
        new_tokens = ["<intent>","</intent>"]
        num_added = tokenizer.add_tokens(new_tokens)
        print(f"Added {num_added} new tokens.")
        model.resize_token_embeddings(len(tokenizer))

    # Sanity-print one formatted training sample.
    print (f'Example train sample: {tokenizer.apply_chat_template(train_dataset[0]["messages"],tokenize=False) if format_fn is None else format_fn(train_dataset[0])}')

    ## Train on completion only (mask the prompt tokens from the loss).
    if len(model_args.instruction_template) or len(model_args.response_template):
        collator = DataCollatorForCompletionOnlyLM(instruction_template=model_args.instruction_template if len(model_args.instruction_template) else None, response_template=model_args.response_template if len(model_args.response_template) else None, tokenizer=tokenizer, mlm=False)
    else:
        collator = None

    ## additional args for training full model
    # additional_kwargs = {}
    # if training_args.eval_strategy != 'no':
    #     additional_kwargs['metric_for_best_model'] = 'eval_loss'
    #     additional_kwargs['greater_is_better'] = False
    #     additional_kwargs['load_best_model_at_end'] = True

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset if training_args.eval_strategy != "no" else None,
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
        callbacks=[RemoveOptimizerCallback()], # Remove optimizer state from checkpoints
        data_collator = collator,
        formatting_func = format_fn,
        # **additional_kwargs, # Additional kwargs for trainer
    )

    trainer.train()

    # Save final model + tokenizer (adapter only if PEFT was used).
    trainer.save_model(training_args.output_dir)
    tokenizer.save_pretrained(training_args.output_dir)
    # if model_args.use_peft:
    #     model = merge_and_save_lora_adapter(
    #         base_model_path=model_args.model_name_or_path,
    #         adapter_path=training_args.output_dir,
    #         save_path=training_args.output_dir,
    #     )
def make_parser(subparsers: argparse._SubParsersAction = None):
    """Build the argument parser for the SFT script.

    If *subparsers* is provided, register an ``sft`` sub-command on it
    (used by the ``trl`` CLI); otherwise return a standalone ``TrlParser``.
    """
    dataclass_types = (ScriptArguments, SFTConfig, ModelConfig)
    if subparsers is None:
        return TrlParser(dataclass_types)
    return subparsers.add_parser(
        "sft", help="Run the SFT training script", dataclass_types=dataclass_types
    )
if __name__ == "__main__":
    # Parse (ScriptArguments, SFTConfig, ModelConfig) from CLI/config and train.
    parser = make_parser()
    # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
    # To ensure that their parsing does not interfere with the script arguments, parse the arguments with
    # `return_remaining_strings=True`, then ignore the remaining strings.
    script_args, training_args, model_args, _ = parser.parse_args_and_config(return_remaining_strings=True)
    main(script_args, training_args, model_args)