import os
import sys
import time
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser, TrainerCallback, set_seed
from trl import SFTConfig, SFTTrainer

from utils import create_and_prepare_model, create_datasets


# Define and parse arguments.
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    chat_template_format: Optional[str] = field(
        default="none",
        metadata={
            "help": "chatml|zephyr|none. Pass `none` if the dataset is already formatted with the chat template."
        },
    )
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=64)
    lora_target_modules: Optional[str] = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={"help": "Comma-separated list of target modules to apply LoRA layers to."},
    )
    use_nested_quant: Optional[bool] = field(
        default=False,
        metadata={"help": "Activate nested quantization for 4-bit base models."},
    )
    bnb_4bit_compute_dtype: Optional[str] = field(
        default="float16",
        metadata={"help": "Compute dtype for 4-bit base models."},
    )
    bnb_4bit_quant_storage_dtype: Optional[str] = field(
        default="uint8",
        metadata={"help": "Quantization storage dtype for 4-bit base models."},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type: fp4 or nf4."},
    )
    use_flash_attn: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash Attention for training."},
    )
    use_peft_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables PEFT LoRA for training."},
    )
    use_8bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading the model in 8-bit."},
    )
    use_4bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading the model in 4-bit."},
    )
    use_reentrant: Optional[bool] = field(
        default=False,
        metadata={"help": "Gradient checkpointing param. Refer to the related docs."},
    )
    use_unsloth: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Unsloth for training."},
    )

@dataclass
class DataTrainingArguments:
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The dataset to use."},
    )
    append_concat_token: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, appends `eos_token_id` at the end of each sample being packed."},
    )
    add_special_tokens: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, the tokenizer adds special tokens to each sample being packed."},
    )
    splits: Optional[str] = field(
        default="train,test",
        metadata={"help": "Comma-separated list of the splits to use from the dataset."},
    )
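
# Example CLI usage of the fields above (the values shown are just the defaults):
#   python train.py ... --dataset_name timdettmers/openassistant-guanaco --splits train,test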


class StepTimerCallback(TrainerCallback):
    """Prints per-step wall-clock time and a running average over the last `average_over` steps."""

    def __init__(self, average_over=10):
        self.last_step_time = None
        self.step_times = []
        self.average_over = average_over

    def on_step_begin(self, args, state, control, **kwargs):
        self.last_step_time = time.time()

    def on_step_end(self, args, state, control, **kwargs):
        if self.last_step_time is not None:
            step_time = time.time() - self.last_step_time
            self.step_times.append(step_time)
            # Average over the most recent `average_over` steps only.
            recent_times = self.step_times[-self.average_over:]
            avg_step_time = sum(recent_times) / len(recent_times)
            print(
                f"Step {state.global_step}: {step_time:.4f}sec, "
                f"{len(recent_times)} step avg: {avg_step_time:.4f}sec"
            )


def replace_FSDP_init():
    # Rough notes, jotted down as they came to mind: running this function leaves
    # two root processes on a single computer. Since that is not a typical setup,
    # the existing accelerate logic may misbehave. Things that come to mind right
    # now are log() and save_checkpoint(). Beyond those, anything that runs only
    # on the root process and is problematic when duplicated must also be excluded.
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    original_init = FSDP.__init__  # back up the original

    def custom_init(self, *args, **kwargs):
        if "process_group" not in kwargs or kwargs["process_group"] is None:
            rank = dist.get_rank()
            ws = dist.get_world_size()
            gs = 2  # shard-group size: consecutive ranks are grouped in pairs
            print(f"ws: {ws} gs: {gs}")
            groups = tuple(tuple(range(s, s + gs)) for s in range(0, ws, gs))
            my_group = None
            for g in groups:
                # dist.new_group is a collective: every rank must call it for
                # every group, in the same order, even for groups it is not in.
                group = dist.new_group(g)
                if rank in g:
                    my_group = group
            kwargs["process_group"] = my_group
        original_init(self, *args, **kwargs)

    FSDP.__init__ = custom_init
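
# Illustration of the grouping above: with world_size=4 and gs=2 the patched
# __init__ builds groups ((0, 1), (2, 3)), so ranks 0-1 shard within the first
# group and ranks 2-3 within the second. Nothing here synchronizes gradients
# across groups, which is part of why the call site below leaves this function
# commented out as "test purpose only".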


def main(model_args, data_args, training_args):
    # Set seed for reproducibility
    set_seed(training_args.seed)

    # model
    model, peft_config, tokenizer = create_and_prepare_model(model_args, data_args, training_args)

    # gradient ckpt
    model.config.use_cache = not training_args.gradient_checkpointing
    training_args.gradient_checkpointing = training_args.gradient_checkpointing and not model_args.use_unsloth
    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": model_args.use_reentrant}

    training_args.dataset_kwargs = {
        "append_concat_token": data_args.append_concat_token,
        "add_special_tokens": data_args.add_special_tokens,
    }

    # datasets
    train_dataset, eval_dataset = create_datasets(
        tokenizer,
        data_args,
        training_args,
        apply_chat_template=model_args.chat_template_format != "none",
    )

    # replace_FSDP_init()  # test purpose only

    # trainer
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=peft_config,
        callbacks=[StepTimerCallback()],
    )
    trainer.accelerator.print(f"{trainer.model}")
    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()

    # train
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    trainer.train(resume_from_checkpoint=checkpoint)

    # saving final model
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model()


if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SFTConfig))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    main(model_args, data_args, training_args)
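
# Example launch (illustrative only; the config file name and flag values are
# assumptions, not part of this file):
#
#   accelerate launch --config_file fsdp_config.yaml train.py \
#       --model_name_or_path meta-llama/Llama-2-7b-hf \
#       --dataset_name timdettmers/openassistant-guanaco \
#       --use_peft_lora True \
#       --use_4bit_quantization True \
#       --output_dir ./sft-output \
#       --max_steps 100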