diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f6582e1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.DS_Store +flask/.DS_Store +notebooks/.DS_Store diff --git a/README.md b/README.md index 4aaca1b..1268bd9 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,56 @@ -# face -Predicting the Current Face of Missing Children +# CV Project + +- Course: YearDream School 2nd (Track CV) +- Project period: 21 Jul. ~ 5 Aug., 2022 +- Team members & roles + #### 권오균(O. Kwon) : model search, code modification and application, paper review, data search, Flask web implementation, presentation slide production. + #### 홍승현(S. Hong) : Flask web implementation (input page), data search, model search. + #### 김지혜(J. Kim) : team leader, Flask web implementation (output page), presentation slide production, presentation, data search. + #### 권민경(M. Kwon) : presentation slide production, model search, data search, paper review. + #### 이진석(J. Lee) : model search, paper review, code modification support, Flask web implementation support, meeting reports. 
+- Stacks + + ![Python](https://img.shields.io/badge/python-3670A0?style=for-the-badge&logo=python&logoColor=ffdd54) + ![PyTorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?style=for-the-badge&logo=PyTorch&logoColor=white) + ![Anaconda](https://img.shields.io/badge/Anaconda-%2344A833.svg?style=for-the-badge&logo=anaconda&logoColor=white) + ![OpenCV](https://img.shields.io/badge/opencv-%23white.svg?style=for-the-badge&logo=opencv&logoColor=white) + ![NumPy](https://img.shields.io/badge/numpy-%23013243.svg?style=for-the-badge&logo=numpy&logoColor=white) + ![SciPy](https://img.shields.io/badge/SciPy-%230C55A5.svg?style=for-the-badge&logo=scipy&logoColor=%white) + + + ![Google Colab](https://img.shields.io/badge/Google%20Colab-F9AB00?style=for-the-badge&logo=googlecolab&logoColor=white) + ![Google Drive](https://img.shields.io/badge/Google%20Drive-4285F4?style=for-the-badge&logo=googledrive&logoColor=white) + ![Visual Studio Code](https://img.shields.io/badge/Visual%20Studio%20Code-0078d7.svg?style=for-the-badge&logo=visual-studio-code&logoColor=white) + ![Jupyter Notebook](https://img.shields.io/badge/jupyter-%23FA0F00.svg?style=for-the-badge&logo=jupyter&logoColor=white) + + ![Flask](https://img.shields.io/badge/flask-%23000.svg?style=for-the-badge&logo=flask&logoColor=white) + ![HTML5](https://img.shields.io/badge/html5-%23E34F26.svg?style=for-the-badge&logo=html5&logoColor=white) + ![CSS3](https://img.shields.io/badge/css3-%231572B6.svg?style=for-the-badge&logo=css3&logoColor=white) + ![JavaScript](https://img.shields.io/badge/javascript-%23323330.svg?style=for-the-badge&logo=javascript&logoColor=%23F7DF1E) + ![Bootstrap](https://img.shields.io/badge/bootstrap-%23563D7C.svg?style=for-the-badge&logo=bootstrap&logoColor=white) + + ![Notion](https://img.shields.io/badge/Notion-%23000000.svg?style=for-the-badge&logo=notion&logoColor=white) + ![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white) + + +## 
Project topic +### Missing Children Current Face Prediction + +#### The reason we chose this topic + +##### All the team members wanted to try face detection and CycleGAN-family models, and while looking for a suitable model, we found Lifespan Age Transformation Synthesis (LATS) and decided to try it. +[LATS] https://github.com/royorel/Lifespan_Age_Transformation_Synthesis +##### Then, after thinking about how to use the model we found, we thought it would be good to predict the current appearance of children who have gone missing, so we started this project. + +## Why we used Colab +##### The original plan was to implement a web service with Flask in a local environment. +##### However, since the computer we are using is an M1 iMac, there are compatibility issues with PyTorch and it is difficult to use the GPU, so we decided to implement the service through Google Colab. We used ngrok to run Flask on Colab. + +## Flask web implementation result + +![ezgif-2-6189335993](https://user-images.githubusercontent.com/87400909/182803440-dc213fd9-9594-4487-a8b0-78ebb8a61899.gif) + +## Presentation + +![CV_오프라인-1조](https://user-images.githubusercontent.com/87400909/182988156-8bcfc09a-a6e7-4618-a7dc-144d6f3000f6.gif) + diff --git a/data/base_dataset.py b/data/base_dataset.py new file mode 100755 index 0000000..7d9a1d1 --- /dev/null +++ b/data/base_dataset.py @@ -0,0 +1,13 @@ +### Copyright (C) 2020 Roy Or-El. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import torch.utils.data as data + +class BaseDataset(data.Dataset): + def __init__(self): + super(BaseDataset, self).__init__() + + def name(self): + return 'BaseDataset' + + def initialize(self, opt): + pass diff --git a/data/data_loader.py b/data/data_loader.py new file mode 100755 index 0000000..b1d784e --- /dev/null +++ b/data/data_loader.py @@ -0,0 +1,39 @@ +### Copyright (C) 2020 Roy Or-El. All rights reserved. 
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import torch.utils.data +from data.multiclass_unaligned_dataset import MulticlassUnalignedDataset +from pdb import set_trace as st + +class AgingDataLoader(): + def name(self): + return 'AgingDataLoader' + + def initialize(self, opt): + self.opt = opt + self.dataset = CreateDataset(opt) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batchSize, + shuffle=not opt.serial_batches, + drop_last=True, + num_workers=int(opt.nThreads)) + + def load_data(self): + return self.dataloader + + def __len__(self): + return min(len(self.dataset), self.opt.max_dataset_size) + + +def CreateDataset(opt): + dataset = MulticlassUnalignedDataset() + print("dataset [%s] was created" % (dataset.name())) + dataset.initialize(opt) + return dataset + + +def CreateDataLoader(opt): + data_loader = AgingDataLoader() + print(data_loader.name()) + data_loader.initialize(opt) + return data_loader diff --git a/data/dataset_utils.py b/data/dataset_utils.py new file mode 100755 index 0000000..7332d3b --- /dev/null +++ b/data/dataset_utils.py @@ -0,0 +1,62 @@ +### Copyright (C) 2020 Roy Or-El. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
+import os +from PIL import Image +import torchvision.transforms as transforms +import numpy as np +import random +from pdb import set_trace as st + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def list_folder_images(dir, opt): + images = [] + parsings = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + for fname in os.listdir(dir): + if is_image_file(fname): + path = os.path.join(dir, fname) + # make sure there's a matching parsings for the image + # parsing files must be png + parsing_fname = fname[:-3] + 'png' + if os.path.isfile(os.path.join(dir, 'parsings', parsing_fname)): + parsing_path = os.path.join(dir, 'parsings', parsing_fname) + images.append(path) + parsings.append(parsing_path) + + # sort according to identity in case of FGNET test + if 'fgnet' in opt.dataroot.lower(): + images.sort(key=str.lower) + parsings.sort(key=str.lower) + + return images, parsings + +def get_transform(opt, normalize=True): + transform_list = [] + + if opt.resize_or_crop == 'resize_and_crop': + osize = [opt.loadSize, opt.loadSize] + transform_list.append(transforms.Resize(osize, interpolation=Image.NEAREST)) + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'crop': + transform_list.append(transforms.RandomCrop(opt.fineSize)) + + if opt.isTrain and not opt.no_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + + transform_list += [transforms.ToTensor()] + + if normalize: + mean = (0.5,) + std = (0.5,) + transform_list += [transforms.Normalize(mean,std)] + + return transforms.Compose(transform_list) diff --git a/data/multiclass_unaligned_dataset.py b/data/multiclass_unaligned_dataset.py new file mode 100755 index 0000000..6fcd2b2 --- /dev/null +++ b/data/multiclass_unaligned_dataset.py @@ -0,0 +1,226 @@ +### Copyright (C) 2020 Roy Or-El. All rights reserved. 
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import os.path +import re +import torch +import random +import numpy as np +from data.base_dataset import BaseDataset +from data.dataset_utils import list_folder_images, get_transform +from util.preprocess_itw_im import preprocessInTheWildImage +from PIL import Image +from pdb import set_trace as st + +CLASSES_UPPER_BOUNDS = [2, 6, 9, 14, 19, 29, 39, 49, 69, 120] + +class MulticlassUnalignedDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.name_mapping = {} + self.prev_A = -1 + self.prev_B = -1 + self.class_A = -1 + self.class_B = -1 + self.get_samples = False + if not self.opt.isTrain: + self.in_the_wild = opt.in_the_wild + else: + self.in_the_wild = False + + # find all existing classes in root + if not self.in_the_wild: + self.tempClassNames = [] + subDirs = next(os.walk(self.root))[1] # a quick way to get all subdirectories + for currDir in subDirs: + if self.opt.isTrain: + prefix = 'train' + else: + prefix = 'test' + if prefix in currDir: # we assume that the class name starts with the prefix + len_prefix = len(prefix) + className = currDir[len_prefix:] + self.tempClassNames += [className] + + # sort classes + if len(self.opt.sort_order) > 0: + self.classNames = [] + for i, nextClass in enumerate(self.opt.sort_order): + for currClass in self.tempClassNames: + if nextClass == currClass: + self.classNames += [currClass] + curr_class_num = self.assign_age_class(currClass) + self.name_mapping[currClass] = curr_class_num + else: + self.classNames = sorted(self.tempClassNames) + for i, currClass in enumerate(self.classNames): + curr_class_num = self.assign_age_class(currClass) + self.name_mapping[currClass] = curr_class_num + else: + self.classNames = [] + for i, nextClass in enumerate(self.opt.sort_order): + self.classNames += [nextClass] + curr_class_num = self.assign_age_class(nextClass) + 
self.name_mapping[nextClass] = curr_class_num + + self.active_classes_mapping = {} + + for i, name in enumerate(self.classNames): + self.active_classes_mapping[i] = self.name_mapping[name] + + self.numClasses = len(self.classNames) + opt.numClasses = self.numClasses + opt.classNames = self.classNames + + # set class counter for test mode + if self.opt.isTrain is False: + opt.batchSize = self.numClasses + self.class_counter = 0 + self.img_counter = 0 + + # arrange directories + if not self.in_the_wild: + self.dirs = [] + self.img_paths = [] + self.parsing_paths = [] + self.sizes = [] + + for currClass in self.classNames: + self.dirs += [os.path.join(self.root, opt.phase + currClass)] + imgs, parsings = list_folder_images(self.dirs[-1], self.opt) + self.img_paths += [imgs] + self.parsing_paths += [parsings] + self.sizes += [len(self.img_paths[-1])] + + opt.dataset_size = self.__len__() + + self.transform = get_transform(opt) + + if (not self.opt.isTrain) and self.in_the_wild: + self.preprocessor = preprocessInTheWildImage(out_size=opt.fineSize) + + def set_sample_mode(self, mode=False): + self.get_samples = mode + self.class_counter = 0 + self.img_counter = 0 + + def assign_age_class(self, class_name): + ages = [int(s) for s in re.split('-|_', class_name) if s.isdigit()] + max_age = ages[-1] + for i in range(len(CLASSES_UPPER_BOUNDS)): + if max_age <= CLASSES_UPPER_BOUNDS[i]: + break + + return i + + def mask_image(self, img, parsings): + labels_to_mask = [0,14,15,16,18] + for idx in labels_to_mask: + img[parsings == idx] = 128 + + return img + + def get_item_from_path(self, path): + path_dir, im_name = os.path.split(path) + img = Image.open(path).convert('RGB') + img = np.array(img.getdata(), dtype=np.uint8).reshape(img.size[1], img.size[0], 3) + + if self.in_the_wild: + img, parsing = self.preprocessor.forward(img) + else: + parsing_path = os.path.join(path_dir, 'parsings', im_name[:-4] + '.png') + parsing = Image.open(parsing_path).convert('RGB') + parsing = 
np.array(parsing.getdata(), dtype=np.uint8).reshape(parsing.size[1], parsing.size[0], 3) + + img = Image.fromarray(self.mask_image(img, parsing)) + img = self.transform(img).unsqueeze(0) + + return {'Imgs': img, + 'Paths': [path], + 'Classes': torch.zeros(1, dtype=torch.int), + 'Valid': True} + + def __getitem__(self, index): + if self.opt.isTrain and not self.get_samples: + condition = True + self.class_A_idx = random.randint(0,self.numClasses - 1) + self.class_A = self.active_classes_mapping[self.class_A_idx] + while condition: + self.class_B_idx = random.randint(0,self.numClasses - 1) + self.class_B = self.active_classes_mapping[self.class_B_idx] + condition = self.class_A == self.class_B + + index_A = random.randint(0, self.sizes[self.class_A_idx] - 1) + index_B = random.randint(0, self.sizes[self.class_B_idx] - 1) + + A_img_path = self.img_paths[self.class_A_idx][index_A] + A_img = Image.open(A_img_path).convert('RGB') + A_img = np.array(A_img.getdata(), dtype=np.uint8).reshape(A_img.size[1], A_img.size[0], 3) + + B_img_path = self.img_paths[self.class_B_idx][index_B] + B_img = Image.open(B_img_path).convert('RGB') + B_img = np.array(B_img.getdata(), dtype=np.uint8).reshape(B_img.size[1], B_img.size[0], 3) + + A_parsing_path = self.parsing_paths[self.class_A_idx][index_A] + A_parsing = Image.open(A_parsing_path).convert('RGB') + A_parsing = np.array(A_parsing.getdata(), dtype=np.uint8).reshape(A_parsing.size[1], A_parsing.size[0], 3) + A_img = Image.fromarray(self.mask_image(A_img, A_parsing)) + + B_parsing_path = self.parsing_paths[self.class_B_idx][index_B] + B_parsing = Image.open(B_parsing_path).convert('RGB') + B_parsing = np.array(B_parsing.getdata(), dtype=np.uint8).reshape(B_parsing.size[1], B_parsing.size[0], 3) + B_img = Image.fromarray(self.mask_image(B_img, B_parsing)) + + # numpy conversions are an annoying hack to form a PIL image with more than 3 channels + A_img = self.transform(A_img) + B_img = self.transform(B_img) + + return {'A': A_img, 
'B': B_img, + "A_class": self.class_A_idx, "B_class": self.class_B_idx, + 'A_paths': A_img_path, 'B_paths': B_img_path} + + else: # in test mode, load one image from each class + i = self.class_counter % self.numClasses + self.class_counter += 1 + + if self.get_samples: + ind = random.randint(0, self.sizes[i] - 1) + else: + ind = self.img_counter if self.img_counter < self.sizes[i] else -1 + + if i == self.numClasses - 1: + self.img_counter += 1 + + if ind > -1: + valid = True + paths = self.img_paths[i][ind] + img = Image.open(self.img_paths[i][ind]).convert('RGB') + img = np.array(img.getdata(), dtype=np.uint8).reshape(img.size[1], img.size[0], 3) + + parsing_path = self.parsing_paths[i][ind] + parsing = Image.open(parsing_path).convert('RGB') + parsing = np.array(parsing.getdata(), dtype=np.uint8).reshape(parsing.size[1], parsing.size[0], 3) + img = Image.fromarray(self.mask_image(img, parsing)) + + img = self.transform(img) + + else: + img = torch.zeros(3, self.opt.fineSize, self.opt.fineSize) + paths = '' + valid = False + + return {'Imgs': img, + 'Paths': paths, + 'Classes': i, + 'Valid': valid} + + def __len__(self): + if self.opt.isTrain: + return round(sum(self.sizes) / 2) # this determines how many iterations we make per epoch + elif self.in_the_wild: + return 0 + else: + return max(self.sizes) * self.numClasses + + def name(self): + return 'MulticlassUnalignedDataset' diff --git a/datasets/create_dataset.py b/datasets/create_dataset.py new file mode 100644 index 0000000..314b619 --- /dev/null +++ b/datasets/create_dataset.py @@ -0,0 +1,96 @@ +### Copyright (C) 2020 Roy Or-El. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
+import argparse +import PIL.Image as Image +import os +import csv +import shutil +from pdb import set_trace as st + +clusters = ['0-2','3-6','7-9','10-14','15-19', + '20-29','30-39','40-49','50-69','70-120'] + +def processIm(img_filename, phase, csv_row, num): + img_basename = os.path.basename(img_filename) + labels_filename = os.path.join(os.path.dirname(img_filename), 'parsings', img_basename) + + age, age_conf = csv_row['age_group'], float(csv_row['age_group_confidence']) + gender, gender_conf = csv_row['gender'], float(csv_row['gender_confidence']) + head_pitch, head_roll, head_yaw = float(csv_row['head_pitch']), float(csv_row['head_roll']), float(csv_row['head_yaw']) + left_eye_occluded, right_eye_occluded = float(csv_row['left_eye_occluded']), float(csv_row['right_eye_occluded']) + glasses = csv_row['glasses'] + + no_attributes_found = head_pitch == -1 and head_roll == -1 and head_yaw == -1 and \ + left_eye_occluded == -1 and right_eye_occluded == -1 and glasses == -1 + + age_cond = age_conf > 0.6 + gender_cond = gender_conf > 0.66 + head_pose_cond = abs(head_pitch) < 30.0 and abs(head_yaw) < 40.0 + eyes_cond = (left_eye_occluded < 90.0 and right_eye_occluded < 50.0) or (left_eye_occluded < 50.0 and right_eye_occluded < 90.0) + glasses_cond = glasses != 'Dark' + + valid1 = age_cond and gender_cond and no_attributes_found + valid2 = age_cond and gender_cond and head_pose_cond and eyes_cond and glasses_cond + + if gender == 'male': + dst_gender = 'males' + else: + dst_gender = 'females' + + dst_cluster = phase + age + + if (valid1 or valid2): + dst_path = os.path.join(dst_gender, dst_cluster, img_basename) + parsing_dst_path = os.path.join(dst_gender, dst_cluster, 'parsings', img_basename) + shutil.copy(img_filename, dst_path) + shutil.copy(labels_filename, parsing_dst_path) + + +def create_dataset(folder, labels_file, train_split): + # create all necessary subfolders + for clust in clusters: + trainMaleClusterPath = "males/train" + clust + testMaleClusterPath 
= "males/test" + clust + trainFemaleClusterPath = "females/train" + clust + testFemleClusterPath = "females/test" + clust + + if not os.path.isdir(trainMaleClusterPath): + os.makedirs(trainMaleClusterPath) + os.makedirs(os.path.join(trainMaleClusterPath,'parsings')) + os.makedirs(testMaleClusterPath) + os.makedirs(os.path.join(testMaleClusterPath,'parsings')) + os.makedirs(trainFemaleClusterPath) + os.makedirs(os.path.join(trainFemaleClusterPath,'parsings')) + os.makedirs(testFemleClusterPath) + os.makedirs(os.path.join(testFemleClusterPath,'parsings')) + + # process images + with open(labels_file,'r', newline='') as f: + reader = csv.DictReader(f) + for csv_row in reader: + num = int(csv_row['image_number']) + + if num < train_split: + phase = 'train' + else: + phase = 'test' + + subdir = str(num - (num % 1000)).zfill(5) + img_filename = os.path.join(folder,subdir,str(num).zfill(5)+'.png') + if os.path.isfile(img_filename): + print('processing {}'.format(img_filename)) + processIm(img_filename, phase, csv_row, num) + else: + print('Image {}.png was not found'.format(str(num).zfill(5))) + + +if __name__ == '__main__': + argparser = argparse.ArgumentParser() + argparser.add_argument('--folder', type=str, default='../../FFHQ-Aging-Dataset/ffhq_aging256x256', help='Location of the raw FFHQ-Aging dataset') + argparser.add_argument('--labels_file', type=str, default='../../FFHQ-Aging-Dataset/ffhq_aging_labels.csv', help='Location of the raw FFHQ-Aging dataset') + argparser.add_argument('--train_split', type=int, default=69000, help='number of images to allocate for training') + args = argparser.parse_args() + folder = args.folder + labels_file = args.labels_file + train_split = args.train_split + create_dataset(folder, labels_file, train_split) diff --git a/download_models.py b/download_models.py new file mode 100644 index 0000000..bb1d7b0 --- /dev/null +++ b/download_models.py @@ -0,0 +1,4 @@ +import util.util as util + +if __name__ == "__main__": + 
util.download_pretrained_models() diff --git a/females_image_list.txt b/females_image_list.txt new file mode 100644 index 0000000..f8d830f --- /dev/null +++ b/females_image_list.txt @@ -0,0 +1,16 @@ +datasets/females/test3-6/69418.png +datasets/females/test7-9/69859.png +datasets/females/test15-19/69828.png +datasets/females/test50-69/69249.png +datasets/females/test3-6/69270.png +datasets/females/test0-2/69766.png +datasets/females/test7-9/69369.png +datasets/females/test15-19/69352.png +datasets/females/test7-9/69222.png +datasets/females/test7-9/69526.png +datasets/females/test15-19/69489.png +datasets/females/test30-39/69716.png +datasets/females/test3-6/69583.png +datasets/females/test15-19/69131.png +datasets/females/test7-9/69005.png +datasets/females/test15-19/69435.png diff --git a/males_image_list.txt b/males_image_list.txt new file mode 100755 index 0000000..f490ace --- /dev/null +++ b/males_image_list.txt @@ -0,0 +1,17 @@ +datasets/males/test7-9/69235.png +datasets/males/test3-6/69269.png +datasets/males/test15-19/69969.png +datasets/males/test50-69/69063.png +datasets/males/test30-39/69587.png +datasets/males/test50-69/69443.png +datasets/males/test30-39/69193.png +datasets/males/test50-69/69691.png +datasets/males/test30-39/69553.png +datasets/males/test0-2/69293.png +datasets/males/test0-2/69598.png +datasets/males/test0-2/69336.png +datasets/males/test0-2/69085.png +datasets/males/test0-2/69782.png +datasets/males/test0-2/69754.png +datasets/males/test15-19/69735.png +datasets/males/test15-19/69451.png diff --git a/models/LATS_model.py b/models/LATS_model.py new file mode 100755 index 0000000..e5b2eb7 --- /dev/null +++ b/models/LATS_model.py @@ -0,0 +1,504 @@ +### Copyright (C) 2020 Roy Or-El. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
+import numpy as np +import torch +import torch.nn as nn +import re +import functools +from collections import OrderedDict +from .base_model import BaseModel +import util.util as util +from . import networks +from pdb import set_trace as st + +class LATS(BaseModel): #Lifetime Age Transformation Synthesis + def name(self): + return 'LATS' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + + # if opt.resize_or_crop != 'none': # when training at full res this causes OOM + torch.backends.cudnn.benchmark = True + + # determine mode of operation [train, test, deploy, traverse (latent interpolation)] + self.isTrain = opt.isTrain + self.traverse = (not self.isTrain) and opt.traverse + + # mode to generate Fig. 15 in the paper + self.compare_to_trained_outputs = (not self.isTrain) and opt.compare_to_trained_outputs + if self.compare_to_trained_outputs: + self.compare_to_trained_class = opt.compare_to_trained_class + self.trained_class_jump = opt.trained_class_jump + + self.deploy = (not self.isTrain) and opt.deploy + if not self.isTrain and opt.random_seed != -1: + torch.manual_seed(opt.random_seed) + torch.cuda.manual_seed_all(opt.random_seed) + np.random.seed(opt.random_seed) + + # network architecture parameters + self.nb = opt.batchSize + self.size = opt.fineSize + self.ngf = opt.ngf + self.ngf_global = self.ngf + + self.numClasses = opt.numClasses + self.use_moving_avg = not opt.no_moving_avg + + self.no_cond_noise = opt.no_cond_noise + style_dim = opt.gen_dim_per_style * self.numClasses + self.duplicate = opt.gen_dim_per_style + + self.cond_length = style_dim + + # self.active_classes_mapping = opt.active_classes_mapping + + if not self.isTrain: + self.debug_mode = opt.debug_mode + else: + self.debug_mode = False + + ##### define networks + # Generators + self.netG = self.parallelize(networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.n_downsample, + id_enc_norm=opt.id_enc_norm, gpu_ids=self.gpu_ids, padding_type='reflect', style_dim=style_dim, 
+ init_type='kaiming', conv_weight_norm=opt.conv_weight_norm, + decoder_norm=opt.decoder_norm, activation=opt.activation, + adaptive_blocks=opt.n_adaptive_blocks, normalize_mlp=opt.normalize_mlp, + modulated_conv=opt.use_modulated_conv)) + if self.isTrain and self.use_moving_avg: + self.g_running = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.n_downsample, + id_enc_norm=opt.id_enc_norm, gpu_ids=self.gpu_ids, padding_type='reflect', style_dim=style_dim, + init_type='kaiming', conv_weight_norm=opt.conv_weight_norm, + decoder_norm=opt.decoder_norm, activation=opt.activation, + adaptive_blocks=opt.n_adaptive_blocks, normalize_mlp=opt.normalize_mlp, + modulated_conv=opt.use_modulated_conv) + self.g_running.train(False) + self.requires_grad(self.g_running, flag=False) + self.accumulate(self.g_running, self.netG, decay=0) + + # Discriminator network + if self.isTrain: + self.netD = self.parallelize(networks.define_D(opt.output_nc, opt.ndf, n_layers=opt.n_layers_D, + numClasses=self.numClasses, gpu_ids=self.gpu_ids, + init_type='kaiming')) + + if self.opt.verbose: + print('---------- Networks initialized -------------') + + # load networks + if (not self.isTrain) or opt.continue_train or opt.load_pretrain: + pretrained_path = '' if (not self.isTrain) or (self.isTrain and opt.continue_train) else opt.load_pretrain + if self.isTrain: + self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) + self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) + if self.use_moving_avg: + self.load_network(self.g_running, 'g_running', opt.which_epoch, pretrained_path) + elif self.use_moving_avg: + self.load_network(self.netG, 'g_running', opt.which_epoch, pretrained_path) + else: + self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) + + + # set loss functions and optimizers + if self.isTrain: + # define loss functions + self.criterionGAN = self.parallelize(networks.SelectiveClassesNonSatGANLoss()) + self.R1_reg = networks.R1_reg() + 
self.age_reconst_criterion = self.parallelize(networks.FeatureConsistency()) + self.identity_reconst_criterion = self.parallelize(networks.FeatureConsistency()) + self.criterionCycle = self.parallelize(networks.FeatureConsistency()) #torch.nn.L1Loss() + self.criterionRec = self.parallelize(networks.FeatureConsistency()) #torch.nn.L1Loss() + + # initialize optimizers + self.old_lr = opt.lr + + # set optimizer G + paramsG = [] + params_dict_G = dict(self.netG.named_parameters()) + # set the MLP learning rate to 0.01 or the global learning rate + for key, value in params_dict_G.items(): + decay_cond = ('decoder.mlp' in key) + if opt.decay_adain_affine_layers: + decay_cond = decay_cond or ('class_std' in key) or ('class_mean' in key) + if decay_cond: + paramsG += [{'params':[value],'lr':opt.lr * 0.01,'mult':0.01}] + else: + paramsG += [{'params':[value],'lr':opt.lr}] + + self.optimizer_G = torch.optim.Adam(paramsG, lr=opt.lr, betas=(opt.beta1, opt.beta2)) + + # set optimizer D + paramsD = list(self.netD.parameters()) + self.optimizer_D = torch.optim.Adam(paramsD, lr=opt.lr, betas=(opt.beta1, opt.beta2)) + + + def parallelize(self, model): + # parallelize a network + if self.isTrain and len(self.gpu_ids) > 0: + return networks._CustomDataParallel(model) + else: + return model + + + def requires_grad(self, model, flag=True): + # freeze network weights + for p in model.parameters(): + p.requires_grad = flag + + + def accumulate(self, model1, model2, decay=0.999): + # implements exponential moving average + params1 = dict(model1.named_parameters()) + params2 = dict(model2.named_parameters()) + model1_parallel = isinstance(model1, nn.DataParallel) + model2_parallel = isinstance(model2, nn.DataParallel) + + for k in params1.keys(): + if model2_parallel and not model1_parallel: + k2 = 'module.' 
+ k + elif model1_parallel and not model2_parallel: + k2 = re.sub('module.', '', k) + else: + k2 = k + params1[k].data.mul_(decay).add_(1 - decay, params2[k2].data) + + + def set_inputs(self, data, mode='train'): + # set input data to feed to the network + if mode == 'train': + real_A = data['A'] + real_B = data['B'] + + self.class_A = data['A_class'] + self.class_B = data['B_class'] + + self.reals = torch.cat((real_A, real_B), 0) + + if len(self.gpu_ids) > 0: + self.reals = self.reals.cuda() + + else: + inputs = data['Imgs'] + if inputs.dim() > 4: + inputs = inputs.squeeze(0) + + self.class_A = data['Classes'] + if self.class_A.dim() > 1: + self.class_A = self.class_A.squeeze(0) + + if torch.is_tensor(data['Valid']): + self.valid = data['Valid'].bool() + else: + self.valid = torch.ones(1, dtype=torch.bool) + + if self.valid.dim() > 1: + self.valid = self.valid.squeeze(0) + + if isinstance(data['Paths'][0], tuple): + self.image_paths = [path[0] for path in data['Paths']] + else: + self.image_paths = data['Paths'] + + self.isEmpty = False if any(self.valid) else True + if not self.isEmpty: + available_idx = torch.arange(len(self.class_A)) + select_idx = torch.masked_select(available_idx, self.valid).long() + inputs = torch.index_select(inputs, 0, select_idx) + + self.class_A = torch.index_select(self.class_A, 0, select_idx) + self.image_paths = [val for i, val in enumerate(self.image_paths) if self.valid[i] == 1] + + self.reals = inputs + + if len(self.gpu_ids) > 0: + self.reals = self.reals.cuda() + + + def get_conditions(self, mode='train'): + # set conditional inputs to the network + if mode == 'train': + nb = self.reals.shape[0] // 2 + elif self.traverse or self.deploy: + if self.traverse and self.compare_to_trained_outputs: + nb = 2 + else: + nb = self.numClasses + else: + nb = self.numValid + + #tex condition mapping + condG_A_gen = self.Tensor(nb, self.cond_length) + condG_B_gen = self.Tensor(nb, self.cond_length) + condG_A_orig = self.Tensor(nb, 
self.cond_length) + condG_B_orig = self.Tensor(nb, self.cond_length) + + if self.no_cond_noise: + noise_sigma = 0 + else: + noise_sigma = 0.2 + + for i in range(nb): + condG_A_gen[i, :] = (noise_sigma * torch.randn(1, self.cond_length)).cuda() + condG_A_gen[i, self.class_B[i]*self.duplicate:(self.class_B[i] + 1)*self.duplicate] += 1 + if not (self.traverse or self.deploy): + condG_B_gen[i, :] = (noise_sigma * torch.randn(1, self.cond_length)).cuda() + condG_B_gen[i, self.class_A[i]*self.duplicate:(self.class_A[i] + 1)*self.duplicate] += 1 + + condG_A_orig[i, :] = (noise_sigma * torch.randn(1, self.cond_length)).cuda() + condG_A_orig[i, self.class_A[i]*self.duplicate:(self.class_A[i] + 1)*self.duplicate] += 1 + + condG_B_orig[i, :] = (noise_sigma * torch.randn(1, self.cond_length)).cuda() + condG_B_orig[i, self.class_B[i]*self.duplicate:(self.class_B[i] + 1)*self.duplicate] += 1 + + if mode == 'train': + self.gen_conditions = torch.cat((condG_A_gen, condG_B_gen), 0) #torch.cat((self.class_B, self.class_A), 0) + # if the results are not good this might be the issue!!!! uncomment and update code respectively + self.cyc_conditions = torch.cat((condG_B_gen, condG_A_gen), 0) + self.orig_conditions = torch.cat((condG_A_orig, condG_B_orig),0) + else: + self.gen_conditions = condG_A_gen #self.class_B + if not (self.traverse or self.deploy): + # if the results are not good this might be the issue!!!! 
    def update_G(self, infer=False):
        """Run one generator optimization step.

        Computes self-reconstruction, cycle, identity-feature, age-feature and
        adversarial losses on the current batch, backprops through netG and
        steps optimizer_G. Optionally produces images for visualization.

        Args:
            infer: when True, also decode images for visdom display
                   (from the EMA generator if use_moving_avg is set).

        Returns:
            [loss_dict, reals, gen_images, rec_images, cyc_images] — the image
            entries are None unless infer is True.
        """
        # Generator optimization step
        self.optimizer_G.zero_grad()
        self.get_conditions()

        ############### multi GPU ###############
        rec_images, gen_images, cyc_images, orig_id_features, \
        orig_age_features, fake_id_features, fake_age_features = \
            self.netG(self.reals, self.gen_conditions, self.cyc_conditions, self.orig_conditions)

        # discriminator pass on the translated images
        disc_out = self.netD(gen_images)

        # self-reconstruction loss (decode back to the source age class)
        if self.opt.lambda_rec > 0:
            loss_G_Rec = self.criterionRec(rec_images, self.reals) * self.opt.lambda_rec
        else:
            loss_G_Rec = torch.zeros(1).cuda()

        # cycle loss (translate to target age, then back to source age)
        if self.opt.lambda_cyc > 0:
            loss_G_Cycle = self.criterionCycle(cyc_images, self.reals) * self.opt.lambda_cyc
        else:
            loss_G_Cycle = torch.zeros(1).cuda()

        # identity feature loss: features of the translated image must match the original's
        loss_G_identity_reconst = self.identity_reconst_criterion(fake_id_features, orig_id_features) * self.opt.lambda_id
        # age feature loss on the generated images
        loss_G_age_reconst = self.age_reconst_criterion(fake_age_features, self.gen_conditions) * self.opt.lambda_age
        # age feature loss on the original images
        loss_G_age_reconst += self.age_reconst_criterion(orig_age_features, self.orig_conditions) * self.opt.lambda_age

        # adversarial loss, evaluated on the target classes of the translations
        target_classes = torch.cat((self.class_B,self.class_A),0)
        loss_G_GAN = self.criterionGAN(disc_out, target_classes, True, is_gen=True)

        # overall loss
        loss_G = (loss_G_GAN + loss_G_Rec + loss_G_Cycle + \
                  loss_G_identity_reconst + loss_G_age_reconst).mean()

        loss_G.backward()
        self.optimizer_G.step()

        # update exponential moving average of the generator weights
        if self.use_moving_avg:
            self.accumulate(self.g_running, self.netG)

        # generate images for visdom
        if infer:
            if self.use_moving_avg:
                with torch.no_grad():
                    orig_id_features_out, _ = self.g_running.encode(self.reals)
                    # within-domain decode (reconstruction)
                    if self.opt.lambda_rec > 0:
                        rec_images_out = self.g_running.decode(orig_id_features_out, self.orig_conditions)

                    # cross-domain decode (age translation)
                    gen_images_out = self.g_running.decode(orig_id_features_out, self.gen_conditions)
                    # encode generated images
                    # NOTE(review): this encodes `gen_images` (from netG), not
                    # `gen_images_out` (from g_running) — presumably intentional
                    # reuse of the training-pass output; confirm.
                    fake_id_features_out, _ = self.g_running.encode(gen_images)
                    # decode generated images back to the source domain
                    if self.opt.lambda_cyc > 0:
                        cyc_images_out = self.g_running.decode(fake_id_features_out, self.cyc_conditions)
            else:
                gen_images_out = gen_images
                if self.opt.lambda_rec > 0:
                    rec_images_out = rec_images
                if self.opt.lambda_cyc > 0:
                    cyc_images_out = cyc_images

        loss_dict = {'loss_G_Adv': loss_G_GAN.mean(), 'loss_G_Cycle': loss_G_Cycle.mean(),
                     'loss_G_Rec': loss_G_Rec.mean(), 'loss_G_identity_reconst': loss_G_identity_reconst.mean(),
                     'loss_G_age_reconst': loss_G_age_reconst.mean()}

        return [loss_dict,
                None if not infer else self.reals,
                None if not infer else gen_images_out,
                None if not infer else rec_images_out,
                None if not infer else cyc_images_out]


    def update_D(self):
        """Run one discriminator optimization step.

        Generates fakes with netG (disc_pass=True skips reconstruction/cycle
        decoding), evaluates real/fake GAN losses plus R1 gradient
        regularization on the reals, backprops and steps optimizer_D.

        Returns:
            dict with mean 'loss_D_real', 'loss_D_fake', 'loss_D_reg'.
        """
        # Discriminator optimization step
        self.optimizer_D.zero_grad()
        self.get_conditions()

        ############### multi GPU ###############
        _, gen_images, _, _, _, _, _ = self.netG(self.reals, self.gen_conditions, None, None, disc_pass=True)

        # fake discriminator pass (detach: no gradients into the generator)
        fake_disc_in = gen_images.detach()
        fake_disc_out = self.netD(fake_disc_in)

        # real discriminator pass
        real_disc_in = self.reals

        # necessary for R1 regularization (gradient w.r.t. the real inputs)
        real_disc_in.requires_grad_()

        real_disc_out = self.netD(real_disc_in)

        # Fake GAN loss
        fake_target_classes = torch.cat((self.class_B,self.class_A),0)
        loss_D_fake = self.criterionGAN(fake_disc_out, fake_target_classes, False, is_gen=False)

        # Real GAN loss
        real_target_classes = torch.cat((self.class_A,self.class_B),0)
        loss_D_real = self.criterionGAN(real_disc_out, real_target_classes, True, is_gen=False)

        # R1 regularization
        loss_D_reg = self.R1_reg(real_disc_out, real_disc_in)

        loss_D = (loss_D_fake + loss_D_real + loss_D_reg).mean()
        loss_D.backward()
        self.optimizer_D.step()

        return {'loss_D_real': loss_D_real.mean(), 'loss_D_fake': loss_D_fake.mean(), 'loss_D_reg': loss_D_reg.mean()}
    def inference(self, data):
        """Run the model in test mode on a batch.

        Fills self.fake_B with age-translated outputs for every target class
        (or for a traversal/deploy schedule) and self.cyc_A with their cycle
        reconstructions, then returns visuals for display. Returns None when
        the batch is empty.
        """
        self.set_inputs(data, mode='test')
        if self.isEmpty:
            return

        self.numValid = self.valid.sum().item()
        sz = self.reals.size()
        # one output stack per target class: (numClasses, B, C, H, W)
        self.fake_B = self.Tensor(self.numClasses, sz[0], sz[1], sz[2], sz[3])
        self.cyc_A = self.Tensor(self.numClasses, sz[0], sz[1], sz[2], sz[3])

        with torch.no_grad():
            if self.traverse or self.deploy:
                if self.traverse and self.compare_to_trained_outputs:
                    start = self.compare_to_trained_class - self.trained_class_jump
                    end = start + (self.trained_class_jump * 2) * 2  # arange is between [start, end), end is always omitted
                    self.class_B = torch.arange(start, end, step=self.trained_class_jump*2, dtype=self.class_A.dtype)
                else:
                    self.class_B = torch.arange(self.numClasses, dtype=self.class_A.dtype)

                self.get_conditions(mode='test')

                self.fake_B = self.netG.infer(self.reals, self.gen_conditions, traverse=self.traverse, deploy=self.deploy, interp_step=self.opt.interp_step)
            else:
                # translate to each class in turn, then cycle back
                for i in range(self.numClasses):
                    self.class_B = self.Tensor(self.numValid).long().fill_(i)
                    self.get_conditions(mode='test')

                    # during training-time validation use the EMA generator
                    if self.isTrain:
                        self.fake_B[i, :, :, :, :] = self.g_running.infer(self.reals, self.gen_conditions)
                    else:
                        self.fake_B[i, :, :, :, :] = self.netG.infer(self.reals, self.gen_conditions)

                    cyc_input = self.fake_B[i, :, :, :, :]

                    if self.isTrain:
                        self.cyc_A[i, :, :, :, :] = self.g_running.infer(cyc_input, self.cyc_conditions)
                    else:
                        self.cyc_A[i, :, :, :, :] = self.netG.infer(cyc_input, self.cyc_conditions)

        visuals = self.get_visuals()

        return visuals


    def save(self, which_epoch):
        """Save generator, discriminator and (when enabled) the EMA generator."""
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.use_moving_avg:
            self.save_network(self.g_running, 'g_running', which_epoch, self.gpu_ids)
    def update_learning_rate(self):
        """Decay both optimizers' learning rates by opt.decay_gamma.

        Generator parameter groups honor an optional per-group 'mult' factor.
        """
        lr = self.old_lr * self.opt.decay_gamma
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            mult = param_group.get('mult', 1.0)
            param_group['lr'] = lr * mult
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr


    def get_visuals(self):
        """Assemble per-sample OrderedDicts of numpy images for visualization.

        Returns a list (one entry per valid sample) of OrderedDicts mapping
        label -> HxWx3 image: the original, one translation per output class,
        and (in debug mode) the per-class cycle reconstructions.
        """
        return_dicts = [OrderedDict() for i in range(self.numValid)]

        real_A = util.tensor2im(self.reals.data)
        fake_B_tex = util.tensor2im(self.fake_B.data)

        if self.debug_mode:
            rec_A_tex = util.tensor2im(self.cyc_A.data[:,:,:,:,:])

        if self.numValid == 1:
            # keep a leading batch axis so the loop below indexes uniformly
            real_A = np.expand_dims(real_A, axis=0)

        for i in range(self.numValid):
            # get the original image and the results for the current samples
            curr_real_A = real_A[i, :, :, :]
            real_A_img = curr_real_A[:, :, :3]

            # start with age progression/regression images
            if self.traverse or self.deploy:
                curr_fake_B_tex = fake_B_tex
                orig_dict = OrderedDict([('orig_img', real_A_img)])
            else:
                curr_fake_B_tex = fake_B_tex[:, i, :, :, :]
                orig_dict = OrderedDict([('orig_img_cls_' + str(self.class_A[i].item()), real_A_img)])

            return_dicts[i].update(orig_dict)

            # set output classes number
            if self.traverse:
                out_classes = curr_fake_B_tex.shape[0]
            else:
                out_classes = self.numClasses

            for j in range(out_classes):
                fake_res_tex = curr_fake_B_tex[j, :, :, :3]
                fake_dict_tex = OrderedDict([('tex_trans_to_class_' + str(j), fake_res_tex)])
                return_dicts[i].update(fake_dict_tex)

            if not (self.traverse or self.deploy):
                if self.debug_mode:
                    # continue with tex reconstructions
                    curr_rec_A_tex = rec_A_tex[:, i, :, :, :]
                    orig_dict = OrderedDict([('orig_img2', real_A_img)])
                    return_dicts[i].update(orig_dict)
                    for j in range(self.numClasses):
                        rec_res_tex = curr_rec_A_tex[j, :, :, :3]
                        rec_dict_tex = OrderedDict([('tex_rec_from_class_' + str(j), rec_res_tex)])
                        return_dicts[i].update(rec_dict_tex)

        return return_dicts


class InferenceModel(LATS):
    """Thin wrapper exposing LATS.inference through forward() for test-time use."""
    def forward(self, data):
        return self.inference(data)
def load_network(self, network, network_label, epoch_label, save_dir=''):
    """Load weights for `network` from '<epoch>_net_<label>.pth'.

    Falls back gracefully when the checkpoint's architecture drifted from the
    current one: loads the intersection of layers that exist in both and have
    matching sizes, reporting what was skipped when opt.verbose is set.

    Args:
        network: module (possibly wrapped in nn.DataParallel) to load into.
        network_label: e.g. 'G' or 'D'; a missing generator is fatal.
        epoch_label: checkpoint epoch tag used in the filename.
        save_dir: directory to load from; defaults to self.save_dir.

    Raises:
        FileNotFoundError: when a generator checkpoint is missing.
    """
    save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
    if not save_dir:
        save_dir = self.save_dir
    save_path = os.path.join(save_dir, save_filename)
    if not os.path.isfile(save_path):
        print('%s not exists yet!' % save_path)
        if 'G' in network_label:
            # bugfix: the original `raise('Generator must exist!')` raised a
            # bare string, which is a TypeError in Python 3; raise a real
            # exception with the same message instead.
            raise FileNotFoundError('Generator must exist!')
    else:
        try:
            if isinstance(network, nn.DataParallel):
                network.module.load_state_dict(torch.load(save_path))
            else:
                network.load_state_dict(torch.load(save_path))
        except Exception:  # narrowed from a bare `except:` (kept KeyboardInterrupt etc. alive)
            pretrained_dict = torch.load(save_path)
            if isinstance(network, nn.DataParallel):
                model_dict = network.module.state_dict()
            else:
                model_dict = network.state_dict()
            try:
                # checkpoint has extra layers: keep only those the model knows
                pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                network.load_state_dict(pretrained_dict)
                if self.opt.verbose:
                    print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
            except Exception:
                # checkpoint has fewer/mismatched layers: load what fits
                print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
                not_initialized = set()
                for k, v in pretrained_dict.items():
                    if v.size() == model_dict[k].size():
                        model_dict[k] = v

                for k, v in model_dict.items():
                    if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                        not_initialized.add(k.split('.')[0])
                if self.opt.verbose:
                    print(sorted(not_initialized))
                network.load_state_dict(model_dict)

def update_learning_rate(self):
    """Learning-rate scheduling hook; base implementation is a no-op.

    Bugfix: the original signature omitted ``self``, so every instance call
    (``model.update_learning_rate()``) raised TypeError.
    """
    pass
def weights_init(init_type='gaussian'):
    """Return an initializer callable suitable for ``net.apply(...)``.

    Initializes the weights of Conv*/Linear modules according to
    `init_type` ('gaussian', 'xavier', 'kaiming', 'orthogonal', 'default')
    and zeroes their biases.

    Raises:
        AssertionError: for an unsupported init_type.
    """
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                init.normal_(m.weight.data, 0.0, 0.02)
            elif init_type == 'xavier':
                # bugfix: this module only does `from math import sqrt`, so the
                # original `math.sqrt(2)` raised NameError on this branch
                init.xavier_normal_(m.weight.data, gain=sqrt(2))
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                # bugfix: same NameError as the xavier branch
                init.orthogonal_(m.weight.data, gain=sqrt(2))
            elif init_type == 'default':
                pass
            else:
                assert 0, "Unsupported initialization: {}".format(init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    return init_fun
def print_network(net):
    """Print a network's structure followed by its total parameter count.

    Accepts either a module or a list of modules (only the first is printed).
    """
    if isinstance(net, list):
        net = net[0]
    num_params = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % num_params)
class SelectiveClassesNonSatGANLoss(nn.Module):
    """Non-saturating GAN loss evaluated only on each sample's target-class logits."""

    def __init__(self):
        super(SelectiveClassesNonSatGANLoss, self).__init__()
        self.softplus = nn.Softplus()

    def __call__(self, input, target_classes, target_is_real, is_gen=False):
        # pick, for every batch element, the logit map of its target class
        batch_idx = torch.arange(input.shape[0]).long()
        selected = input[batch_idx, target_classes, :, :]
        # softplus(-x) = -log(sigmoid(x)): non-saturating loss for reals;
        # softplus(x) = -log(1 - sigmoid(x)) for fakes
        sign = -1 if target_is_real else 1
        return self.softplus(sign * selected).mean()
class PixelNorm(nn.Module):
    """Normalize each spatial location's feature vector to (roughly) unit RMS."""

    def __init__(self, num_channels=None):
        # num_channels is accepted only to match the signature of other
        # normalization layers; it is unused.
        super().__init__()

    def forward(self, input):
        rms = torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-5)
        return input / rms
class BlurFunctionBackward(Function):
    """Backward pass of the blur op, written as its own Function so the blur
    supports double backprop (needed e.g. for R1-style gradient penalties)."""

    @staticmethod
    def forward(ctx, grad_output, kernel, kernel_flip):
        ctx.save_for_backward(kernel, kernel_flip)
        # gradient of a depthwise conv is a depthwise conv with the flipped kernel
        return F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1])

    @staticmethod
    def backward(ctx, gradgrad_output):
        kernel, kernel_flip = ctx.saved_tensors
        out = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1])
        # no gradients flow to the kernel buffers
        return out, None, None
class MLP(nn.Module):
    """Fully-connected network mapping input_dim -> out_dim through n_fc layers.

    Optionally uses equalized-LR linear layers, a choice of activation, and
    PixelNorm after the input, every hidden layer, and the output.
    """

    def __init__(self, input_dim, out_dim, fc_dim, n_fc,
                 weight_norm=False, activation='relu', normalize_mlp=False):
        super(MLP, self).__init__()
        linear = EqualLinear if weight_norm else nn.Linear

        if activation == 'lrelu':
            actvn = nn.LeakyReLU(0.2, True)
        elif activation == 'blrelu':
            actvn = BidirectionalLeakyReLU()
        else:
            actvn = nn.ReLU(True)

        self.input_dim = input_dim

        layers = []
        if normalize_mlp:
            layers.append(PixelNorm())            # normalize input
        # first layer
        layers += [linear(input_dim, fc_dim), actvn]
        if normalize_mlp:
            layers.append(PixelNorm())
        # inner layers
        for _ in range(n_fc - 2):
            layers += [linear(fc_dim, fc_dim), actvn]
            if normalize_mlp:
                layers.append(PixelNorm())
        # last layer — no output activation
        layers.append(linear(fc_dim, out_dim))
        if normalize_mlp:
            layers.append(PixelNorm())            # normalize output

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)
class StyledConvBlock(nn.Module):
    """Two 3x3 convolutions — style-modulated or plain equalized — with
    optional nearest-neighbor upsampling before conv0 and blur+avg-pool
    downsampling around conv1."""

    def __init__(self, fin, fout, latent_dim=256, padding='reflect', upsample=False, downsample=False,
                 actvn='lrelu', use_pixel_norm=False, normalize_affine_output=False, modulated_conv=False):
        super(StyledConvBlock, self).__init__()
        if not modulated_conv:
            if padding == 'reflect':
                padding_layer = nn.ReflectionPad2d
            else:
                padding_layer = nn.ZeroPad2d

        if modulated_conv:
            conv2d = ModulatedConv2d
        else:
            conv2d = EqualConv2d

        if modulated_conv:
            self.actvn_gain = sqrt(2)
        else:
            self.actvn_gain = 1.0

        self.use_pixel_norm = use_pixel_norm
        self.upsample = upsample
        self.downsample = downsample
        self.modulated_conv = modulated_conv

        if actvn == 'relu':
            activation = nn.ReLU(True)
        else:
            activation = nn.LeakyReLU(0.2, True)

        if self.downsample:
            self.downsampler = nn.AvgPool2d(2)

        if self.upsample and not self.modulated_conv:
            # bugfix: the non-modulated upsample branch below used
            # self.upsampler, which was never created in this class (only
            # ModulatedConv2d defines its own); upsample=True with
            # modulated_conv=False raised AttributeError.
            self.upsampler = nn.Upsample(scale_factor=2, mode='nearest')

        if self.modulated_conv:
            self.conv0 = conv2d(fin, fout, kernel_size=3, padding_type=padding, upsample=upsample,
                                latent_dim=latent_dim, normalize_mlp=normalize_affine_output)
        else:
            conv0 = conv2d(fin, fout, kernel_size=3)
            if self.upsample:
                seq0 = [self.upsampler, padding_layer(1), conv0, Blur(fout)]
            else:
                seq0 = [padding_layer(1), conv0]
            self.conv0 = nn.Sequential(*seq0)

        if use_pixel_norm:
            self.pxl_norm0 = PixelNorm()

        self.actvn0 = activation

        if self.modulated_conv:
            self.conv1 = conv2d(fout, fout, kernel_size=3, padding_type=padding, downsample=downsample,
                                latent_dim=latent_dim, normalize_mlp=normalize_affine_output)
        else:
            conv1 = conv2d(fout, fout, kernel_size=3)
            if self.downsample:
                seq1 = [Blur(fout), padding_layer(1), conv1, self.downsampler]
            else:
                seq1 = [padding_layer(1), conv1]
            self.conv1 = nn.Sequential(*seq1)

        if use_pixel_norm:
            self.pxl_norm1 = PixelNorm()

        self.actvn1 = activation

    def forward(self, input, latent=None):
        """Apply both convs; `latent` is only used when modulated_conv is set."""
        if self.modulated_conv:
            out = self.conv0(input, latent)
        else:
            out = self.conv0(input)

        out = self.actvn0(out) * self.actvn_gain
        if self.use_pixel_norm:
            out = self.pxl_norm0(out)

        if self.modulated_conv:
            out = self.conv1(out, latent)
        else:
            out = self.conv1(out)

        out = self.actvn1(out) * self.actvn_gain
        if self.use_pixel_norm:
            out = self.pxl_norm1(out)

        return out
class IdentityEncoder(nn.Module):
    """Convolutional encoder producing identity features: a 7x7 stem,
    n_downsampling stride-2 convs, then n_blocks residual blocks."""

    def __init__(self, input_nc, ngf=64, n_downsampling=3, n_blocks=7,
                 norm_layer=PixelNorm, padding_type='reflect',
                 conv_weight_norm=False, actvn='relu'):
        assert(n_blocks >= 0)
        super(IdentityEncoder, self).__init__()

        if padding_type == 'reflect':
            padding_layer = nn.ReflectionPad2d
        else:
            padding_layer = nn.ZeroPad2d

        # equalized-LR convs when requested, plain convs otherwise
        if conv_weight_norm:
            conv2d = EqualConv2d
        else:
            conv2d = nn.Conv2d

        if actvn == 'lrelu':
            activation = nn.LeakyReLU(0.2, True)
        else:
            activation = nn.ReLU(True)

        encoder = [padding_layer(3), conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            encoder += [padding_layer(1),
                        conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0),
                        norm_layer(ngf * mult * 2), activation]

        ### resnet blocks
        mult = 2**n_downsampling
        for i in range(n_blocks):
            encoder += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation,
                                    norm_layer=norm_layer, conv_weight_norm=conv_weight_norm)]

        self.encoder = nn.Sequential(*encoder)

    def forward(self, input):
        """Return the identity feature map for `input`."""
        return self.encoder(input)
class AgeEncoder(nn.Module):
    """Convolutional encoder producing a style_dim age latent: a 7x7 stem,
    n_downsampling stride-2 convs, a 1x1 projection, then global average
    pooling over the spatial dimensions."""

    def __init__(self, input_nc, ngf=64, n_downsampling=4, style_dim=50, padding_type='reflect',
                 conv_weight_norm=False, actvn='lrelu'):
        super(AgeEncoder, self).__init__()

        if padding_type == 'reflect':
            padding_layer = nn.ReflectionPad2d
        else:
            padding_layer = nn.ZeroPad2d

        if conv_weight_norm:
            conv2d = EqualConv2d
        else:
            conv2d = nn.Conv2d

        if actvn == 'lrelu':
            activation = nn.LeakyReLU(0.2, True)
        else:
            activation = nn.ReLU(True)

        encoder = [padding_layer(3), conv2d(input_nc, ngf, kernel_size=7, padding=0), activation]
        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            encoder += [padding_layer(1),
                        conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0),
                        activation]

        # 1x1 projection to the age-latent dimensionality
        encoder += [conv2d(ngf * mult * 2, style_dim, kernel_size=1, stride=1, padding=0)]

        self.encoder = nn.Sequential(*encoder)

    def forward(self, input):
        features = self.encoder(input)
        # global average pool over H and W -> (batch, style_dim)
        latent = features.mean(dim=3).mean(dim=2)
        return latent


class StyledDecoder(nn.Module):
    """StyleGAN-like decoder: four fixed-resolution styled conv blocks, two
    upsampling styled conv blocks, and a 1x1 conv to image space with tanh.
    The age latent is mapped to the modulation space by an 8-layer MLP."""

    def __init__(self, output_nc, ngf=64, style_dim=50, latent_dim=256, n_downsampling=2,
                 padding_type='reflect', actvn='lrelu', use_tanh=True, use_pixel_norm=False,
                 normalize_mlp=False, modulated_conv=False):
        super(StyledDecoder, self).__init__()
        if padding_type == 'reflect':
            padding_layer = nn.ReflectionPad2d
        else:
            padding_layer = nn.ZeroPad2d

        mult = 2**n_downsampling
        last_upconv_out_layers = ngf * mult // 4

        self.StyledConvBlock_0 = StyledConvBlock(ngf * mult, ngf * mult, latent_dim=latent_dim,
                                                 padding=padding_type, actvn=actvn,
                                                 use_pixel_norm=use_pixel_norm,
                                                 normalize_affine_output=normalize_mlp,
                                                 modulated_conv=modulated_conv)

        self.StyledConvBlock_1 = StyledConvBlock(ngf * mult, ngf * mult, latent_dim=latent_dim,
                                                 padding=padding_type, actvn=actvn,
                                                 use_pixel_norm=use_pixel_norm,
                                                 normalize_affine_output=normalize_mlp,
                                                 modulated_conv=modulated_conv)

        self.StyledConvBlock_2 = StyledConvBlock(ngf * mult, ngf * mult, latent_dim=latent_dim,
                                                 padding=padding_type, actvn=actvn,
                                                 use_pixel_norm=use_pixel_norm,
                                                 normalize_affine_output=normalize_mlp,
                                                 modulated_conv=modulated_conv)

        self.StyledConvBlock_3 = StyledConvBlock(ngf * mult, ngf * mult, latent_dim=latent_dim,
                                                 padding=padding_type, actvn=actvn,
                                                 use_pixel_norm=use_pixel_norm,
                                                 normalize_affine_output=normalize_mlp,
                                                 modulated_conv=modulated_conv)

        self.StyledConvBlock_up0 = StyledConvBlock(ngf * mult, ngf * mult // 2, latent_dim=latent_dim,
                                                   padding=padding_type, upsample=True, actvn=actvn,
                                                   use_pixel_norm=use_pixel_norm,
                                                   normalize_affine_output=normalize_mlp,
                                                   modulated_conv=modulated_conv)
        self.StyledConvBlock_up1 = StyledConvBlock(ngf * mult // 2, last_upconv_out_layers, latent_dim=latent_dim,
                                                   padding=padding_type, upsample=True, actvn=actvn,
                                                   use_pixel_norm=use_pixel_norm,
                                                   normalize_affine_output=normalize_mlp,
                                                   modulated_conv=modulated_conv)

        self.conv_img = nn.Sequential(EqualConv2d(last_upconv_out_layers, output_nc, 1), nn.Tanh())
        self.mlp = MLP(style_dim, latent_dim, 256, 8, weight_norm=True, activation=actvn, normalize_mlp=normalize_mlp)

    def forward(self, id_features, target_age=None, traverse=False, deploy=False, interp_step=0.5):
        """Decode identity features under an age condition.

        When `traverse` is True, the identity is replicated and the latent is
        linearly interpolated between consecutive age-class latents with step
        `interp_step`; when `deploy` is True, the identity is replicated once
        per supplied age code.
        """
        if target_age is not None:
            if traverse:
                # interpolation weights from 1 down to (but excluding) 0
                alphas = torch.arange(1,0,step=-interp_step).view(-1,1).cuda()
                interps = len(alphas)
                orig_class_num = target_age.shape[0]
                output_classes = interps * (orig_class_num - 1) + 1
                temp_latent = self.mlp(target_age)
                latent = temp_latent.new_zeros((output_classes, temp_latent.shape[1]))
            else:
                latent = self.mlp(target_age)
        else:
            latent = None

        if traverse:
            id_features = id_features.repeat(output_classes,1,1,1)
            # blend each adjacent pair of class latents
            for i in range(orig_class_num-1):
                latent[interps*i:interps*(i+1), :] = alphas * temp_latent[i,:] + (1 - alphas) * temp_latent[i+1,:]
            latent[-1,:] = temp_latent[-1,:]
        elif deploy:
            output_classes = target_age.shape[0]
            id_features = id_features.repeat(output_classes,1,1,1)

        out = self.StyledConvBlock_0(id_features, latent)
        out = self.StyledConvBlock_1(out, latent)
        out = self.StyledConvBlock_2(out, latent)
        out = self.StyledConvBlock_3(out, latent)
        out = self.StyledConvBlock_up0(out, latent)
        out = self.StyledConvBlock_up1(out, latent)
        out = self.conv_img(out)

        return out
class Generator(nn.Module):
    """Full LATS generator: identity encoder + age encoder + styled decoder."""

    def __init__(self, input_nc, output_nc, ngf=64, style_dim=50, n_downsampling=2,
                 n_blocks=4, adaptive_blocks=4, id_enc_norm=PixelNorm,
                 padding_type='reflect', conv_weight_norm=False,
                 decoder_norm='pixel', actvn='lrelu', normalize_mlp=False,
                 modulated_conv=False):
        super(Generator, self).__init__()
        self.id_encoder = IdentityEncoder(input_nc, ngf, n_downsampling, n_blocks, id_enc_norm,
                                          padding_type, conv_weight_norm=conv_weight_norm,
                                          actvn='relu')  # replacing relu with leaky relu here causes nans and the entire training to collapse immediately
        self.age_encoder = AgeEncoder(input_nc, ngf=ngf, n_downsampling=4, style_dim=style_dim,
                                      padding_type=padding_type, actvn=actvn,
                                      conv_weight_norm=conv_weight_norm)

        use_pixel_norm = decoder_norm == 'pixel'
        self.decoder = StyledDecoder(output_nc, ngf=ngf, style_dim=style_dim,
                                     n_downsampling=n_downsampling, actvn=actvn,
                                     use_pixel_norm=use_pixel_norm,
                                     normalize_mlp=normalize_mlp,
                                     modulated_conv=modulated_conv)

    def encode(self, input):
        """Return (identity_features, age_features); (None, None) for non-tensors."""
        if torch.is_tensor(input):
            id_features = self.id_encoder(input)
            age_features = self.age_encoder(input)
            return id_features, age_features
        else:
            return None, None

    def decode(self, id_features, target_age_features, traverse=False, deploy=False, interp_step=0.5):
        """Decode identity features under an age condition; None passes through."""
        if torch.is_tensor(id_features):
            return self.decoder(id_features, target_age_features, traverse=traverse, deploy=deploy, interp_step=interp_step)
        else:
            return None

    # parallel forward
    def forward(self, input, target_age_code, cyc_age_code, source_age_code, disc_pass=False):
        """Training forward pass.

        With disc_pass=True only the target-age translation is produced
        (reconstruction, cycle and fake-feature outputs are None).

        Returns:
            (rec_out, gen_out, cyc_out, orig_id_features, orig_age_features,
             fake_id_features, fake_age_features)
        """
        orig_id_features = self.id_encoder(input)
        orig_age_features = self.age_encoder(input)
        if disc_pass:
            rec_out = None
        else:
            rec_out = self.decode(orig_id_features, source_age_code)

        gen_out = self.decode(orig_id_features, target_age_code)
        if disc_pass:
            fake_id_features = None
            fake_age_features = None
            cyc_out = None
        else:
            # re-encode the translation for the cycle and feature losses
            fake_id_features = self.id_encoder(gen_out)
            fake_age_features = self.age_encoder(gen_out)
            cyc_out = self.decode(fake_id_features, cyc_age_code)
        return rec_out, gen_out, cyc_out, orig_id_features, orig_age_features, fake_id_features, fake_age_features

    def infer(self, input, target_age_features, traverse=False, deploy=False, interp_step=0.5):
        """Test-time translation: encode identity, decode to the target age(s)."""
        id_features = self.id_encoder(input)
        out = self.decode(id_features, target_age_features, traverse=traverse, deploy=deploy, interp_step=interp_step)
        return out
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block: x + conv_block(x), spatial size preserved via padding."""

    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True),
                 conv_weight_norm=False, use_pixel_norm=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation,
                                                conv_weight_norm, use_pixel_norm)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, conv_weight_norm, use_pixel_norm):
        """Build the two-conv residual branch (conv+norm+act, conv+norm)."""
        conv2d = EqualConv2d if conv_weight_norm else nn.Conv2d

        self.use_pixel_norm = use_pixel_norm
        if self.use_pixel_norm:
            self.pixel_norm = PixelNorm()

        def padding_for(scheme):
            # (explicit padding layers, conv padding) for the requested scheme
            if scheme == 'reflect':
                return [nn.ReflectionPad2d(1)], 0
            if scheme == 'replicate':
                return [nn.ReplicationPad2d(1)], 0
            if scheme == 'zero':
                return [], 1
            raise NotImplementedError('padding [%s] is not implemented' % scheme)

        layers = []
        pads, p = padding_for(padding_type)
        layers += pads
        layers += [conv2d(dim, dim, kernel_size=3, padding=p), norm_layer(dim), activation]

        pads, p = padding_for(padding_type)
        layers += pads
        layers += [conv2d(dim, dim, kernel_size=3, padding=p), norm_layer(dim)]

        return nn.Sequential(*layers)

    def forward(self, x):
        return x + self.conv_block(x)
input_nc, ndf=64, n_layers=6, numClasses=2, padding_type='reflect'): + super(StyleGANDiscriminator, self).__init__() + self.n_layers = n_layers + if padding_type == 'reflect': + padding_layer = nn.ReflectionPad2d + else: + padding_layer = nn.ZeroPad2d + + activation = nn.LeakyReLU(0.2,True) + + sequence = [padding_layer(0), EqualConv2d(input_nc, ndf, kernel_size=1), activation] + + nf = ndf + for n in range(n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [StyledConvBlock(nf_prev, nf, downsample=True, actvn=activation)] + + self.model = nn.Sequential(*sequence) + + output_nc = numClasses + self.gan_head = nn.Sequential(padding_layer(1), EqualConv2d(nf+1, nf, kernel_size=3), activation, + EqualConv2d(nf, output_nc, kernel_size=4), activation) + + def minibatch_stdev(self, input): + out_std = torch.sqrt(input.var(0, unbiased=False) + 1e-8) + mean_std = out_std.mean() + mean_std = mean_std.expand(input.size(0), 1, input.size(2), input.size(3)) + out = torch.cat((input, mean_std), 1) + return out + + def forward(self, input): + features = self.model(input) + gan_out = self.gan_head(self.minibatch_stdev(features)) + return gan_out diff --git a/options/__init__.py b/options/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/options/__pycache__/__init__.cpython-310.pyc b/options/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..ea106a5 Binary files /dev/null and b/options/__pycache__/__init__.cpython-310.pyc differ diff --git a/options/__pycache__/__init__.cpython-38.pyc b/options/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..62bb8be Binary files /dev/null and b/options/__pycache__/__init__.cpython-38.pyc differ diff --git a/options/__pycache__/base_options.cpython-310.pyc b/options/__pycache__/base_options.cpython-310.pyc new file mode 100644 index 0000000..5b51dbd Binary files /dev/null and b/options/__pycache__/base_options.cpython-310.pyc differ diff --git 
### Copyright (C) 2020 Roy Or-El. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import argparse
import os
import torch

class BaseOptions():
    """Base command-line option parser shared by train and test runs.

    Subclasses extend initialize() with phase-specific flags and must set
    self.isTrain before parse() is called.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.initialized = False

    def initialize(self):
        """Register the options common to every phase."""
        # NOTE(review): argparse `type=bool` is a known pitfall -- any non-empty
        # string (including 'False') parses as True. Kept as-is so existing run
        # scripts behave unchanged; a str2bool converter would be the real fix.
        # experiment specifics
        self.parser.add_argument('--name', type=str, default='debug', help='name of the experiment. It decides where to store samples and models')
        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
        self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')

        # input/output sizes
        self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
        self.parser.add_argument('--loadSize', type=int, default=256, help='scale images to this size')
        self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
        self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
        self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')

        # for setting inputs
        self.parser.add_argument('--dataroot', type=str, default='./datasets/males/')
        self.parser.add_argument('--sort_classes', type=bool, default=True, help='a flag that indicates whether to sort the classes')
        self.parser.add_argument('--sort_order', type=str, default='0-2,3-6,7-9,15-19,30-39,50-69', help='a specific order to sort the classes, must contain all classes, only works when sort_classes is true')
        self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
        self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation')
        self.parser.add_argument('--nThreads', default=4, type=int, help='# threads for loading data')
        self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        self.parser.add_argument('--display_single_pane_ncols', type=int, default=6, help='if positive, display all images in a single visdom web panel with certain number of images per row.')

        # for displays
        self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
        self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')

        # for generator
        self.parser.add_argument('--use_modulated_conv', type=bool, default=True, help='if specified, use modulated conv layers in the decoder like in StyleGAN2')
        self.parser.add_argument('--conv_weight_norm', type=bool, default=True, help='if specified, use weight normalization in conv and linear layers like in progrssive growing of GANs')
        self.parser.add_argument('--id_enc_norm', type=str, default='pixel', help='instance, pixel normalization')
        self.parser.add_argument('--decoder_norm', type=str, default='pixel', choices=['pixel','none'], help='type of upsampling layers normalization')
        self.parser.add_argument('--n_adaptive_blocks', type=int, default=4, help='# of adaptive normalization blocks')
        self.parser.add_argument('--activation', type=str, default='lrelu', choices=['relu','lrelu'], help='type of generator activation layer')
        self.parser.add_argument('--normalize_mlp', type=bool, default=True, help='if specified, normalize the generator MLP inputs and outputs')
        self.parser.add_argument('--no_moving_avg', action='store_true', help='if specified, do not use moving average network')
        self.parser.add_argument('--use_resblk_pixel_norm', action='store_true', help='if specified, apply pixel norm on the resnet block outputs')
        self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
        self.parser.add_argument('--no_cond_noise', action='store_true', help='remove gaussian noise from latent age code')
        self.parser.add_argument('--gen_dim_per_style', type=int, default=50, help='per class dimension of adain generator style latent code')
        self.parser.add_argument('--n_downsample', type=int, default=2, help='number of downsampling layers in generator')
        self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose')

        self.initialized = True

    def parse(self, save=True):
        """Parse argv into self.opt, post-process composite options, print them,
        and optionally persist them to <checkpoints_dir>/<name>/opt.txt.

        Returns the populated argparse.Namespace.
        """
        from util import util  # deferred so this module imports without the project package

        if not self.initialized:
            self.initialize()
        try:
            self.opt = self.parser.parse_args()
        except SystemExit:
            # argparse raises SystemExit on unrecognized argv (e.g. the extra
            # arguments a Colab/Jupyter kernel injects); fall back to defaults.
            # The original bare `except:` also hid every other error.
            self.opt = self.parser.parse_args(args=[])

        self.opt.isTrain = self.isTrain  # train or test

        # '0,1,2' -> [0, 1, 2]; ids < 0 (e.g. -1 for CPU) are dropped
        self.opt.gpu_ids = [gpu_id for gpu_id in
                            (int(s) for s in self.opt.gpu_ids.split(',')) if gpu_id >= 0]

        # set gpu ids
        if len(self.opt.gpu_ids) > 0:
            torch.cuda.set_device(self.opt.gpu_ids[0])

        # set class specific sort order: '0-2,3-6' -> ['0-2', '3-6']
        if self.opt.sort_order is not None:
            self.opt.sort_order = self.opt.sort_order.split(',')

        # set decay schedule: '50,100' -> [50, 100]
        if self.isTrain and self.opt.decay_epochs is not None:
            self.opt.decay_epochs = [int(e) for e in self.opt.decay_epochs.split(',')]

        # create full image paths in traverse/deploy mode
        if (not self.isTrain) and (self.opt.traverse or self.opt.deploy):
            with open(self.opt.image_path_file, 'r') as f:
                self.opt.image_path_list = f.read().splitlines()

        args = vars(self.opt)

        print('------------ Options -------------')
        for k, v in sorted(args.items()):
            print('%s: %s' % (str(k), str(v)))
        print('-------------- End ----------------')

        # save to the disk
        expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
        util.mkdirs(expr_dir)
        if save:
            file_name = os.path.join(expr_dir, 'opt.txt')
            with open(file_name, 'wt') as opt_file:
                opt_file.write('------------ Options -------------\n')
                for k, v in sorted(args.items()):
                    opt_file.write('%s: %s\n' % (str(k), str(v)))
                opt_file.write('-------------- End ----------------\n')
        return self.opt


class TestOptions(BaseOptions):
    """Options used by test.py; adds inference-only flags to BaseOptions."""

    def initialize(self):
        BaseOptions.initialize(self)
        self.parser.add_argument('--random_seed', type=int, default=-1, help='random seed for generating different outputs from the same model.')
        self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
        self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
        self.parser.add_argument('--in_the_wild', action='store_true', help='for evaluating on in the wild images')
        self.parser.add_argument('--traverse', action='store_true', help='when true, run latent space traversal on a list of images')
        self.parser.add_argument('--full_progression', action='store_true', help='when true, deploy mode saves all outputs as a single image')
        self.parser.add_argument('--make_video', action='store_true', help='when true, make a video from the traversal results')
        self.parser.add_argument('--compare_to_trained_outputs', action='store_true', help='when true, interpolate a trained class in order to compare to trained outputs')
        self.parser.add_argument('--compare_to_trained_class', type=int, default=1, help='what class to compare to')
        self.parser.add_argument('--trained_class_jump', type=int, default=1, choices=[1,2], help='how many classes to jump')
        self.parser.add_argument('--interp_step', type=float, default=0.5, help='step size of interpolated w space vectors between each 2 true w space vectors')
        self.parser.add_argument('--deploy', action='store_true', help='when true, run forward pass on a list of images')
        self.parser.add_argument('--image_path_file', type=str, help='a file with a list of images to perform run through the network and/or latent space traversal on')
        self.parser.add_argument('--debug_mode', action='store_true', help='when true, all intermediate outputs are saved to the html file')
        self.isTrain = False
### Copyright (C) 2020 Roy Or-El. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

from .base_options import BaseOptions


class TrainOptions(BaseOptions):
    """Options used by train.py; adds training-only flags to BaseOptions."""

    def initialize(self):
        BaseOptions.initialize(self)
        add = self.parser.add_argument  # shorthand; every flag below goes through it

        # display / logging frequencies
        add('--display_freq', type=int, default=40, help='frequency of showing training results on screen')
        add('--print_freq', type=int, default=40, help='frequency of showing training results on console')
        add('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
        add('--save_display_freq', type=int, default=5000, help='save the current display of results every save_display_freq_iterations')
        add('--save_epoch_freq', type=int, default=20, help='frequency of saving checkpoints at the end of epochs')
        add('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        add('--debug', action='store_true', help='only do one epoch and displays at each iteration')

        # training schedule & optimizer
        add('--continue_train', action='store_true', help='continue training: load the latest model')
        add('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
        add('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model. This flag must be used when the continue_train flag is on')
        add('--phase', type=str, default='train', help='train, val, test, etc')
        add('--epochs', type=int, default=400, help='# of epochs to train')
        add('--decay_gamma', type=float, default=0.5, help='decay the learning rate by this value')
        add('--decay_epochs', type=str, default='50,100', help='epochs to perform step lr decay')
        add('--beta1', type=float, default=0.0, help='momentum term of adam')
        add('--beta2', type=float, default=0.999, help='momentum term of adam')
        add('--lr', type=float, default=0.001, help='initial learning rate for adam')
        add('--decay_adain_affine_layers', type=bool, default=True, help='when true adain affine layer learning rate is decayed by 0.01')

        # discriminator
        add('--n_layers_D', type=int, default=6, help='number of styled convolution layers in the discriminator')
        add('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')

        # loss weights
        add('--lambda_cyc', type=float, default=10.0, help='weight for cycle loss')
        add('--lambda_rec', type=float, default=10.0, help='weight for reconstruction loss')
        add('--lambda_id', type=float, default=1.0, help='weight for identity encoding consistency loss')
        add('--lambda_age', type=float, default=1.0, help='weight for age encoding consistency loss')

        self.isTrain = True
b/run.ipynb @@ -0,0 +1 @@ +{"cells":[{"cell_type":"markdown","metadata":{"id":"tlwOLgIhTBQ8"},"source":["# Use Model\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"E9aZlYl0MZom"},"outputs":[],"source":["# from google.colab import drive\n","# drive.mount('/content/drive')"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"J2uubo7PsvxQ"},"outputs":[],"source":["%cd /face\n","!pip3 install -r requirements.txt\n","!python download_models.py"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"rQKoh2xrw697"},"outputs":[],"source":["import os\n","from collections import OrderedDict\n","from options.test_options import TestOptions\n","from data.data_loader import CreateDataLoader\n","from models.models import create_model\n","import util.util as util\n","from util.visualizer import Visualizer\n","\n","opt = TestOptions().parse(save=False)\n","opt.display_id = 0 # do not launch visdom\n","opt.nThreads = 1 # test code only supports nThreads = 1\n","opt.batchSize = 1 # test code only supports batchSize = 1\n","opt.serial_batches = True # no shuffle\n","opt.no_flip = True # no flip\n","opt.in_the_wild = True # This triggers preprocessing of in the wild images in the dataloader\n","opt.traverse = True # This tells the model to traverse the latent space between anchor classes\n","opt.interp_step = 0.05 # this controls the number of images to interpolate between anchor classes\n","\n","data_loader = CreateDataLoader(opt)\n","dataset = data_loader.load_data()\n","visualizer = Visualizer(opt)"]},{"cell_type":"markdown","metadata":{"id":"jDnq2nS7T0QC"},"source":["# CoLab"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"FZsgYQSBl_9I"},"outputs":[],"source":["opt.name = 'males_model' # or 'femail_model'\n","model = create_model(opt)\n","model.eval()\n","\n","# upload your image (the code supports only a single image at a time)\n","from google.colab import files\n","uploaded = files.upload()\n","for filename in 
uploaded.keys():\n"," img_path = filename\n"," print('User uploaded file \"{name}\"'.format(name=filename))"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Av4qUAeemAgQ"},"outputs":[],"source":["data = dataset.dataset.get_item_from_path(img_path)\n","visuals = model.inference(data)\n","\n","os.makedirs('results', exist_ok=True)\n","out_path = os.path.join('results', os.path.splitext(img_path)[0].replace(' ', '_') + '.mp4')\n","visualizer.make_video(visuals, out_path)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"dDA0JPOfmCAG"},"outputs":[],"source":["# Result\n","# use_webm = False\n","\n","!pip3 install webm\n","webm_out_path = os.path.join('results', os.path.splitext(img_path)[0].replace(' ', '_') + '.webm')\n","# !webm -i $out_path $webm_out_path\n","use_webm = True\n","\n","from IPython.display import HTML\n","from base64 import b64encode\n","video_path = webm_out_path if use_webm else out_path\n","video_type = \"video/webm\" if use_webm else \"video/mp4\"\n","mp4 = open(video_path,'rb').read()\n","data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n","HTML(\"\"\"\n","\n","\"\"\".format(opt.fineSize, data_url, video_type))"]},{"cell_type":"markdown","metadata":{"id":"_n9pJtLIT7zp"},"source":["# FLASK"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"5Dz3AeiTOs8c"},"outputs":[],"source":["!pip install flask-ngrok\n","!pip install pyngrok==4.1.1\n","!ngrok authtoken ''"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"DdPwCWosOub2"},"outputs":[],"source":["import os\n","from crypt import methods\n","from flask_ngrok import run_with_ngrok\n","import urllib.request\n","from flask import Flask, flash, request, redirect, url_for, render_template\n","from werkzeug.utils import secure_filename\n","import pandas as pd\n","\n","UPLOAD_FOLDER = 'static/uploads/'\n","\n","app = Flask(__name__)\n","run_with_ngrok(app) \n","\n","app.secret_key = \"secret key\"\n","app.config['UPLOAD_FOLDER'] = 
UPLOAD_FOLDER\n","app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"6M1bMmluUhHb"},"outputs":[],"source":["ALLOWED_EXTENSIONS = set(['png', 'jpg', 'jpeg', 'gif'])\n","\n","def allowed_file(filename):\n","\treturn '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS\n","\n","@app.route('/', methods=['GET', 'POST'])\n","\n","def index():\n"," return render_template('index.html', title='face GAN')\n","\n","@app.route('/index', methods=['POST'])\n","\n","def aging():\n"," age = request.form['age']\n"," target_age = request.form['target_age']\n"," gender = request.form['gender']\n"," image = request.files['image']\n"," path = os.path.join(app.config[\"UPLOAD_FOLDER\"], image.filename)\n"," image.save(path)\n"," \n"," if gender == \"male\":\n"," input = pd.DataFrame({\n"," 'age' : [int(age)],\n"," 'target_age' :[int(target_age)]\n"," })\n","\n"," opt.name = 'males_model'\n"," model = create_model(opt)\n"," model.eval()\n","\n"," elif gender == \"female\":\n"," input = pd.DataFrame({\n"," 'age' : [int(age)],\n"," 'target_age' :[int(target_age)]\n"," })\n","\n"," opt.name = 'females_model'\n"," model = create_model(opt)\n"," model.eval() \n","\n"," name = f'./static/uploads/{image.filename}'\n"," data = dataset.dataset.get_item_from_path(name)\n"," visuals = model.inference(data)\n","\n"," # Model running (images)\n"," os.makedirs(f'results/{image.filename}', exist_ok=True)\n"," out_pathi = f'./results/{image.filename}' \n","\n"," visualizer.save_images_deploy(visuals, out_pathi)\n"," \n"," # Model running (video)\n"," os.makedirs('results', exist_ok=True)\n"," out_pathv = os.path.join('results', os.path.splitext(name)[0].replace(' ', '_') + '.webm')\n"," visualizer.make_video(visuals, out_pathv)\n","\n"," return render_template('output.html', filename=image.filename) #Output = ModelOutput)\n","\n","if __name__ == \"__main__\":\n"," 
app.run()"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["@inproceedings{orel2020lifespan,\n"," title={Lifespan Age Transformation Synthesis},\n"," author={Or-El, Roy\n"," and Sengupta, Soumyadip\n"," and Fried, Ohad\n"," and Shechtman, Eli\n"," and Kemelmacher-Shlizerman, Ira},\n"," booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},\n"," year={2020}\n","}"]}],"metadata":{"accelerator":"GPU","colab":{"collapsed_sections":[],"name":"run.ipynb","private_outputs":true,"provenance":[]},"gpuClass":"standard","kernelspec":{"display_name":"Python 3.8.13 ('krc3')","language":"python","name":"python3"},"language_info":{"name":"python","version":"3.8.13"},"vscode":{"interpreter":{"hash":"9ab68bc9b37578d62bcac5013120789b9661e00d8bcff4fda830d0f3b16c2dab"}}},"nbformat":4,"nbformat_minor":0} diff --git a/run_flask.py b/run_flask.py new file mode 100644 index 0000000..076801f --- /dev/null +++ b/run_flask.py @@ -0,0 +1,101 @@ +from crypt import methods +from flask_ngrok import run_with_ngrok +from flask import Flask +from flask import render_template , url_for, redirect, flash +from flask import request +import pandas as pd +import os +from werkzeug.utils import secure_filename + +# MODEL +import os +from collections import OrderedDict +from options.test_options import TestOptions +from data.data_loader import CreateDataLoader +from models.models import create_model +import util.util as util +from util.visualizer import Visualizer + +opt = TestOptions().parse(save=False) +opt.display_id = 0 # do not launch visdom +opt.nThreads = 1 # test code only supports nThreads = 1 +opt.batchSize = 1 # test code only supports batchSize = 1 +opt.serial_batches = True # no shuffle +opt.no_flip = True # no flip +opt.in_the_wild = True # This triggers preprocessing of in the wild images in the dataloader +opt.traverse = True # This tells the model to traverse the latent space between anchor classes +opt.interp_step = 0.05 # this 
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
visualizer = Visualizer(opt)

# NOTE(review): the file header also does `from crypt import methods` -- it is
# unused and the `crypt` module was removed in Python 3.13; drop that import.
UPLOAD_FOLDER = 'static/uploads/'
app = Flask(__name__, static_folder='static')
# run_with_ngrok(app)  # enable when tunnelling through ngrok

app.secret_key = "secret key"  # NOTE(review): hard-coded secret; load from env in production
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # reject uploads over 16 MB

ALLOWED_EXTENSIONS = set(['png', 'jpg', 'jpeg', 'gif'])

def allowed_file(filename):
    """Return True when *filename* carries one of the allowed image extensions."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the upload form."""
    return render_template('index.html', title='face GAN')

@app.route('/index', methods=['POST'])
def aging():
    """Run the age-transformation model on an uploaded face image.

    Reads the 'age', 'target_age' and 'gender' form fields plus the 'image'
    file, saves the upload, runs the gender-specific checkpoint and renders
    output.html with the produced stills and video.
    """
    gender = request.form['gender']
    image = request.files['image']

    # secure_filename strips path separators, closing the path-traversal hole
    # the original had by trusting image.filename verbatim; allowed_file was
    # defined but never applied to uploads before.
    filename = secure_filename(image.filename)
    if not filename or not allowed_file(filename):
        flash('Please upload a png/jpg/jpeg/gif image.')
        return redirect(url_for('index'))

    path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
    image.save(path)

    # Pick the checkpoint by gender; the original duplicated this whole branch
    # (including a dead, unused DataFrame) and crashed with an unbound `model`
    # on any other value.
    if gender == "male":
        opt.name = 'males_model'
    elif gender == "female":
        opt.name = 'females_model'
    else:
        flash('Unknown gender value.')
        return redirect(url_for('index'))
    model = create_model(opt)
    model.eval()

    data = dataset.dataset.get_item_from_path(path)
    visuals = model.inference(data)

    # Model running (images)
    os.makedirs(f'results/{filename}', exist_ok=True)
    visualizer.save_images_deploy(visuals, f'./results/{filename}')

    # Model running (video). The original joined 'static' with the already
    # './static/uploads/...' prefixed path, producing a directory that was
    # never created; write next to the upload instead.
    os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True)
    out_pathv = os.path.splitext(path)[0].replace(' ', '_') + '.webm'
    visualizer.make_video(visuals, out_pathv)

    return render_template('output.html', filename=filename)

if __name__ == "__main__":
    app.run()
    # app.run(host='0.0.0.0', port=5000, debug=True)
__name__ == "__main__": + app.run() +# app.run(host='0.0.0.0', port=5000, debug=True) \ No newline at end of file diff --git a/run_scripts/deploy.bat b/run_scripts/deploy.bat new file mode 100755 index 0000000..9328456 --- /dev/null +++ b/run_scripts/deploy.bat @@ -0,0 +1,5 @@ +@echo off + +set CUDA_VISIBLE_DEVICES=0 + +python test.py --dataroot ./datasets/males --name males_model --which_epoch latest --display_id 0 --deploy --image_path_file males_image_list.txt --full_progression --verbose diff --git a/run_scripts/deploy.sh b/run_scripts/deploy.sh new file mode 100755 index 0000000..7cfbf18 --- /dev/null +++ b/run_scripts/deploy.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0 python test.py --dataroot ./datasets/males --name males_model --which_epoch latest --display_id 0 --deploy --image_path_file males_image_list.txt --full_progression --verbose diff --git a/run_scripts/in_the_wild.bat b/run_scripts/in_the_wild.bat new file mode 100755 index 0000000..63aac44 --- /dev/null +++ b/run_scripts/in_the_wild.bat @@ -0,0 +1,5 @@ +@echo off + +set CUDA_VISIBLE_DEVICES=0 + +python test.py --name males_model --which_epoch latest --display_id 0 --traverse --interp_step 0.05 --image_path_file males_image_list.txt --make_video --in_the_wild --verbose diff --git a/run_scripts/in_the_wild.sh b/run_scripts/in_the_wild.sh new file mode 100755 index 0000000..dfb74c3 --- /dev/null +++ b/run_scripts/in_the_wild.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0 python test.py --name males_model --which_epoch latest --display_id 0 --traverse --interp_step 0.05 --image_path_file males_image_list.txt --make_video --in_the_wild --verbose diff --git a/run_scripts/test.bat b/run_scripts/test.bat new file mode 100755 index 0000000..8d99fa5 --- /dev/null +++ b/run_scripts/test.bat @@ -0,0 +1,5 @@ +@echo off + +set CUDA_VISIBLE_DEVICES=0 + +python test.py --verbose --dataroot ./datasets/males --name males_model --which_epoch latest --how_many 100 --display_id 0 diff --git a/run_scripts/test.sh 
b/run_scripts/test.sh new file mode 100755 index 0000000..884555a --- /dev/null +++ b/run_scripts/test.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0 python test.py --verbose --dataroot ./datasets/males --name males_model --which_epoch latest --how_many 100 --display_id 0 diff --git a/run_scripts/train.bat b/run_scripts/train.bat new file mode 100755 index 0000000..75521d6 --- /dev/null +++ b/run_scripts/train.bat @@ -0,0 +1,5 @@ +@echo off + +set CUDA_VISIBLE_DEVICES=0,1,2,3 + +python train.py --gpu_ids 0,1,2,3 --dataroot ./datasets/males --name males_model --batchSize 6 --verbose diff --git a/run_scripts/train.sh b/run_scripts/train.sh new file mode 100755 index 0000000..e4bdfcf --- /dev/null +++ b/run_scripts/train.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --gpu_ids 0,1,2,3 --dataroot ./datasets/males --name males_model --batchSize 6 --verbose diff --git a/run_scripts/traversal.bat b/run_scripts/traversal.bat new file mode 100755 index 0000000..564dd79 --- /dev/null +++ b/run_scripts/traversal.bat @@ -0,0 +1,5 @@ +@echo off + +set CUDA_VISIBLE_DEVICES=0 + +python test.py --dataroot ./datasets/males --name males_model --which_epoch latest --display_id 0 --traverse --interp_step 0.05 --image_path_file males_image_list.txt --make_video --verbose diff --git a/run_scripts/traversal.sh b/run_scripts/traversal.sh new file mode 100755 index 0000000..1590afc --- /dev/null +++ b/run_scripts/traversal.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0 python test.py --dataroot ./datasets/males --name males_model --which_epoch latest --display_id 0 --traverse --interp_step 0.05 --image_path_file males_image_list.txt --make_video --verbose diff --git a/static/script.js b/static/script.js new file mode 100755 index 0000000..595e754 --- /dev/null +++ b/static/script.js @@ -0,0 +1,107 @@ +$(document).ready(function () { + + var timer = null; + var self = $(".wrap button"); + var clicked = false; + $(".wrap button").on("click", function () { + if (clicked === false){ + 
// --- Drag & drop upload area -------------------------------------------
// Grab the pieces of the .drag-area widget once, up front.
const dropArea = document.querySelector(".drag-area"),
dragText = dropArea.querySelector("header"),
button = dropArea.querySelector("button"),
input = dropArea.querySelector("input");
let file; // currently selected file, shared by the handlers below

// Clicking the visible button forwards to the hidden <input type="file">.
button.onclick = () => {
  input.click();
};

// File chosen through the picker: remember it, highlight, preview.
input.addEventListener("change", function () {
  file = this.files[0]; // only the first selected file is used
  dropArea.classList.add("active");
  showFile();
});

// Dragging over the area: highlight and update the hint text.
dropArea.addEventListener("dragover", (event) => {
  event.preventDefault();
  dropArea.classList.add("active");
  dragText.textContent = "업로드 할 사진을 놓으세요";
});

// Leaving without dropping: restore the idle state.
dropArea.addEventListener("dragleave", () => {
  dropArea.classList.remove("active");
  dragText.textContent = "사진을 끌어오세요";
});

// Dropping a file: same handling as the picker path.
dropArea.addEventListener("drop", (event) => {
  event.preventDefault();
  file = event.dataTransfer.files[0];
  showFile();
});

// Validate the selected file by MIME type and show it in the drop area.
function showFile() {
  const validExtensions = ["image/jpeg", "image/jpg", "image/png"];
  if (validExtensions.includes(file.type)) {
    const reader = new FileReader();
    reader.onload = () => {
      // NOTE(review): the original intentionally left the preview <img>
      // markup commented out and inserts the literal text `image`; the
      // data-URL is available in reader.result when re-enabling it.
      dropArea.innerHTML = `image`;
    };
    reader.readAsDataURL(file);
  } else {
    alert("이미지 파일이 아닙니다!");
    dropArea.classList.remove("active");
    dragText.textContent = "사진을 끌어오세요";
  }
}
border-radius: 3px; + overflow: hidden; +} +.inp .label { + position: absolute; + top: 20px; + left: 12px; + font-size: 16px; + color: rgb(255, 255, 255); + font-weight: 500; + transform-origin: 0 0; + transform: translate3d(0, 0, 0); + transition: all 0.2s ease; + pointer-events: none; +} +.inp .focus-bg { + position: absolute; + top: 0; + left: 0; + width: 100%; + height: 100%; + background: rgba(255, 255, 255, 0.05); + z-index: -1; + transform: scaleX(0); + +} +.inp input { + -webkit-appearance: none; + -moz-appearance: none; + appearance: none; + width: 100%; + border: 0; + font-family: inherit; + padding: 16px 12px 0 12px; + height: 56px; + font-size: 16px; + font-weight: 400; + background: rgba(255, 255, 255, 0.02); + box-shadow: inset 0 -1px 0 rgba(255, 255, 255, 0.3); + color: rgb(190, 190, 190); + transition: all 0.15s ease; +} +.inp input:hover { + background: rgba(255, 255, 255, 0.04); + box-shadow: inset 0 -1px 0 rgba(255, 255, 255, 0.5); +} +.inp input:not(:-moz-placeholder-shown) + .label { + color: rgba(255, 255, 255, 0.5); + transform: translate3d(0, -12px, 0) scale(0.75); +} +.inp input:not(:-ms-input-placeholder) + .label { + color: rgba(255, 255, 255, 0.5); + transform: translate3d(0, -12px, 0) scale(0.75); +} +.inp input:not(:placeholder-shown) + .label { + color: rgba(255, 255, 255, 0.5); + transform: translate3d(0, -12px, 0) scale(0.75); +} +.inp input:focus { + background: rgba(255, 255, 255, 0.05); + outline: none; + box-shadow: inset 0 -2px 0 #1ECD97; +} +.inp input:focus + .label { + color: #1ECD97; + transform: translate3d(0, -12px, 0) scale(0.75); +} +.inp input:focus + .label + .focus-bg { + transform: scaleX(1); + transition: all 0.1s ease; +} + + +.drag-area { + border: 2px dashed #fff; + height: 350px; + width: 500px; + border-radius: 5px; + display: flex; + align-items: center; + justify-content: center; + flex-direction: column; +} +.drag-area.active { + border: 2px solid #fff; +} +.drag-area .icon { + font-size: 100px; + color: 
#fff; +} +.drag-area header { + font-size: 30px; + font-weight: 500; + color: #fff; +} +.drag-area span { + font-size: 25px; + font-weight: 500; + color: #fff; + margin: 10px 0 15px 0; +} +.drag-area button { + padding: 10px 25px; + font-size: 20px; + font-weight: 500; + border: none; + outline: none; + background: #fff; + color: #5256ad; + border-radius: 5px; + cursor: pointer; +} +.drag-area img { + height: 100%; + width: 100%; + object-fit: cover; + border-radius: 5px; +} + + + + +html { + line-height: 1; +} + +ol, ul { + list-style: none; +} + +table { + border-collapse: collapse; + border-spacing: 0; +} + +caption, th, td { + text-align: left; + font-weight: normal; + vertical-align: middle; +} + +q, blockquote { + quotes: none; +} +q:before, q:after, blockquote:before, blockquote:after { + content: ""; + content: none; +} + +a img { + border: none; +} + +article, aside, details, figcaption, figure, footer, header, hgroup, main, menu, nav, section, summary { + display: block; +} + + +.wrap { + position: relative; + margin: auto; + margin-top: 3%; + width: 191px; + text-align: center; +} +.wrap button { + display: block; + height: 60px; + padding: 0; + width: 191px; + background: none; + margin: auto; + border: 2px solid #1ECD97; + font-size: 18px; + font-family: "Lato"; + color: #1ECD97; + cursor: pointer; + outline: none; + text-align: center; + -moz-box-sizing: border-box; + -webkit-box-sizing: border-box; + box-sizing: border-box; + -moz-border-radius: 30px; + -webkit-border-radius: 30px; + border-radius: 30px; + -moz-transition: background 0.4s, color 0.4s, font-size 0.05s, width 0.4s, border 0.4s; + -o-transition: background 0.4s, color 0.4s, font-size 0.05s, width 0.4s, border 0.4s; + -webkit-transition: background 0.4s, color 0.4s, font-size 0.05s, width 0.4s, border 0.4s; + transition: background 0.4s, color 0.4s, font-size 0.05s, width 0.4s, border 0.4s; +} +.wrap button:hover { + background: #1ECD97; + color: white; +} +.wrap img { + position: 
absolute; + top: 11px; + display: none; + left: 71.5px; + -moz-transform: scale(0.6, 0.6); + -ms-transform: scale(0.6, 0.6); + -webkit-transform: scale(0.6, 0.6); + transform: scale(0.6, 0.6); +} +.wrap svg { + -moz-transform: rotate(270deg); + -ms-transform: rotate(270deg); + -webkit-transform: rotate(270deg); + transform: rotate(270deg); + /* @include rotate(270deg); */ + position: absolute; + top: -2px; + left: 62px; + display: none; +} +.wrap svg .circle_2 { + stroke-dasharray: 0 200; +} +.wrap svg .fill_circle { + -moz-animation: fill-stroke 2s 0.4s linear forwards; + -webkit-animation: fill-stroke 2s 0.4s linear forwards; + animation: fill-stroke 2s 0.4s linear forwards; +} +.wrap .circle { + width: 60px; + border: 3px solid #c3c3c3; + /* border: none; */ +} +.wrap .circle:hover { + background: none; +} +.wrap .filled { + background: #1ECD97; + color: white; + line-height: 60px; + font-size: 160%; +} + +footer p { + color: #738087; + margin-top: 100px; + font-size: 18px; + line-height: 28px; +} + +@-moz-keyframes fill-stroke { + 0% { + stroke-dasharray: 0 200; + } + 20% { + stroke-dasharray: 20 200; + } + 40% { + stroke-dasharray: 30 200; + } + 50% { + stroke-dasharray: 90 200; + } + 70% { + stroke-dasharray: 120 200; + } + 90% { + stroke-dasharray: 140 200; + } + 100% { + stroke-dasharray: 182 200; + } +} +@-webkit-keyframes fill-stroke { + 0% { + stroke-dasharray: 0 200; + } + 20% { + stroke-dasharray: 20 200; + } + 40% { + stroke-dasharray: 30 200; + } + 50% { + stroke-dasharray: 90 200; + } + 70% { + stroke-dasharray: 120 200; + } + 90% { + stroke-dasharray: 140 200; + } + 100% { + stroke-dasharray: 182 200; + } +} +@keyframes fill-stroke { + 0% { + stroke-dasharray: 0 200; + } + 20% { + stroke-dasharray: 20 200; + } + 40% { + stroke-dasharray: 30 200; + } + 50% { + stroke-dasharray: 90 200; + } + 70% { + stroke-dasharray: 120 200; + } + 90% { + stroke-dasharray: 140 200; + } + 100% { + stroke-dasharray: 182 200; + } +} +a, p { + line-height: 1.6em; +} + 
+a { + color: #738087; +} diff --git a/templates/index.html b/templates/index.html new file mode 100755 index 0000000..f756cad --- /dev/null +++ b/templates/index.html @@ -0,0 +1,70 @@ + + + + + face GAN + + + + + + + + + +

Face Aging GAN

+ +
+
+
+
사진을 끌어오세요
+ OR + + + + + +
+ + + + + + +
+ +
+ +
+ +
+ +
+ + 남성
+ 여성
+ +
+ + + +
+ + + + + +
+
+ + + + diff --git a/templates/output.html b/templates/output.html new file mode 100644 index 0000000..3f0d66b --- /dev/null +++ b/templates/output.html @@ -0,0 +1,20 @@ + + + + + face GAN + + + + +
+

Result!!

+ +
+ + + + + \ No newline at end of file diff --git a/test.py b/test.py new file mode 100755 index 0000000..0d38d56 --- /dev/null +++ b/test.py @@ -0,0 +1,89 @@ +### Copyright (C) 2020 Roy Or-El. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import os +import scipy # this is to prevent a potential error caused by importing torch before scipy (happens due to a bad combination of torch & scipy versions) +from collections import OrderedDict +from options.test_options import TestOptions +from data.data_loader import CreateDataLoader +from models.models import create_model +import util.util as util +from util.visualizer import Visualizer +from util import html +import torch +from pdb import set_trace as st + + +def test(opt): + opt.nThreads = 1 # test code only supports nThreads = 1 + opt.batchSize = 1 # test code only supports batchSize = 1 + opt.serial_batches = True # no shuffle + opt.no_flip = True # no flip + + data_loader = CreateDataLoader(opt) + dataset = data_loader.load_data() + dataset_size = len(data_loader) + print('#test batches = %d' % (int(dataset_size / len(opt.sort_order)))) + visualizer = Visualizer(opt) + model = create_model(opt) + model.eval() + + # create webpage + if opt.random_seed != -1: + exp_dir = '%s_%s_seed%s' % (opt.phase, opt.which_epoch, str(opt.random_seed)) + else: + exp_dir = '%s_%s' % (opt.phase, opt.which_epoch) + web_dir = os.path.join(opt.results_dir, opt.name, exp_dir) + + if opt.traverse or opt.deploy: + if opt.traverse: + out_dirname = 'traversal' + else: + out_dirname = 'deploy' + output_dir = os.path.join(web_dir,out_dirname) + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + + for image_path in opt.image_path_list: + print(image_path) + data = dataset.dataset.get_item_from_path(image_path) + visuals = model.inference(data) + if opt.traverse and opt.make_video: + out_path = os.path.join(output_dir, 
os.path.splitext(os.path.basename(image_path))[0] + '.mp4') + visualizer.make_video(visuals, out_path) + elif opt.traverse or (opt.deploy and opt.full_progression): + if opt.traverse and opt.compare_to_trained_outputs: + out_path = os.path.join(output_dir, os.path.splitext(os.path.basename(image_path))[0] + '_compare_to_{}_jump_{}.png'.format(opt.compare_to_trained_class, opt.trained_class_jump)) + else: + out_path = os.path.join(output_dir, os.path.splitext(os.path.basename(image_path))[0] + '.png') + visualizer.save_row_image(visuals, out_path, traverse=opt.traverse) + else: + out_path = os.path.join(output_dir, os.path.basename(image_path[:-4])) + visualizer.save_images_deploy(visuals, out_path) + else: + webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) + + # test + for i, data in enumerate(dataset): + if i >= opt.how_many: + break + + visuals = model.inference(data) + img_path = data['Paths'] + rem_ind = [] + for i, path in enumerate(img_path): + if path != '': + print('process image... %s' % path) + else: + rem_ind += [i] + + for ind in reversed(rem_ind): + del img_path[ind] + + visualizer.save_images(webpage, visuals, img_path) + + webpage.save() + + +if __name__ == "__main__": + opt = TestOptions().parse(save=False) + test(opt) diff --git a/train.py b/train.py new file mode 100755 index 0000000..69a6827 --- /dev/null +++ b/train.py @@ -0,0 +1,159 @@ +### Copyright (C) 2020 Roy Or-El. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
+import time +import scipy # this is to prevent a potential error caused by importing torch before scipy (happens due to a bad combination of torch & scipy versions) +from collections import OrderedDict +from options.train_options import TrainOptions +from data.data_loader import CreateDataLoader +from models.models import create_model +import util.util as util +from util.visualizer import Visualizer +import os +import numpy as np +import torch +from pdb import set_trace as st + +def train(opt): + iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') + + if opt.continue_train: + if opt.which_epoch == 'latest': + try: + start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int) + except: + start_epoch, epoch_iter = 1, 0 + else: + start_epoch, epoch_iter = int(opt.which_epoch), 0 + + print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) + for update_point in opt.decay_epochs: + if start_epoch < update_point: + break + + opt.lr *= opt.decay_gamma + else: + start_epoch, epoch_iter = 0, 0 + + data_loader = CreateDataLoader(opt) + dataset = data_loader.load_data() + dataset_size = len(data_loader) + print('#training images = %d' % dataset_size) + + model = create_model(opt) + visualizer = Visualizer(opt) + + total_steps = (start_epoch) * dataset_size + epoch_iter + + display_delta = total_steps % opt.display_freq + print_delta = total_steps % opt.print_freq + save_delta = total_steps % opt.save_latest_freq + bSize = opt.batchSize + + #in case there's no display sample one image from each class to test after every epoch + if opt.display_id == 0: + dataset.dataset.set_sample_mode(True) + dataset.num_workers = 1 + for i, data in enumerate(dataset): + if i*opt.batchSize >= opt.numClasses: + break + if i == 0: + sample_data = data + else: + for key, value in data.items(): + if torch.is_tensor(data[key]): + sample_data[key] = torch.cat((sample_data[key], data[key]), 0) + else: + sample_data[key] = sample_data[key] + data[key] 
+ dataset.num_workers = opt.nThreads + dataset.dataset.set_sample_mode(False) + + for epoch in range(start_epoch, opt.epochs): + epoch_start_time = time.time() + if epoch != start_epoch: + epoch_iter = 0 + for i, data in enumerate(dataset, start=epoch_iter): + iter_start_time = time.time() + total_steps += opt.batchSize + epoch_iter += opt.batchSize + + # whether to collect output images + save_fake = (total_steps % opt.display_freq == display_delta) and (opt.display_id > 0) + + ############## Network Pass ######################## + model.set_inputs(data) + disc_losses = model.update_D() + gen_losses, gen_in, gen_out, rec_out, cyc_out = model.update_G(infer=save_fake) + loss_dict = dict(gen_losses, **disc_losses) + ################################################## + + ############## Display results and errors ########## + ### print out errors + if total_steps % opt.print_freq == print_delta: + errors = {k: v.item() if not (isinstance(v, float) or isinstance(v, int)) else v for k, v in loss_dict.items()} + t = (time.time() - iter_start_time) / opt.batchSize + visualizer.print_current_errors(epoch+1, epoch_iter, errors, t) + if opt.display_id > 0: + visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors) + + ### display output images + if save_fake and opt.display_id > 0: + class_a_suffix = ' class {}'.format(data['A_class'][0]) + class_b_suffix = ' class {}'.format(data['B_class'][0]) + classes = None + + visuals = OrderedDict() + visuals_A = OrderedDict([('real image' + class_a_suffix, util.tensor2im(gen_in.data[0]))]) + visuals_B = OrderedDict([('real image' + class_b_suffix, util.tensor2im(gen_in.data[bSize]))]) + + A_out_vis = OrderedDict([('synthesized image' + class_b_suffix, util.tensor2im(gen_out.data[0]))]) + B_out_vis = OrderedDict([('synthesized image' + class_a_suffix, util.tensor2im(gen_out.data[bSize]))]) + if opt.lambda_rec > 0: + A_out_vis.update([('reconstructed image' + class_a_suffix, util.tensor2im(rec_out.data[0]))]) 
+ B_out_vis.update([('reconstructed image' + class_b_suffix, util.tensor2im(rec_out.data[bSize]))]) + if opt.lambda_cyc > 0: + A_out_vis.update([('cycled image' + class_a_suffix, util.tensor2im(cyc_out.data[0]))]) + B_out_vis.update([('cycled image' + class_b_suffix, util.tensor2im(cyc_out.data[bSize]))]) + + visuals_A.update(A_out_vis) + visuals_B.update(B_out_vis) + visuals.update(visuals_A) + visuals.update(visuals_B) + + ncols = len(visuals_A) + visualizer.display_current_results(visuals, epoch, classes, ncols) + + ### save latest model + if total_steps % opt.save_latest_freq == save_delta: + print('saving the latest model (epoch %d, total_steps %d)' % (epoch+1, total_steps)) + model.save('latest') + np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') + if opt.display_id == 0: + model.eval() + visuals = model.inference(sample_data) + visualizer.save_matrix_image(visuals, 'latest') + model.train() + + # end of epoch + iter_end_time = time.time() + print('End of epoch %d / %d \t Time Taken: %d sec' % + (epoch+1, opt.epochs, time.time() - epoch_start_time)) + + ### save model for this epoch + if (epoch+1) % opt.save_epoch_freq == 0: + print('saving the model at the end of epoch %d, iters %d' % (epoch+1, total_steps)) + model.save('latest') + model.save(epoch+1) + np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d') + if opt.display_id == 0: + model.eval() + visuals = model.inference(sample_data) + visualizer.save_matrix_image(visuals, epoch+1) + model.train() + + ### multiply learning rate by opt.decay_gamma after certain iterations + if (epoch+1) in opt.decay_epochs: + model.update_learning_rate() + +if __name__ == "__main__": + opt = TrainOptions().parse() + train(opt) diff --git a/util/__init__.py b/util/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/util/__pycache__/__init__.cpython-310.pyc b/util/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..0893879 Binary files /dev/null and 
b/util/__pycache__/__init__.cpython-310.pyc differ diff --git a/util/__pycache__/__init__.cpython-38.pyc b/util/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..2a6ebd3 Binary files /dev/null and b/util/__pycache__/__init__.cpython-38.pyc differ diff --git a/util/__pycache__/deeplab.cpython-38.pyc b/util/__pycache__/deeplab.cpython-38.pyc new file mode 100644 index 0000000..09fde7f Binary files /dev/null and b/util/__pycache__/deeplab.cpython-38.pyc differ diff --git a/util/__pycache__/html.cpython-38.pyc b/util/__pycache__/html.cpython-38.pyc new file mode 100644 index 0000000..4547ada Binary files /dev/null and b/util/__pycache__/html.cpython-38.pyc differ diff --git a/util/__pycache__/preprocess_itw_im.cpython-310.pyc b/util/__pycache__/preprocess_itw_im.cpython-310.pyc new file mode 100644 index 0000000..2b1383c Binary files /dev/null and b/util/__pycache__/preprocess_itw_im.cpython-310.pyc differ diff --git a/util/__pycache__/preprocess_itw_im.cpython-38.pyc b/util/__pycache__/preprocess_itw_im.cpython-38.pyc new file mode 100644 index 0000000..2c760fa Binary files /dev/null and b/util/__pycache__/preprocess_itw_im.cpython-38.pyc differ diff --git a/util/__pycache__/util.cpython-310.pyc b/util/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000..4aa3922 Binary files /dev/null and b/util/__pycache__/util.cpython-310.pyc differ diff --git a/util/__pycache__/util.cpython-38.pyc b/util/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000..4a2c672 Binary files /dev/null and b/util/__pycache__/util.cpython-38.pyc differ diff --git a/util/__pycache__/visualizer.cpython-38.pyc b/util/__pycache__/visualizer.cpython-38.pyc new file mode 100644 index 0000000..2e6f2aa Binary files /dev/null and b/util/__pycache__/visualizer.cpython-38.pyc differ diff --git a/util/deeplab.py b/util/deeplab.py new file mode 100644 index 0000000..002bdec --- /dev/null +++ b/util/deeplab.py @@ -0,0 +1,257 @@ +# Copyright (c) 2020, Roy 
Or-El. All rights reserved. +# +# This work is licensed under the Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +# This file was taken as is from the https://github.com/chenxi116/DeepLabv3.pytorch repository. + +import torch +import torch.nn as nn +import math +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +__all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152'] + + +model_urls = { + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +class Conv2d(nn.Conv2d): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + + def forward(self, x): + # return super(Conv2d, self).forward(x) + weight = self.weight + weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, + keepdim=True).mean(dim=3, keepdim=True) + weight = weight - weight_mean + std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 + weight = weight / std.expand_as(weight) + return F.conv2d(x, weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +class ASPP(nn.Module): + + def __init__(self, C, depth, num_classes, conv=nn.Conv2d, norm=nn.BatchNorm2d, momentum=0.0003, mult=1): + super(ASPP, self).__init__() + self._C = C + self._depth = depth + self._num_classes = num_classes + + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.relu = nn.ReLU(inplace=True) + self.aspp1 = conv(C, depth, kernel_size=1, stride=1, bias=False) + self.aspp2 = conv(C, 
depth, kernel_size=3, stride=1, + dilation=int(6*mult), padding=int(6*mult), + bias=False) + self.aspp3 = conv(C, depth, kernel_size=3, stride=1, + dilation=int(12*mult), padding=int(12*mult), + bias=False) + self.aspp4 = conv(C, depth, kernel_size=3, stride=1, + dilation=int(18*mult), padding=int(18*mult), + bias=False) + self.aspp5 = conv(C, depth, kernel_size=1, stride=1, bias=False) + self.aspp1_bn = norm(depth, momentum) + self.aspp2_bn = norm(depth, momentum) + self.aspp3_bn = norm(depth, momentum) + self.aspp4_bn = norm(depth, momentum) + self.aspp5_bn = norm(depth, momentum) + self.conv2 = conv(depth * 5, depth, kernel_size=1, stride=1, + bias=False) + self.bn2 = norm(depth, momentum) + self.conv3 = nn.Conv2d(depth, num_classes, kernel_size=1, stride=1) + + def forward(self, x): + x1 = self.aspp1(x) + x1 = self.aspp1_bn(x1) + x1 = self.relu(x1) + x2 = self.aspp2(x) + x2 = self.aspp2_bn(x2) + x2 = self.relu(x2) + x3 = self.aspp3(x) + x3 = self.aspp3_bn(x3) + x3 = self.relu(x3) + x4 = self.aspp4(x) + x4 = self.aspp4_bn(x4) + x4 = self.relu(x4) + x5 = self.global_pooling(x) + x5 = self.aspp5(x5) + x5 = self.aspp5_bn(x5) + x5 = self.relu(x5) + x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear', + align_corners=True)(x5) + x = torch.cat((x1, x2, x3, x4, x5), 1) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.conv3(x) + + return x + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, conv=None, norm=None): + super(Bottleneck, self).__init__() + self.conv1 = conv(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = norm(planes) + self.conv2 = conv(planes, planes, kernel_size=3, stride=stride, + dilation=dilation, padding=dilation, bias=False) + self.bn2 = norm(planes) + self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = norm(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + 
self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes, num_groups=None, weight_std=False, beta=False): + self.inplanes = 64 + self.norm = lambda planes, momentum=0.05: nn.BatchNorm2d(planes, momentum=momentum) if num_groups is None else nn.GroupNorm(num_groups, planes) + self.conv = Conv2d if weight_std else nn.Conv2d + + super(ResNet, self).__init__() + if not beta: + self.conv1 = self.conv(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + else: + self.conv1 = nn.Sequential( + self.conv(3, 64, 3, stride=2, padding=1, bias=False), + self.conv(64, 64, 3, stride=1, padding=1, bias=False), + self.conv(64, 64, 3, stride=1, padding=1, bias=False)) + self.bn1 = self.norm(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=1, + dilation=2) + self.aspp = ASPP(512 * block.expansion, 256, num_classes, conv=self.conv, norm=self.norm) + + for m in self.modules(): + if isinstance(m, self.conv): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1): + downsample = None + if stride != 1 or dilation != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + self.conv(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, dilation=max(1, dilation/2), bias=False), + self.norm(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, dilation=max(1, dilation/2), conv=self.conv, norm=self.norm)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation, conv=self.conv, norm=self.norm)) + + return nn.Sequential(*layers) + + def forward(self, x): + size = (x.shape[2], x.shape[3]) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.aspp(x) + x = nn.Upsample(size, mode='bilinear', align_corners=True)(x) + return x + + +def resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) + return model + + +def resnet101(pretrained=False, num_groups=None, weight_std=False, **kwargs): + """Constructs a ResNet-101 model. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], num_groups=num_groups, weight_std=weight_std, **kwargs) + if pretrained: + model_dict = model.state_dict() + if num_groups and weight_std: + pretrained_dict = torch.load('deeplab_model/R-101-GN-WS.pth.tar') + overlap_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict} + assert len(overlap_dict) == 312 + elif not num_groups and not weight_std: + pretrained_dict = model_zoo.load_url(model_urls['resnet101']) + overlap_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + else: + raise ValueError('Currently only support BN or GN+WS') + model_dict.update(overlap_dict) + model.load_state_dict(model_dict) + return model + + +def resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) + return model diff --git a/util/html.py b/util/html.py new file mode 100755 index 0000000..f144550 --- /dev/null +++ b/util/html.py @@ -0,0 +1,82 @@ +### Copyright (C) 2020 Roy Or-El. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
+import dominate +import math +from dominate.tags import * +import os + + +class HTML: + def __init__(self, web_dir, title, refresh=0): + self.title = title + self.web_dir = web_dir + self.img_dir = os.path.join(self.web_dir, 'images') + if not os.path.exists(self.web_dir): + os.makedirs(self.web_dir) + if not os.path.exists(self.img_dir): + os.makedirs(self.img_dir) + + self.doc = dominate.document(title=title) + if refresh > 0: + with self.doc.head: + meta(http_equiv="refresh", content=str(refresh)) + + def get_image_dir(self): + return self.img_dir + + def add_header(self, str): + with self.doc: + h3(str) + + def add_table(self, border=1): + self.t = table(border=border, style="table-layout: fixed;") + self.doc.add(self.t) + + def add_images(self, ims, txts, links, width=512, cols=0): + imNum = len(ims) + self.add_table() + with self.t: + if cols == 0: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + img(style="width:%dpx" % width, src=os.path.join('images', im)) + br() + p(txt) + else: + rows = int(math.ceil(float(imNum) / float(cols))) + for i in range(rows): + with tr(): + for j in range(cols): + im = ims[i*cols + j] + txt = txts[i*cols + j] + link = links[i*cols + j] + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + img(style="width:%dpx" % width, src=os.path.join('images', im)) + br() + p(txt) + + def save(self): + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims = [] + txts = [] + links = [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git 
a/util/preprocess_itw_im.py b/util/preprocess_itw_im.py new file mode 100644 index 0000000..989611c --- /dev/null +++ b/util/preprocess_itw_im.py @@ -0,0 +1,188 @@ +import os +import dlib +import shutil +import requests +import numpy as np +import scipy.ndimage +import torch +import torchvision.transforms as transforms +import util.deeplab as deeplab +from PIL import Image +from util.util import download_file +from pdb import set_trace as st + +resnet_file_path = 'deeplab_model/R-101-GN-WS.pth.tar' +deeplab_file_path = 'deeplab_model/deeplab_model.pth' +predictor_file_path = 'util/shape_predictor_68_face_landmarks.dat' +model_fname = 'deeplab_model/deeplab_model.pth' +deeplab_classes = ['background' ,'skin','nose','eye_g','l_eye','r_eye','l_brow','r_brow','l_ear','r_ear','mouth','u_lip','l_lip','hair','hat','ear_r','neck_l','neck','cloth'] + + +class preprocessInTheWildImage(): + def __init__(self, out_size=256): + self.out_size = out_size + + # load landmark detector models + self.detector = dlib.get_frontal_face_detector() + if not os.path.isfile(predictor_file_path): + print('Cannot find landmarks shape predictor model.\n'\ + 'Please run download_models.py to download the model') + raise OSError + + self.predictor = dlib.shape_predictor(predictor_file_path) + + # deeplab data properties + self.deeplab_data_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + self.deeplab_input_size = 513 + + # load deeplab model + assert torch.cuda.is_available() + torch.backends.cudnn.benchmark = True + if not os.path.isfile(resnet_file_path): + print('Cannot find DeeplabV3 backbone Resnet model.\n' \ + 'Please run download_models.py to download the model') + raise OSError + + self.deeplab_model = getattr(deeplab, 'resnet101')( + pretrained=True, + num_classes=len(deeplab_classes), + num_groups=32, + weight_std=True, + beta=False) + + self.deeplab_model.eval() + if not 
os.path.isfile(deeplab_file_path): + print('Cannot find DeeplabV3 model.\n' \ + 'Please run download_models.py to download the model') + raise OSError + + checkpoint = torch.load(model_fname) + state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items() if 'tracked' not in k} + self.deeplab_model.load_state_dict(state_dict) + + def dlib_shape_to_landmarks(self, shape): + # initialize the list of (x, y)-coordinates + landmarks = np.zeros((68, 2), dtype=np.float32) + # loop over the 68 facial landmarks and convert them + # to a 2-tuple of (x, y)-coordinates + for i in range(0, 68): + landmarks[i] = (shape.part(i).x, shape.part(i).y) + # return the list of (x, y)-coordinates + return landmarks + + def extract_face_landmarks(self, img): + # detect all faces in the image and + # keep the detection with the largest bounding box + dets = self.detector(img, 1) + if len(dets) == 0: + print ('Could not detect any face in the image, please try again with a different image') + raise + + max_area = 0 + max_idx = -1 + for k, d in enumerate(dets): + area = (d.right() - d.left()) * (d.bottom() - d.top()) + if area > max_area: + max_area = area + max_idx = k + + # Get the landmarks/parts for the face in box d. + dlib_shape = self.predictor(img, dets[max_idx]) + landmarks = self.dlib_shape_to_landmarks(dlib_shape) + return landmarks + + def align_in_the_wild_image(self, np_img, lm, transform_size=4096, enable_padding=True): + # Parse landmarks. + lm_chin = lm[0 : 17] # left-right + lm_eyebrow_left = lm[17 : 22] # left-right + lm_eyebrow_right = lm[22 : 27] # left-right + lm_nose = lm[27 : 31] # top-down + lm_nostrils = lm[31 : 36] # top-down + lm_eye_left = lm[36 : 42] # left-clockwise + lm_eye_right = lm[42 : 48] # left-clockwise + lm_mouth_outer = lm[48 : 60] # left-clockwise + lm_mouth_inner = lm[60 : 68] # left-clockwise + + # Calculate auxiliary vectors. 
+ eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + mouth_left = lm_mouth_outer[0] + mouth_right = lm_mouth_outer[6] + mouth_avg = (mouth_left + mouth_right) * 0.5 + eye_to_mouth = mouth_avg - eye_avg + + # Choose oriented crop rectangle. + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + x /= np.hypot(*x) + x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 2.2) # This results in larger crops then the original FFHQ. For the original crops, replace 2.2 with 1.8 + y = np.flipud(x) * [-1, 1] + c = eye_avg + eye_to_mouth * 0.1 + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + qsize = np.hypot(*x) * 2 + + # Load in-the-wild image. + img = Image.fromarray(np_img) + + # Shrink. + shrink = int(np.floor(qsize / self.out_size * 0.5)) + if shrink > 1: + rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) + img = img.resize(rsize, Image.ANTIALIAS) + quad /= shrink + qsize /= shrink + + # Crop. + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1])) + if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: + img = img.crop(crop) + quad -= crop[0:2] + + # Pad. 
+ pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0)) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + h, w, _ = img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3])) + blur = qsize * 0.02 + img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0) + img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') + quad += pad[:2] + + # Transform. + img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR) + if self.out_size < transform_size: + img = img.resize((self.out_size, self.out_size), Image.ANTIALIAS) + + return img + + + def get_segmentation_maps(self, img): + img = img.resize((self.deeplab_input_size,self.deeplab_input_size),Image.BILINEAR) + img = self.deeplab_data_transform(img) + img = img.cuda() + self.deeplab_model.cuda() + outputs = self.deeplab_model(img.unsqueeze(0)) + self.deeplab_model.cpu() + _, pred = torch.max(outputs, 1) + pred = pred.data.cpu().numpy().squeeze().astype(np.uint8) + seg_map = Image.fromarray(pred) + seg_map = np.uint8(seg_map.resize((self.out_size,self.out_size), Image.NEAREST)) + return seg_map + + def forward(self, img): + landmarks = self.extract_face_landmarks(img) + aligned_img = self.align_in_the_wild_image(img, landmarks) + seg_map = self.get_segmentation_maps(aligned_img) + aligned_img = np.array(aligned_img.getdata(), 
dtype=np.uint8).reshape(self.out_size, self.out_size, 3) + return aligned_img, seg_map diff --git a/util/util.py b/util/util.py new file mode 100755 index 0000000..40df520 --- /dev/null +++ b/util/util.py @@ -0,0 +1,196 @@ +### Copyright (C) 2020 Roy Or-El. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import os +import html +import glob +import uuid +import hashlib +import requests +import torch +import zipfile +import numpy as np +from tqdm import tqdm +from PIL import Image +from pdb import set_trace as st + + +males_model_spec = dict(file_url='https://drive.google.com/uc?id=1MsXN54hPi9PWDmn1HKdmKfv-J5hWYFVZ', + alt_url='https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/males_model.zip', + file_path='checkpoints/males_model.zip', file_size=213175683, file_md5='0079186147ec816176b946a073d1f396') +females_model_spec = dict(file_url='https://drive.google.com/uc?id=1LNm0zAuiY0CIJnI0lHTq1Ttcu9_M1NAJ', + alt_url='https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/females_model.zip', + file_path='checkpoints/females_model.zip', file_size=213218113, file_md5='0675f809413c026170cf1f22b27f3c5d') +resnet_file_spec = dict(file_url='https://drive.google.com/uc?id=1oRGgrI4KNdefbWVpw0rRkEP1gbJIRokM', + alt_url='https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/R-101-GN-WS.pth.tar', + file_path='deeplab_model/R-101-GN-WS.pth.tar', file_size=178260167, file_md5='aa48cc3d3ba3b7ac357c1489b169eb32') +deeplab_file_spec = dict(file_url='https://drive.google.com/uc?id=1w2XjDywFr2NjuUWaLQDRktH7VwIfuNlY', + alt_url='https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/deeplab_model.pth', + file_path='deeplab_model/deeplab_model.pth', file_size=464446305, file_md5='8e8345b1b9d95e02780f9bed76cc0293') +predictor_file_spec = 
dict(file_url='https://drive.google.com/uc?id=1fhq5lvWy-rjrzuHdMoZfLsULvF0gJGwD', + alt_url='https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/shape_predictor_68_face_landmarks.dat', + file_path='util/shape_predictor_68_face_landmarks.dat', file_size=99693937, file_md5='73fde5e05226548677a050913eed4e04') + +# Converts a Tensor into a Numpy array +# |imtype|: the desired type of the converted numpy array +def tensor2im(image_tensor, imtype=np.uint8, normalize=True): + im_sz = image_tensor.size() + ndims = image_tensor.dim() + if ndims == 2: + image_numpy = image_tensor.cpu().float().numpy() + image_numpy = (image_numpy + 1) / 2.0 * 255.0 + elif ndims == 3: + image_numpy = image_tensor.cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + elif ndims == 4 and im_sz[0] == 1: + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + elif ndims == 4: + image_numpy = image_tensor.cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (0, 2, 3, 1)) + 1) / 2.0 * 255.0 + else: # ndims == 5 + image_numpy = image_tensor.cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (0, 1, 3, 4, 2)) + 1) / 2.0 * 255.0 + + return image_numpy.astype(imtype) + +def save_image(image_numpy, image_path): + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + +def download_pretrained_models(): + print('Downloading males model') + with requests.Session() as session: + try: + download_file(session, males_model_spec) + except: + print('Google Drive download failed.\n' \ + 'Trying do download from alternate server') + download_file(session, males_model_spec, use_alt_url=True) + + 
print('Extracting males model zip file') + with zipfile.ZipFile('./checkpoints/males_model.zip','r') as zip_fname: + zip_fname.extractall('./checkpoints') + + print('Done!') + os.remove(males_model_spec['file_path']) + + print('Downloading females model') + with requests.Session() as session: + try: + download_file(session, females_model_spec) + except: + print('Google Drive download failed.\n' \ + 'Trying do download from alternate server') + download_file(session, females_model_spec, use_alt_url=True) + + print('Extracting females model zip file') + with zipfile.ZipFile('./checkpoints/females_model.zip','r') as zip_fname: + zip_fname.extractall('./checkpoints') + + print('Done!') + os.remove(females_model_spec['file_path']) + + print('Downloading face landmarks shape predictor') + with requests.Session() as session: + try: + download_file(session, predictor_file_spec) + except: + print('Google Drive download failed.\n' \ + 'Trying do download from alternate server') + download_file(session, predictor_file_spec, use_alt_url=True) + + print('Done!') + + print('Downloading DeeplabV3 backbone Resnet Model parameters') + with requests.Session() as session: + try: + download_file(session, resnet_file_spec) + except: + print('Google Drive download failed.\n' \ + 'Trying do download from alternate server') + download_file(session, resnet_file_spec, use_alt_url=True) + + print('Done!') + + print('Downloading DeeplabV3 Model parameters') + with requests.Session() as session: + try: + download_file(session, deeplab_file_spec) + except: + print('Google Drive download failed.\n' \ + 'Trying do download from alternate server') + download_file(session, deeplab_file_spec, use_alt_url=True) + + print('Done!') + +def download_file(session, file_spec, use_alt_url=False, chunk_size=128, num_attempts=10): + file_path = file_spec['file_path'] + if use_alt_url: + file_url = file_spec['alt_url'] + else: + file_url = file_spec['file_url'] + + file_dir = os.path.dirname(file_path) + 
tmp_path = file_path + '.tmp.' + uuid.uuid4().hex + if file_dir: + os.makedirs(file_dir, exist_ok=True) + + progress_bar = tqdm(total=file_spec['file_size'], unit='B', unit_scale=True) + for attempts_left in reversed(range(num_attempts)): + data_size = 0 + progress_bar.reset() + try: + # Download. + data_md5 = hashlib.md5() + with session.get(file_url, stream=True) as res: + res.raise_for_status() + with open(tmp_path, 'wb') as f: + for chunk in res.iter_content(chunk_size=chunk_size<<10): + progress_bar.update(len(chunk)) + f.write(chunk) + data_size += len(chunk) + data_md5.update(chunk) + + # Validate. + if 'file_size' in file_spec and data_size != file_spec['file_size']: + raise IOError('Incorrect file size', file_path) + if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']: + raise IOError('Incorrect file MD5', file_path) + break + + except: + # Last attempt => raise error. + if not attempts_left: + raise + + # Handle Google Drive virus checker nag. + if data_size > 0 and data_size < 8192: + with open(tmp_path, 'rb') as f: + data = f.read() + links = [html.unescape(link) for link in data.decode('utf-8').split('"') if 'export=download' in link] + if len(links) == 1: + file_url = requests.compat.urljoin(file_url, links[0]) + continue + + progress_bar.close() + + # Rename temp file to the correct name. + os.replace(tmp_path, file_path) # atomic + + # Attempt to clean up any leftover temps. + for filename in glob.glob(file_path + '.tmp.*'): + try: + os.remove(filename) + except: + pass diff --git a/util/visualizer.py b/util/visualizer.py new file mode 100755 index 0000000..a21d04d --- /dev/null +++ b/util/visualizer.py @@ -0,0 +1,239 @@ +### Copyright (C) 2020 Roy Or-El. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import numpy as np +import os +import cv2 +import time +import unidecode +from . import util +from . 
import html
from pdb import set_trace as st

class Visualizer():
    """Training/eval visualization: visdom panes, generated-image HTML pages,
    and the textual loss log under the experiment's checkpoints directory."""

    def __init__(self, opt):
        # self.opt = opt
        self.display_id = opt.display_id
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        self.numClasses = opt.numClasses
        self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'images')
        self.isTrain = opt.isTrain
        if self.isTrain:
            self.save_freq = opt.save_display_freq

        if self.display_id > 0:
            import visdom
            self.vis = visdom.Visdom(port=opt.display_port)
            # NOTE(review): this attribute only exists when display_id > 0, yet
            # display_current_results reads it unconditionally — confirm callers
            # never invoke it with display_id <= 0.
            self.display_single_pane_ncols = opt.display_single_pane_ncols

        if self.use_html:
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        if self.isTrain:
            with open(self.log_name, "a") as log_file:
                now = time.strftime("%c")
                log_file.write('================ Training Loss (%s) ================\n' % now)

    # |visuals|: dictionary of images to display or save
    def display_current_results(self, visuals, it, classes, ncols):
        """Push the current result images (and their labels) to visdom."""
        if self.display_single_pane_ncols > 0:
            h, w = next(iter(visuals.values())).shape[:2]
            # NOTE(review): the CSS template text was lost in this paste;
            # formatting an empty string with % (w, h) raised TypeError, so the
            # format args were dropped. Restore the CSS from upstream if needed.
            table_css = """"""
            title = self.name
            label_html = ''
            label_html_row = ''
            nrows = int(np.ceil(len(visuals.items()) / ncols))
            images = []
            idx = 0
            for label, image_numpy in visuals.items():
                label_html_row += '%s' % label
                if image_numpy.ndim < 3:
                    # Grayscale -> fake RGB so visdom gets a 3-channel image.
                    image_numpy = np.expand_dims(image_numpy, 2)
                    image_numpy = np.tile(image_numpy, (1, 1, 3))

                images.append(image_numpy.transpose([2, 0, 1]))
                idx += 1
                if idx % ncols == 0:
                    label_html += '%s' % label_html_row
                    label_html_row = ''
            # Pad the last row with white tiles so the grid stays rectangular.
            white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
            while idx % ncols != 0:
                images.append(white_image)
                label_html_row += ''
                idx += 1
            if label_html_row != '':
                label_html += '%s' % label_html_row

            self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                            padding=2, opts=dict(title=title + ' images'))
            label_html = '%s ' % label_html
            self.vis.text(table_css + label_html, win=self.display_id + 2,
                          opts=dict(title=title + ' labels'))
        else:
            idx = 1
            for label, image_numpy in visuals.items():
                self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
                               win=self.display_id + idx)
                idx += 1

    # errors: dictionary of error labels and values
    def plot_current_errors(self, epoch, counter_ratio, opt, errors):
        """Append the current losses to the running visdom line plot."""
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
        self.vis.line(
            X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
            Y=np.array(self.plot_data['Y']),
            opts={
                'title': self.name + ' loss over time',
                'legend': self.plot_data['legend'],
                'xlabel': 'epoch',
                'ylabel': 'loss'},
            win=self.display_id)

    # errors: same format as |errors| of plotCurrentErrors
    def print_current_errors(self, epoch, i, errors, t):
        """Print the current losses and append the same line to the loss log."""
        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
        for k, v in errors.items():
            message += '%s: %.3f ' % (k, v)

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)

    def save_matrix_image(self, visuals, epoch):
        """Save a grid image: one row per sample, original + one column per class."""
        for i in range(len(visuals)):
            visual = visuals[i]
            orig_img = visual['orig_img_cls_' + str(i)]
            curr_row_img = orig_img
            for cls in range(self.numClasses):
                next_im = visual['tex_trans_to_class_' + str(cls)]
                curr_row_img = np.concatenate((curr_row_img, next_im), 1)

            if i == 0:
                matrix_img = curr_row_img
            else:
                matrix_img = np.concatenate((matrix_img, curr_row_img), 0)

        if epoch != 'latest':
            epoch_txt = 'epoch_' + str(epoch)
        else:
            # Bug fix: was `epoch_txt = epochs` — an undefined name that raised
            # NameError whenever epoch == 'latest'.
            epoch_txt = epoch

        image_path = os.path.join(self.img_dir, 'sample_batch_{}.png'.format(epoch_txt))
        util.save_image(matrix_img, image_path)

    def save_row_image(self, visuals, image_path, traverse=False):
        """Save a single row: the original image, a spacer, then each class output."""
        visual = visuals[0]
        orig_img = visual['orig_img']
        h, w, c = orig_img.shape
        traversal_img = np.concatenate((orig_img, np.full((h, 10, c), 255, dtype=np.uint8)), 1)
        if traverse:
            out_classes = len(visual) - 1
        else:
            out_classes = self.numClasses
        for cls in range(out_classes):
            next_im = visual['tex_trans_to_class_' + str(cls)]
            traversal_img = np.concatenate((traversal_img, next_im), 1)

        util.save_image(traversal_img, image_path)

    def make_video(self, visuals, video_path):
        """Write the per-class outputs of the first sample as an mp4 video."""
        fps = 20  # 25
        visual = visuals[0]
        orig_img = visual['orig_img']
        h, w = orig_img.shape[0], orig_img.shape[1]
        writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
        out_classes = len(visual) - 1
        for cls in range(out_classes):
            next_im = visual['tex_trans_to_class_' + str(cls)]
            # OpenCV expects BGR; flip the channel order.
            writer.write(next_im[:, :, ::-1])

        writer.release()

    # save image to the disk
    def save_images_deploy(self, visuals, image_path):
        """Save every visual of every sample as '<image_path>_<label>.png'."""
        for i in range(len(visuals)):
            visual = visuals[i]
            for label, image_numpy in visual.items():
                save_path = '%s_%s.png' % (image_path, label)
                util.save_image(image_numpy, save_path)

    # save image to the disk
    def save_images(self, webpage, visuals, image_path, gt_visuals=None, gt_path=None):
        """Save result images into the webpage's image dir and add page rows.

        Without ground truth, each sample gets one header + one image row;
        with ground truth, a GT row is repeated before each sample's row.
        """
        cols = self.numClasses + 1
        image_dir = webpage.get_image_dir()
        if gt_visuals is None or gt_path is None:
            for i in range(len(visuals)):
                visual = visuals[i]
                short_path = os.path.basename(image_path[i])
                # unidecode removes accents which cause an html load error
                name = unidecode.unidecode(os.path.splitext(short_path)[0])
                webpage.add_header(name)
                ims = []
                txts = []
                links = []
                for label, image_numpy in visual.items():
                    image_name = '%s_%s.png' % (name, label)
                    save_path = os.path.join(image_dir, image_name)
                    util.save_image(image_numpy, save_path)

                    ims.append(image_name)
                    txts.append(label)
                    links.append(image_name)

                webpage.add_images(ims, txts, links, width=self.win_size, cols=cols)
        else:
            batchSize = len(image_path)

            # save ground truth images
            if gt_path is not None:
                gt_short_path = os.path.basename(gt_path[0])
                # Bug fix: was os.path.splitext(gt_path)[0] — applied to the
                # list itself instead of the extracted basename.
                gt_name = os.path.splitext(gt_short_path)[0]
                gt_ims = []
                gt_txts = []
                gt_links = []
                for label, image_numpy in gt_visuals.items():
                    image_name = '%s_%s.png' % (gt_name, label)
                    save_path = os.path.join(image_dir, image_name)
                    util.save_image(image_numpy, save_path)

                    gt_ims.append(image_name)
                    gt_txts.append(label)
                    gt_links.append(image_name)

            for i in range(batchSize):
                short_path = os.path.basename(image_path[i])
                name = os.path.splitext(short_path)[0]

                # webpage.add_header(name)
                ims = []
                txts = []
                links = []

                for label, image_numpy in visuals[i].items():
                    image_name = '%s_%s.png' % (name, label)
                    save_path = os.path.join(image_dir, image_name)
                    util.save_image(image_numpy, save_path)

                    ims.append(image_name)
                    txts.append(label)
                    links.append(image_name)
                print("saving results for: " + name)

                if gt_path is not None:
                    webpage.add_header(gt_name)
                    webpage.add_images(gt_ims, gt_txts, gt_links, width=self.win_size, cols=batchSize)

                webpage.add_header(name)
                webpage.add_images(ims, txts, links, width=self.win_size, cols=self.numClasses + 1)