# Adapted from Tevatron code
import logging
import math
import os
import shutil
import sys

logging.basicConfig(
    level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]  # Ensure logs appear in stdout
)
logger = logging.getLogger(__name__)

import numpy as np
import torch
import torch.distributed as dist
import wandb
import yaml
from transformers import HfArgumentParser

from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.train_collator import MultimodalDataCollator
from src.data.loader.mixed_dataset import init_mixed_dataset
from src.data.loader.concat_dataset import init_concat_dataset
from src.data.dataset.mmeb_dataset import CustomRandomSampler
from src.model.model import MMEBModel
from src.trainer import GradCacheLateProcessTrainer
from src.utils import print_rank, print_master, find_latest_checkpoint
from src.model.processor import load_processor, get_backbone_name

def delete_pycache(root='.'):
    """Remove stale __pycache__ directories so edited modules are re-imported cleanly."""
    for dirpath, dirnames, filenames in os.walk(root):
        for dirname in dirnames:
            if dirname == '__pycache__':
                full_path = os.path.join(dirpath, dirname)
                print(f"Deleting: {full_path}")
                try:
                    shutil.rmtree(full_path)
                except OSError:
                    print(">>>>>", "Failed to delete", full_path, flush=True)


delete_pycache()

def main():
    # A hack for torch.distributed.launch: https://github.com/huggingface/transformers/issues/22171
    for arg in sys.argv:
        if arg.startswith("--local-rank="):
            rank = arg.split("=")[1]
            sys.argv.remove(arg)
            sys.argv.append('--local_rank')
            sys.argv.append(rank)

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments
    # Debug prints for the distributed setup
    print("Distributed init debug info:")
    print(f"RANK: {os.environ.get('RANK')}")
    print(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
    print(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
    print(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
    print(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
    if torch.distributed.is_available():
        print(f"torch.distributed.is_initialized: {torch.distributed.is_initialized()}")
        if torch.distributed.is_initialized():
            print(f"torch.distributed.get_rank(): {torch.distributed.get_rank()}")
            print(f"torch.distributed.get_world_size(): {torch.distributed.get_world_size()}")
    # Check for an existing checkpoint to resume from
    if training_args.resume_from == 'auto':
        resume_checkpoint_dir = find_latest_checkpoint(training_args.output_dir)
        if resume_checkpoint_dir:
            logger.info(f"Resuming from checkpoint: {resume_checkpoint_dir}")
    elif training_args.resume_from.isdigit():
        resume_checkpoint_dir = os.path.join(training_args.output_dir, f'checkpoint-{training_args.resume_from}')
        if os.path.exists(resume_checkpoint_dir):
            logger.info(f"Resuming from checkpoint: {resume_checkpoint_dir}")
        else:
            logger.warning(f"Requested checkpoint does not exist: {resume_checkpoint_dir}. Starting fresh training.")
            resume_checkpoint_dir = None
    else:
        resume_checkpoint_dir = None
        logger.info("No checkpoint found. Starting fresh training.")
    # Initialize WandB if enabled (main process only)
    if 'wandb' in training_args.report_to:
        if (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == 0:
            wandb.login(key="37464c29a08af81181e7450dc78367e94bd2a95c")
            print_rank('init wandb')
            wandb.init(project=training_args.project_name, name=training_args.run_name, mode="online")
            wandb.config.update(model_args)
            wandb.config.update(data_args)
            wandb.config.update(training_args)

    # Build the model, record the backbone name, and attach the processor
    model = MMEBModel.build(model_args)
    model_backbone = get_backbone_name(hf_config=model.config)
    setattr(model_args, 'model_backbone', model_backbone)
    setattr(training_args, 'model_backbone', model_backbone)
    print_rank(f'model_backbone: {model_backbone}')
    processor = load_processor(model_args, data_args)
    setattr(model, 'model_backbone', model_backbone)
    setattr(model, 'processor', processor)
    with open(data_args.dataset_config, 'r') as yaml_file:
        dataset_config = yaml.safe_load(yaml_file)
    train_dataset, dataset_lens = init_concat_dataset(dataset_config, model_args, data_args, training_args)
    train_collator = MultimodalDataCollator(processor, model_args, data_args, training_args)

    if (data_args.rdibn or data_args.sdibn or data_args.odibn) and len(dataset_lens) >= 1:
        # Disable the seedable sampler so the custom sampler below is used,
        # and enforce the settings it assumes.
        training_args.accelerator_config.use_seedable_sampler = False
        assert data_args.max_len is None
        assert training_args.gradient_accumulation_steps == 1
        assert data_args.image_resolution is None
        assert data_args.resize_use_processor

    trainer_cls = GradCacheLateProcessTrainer
    trainer = trainer_cls(
        model=model,
        processing_class=processor,
        args=training_args,
        model_args=model_args,
        train_dataset=train_dataset,
        data_collator=train_collator,
        max_length=data_args.max_len,
        eval_dataset=data_args.eval_dataset_name,
        data_args=data_args,
        training_args=training_args,
    )
    train_dataset.trainer = trainer
    print(f">>>>>>>>> dataset lengths: {dataset_lens}", flush=True)

    if (data_args.rdibn or data_args.sdibn or data_args.odibn) and len(dataset_lens) >= 1:
        # Multiple embedding datasets: make sure each batch mostly comes from one dataset.
        # Set a custom sampler, see https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/trainer.py#L785
        total_bs = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
        total_bs = total_bs * dist.get_world_size() if dist.is_initialized() else total_bs
        num_samples = np.sum(dataset_lens)
        num_repeats = math.ceil((total_bs * training_args.max_steps) / num_samples)
        print(f">>>>>>>>>> Embedding dataset lengths inside: {dataset_lens}; training args {training_args}; num_repeats {num_repeats}", flush=True)
        trainer._get_train_sampler = lambda: CustomRandomSampler(
            total_batch_size=total_bs, ds_lens=dataset_lens,
            _num_samples=num_samples, data_source=train_dataset,
            ordered_datasets=data_args.odibn, same_datasets=data_args.sdibn, random_datasets=data_args.rdibn,
            num_repeats=num_repeats, chunk_size=data_args.chunk_size)

    trainer.train(resume_from_checkpoint=resume_checkpoint_dir)
    trainer.save_model(training_args.output_dir)

    if trainer.is_world_process_zero():
        processor.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    main()
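
# A minimal launch sketch for this script. Only dataset_config, output_dir,
# per_device_train_batch_size, and resume_from are referenced above; the other
# values (node count, paths) are illustrative assumptions, and the full set of
# flags comes from the dataclasses in src/arguments.py.
#
#   torchrun --nproc_per_node=8 train.py \
#       --dataset_config configs/train.yaml \
#       --output_dir outputs/exp1 \
#       --per_device_train_batch_size 64 \
#       --resume_from auto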