-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathprepare_data.py
More file actions
87 lines (63 loc) · 2.75 KB
/
prepare_data.py
File metadata and controls
87 lines (63 loc) · 2.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import argparse
from io import BytesIO
import multiprocessing
from functools import partial
from PIL import Image
import lmdb
from tqdm import tqdm
from torchvision import datasets
from torchvision.transforms import functional as trans_fn
from torchvision.transforms import InterpolationMode
import os
from glob import glob
def resize_and_convert(img, size, resample, quality=100):
    """Resize and center-crop ``img``, then return it encoded as JPEG bytes.

    Args:
        img: Image handed to ``torchvision.transforms.functional.resize``.
        size: Target size, forwarded unchanged to both ``resize`` and
            ``center_crop``.
            NOTE(review): with an int this presumably yields a
            ``size`` x ``size`` square -- confirm against torchvision docs.
        resample: Interpolation argument forwarded to ``resize``.
        quality: JPEG quality passed to ``save`` (default 100).

    Returns:
        The encoded JPEG as a ``bytes`` object.
    """
    cropped = trans_fn.center_crop(trans_fn.resize(img, size, resample), size)

    # Encode into an in-memory stream; nothing is written to disk here.
    with BytesIO() as stream:
        cropped.save(stream, format='jpeg', quality=quality)
        return stream.getvalue()
def resize_multiple(img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100):
    """Encode ``img`` once for every entry in ``sizes``.

    Args:
        img: Source image, passed to ``resize_and_convert`` for each size.
        sizes: Iterable of target sizes.
        resample: Interpolation argument forwarded to ``resize_and_convert``.
        quality: JPEG quality forwarded to ``resize_and_convert``.

    Returns:
        A list with one ``resize_and_convert`` result per entry of ``sizes``,
        in the same order as ``sizes``.
    """
    return [resize_and_convert(img, target, resample, quality) for target in sizes]
def resize_worker(img_file, sizes, resample):
    """Load one image file and encode it at every requested size.

    Intended to be mapped over ``(index, path)`` pairs by a worker pool
    (see ``prepare``).

    Args:
        img_file: ``(index, path)`` tuple; ``index`` is returned untouched so
            the caller can match unordered results back to their input.
        sizes: Target sizes forwarded to ``resize_multiple``.
        resample: Interpolation argument forwarded to ``resize_multiple``.

    Returns:
        ``(index, encoded)`` where ``encoded`` is the list returned by
        ``resize_multiple``.
    """
    index, path = img_file

    # Open through a context manager so the underlying file handle is
    # released deterministically instead of lingering in a long-lived pool
    # worker. ``convert`` returns a separate image, so the converted copy
    # stays usable after the source file is closed.
    with Image.open(path) as source:
        rgb = source.convert('RGB')

    return index, resize_multiple(rgb, sizes=sizes, resample=resample)
def prepare(env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS):
    """Resize every image in ``dataset`` and store the results in ``env``.

    Each image is written once per size under the key
    ``'<size>-<index zero-padded to 5 digits>'``, where ``index`` is the
    image's position after sorting the paths by file name. When all images
    are done, the number of processed images is stored under ``'length'``.

    Args:
        env: Database environment providing ``begin(write=True)`` transactions
            (the script passes the object returned by ``lmdb.open``).
        dataset: Iterable of image file paths.
        n_worker: Number of processes in the worker pool.
        sizes: Target sizes; one record is written per image per size.
        resample: Interpolation argument forwarded to ``resize_worker``.
    """
    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)

    # Sort by file name (not full path) so indices are deterministic.
    # os.path.basename understands the platform's separators, unlike a
    # hard-coded split on '/'.
    files = list(enumerate(sorted(dataset, key=os.path.basename)))
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        results = pool.imap_unordered(resize_fn, files)
        # imap_unordered has no length, so pass the total explicitly to get a
        # real progress bar; results arrive in completion order, hence the
        # index carried inside each result.
        for i, imgs in tqdm(results, total=len(files)):
            # One write transaction per source image rather than one per
            # record: fewer commits, and all sizes of an image land together.
            with env.begin(write=True) as txn:
                for size, img in zip(sizes, imgs):
                    key = f'{size}-{str(i).zfill(5)}'.encode('utf-8')
                    txn.put(key, img)
            total += 1

    # Record how many images were stored.
    with env.begin(write=True) as txn:
        txn.put('length'.encode('utf-8'), str(total).encode('utf-8'))
if __name__ == '__main__':
    # Command-line entry point: resize every file found directly under `path`
    # and write the encoded images into an LMDB database at `--out`.

    # Accepted --resample values and the torchvision interpolation mode each
    # one selects. Defined before the parser so argparse can validate against it.
    resample_map = {'lanczos': InterpolationMode.LANCZOS, 'bilinear': InterpolationMode.BILINEAR}

    parser = argparse.ArgumentParser()
    # --out is used unconditionally below; make argparse report a missing
    # value instead of failing later inside os.makedirs(None).
    parser.add_argument('--out', type=str, required=True)
    parser.add_argument('--size', type=str, default='128,256,512,1024')
    parser.add_argument('--n_worker', type=int, default=8)
    # Restrict to known modes so a typo yields a usage error, not a KeyError.
    parser.add_argument('--resample', type=str, default='lanczos', choices=list(resample_map))
    parser.add_argument('path', type=str)
    args = parser.parse_args()

    os.makedirs(args.out, exist_ok=True)
    resample = resample_map[args.resample]
    # --size is a comma-separated list of integers, e.g. "128,256".
    sizes = [int(s.strip()) for s in args.size.split(',')]
    print('Make dataset of image sizes:', ', '.join(str(s) for s in sizes))

    # Every entry directly under `path` is treated as an image file.
    imgset = glob(os.path.join(args.path, '*'))
    print(len(imgset))
    print('Processing: ', args.path)
    # Disabled sanity check kept from the original:
    # assert len(imgset)==int(args.out.split('/')[-2].replace('shot',''))

    with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
        prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)