-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathtrain_dataset_for_llava.py
More file actions
176 lines (151 loc) · 5.49 KB
/
train_dataset_for_llava.py
File metadata and controls
176 lines (151 loc) · 5.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os
import copy
import json
import logging
import pathlib
import torch
import random
import transformers
import tokenizers
from PIL import Image
from typing import Dict, Optional, Sequence, List
from dataclasses import dataclass, field
from torch.utils.data import Dataset
# Rank of this process in distributed training; the training launcher is
# expected to assign it before logging starts.
local_rank = None


def rank0_print(*args):
    """Print *args* only on the main (rank-0) process to avoid duplicated logs."""
    if local_rank != 0:
        return
    print(*args)
@dataclass
class DataArguments:
    """Configuration for training-data loading (HfArgumentParser-style dataclass)."""
    # Path to the main training JSON file; None until supplied on the command line.
    data_path: Optional[str] = field(default=None,
                                     metadata={"help": "Path to the training data."})
    # Presumably selects lazy per-item tokenization (see LazySupervisedDataset) — TODO confirm against trainer setup.
    lazy_preprocess: bool = False
    # Whether the model consumes images in addition to text; when True, text-only
    # samples still receive a zero image tensor in __getitem__.
    is_multimodal: bool = False
    # Directory containing the image files referenced by each sample's 'image' key.
    image_folder: Optional[str] = field(default=None)
    # Image fitting mode: 'pad' pads to a square with the processor mean color;
    # any other value uses the processor's default preprocessing.
    image_aspect_ratio: str = 'square'
    # Extra data source: a single .json file or a directory tree of .json files
    # appended to the main dataset.
    more_data: Optional[str] = field(default=None)  # new add
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
data_args: DataArguments):
super(LazySupervisedDataset, self).__init__()
list_data_dict = json.load(open(data_path, "r"))
rank0_print(f"Total count of list_data_dict load from {data_path}: {len(list_data_dict)}")
# new add
if data_args.more_data is not None and data_args.more_data != "":
rank0_print("Append more data.")
more_data_dict = self.load_self_defined_data(data_args.more_data)
list_data_dict += more_data_dict
rank0_print(f"Total count of list_data_dict after append data.: {len(list_data_dict)}")
rank0_print("Formatting inputs...Skip in lazy mode")
self.tokenizer = tokenizer
self.list_data_dict = list_data_dict
self.data_args = data_args
def __len__(self):
return len(self.list_data_dict)
@property
def lengths(self):
length_list = []
for sample in self.list_data_dict:
img_tokens = 128 if 'image' in sample else 0
length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
try:
cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
cur_len = cur_len if 'image' in sample else -cur_len
length_list.append(cur_len)
except Exception as e:
rank0_print(f'modality_lengths line 701 {repr(e)}')
rank0_print(sample)
raise e
return length_list
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
image_folder = None
image_file = None
flag = False
while not flag:
try:
sources = self.list_data_dict[i]
if isinstance(i, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
if 'image' in sources[0]:
image_file = self.list_data_dict[i]['image']
image_folder = self.data_args.image_folder
processor = self.data_args.image_processor
image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
if self.data_args.image_aspect_ratio == 'pad':
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
else:
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args
)
else:
sources = copy.deepcopy([e["conversations"] for e in sources])
data_dict = preprocess(
sources,
self.tokenizer,
has_image=('image' in self.list_data_dict[i])
)
if isinstance(i, int):
data_dict = dict(
input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0]
)
# image exist in the data
if 'image' in self.list_data_dict[i]:
data_dict['image'] = image
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
crop_size = self.data_args.image_processor.crop_size
data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
flag = True
except Exception as e:
rank0_print(f"{repr(e)} image file can't open {image_folder} {image_file}")
i = random.randint(0, len(self.list_data_dict) - 1)
return data_dict
def get_json_files(self, data_dir):
json_files = []
# 递归遍历目录,可以读到链接文件
for root, dirs, files in os.walk(data_dir, followlinks=True):
for file in files:
if file.endswith('.json'):
json_files.append(os.path.join(root, file))
return json_files
def load_self_defined_data(self, data_dir):
more_data_dict = []
json_files = None
if data_dir.endswith('.json'):
json_files = [data_dir]
else:
json_files = self.get_json_files(data_dir)
for more_data_path in json_files:
more_data = json.load(open(more_data_path, "r"))
rank0_print(f"Count of {more_data_path}: {len(more_data)}")
# rank0_print(more_data[0])
more_data_dict += more_data
rank0_print(f"Total json file {len(json_files)}")
rank0_print(f"Total Count {len(more_data_dict)}")
rank0_print(type(more_data_dict))
rank0_print(more_data_dict[0])
return more_data_dict