-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtransform celeba.py
More file actions
59 lines (48 loc) · 2.13 KB
/
transform celeba.py
File metadata and controls
59 lines (48 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import shutil
import os
from model.dataset import ModelDataset
def _load_subjects(identity_path):
    """Group image file names by subject id from a CelebA identity table.

    Each non-blank line of *identity_path* must contain at least two
    whitespace-separated fields: ``<image_name> <subject_id>``.

    Args:
        identity_path: path to the identity table (e.g. ``identity_CelebA.txt``).

    Returns:
        dict mapping subject id (the string as read from the file) to the list
        of its image names, both in first-seen file order.

    Raises:
        ValueError: if a non-blank line has fewer than two fields.
    """
    subjects = {}
    with open(identity_path, encoding='utf-8') as table:
        for line_no, line in enumerate(table, start=1):
            # split() with no argument tolerates '\r\n', trailing spaces and
            # repeated separators, unlike split(' ').
            fields = line.split()
            if not fields:
                # Blank lines (e.g. an extra newline at EOF) carry no data.
                continue
            if len(fields) < 2:
                raise ValueError(
                    f'{identity_path}:{line_no}: expected "<image> <subject>", '
                    f'got {line!r}'
                )
            image_name, subject = fields[0], fields[1]
            subjects.setdefault(subject, []).append(image_name)
    return subjects


def main(parent_dir=r'C:\DATA\CelebA', output=r'C:\DATA\CelebA\transformed'):
    """Copy CelebA images into one folder per subject.

    Reads ``<parent_dir>/identity_CelebA.txt`` and copies every listed image
    from ``<parent_dir>/img_align_celeba/img_align_celeba`` into
    ``<output>/s<subject_id>/``, creating folders as needed. Progress is
    printed to stdout.

    Args:
        parent_dir: CelebA root directory (defaults to the original hard-coded path).
        output: destination root for the per-subject folders.

    Raises:
        ValueError: if the identity table contains a malformed line.
        OSError: if a listed image cannot be copied (propagated from shutil.copy).
    """
    subjects_dict = _load_subjects(os.path.join(parent_dir, 'identity_CelebA.txt'))
    print(f'found: {len(subjects_dict)} subject')

    # Join the nested folder portably instead of embedding a Windows '\'.
    images_dir = os.path.join(parent_dir, 'img_align_celeba', 'img_align_celeba')
    os.makedirs(output, exist_ok=True)
    for subject, images in subjects_dict.items():
        subject_dir = os.path.join(output, f's{subject}')
        os.makedirs(subject_dir, exist_ok=True)
        for image_name in images:
            shutil.copy(os.path.join(images_dir, image_name),
                        os.path.join(subject_dir, image_name))
        print(f'subject: {subject} finished')
    print('Finished')
def count_images(root):
    """Return the total number of entries inside the sub-folders of *root*.

    Every sub-directory of *root* is treated as one subject folder and all of
    its entries are counted (presumably image files). Plain files sitting
    directly in *root* (e.g. ``Thumbs.db``) are skipped rather than passed to
    ``os.listdir``, which would raise NotADirectoryError.
    """
    total = 0
    for entry in os.listdir(root):
        subject_dir = os.path.join(root, entry)
        if os.path.isdir(subject_dir):
            total += len(os.listdir(subject_dir))
    return total


if __name__ == '__main__':
    # Earlier one-off pipeline steps, kept for reference; uncomment to re-run:
    # ModelDataset.create_samples_from_folders(r'C:\DATA\CelebA\subjects', r'C:\DATA\CelebA\transformed')
    # main()
    # transfer data: move subject folders with <= 10 images from TRAIN to TEST
    # for dir in os.listdir(r'C:\DATA\CelebA\transformed-TRAIN'):
    #     l = len(os.listdir(os.path.join(r'C:\DATA\CelebA\transformed-TRAIN', dir)))
    #     if l <= 10:
    #         print(dir)
    #         shutil.move(os.path.join(r'C:\DATA\CelebA\transformed-TRAIN', dir), r'C:\DATA\CelebA\transformed-TEST')
    train_root = r'C:\DATA\CelebA\transformed-TRAIN'
    test_root = r'C:\DATA\CelebA\transformed-TEST'

    test = count_images(test_root)
    print(test)
    train = count_images(train_root)
    print(train)

    total = train + test
    if total:
        print(f'train test split:\n\t> train:{train/total:.2f}\n\t> test:{test/total:.2f}')
    else:
        # Avoid ZeroDivisionError when both splits are empty.
        print('train test split: no images found')