+# Flask application setup: uploads are saved under static/uploads/ and
+# served from the static folder.
+UPLOAD_FOLDER = 'static/uploads/'
+app = Flask(__name__, static_folder='static')
+# run_with_ngrok(app)
+
+# NOTE(review): hard-coded, predictable secret key — load from an
+# environment variable before any real deployment.
+app.secret_key = "secret key"
+app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
+app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # cap uploads at 16 MB
+
+# Extensions accepted for uploaded images (checked by allowed_file).
+ALLOWED_EXTENSIONS = set(['png', 'jpg', 'jpeg', 'gif'])
+
+def allowed_file(filename):
+ return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+
+@app.route('/', methods=['GET', 'POST'])
+
+def index():
+    """Render the landing page with the upload form."""
+    return render_template('index.html', title='face GAN')
+
+@app.route('/index', methods=['POST'])
+
+def aging():
+ age = request.form['age']
+ target_age = request.form['target_age']
+ gender = request.form['gender']
+ image = request.files['image']
+ path = os.path.join(app.config["UPLOAD_FOLDER"], image.filename)
+ image.save(path)
+
+ if gender == "male":
+ input = pd.DataFrame({
+ 'age' : [int(age)],
+ 'target_age' :[int(target_age)]
+ })
+
+ opt.name = 'males_model'
+ model = create_model(opt)
+ model.eval()
+
+ elif gender == "female":
+ input = pd.DataFrame({
+ 'age' : [int(age)],
+ 'target_age' :[int(target_age)]
+ })
+
+ opt.name = 'females_model'
+ model = create_model(opt)
+ model.eval()
+
+ name = f'./static/uploads/{image.filename}'
+ data = dataset.dataset.get_item_from_path(name)
+ visuals = model.inference(data)
+
+ # Model running (images)
+ os.makedirs(f'results/{image.filename}', exist_ok=True)
+ out_pathi = f'./results/{image.filename}'
+
+ visualizer.save_images_deploy(visuals, out_pathi)
+
+ # Model running (video)
+ os.makedirs('static', exist_ok=True)
+ out_pathv = os.path.join('static', os.path.splitext(name)[0].replace(' ', '_') + '.webm')
+ visualizer.make_video(visuals, out_pathv)
+
+ return render_template('output.html', filename=image.filename) #Output = ModelOutput)
+
+if __name__ == "__main__":
+    # Development server only; use a proper WSGI server in production.
+    app.run()
+# app.run(host='0.0.0.0', port=5000, debug=True)
\ No newline at end of file
diff --git a/run_scripts/deploy.bat b/run_scripts/deploy.bat
new file mode 100755
index 0000000..9328456
--- /dev/null
+++ b/run_scripts/deploy.bat
@@ -0,0 +1,5 @@
+@echo off
+
+REM Deploy mode: full-progression aging for the images listed in
+REM males_image_list.txt, on GPU 0, using the males_model checkpoint.
+set CUDA_VISIBLE_DEVICES=0
+
+python test.py --dataroot ./datasets/males --name males_model --which_epoch latest --display_id 0 --deploy --image_path_file males_image_list.txt --full_progression --verbose
diff --git a/run_scripts/deploy.sh b/run_scripts/deploy.sh
new file mode 100755
index 0000000..7cfbf18
--- /dev/null
+++ b/run_scripts/deploy.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0 python test.py --dataroot ./datasets/males --name males_model --which_epoch latest --display_id 0 --deploy --image_path_file males_image_list.txt --full_progression --verbose
diff --git a/run_scripts/in_the_wild.bat b/run_scripts/in_the_wild.bat
new file mode 100755
index 0000000..63aac44
--- /dev/null
+++ b/run_scripts/in_the_wild.bat
@@ -0,0 +1,5 @@
+@echo off
+
+REM Age-traversal videos for unconstrained ("in the wild") photos,
+REM interpolating in steps of 0.05, on GPU 0.
+set CUDA_VISIBLE_DEVICES=0
+
+python test.py --name males_model --which_epoch latest --display_id 0 --traverse --interp_step 0.05 --image_path_file males_image_list.txt --make_video --in_the_wild --verbose
diff --git a/run_scripts/in_the_wild.sh b/run_scripts/in_the_wild.sh
new file mode 100755
index 0000000..dfb74c3
--- /dev/null
+++ b/run_scripts/in_the_wild.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0 python test.py --name males_model --which_epoch latest --display_id 0 --traverse --interp_step 0.05 --image_path_file males_image_list.txt --make_video --in_the_wild --verbose
diff --git a/run_scripts/test.bat b/run_scripts/test.bat
new file mode 100755
index 0000000..8d99fa5
--- /dev/null
+++ b/run_scripts/test.bat
@@ -0,0 +1,5 @@
+@echo off
+
+REM Evaluate the latest males_model checkpoint on up to 100 test batches.
+set CUDA_VISIBLE_DEVICES=0
+
+python test.py --verbose --dataroot ./datasets/males --name males_model --which_epoch latest --how_many 100 --display_id 0
diff --git a/run_scripts/test.sh b/run_scripts/test.sh
new file mode 100755
index 0000000..884555a
--- /dev/null
+++ b/run_scripts/test.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0 python test.py --verbose --dataroot ./datasets/males --name males_model --which_epoch latest --how_many 100 --display_id 0
diff --git a/run_scripts/train.bat b/run_scripts/train.bat
new file mode 100755
index 0000000..75521d6
--- /dev/null
+++ b/run_scripts/train.bat
@@ -0,0 +1,5 @@
+@echo off
+
+REM Train males_model on 4 GPUs with batch size 6.
+set CUDA_VISIBLE_DEVICES=0,1,2,3
+
+python train.py --gpu_ids 0,1,2,3 --dataroot ./datasets/males --name males_model --batchSize 6 --verbose
diff --git a/run_scripts/train.sh b/run_scripts/train.sh
new file mode 100755
index 0000000..e4bdfcf
--- /dev/null
+++ b/run_scripts/train.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --gpu_ids 0,1,2,3 --dataroot ./datasets/males --name males_model --batchSize 6 --verbose
diff --git a/run_scripts/traversal.bat b/run_scripts/traversal.bat
new file mode 100755
index 0000000..564dd79
--- /dev/null
+++ b/run_scripts/traversal.bat
@@ -0,0 +1,5 @@
+@echo off
+
+REM Age-traversal videos over the male test set (interp step 0.05) on GPU 0.
+set CUDA_VISIBLE_DEVICES=0
+
+python test.py --dataroot ./datasets/males --name males_model --which_epoch latest --display_id 0 --traverse --interp_step 0.05 --image_path_file males_image_list.txt --make_video --verbose
diff --git a/run_scripts/traversal.sh b/run_scripts/traversal.sh
new file mode 100755
index 0000000..1590afc
--- /dev/null
+++ b/run_scripts/traversal.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0 python test.py --dataroot ./datasets/males --name males_model --which_epoch latest --display_id 0 --traverse --interp_step 0.05 --image_path_file males_image_list.txt --make_video --verbose
diff --git a/static/script.js b/static/script.js
new file mode 100755
index 0000000..595e754
--- /dev/null
+++ b/static/script.js
@@ -0,0 +1,107 @@
+$(document).ready(function () {
+
+    // Submit-button progress animation: on the first click the button is
+    // emptied and collapsed into a circle while the SVG stroke animates;
+    // after 2.5 s it switches to the "filled" state. Subsequent clicks
+    // are ignored via the `clicked` flag.
+    var timer = null;
+    var self = $(".wrap button");
+    var clicked = false;
+    $(".wrap button").on("click", function () {
+        if (clicked === false){
+            self.removeClass("filled");
+            self.addClass("circle");
+            self.html("");
+            clicked = true;
+            $("svg").css("display", "block");
+            $(".circle_2").attr("class", "circle_2 fill_circle");
+
+            // setInterval used as a one-shot timer: it clears itself on
+            // the first tick, once the stroke animation has finished.
+            timer = setInterval(
+                function tick() {
+                    self.removeClass("circle");
+                    self.addClass("filled");
+                    // self.html("b");
+                    $(".wrap img").css("display", "block");
+                    $("svg").css("display", "none");
+                    clearInterval(timer);
+                }, 2500);
+        }
+    });
+});
+
+//selecting all required elements
+// Drag-and-drop upload area wiring.
+const dropArea = document.querySelector(".drag-area"),
+dragText = dropArea.querySelector("header"),
+button = dropArea.querySelector("button"),
+input = dropArea.querySelector("input");
+let file; // global: the currently selected file, read by showFile()
+button.onclick = ()=>{
+    input.click(); // clicking the button forwards to the hidden file input
+}
+input.addEventListener("change", function(){
+    // Only the first file is used when multiple files are selected.
+    file = this.files[0];
+    dropArea.classList.add("active");
+    showFile(); // render the preview
+});
+// While a file is dragged over the drop area
+dropArea.addEventListener("dragover", (event)=>{
+    event.preventDefault(); // allow dropping (default would block it)
+    dropArea.classList.add("active");
+    dragText.textContent = "업로드 할 사진을 놓으세요";
+});
+// When the dragged file leaves the drop area
+dropArea.addEventListener("dragleave", ()=>{
+    dropArea.classList.remove("active");
+    dragText.textContent = "사진을 끌어오세요";
+});
+// When the file is dropped on the drop area
+dropArea.addEventListener("drop", (event)=>{
+    event.preventDefault(); // keep the browser from opening the file
+    // Only the first file is used when multiple files are dropped.
+    file = event.dataTransfer.files[0];
+    showFile(); // render the preview
+});
+function showFile(){
+ let fileType = file.type; //getting selected file type
+ let validExtensions = ["image/jpeg", "image/jpg", "image/png"]; //adding some valid image extensions in array
+ if(validExtensions.includes(fileType)){ //if user selected file is an image file
+ let fileReader = new FileReader(); //creating new FileReader object
+ fileReader.onload = ()=>{
+ let fileURL = fileReader.result; //passing user file source in fileURL variable
+ // UNCOMMENT THIS BELOW LINE. I GOT AN ERROR WHILE UPLOADING THIS POST SO I COMMENTED IT
+ let imgTag = `
`; //creating an img tag and passing user selected file source inside src attribute
+ dropArea.innerHTML = imgTag; //adding that created img tag inside dropArea container
+ }
+ fileReader.readAsDataURL(file);
+ }else{
+ alert("이미지 파일이 아닙니다!");
+ dropArea.classList.remove("active");
+ dragText.textContent = "사진을 끌어오세요";
+ }
+}
+
+// $(document).ready(function () {
+// alert('document loaded')
+// var timer = null;
+// var self = $(".wrap button");
+// var clicked = false;
+// $(".wrap button").on("click", function () {
+// if (clicked === false){
+// self.removeClass("filled");
+// self.addClass("circle");
+// self.html("");
+// clicked = true;
+// $("svg").css("display", "block");
+// $(".circle_2").attr("class", "circle_2 fill_circle");
+
+// timer = setInterval(
+// function tick() {
+// self.removeClass("circle");
+// self.addClass("filled");
+// // self.html("b");
+// $(".wrap img").css("display", "block");
+// $("svg").css("display", "none");
+// clearInterval(timer);
+// }, 2500);
+// }
+// });
+// });
+
+
diff --git a/static/style.css b/static/style.css
new file mode 100755
index 0000000..f0e687a
--- /dev/null
+++ b/static/style.css
@@ -0,0 +1,346 @@
+body {
+ height: 70vh;
+ -webkit-text-size-adjust: 100%;
+ -webkit-font-smoothing: antialiased;
+
+ align-items: center;
+ text-align: center;
+ align-content: center;
+ font-family: "Lato";
+ justify-content: center;
+
+ min-height: 70vh;
+ background: #031132;
+}
+
+* {
+ box-sizing: border-box;
+}
+
+.inp {
+ position: relative;
+ margin: auto;
+ width: 100%;
+ display: flex;
+ max-width: 280px;
+ border-radius: 3px;
+ overflow: hidden;
+}
+.inp .label {
+ position: absolute;
+ top: 20px;
+ left: 12px;
+ font-size: 16px;
+ color: rgb(255, 255, 255);
+ font-weight: 500;
+ transform-origin: 0 0;
+ transform: translate3d(0, 0, 0);
+ transition: all 0.2s ease;
+ pointer-events: none;
+}
+.inp .focus-bg {
+ position: absolute;
+ top: 0;
+ left: 0;
+ width: 100%;
+ height: 100%;
+ background: rgba(255, 255, 255, 0.05);
+ z-index: -1;
+ transform: scaleX(0);
+
+}
+.inp input {
+ -webkit-appearance: none;
+ -moz-appearance: none;
+ appearance: none;
+ width: 100%;
+ border: 0;
+ font-family: inherit;
+ padding: 16px 12px 0 12px;
+ height: 56px;
+ font-size: 16px;
+ font-weight: 400;
+ background: rgba(255, 255, 255, 0.02);
+ box-shadow: inset 0 -1px 0 rgba(255, 255, 255, 0.3);
+ color: rgb(190, 190, 190);
+ transition: all 0.15s ease;
+}
+.inp input:hover {
+ background: rgba(255, 255, 255, 0.04);
+ box-shadow: inset 0 -1px 0 rgba(255, 255, 255, 0.5);
+}
+.inp input:not(:-moz-placeholder-shown) + .label {
+ color: rgba(255, 255, 255, 0.5);
+ transform: translate3d(0, -12px, 0) scale(0.75);
+}
+.inp input:not(:-ms-input-placeholder) + .label {
+ color: rgba(255, 255, 255, 0.5);
+ transform: translate3d(0, -12px, 0) scale(0.75);
+}
+.inp input:not(:placeholder-shown) + .label {
+ color: rgba(255, 255, 255, 0.5);
+ transform: translate3d(0, -12px, 0) scale(0.75);
+}
+.inp input:focus {
+ background: rgba(255, 255, 255, 0.05);
+ outline: none;
+ box-shadow: inset 0 -2px 0 #1ECD97;
+}
+.inp input:focus + .label {
+ color: #1ECD97;
+ transform: translate3d(0, -12px, 0) scale(0.75);
+}
+.inp input:focus + .label + .focus-bg {
+ transform: scaleX(1);
+ transition: all 0.1s ease;
+}
+
+
+.drag-area {
+ border: 2px dashed #fff;
+ height: 350px;
+ width: 500px;
+ border-radius: 5px;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ flex-direction: column;
+}
+.drag-area.active {
+ border: 2px solid #fff;
+}
+.drag-area .icon {
+ font-size: 100px;
+ color: #fff;
+}
+.drag-area header {
+ font-size: 30px;
+ font-weight: 500;
+ color: #fff;
+}
+.drag-area span {
+ font-size: 25px;
+ font-weight: 500;
+ color: #fff;
+ margin: 10px 0 15px 0;
+}
+.drag-area button {
+ padding: 10px 25px;
+ font-size: 20px;
+ font-weight: 500;
+ border: none;
+ outline: none;
+ background: #fff;
+ color: #5256ad;
+ border-radius: 5px;
+ cursor: pointer;
+}
+.drag-area img {
+ height: 100%;
+ width: 100%;
+ object-fit: cover;
+ border-radius: 5px;
+}
+
+
+
+
+html {
+ line-height: 1;
+}
+
+ol, ul {
+ list-style: none;
+}
+
+table {
+ border-collapse: collapse;
+ border-spacing: 0;
+}
+
+caption, th, td {
+ text-align: left;
+ font-weight: normal;
+ vertical-align: middle;
+}
+
+q, blockquote {
+ quotes: none;
+}
+q:before, q:after, blockquote:before, blockquote:after {
+ content: "";
+ content: none;
+}
+
+a img {
+ border: none;
+}
+
+article, aside, details, figcaption, figure, footer, header, hgroup, main, menu, nav, section, summary {
+ display: block;
+}
+
+
+.wrap {
+ position: relative;
+ margin: auto;
+ margin-top: 3%;
+ width: 191px;
+ text-align: center;
+}
+.wrap button {
+ display: block;
+ height: 60px;
+ padding: 0;
+ width: 191px;
+ background: none;
+ margin: auto;
+ border: 2px solid #1ECD97;
+ font-size: 18px;
+ font-family: "Lato";
+ color: #1ECD97;
+ cursor: pointer;
+ outline: none;
+ text-align: center;
+ -moz-box-sizing: border-box;
+ -webkit-box-sizing: border-box;
+ box-sizing: border-box;
+ -moz-border-radius: 30px;
+ -webkit-border-radius: 30px;
+ border-radius: 30px;
+ -moz-transition: background 0.4s, color 0.4s, font-size 0.05s, width 0.4s, border 0.4s;
+ -o-transition: background 0.4s, color 0.4s, font-size 0.05s, width 0.4s, border 0.4s;
+ -webkit-transition: background 0.4s, color 0.4s, font-size 0.05s, width 0.4s, border 0.4s;
+ transition: background 0.4s, color 0.4s, font-size 0.05s, width 0.4s, border 0.4s;
+}
+.wrap button:hover {
+ background: #1ECD97;
+ color: white;
+}
+.wrap img {
+ position: absolute;
+ top: 11px;
+ display: none;
+ left: 71.5px;
+ -moz-transform: scale(0.6, 0.6);
+ -ms-transform: scale(0.6, 0.6);
+ -webkit-transform: scale(0.6, 0.6);
+ transform: scale(0.6, 0.6);
+}
+.wrap svg {
+ -moz-transform: rotate(270deg);
+ -ms-transform: rotate(270deg);
+ -webkit-transform: rotate(270deg);
+ transform: rotate(270deg);
+ /* @include rotate(270deg); */
+ position: absolute;
+ top: -2px;
+ left: 62px;
+ display: none;
+}
+.wrap svg .circle_2 {
+ stroke-dasharray: 0 200;
+}
+.wrap svg .fill_circle {
+ -moz-animation: fill-stroke 2s 0.4s linear forwards;
+ -webkit-animation: fill-stroke 2s 0.4s linear forwards;
+ animation: fill-stroke 2s 0.4s linear forwards;
+}
+.wrap .circle {
+ width: 60px;
+ border: 3px solid #c3c3c3;
+ /* border: none; */
+}
+.wrap .circle:hover {
+ background: none;
+}
+.wrap .filled {
+ background: #1ECD97;
+ color: white;
+ line-height: 60px;
+ font-size: 160%;
+}
+
+footer p {
+ color: #738087;
+ margin-top: 100px;
+ font-size: 18px;
+ line-height: 28px;
+}
+
+@-moz-keyframes fill-stroke {
+ 0% {
+ stroke-dasharray: 0 200;
+ }
+ 20% {
+ stroke-dasharray: 20 200;
+ }
+ 40% {
+ stroke-dasharray: 30 200;
+ }
+ 50% {
+ stroke-dasharray: 90 200;
+ }
+ 70% {
+ stroke-dasharray: 120 200;
+ }
+ 90% {
+ stroke-dasharray: 140 200;
+ }
+ 100% {
+ stroke-dasharray: 182 200;
+ }
+}
+@-webkit-keyframes fill-stroke {
+ 0% {
+ stroke-dasharray: 0 200;
+ }
+ 20% {
+ stroke-dasharray: 20 200;
+ }
+ 40% {
+ stroke-dasharray: 30 200;
+ }
+ 50% {
+ stroke-dasharray: 90 200;
+ }
+ 70% {
+ stroke-dasharray: 120 200;
+ }
+ 90% {
+ stroke-dasharray: 140 200;
+ }
+ 100% {
+ stroke-dasharray: 182 200;
+ }
+}
+@keyframes fill-stroke {
+ 0% {
+ stroke-dasharray: 0 200;
+ }
+ 20% {
+ stroke-dasharray: 20 200;
+ }
+ 40% {
+ stroke-dasharray: 30 200;
+ }
+ 50% {
+ stroke-dasharray: 90 200;
+ }
+ 70% {
+ stroke-dasharray: 120 200;
+ }
+ 90% {
+ stroke-dasharray: 140 200;
+ }
+ 100% {
+ stroke-dasharray: 182 200;
+ }
+}
+a, p {
+ line-height: 1.6em;
+}
+
+a {
+ color: #738087;
+}
diff --git a/templates/index.html b/templates/index.html
new file mode 100755
index 0000000..f756cad
--- /dev/null
+++ b/templates/index.html
@@ -0,0 +1,70 @@
+
+
+
+
+ face GAN
+
+
+
+
+
+
+
+
+
+Face Aging GAN
+
+
+
+
+
+
diff --git a/templates/output.html b/templates/output.html
new file mode 100644
index 0000000..3f0d66b
--- /dev/null
+++ b/templates/output.html
@@ -0,0 +1,20 @@
+
+
+
+
+ face GAN
+
+
+
+
+
+
Result!!
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100755
index 0000000..0d38d56
--- /dev/null
+++ b/test.py
@@ -0,0 +1,89 @@
+### Copyright (C) 2020 Roy Or-El. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import os
+import scipy # this is to prevent a potential error caused by importing torch before scipy (happens due to a bad combination of torch & scipy versions)
+from collections import OrderedDict
+from options.test_options import TestOptions
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+import util.util as util
+from util.visualizer import Visualizer
+from util import html
+import torch
+from pdb import set_trace as st
+
+
+def test(opt):
+    """Run inference with a trained model, as configured by *opt*.
+
+    Modes (selected by flags on *opt*):
+    - ``traverse`` / ``deploy``: run on each path in ``opt.image_path_list``
+      and save traversal videos, progression strips, or per-image outputs.
+    - otherwise: iterate the test dataset for up to ``opt.how_many``
+      batches and collect results into an HTML results page.
+    """
+    opt.nThreads = 1   # test code only supports nThreads = 1
+    opt.batchSize = 1  # test code only supports batchSize = 1
+    opt.serial_batches = True  # no shuffle
+    opt.no_flip = True  # no flip
+
+    data_loader = CreateDataLoader(opt)
+    dataset = data_loader.load_data()
+    dataset_size = len(data_loader)
+    print('#test batches = %d' % (int(dataset_size / len(opt.sort_order))))
+    visualizer = Visualizer(opt)
+    model = create_model(opt)
+    model.eval()
+
+    # create webpage — results directory name encodes phase/epoch/seed
+    if opt.random_seed != -1:
+        exp_dir = '%s_%s_seed%s' % (opt.phase, opt.which_epoch, str(opt.random_seed))
+    else:
+        exp_dir = '%s_%s' % (opt.phase, opt.which_epoch)
+    web_dir = os.path.join(opt.results_dir, opt.name, exp_dir)
+
+    if opt.traverse or opt.deploy:
+        if opt.traverse:
+            out_dirname = 'traversal'
+        else:
+            out_dirname = 'deploy'
+        output_dir = os.path.join(web_dir,out_dirname)
+        if not os.path.isdir(output_dir):
+            os.makedirs(output_dir)
+
+        for image_path in opt.image_path_list:
+            print(image_path)
+            data = dataset.dataset.get_item_from_path(image_path)
+            visuals = model.inference(data)
+            if opt.traverse and opt.make_video:
+                out_path = os.path.join(output_dir, os.path.splitext(os.path.basename(image_path))[0] + '.mp4')
+                visualizer.make_video(visuals, out_path)
+            elif opt.traverse or (opt.deploy and opt.full_progression):
+                if opt.traverse and opt.compare_to_trained_outputs:
+                    out_path = os.path.join(output_dir, os.path.splitext(os.path.basename(image_path))[0] + '_compare_to_{}_jump_{}.png'.format(opt.compare_to_trained_class, opt.trained_class_jump))
+                else:
+                    out_path = os.path.join(output_dir, os.path.splitext(os.path.basename(image_path))[0] + '.png')
+                visualizer.save_row_image(visuals, out_path, traverse=opt.traverse)
+            else:
+                # image_path[:-4] assumes a 3-char extension ('.png'/'.jpg') — TODO confirm
+                out_path = os.path.join(output_dir, os.path.basename(image_path[:-4]))
+                visualizer.save_images_deploy(visuals, out_path)
+    else:
+        webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
+
+        # test
+        for i, data in enumerate(dataset):
+            if i >= opt.how_many:
+                break
+
+            visuals = model.inference(data)
+            img_path = data['Paths']
+            # Drop empty path entries before saving results.
+            rem_ind = []
+            # NOTE(review): this inner `i` shadows the outer loop index;
+            # harmless (enumerate reassigns it each batch) but confusing.
+            for i, path in enumerate(img_path):
+                if path != '':
+                    print('process image... %s' % path)
+                else:
+                    rem_ind += [i]
+
+            # Delete from the end so earlier indices stay valid.
+            for ind in reversed(rem_ind):
+                del img_path[ind]
+
+            visualizer.save_images(webpage, visuals, img_path)
+
+        webpage.save()
+
+
+if __name__ == "__main__":
+    # Parse test options (without saving them to disk) and run inference.
+    opt = TestOptions().parse(save=False)
+    test(opt)
diff --git a/train.py b/train.py
new file mode 100755
index 0000000..69a6827
--- /dev/null
+++ b/train.py
@@ -0,0 +1,159 @@
+### Copyright (C) 2020 Roy Or-El. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import time
+import scipy # this is to prevent a potential error caused by importing torch before scipy (happens due to a bad combination of torch & scipy versions)
+from collections import OrderedDict
+from options.train_options import TrainOptions
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+import util.util as util
+from util.visualizer import Visualizer
+import os
+import numpy as np
+import torch
+from pdb import set_trace as st
+
+def train(opt):
+    """Main training loop for the aging GAN.
+
+    Supports resuming from a checkpoint (``opt.continue_train``),
+    alternates discriminator and generator updates each batch, logs and
+    visualizes periodically, checkpoints to ``opt.checkpoints_dir``, and
+    decays the learning rate at the epochs in ``opt.decay_epochs``.
+    """
+    # iter.txt stores (epoch, iteration) for resuming training.
+    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
+
+    if opt.continue_train:
+        if opt.which_epoch == 'latest':
+            try:
+                start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
+            # NOTE(review): bare except — any read failure silently restarts at epoch 1
+            except:
+                start_epoch, epoch_iter = 1, 0
+        else:
+            start_epoch, epoch_iter = int(opt.which_epoch), 0
+
+        print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
+        # Re-apply every learning-rate decay that occurred before start_epoch.
+        for update_point in opt.decay_epochs:
+            if start_epoch < update_point:
+                break
+
+            opt.lr *= opt.decay_gamma
+    else:
+        start_epoch, epoch_iter = 0, 0
+
+    data_loader = CreateDataLoader(opt)
+    dataset = data_loader.load_data()
+    dataset_size = len(data_loader)
+    print('#training images = %d' % dataset_size)
+
+    model = create_model(opt)
+    visualizer = Visualizer(opt)
+
+    total_steps = (start_epoch) * dataset_size + epoch_iter
+
+    # Phase offsets so periodic events stay aligned after a resume.
+    display_delta = total_steps % opt.display_freq
+    print_delta = total_steps % opt.print_freq
+    save_delta = total_steps % opt.save_latest_freq
+    bSize = opt.batchSize
+
+    # in case there's no display, sample one image from each class to test after every epoch
+    if opt.display_id == 0:
+        dataset.dataset.set_sample_mode(True)
+        dataset.num_workers = 1
+        for i, data in enumerate(dataset):
+            if i*opt.batchSize >= opt.numClasses:
+                break
+            if i == 0:
+                sample_data = data
+            else:
+                # Merge batches: concatenate tensors, extend list-like entries.
+                for key, value in data.items():
+                    if torch.is_tensor(data[key]):
+                        sample_data[key] = torch.cat((sample_data[key], data[key]), 0)
+                    else:
+                        sample_data[key] = sample_data[key] + data[key]
+        dataset.num_workers = opt.nThreads
+        dataset.dataset.set_sample_mode(False)
+
+    for epoch in range(start_epoch, opt.epochs):
+        epoch_start_time = time.time()
+        if epoch != start_epoch:
+            epoch_iter = 0
+        for i, data in enumerate(dataset, start=epoch_iter):
+            iter_start_time = time.time()
+            total_steps += opt.batchSize
+            epoch_iter += opt.batchSize
+
+            # whether to collect output images this step
+            save_fake = (total_steps % opt.display_freq == display_delta) and (opt.display_id > 0)
+
+            ############## Network Pass ########################
+            model.set_inputs(data)
+            disc_losses = model.update_D()
+            gen_losses, gen_in, gen_out, rec_out, cyc_out = model.update_G(infer=save_fake)
+            loss_dict = dict(gen_losses, **disc_losses)
+            ##################################################
+
+            ############## Display results and errors ##########
+            ### print out errors
+            if total_steps % opt.print_freq == print_delta:
+                # Convert tensor losses to plain numbers for logging.
+                errors = {k: v.item() if not (isinstance(v, float) or isinstance(v, int)) else v for k, v in loss_dict.items()}
+                t = (time.time() - iter_start_time) / opt.batchSize
+                visualizer.print_current_errors(epoch+1, epoch_iter, errors, t)
+                if opt.display_id > 0:
+                    visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)
+
+            ### display output images
+            if save_fake and opt.display_id > 0:
+                class_a_suffix = ' class {}'.format(data['A_class'][0])
+                class_b_suffix = ' class {}'.format(data['B_class'][0])
+                classes = None
+
+                visuals = OrderedDict()
+                # index 0 is the first domain-A sample, index bSize the first domain-B sample — presumably the batch halves; verify against the model
+                visuals_A = OrderedDict([('real image' + class_a_suffix, util.tensor2im(gen_in.data[0]))])
+                visuals_B = OrderedDict([('real image' + class_b_suffix, util.tensor2im(gen_in.data[bSize]))])
+
+                A_out_vis = OrderedDict([('synthesized image' + class_b_suffix, util.tensor2im(gen_out.data[0]))])
+                B_out_vis = OrderedDict([('synthesized image' + class_a_suffix, util.tensor2im(gen_out.data[bSize]))])
+                if opt.lambda_rec > 0:
+                    A_out_vis.update([('reconstructed image' + class_a_suffix, util.tensor2im(rec_out.data[0]))])
+                    B_out_vis.update([('reconstructed image' + class_b_suffix, util.tensor2im(rec_out.data[bSize]))])
+                if opt.lambda_cyc > 0:
+                    A_out_vis.update([('cycled image' + class_a_suffix, util.tensor2im(cyc_out.data[0]))])
+                    B_out_vis.update([('cycled image' + class_b_suffix, util.tensor2im(cyc_out.data[bSize]))])
+
+                visuals_A.update(A_out_vis)
+                visuals_B.update(B_out_vis)
+                visuals.update(visuals_A)
+                visuals.update(visuals_B)
+
+                ncols = len(visuals_A)
+                visualizer.display_current_results(visuals, epoch, classes, ncols)
+
+            ### save latest model
+            if total_steps % opt.save_latest_freq == save_delta:
+                print('saving the latest model (epoch %d, total_steps %d)' % (epoch+1, total_steps))
+                model.save('latest')
+                np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
+                if opt.display_id == 0:
+                    model.eval()
+                    visuals = model.inference(sample_data)
+                    visualizer.save_matrix_image(visuals, 'latest')
+                    model.train()
+
+        # end of epoch
+        iter_end_time = time.time()
+        print('End of epoch %d / %d \t Time Taken: %d sec' %
+              (epoch+1, opt.epochs, time.time() - epoch_start_time))
+
+        ### save model for this epoch
+        if (epoch+1) % opt.save_epoch_freq == 0:
+            print('saving the model at the end of epoch %d, iters %d' % (epoch+1, total_steps))
+            model.save('latest')
+            model.save(epoch+1)
+            np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
+            if opt.display_id == 0:
+                model.eval()
+                visuals = model.inference(sample_data)
+                visualizer.save_matrix_image(visuals, epoch+1)
+                model.train()
+
+        ### multiply learning rate by opt.decay_gamma after certain iterations
+        if (epoch+1) in opt.decay_epochs:
+            model.update_learning_rate()
+
+if __name__ == "__main__":
+    # Parse training options and start the training loop.
+    opt = TrainOptions().parse()
+    train(opt)
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/util/__pycache__/__init__.cpython-310.pyc b/util/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000..0893879
Binary files /dev/null and b/util/__pycache__/__init__.cpython-310.pyc differ
diff --git a/util/__pycache__/__init__.cpython-38.pyc b/util/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..2a6ebd3
Binary files /dev/null and b/util/__pycache__/__init__.cpython-38.pyc differ
diff --git a/util/__pycache__/deeplab.cpython-38.pyc b/util/__pycache__/deeplab.cpython-38.pyc
new file mode 100644
index 0000000..09fde7f
Binary files /dev/null and b/util/__pycache__/deeplab.cpython-38.pyc differ
diff --git a/util/__pycache__/html.cpython-38.pyc b/util/__pycache__/html.cpython-38.pyc
new file mode 100644
index 0000000..4547ada
Binary files /dev/null and b/util/__pycache__/html.cpython-38.pyc differ
diff --git a/util/__pycache__/preprocess_itw_im.cpython-310.pyc b/util/__pycache__/preprocess_itw_im.cpython-310.pyc
new file mode 100644
index 0000000..2b1383c
Binary files /dev/null and b/util/__pycache__/preprocess_itw_im.cpython-310.pyc differ
diff --git a/util/__pycache__/preprocess_itw_im.cpython-38.pyc b/util/__pycache__/preprocess_itw_im.cpython-38.pyc
new file mode 100644
index 0000000..2c760fa
Binary files /dev/null and b/util/__pycache__/preprocess_itw_im.cpython-38.pyc differ
diff --git a/util/__pycache__/util.cpython-310.pyc b/util/__pycache__/util.cpython-310.pyc
new file mode 100644
index 0000000..4aa3922
Binary files /dev/null and b/util/__pycache__/util.cpython-310.pyc differ
diff --git a/util/__pycache__/util.cpython-38.pyc b/util/__pycache__/util.cpython-38.pyc
new file mode 100644
index 0000000..4a2c672
Binary files /dev/null and b/util/__pycache__/util.cpython-38.pyc differ
diff --git a/util/__pycache__/visualizer.cpython-38.pyc b/util/__pycache__/visualizer.cpython-38.pyc
new file mode 100644
index 0000000..2e6f2aa
Binary files /dev/null and b/util/__pycache__/visualizer.cpython-38.pyc differ
diff --git a/util/deeplab.py b/util/deeplab.py
new file mode 100644
index 0000000..002bdec
--- /dev/null
+++ b/util/deeplab.py
@@ -0,0 +1,257 @@
+# Copyright (c) 2020, Roy Or-El. All rights reserved.
+#
+# This work is licensed under the Creative Commons
+# Attribution-NonCommercial-ShareAlike 4.0 International License.
+# To view a copy of this license, visit
+# http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to
+# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+# This file was taken as is from the https://github.com/chenxi116/DeepLabv3.pytorch repository.
+
+import torch
+import torch.nn as nn
+import math
+import torch.utils.model_zoo as model_zoo
+from torch.nn import functional as F
+
+
+__all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152']
+
+
+model_urls = {
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+class Conv2d(nn.Conv2d):
+    """Conv2d that standardizes its kernel on every forward pass.
+
+    The weight is re-centered to zero mean and divided by its per-filter
+    standard deviation before convolving (weight standardization); the
+    stored parameters themselves are never modified.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True):
+        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
+                                     padding, dilation, groups, bias)
+
+    def forward(self, x):
+        # return super(Conv2d, self).forward(x)
+        weight = self.weight
+        # Per-output-filter mean over dims 1..3 (in-channels, kH, kW).
+        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
+                                                            keepdim=True).mean(dim=3, keepdim=True)
+        weight = weight - weight_mean
+        # Per-filter std; +1e-5 guards against division by zero.
+        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
+        weight = weight / std.expand_as(weight)
+        return F.conv2d(x, weight, self.bias, self.stride,
+                        self.padding, self.dilation, self.groups)
+
+
+class ASPP(nn.Module):
+    """Atrous Spatial Pyramid Pooling head.
+
+    Five parallel branches over the C-channel input — a 1x1 conv, three
+    3x3 atrous convs with dilation rates 6/12/18 (scaled by *mult*), and
+    a global-average-pooling branch — are concatenated, fused with a 1x1
+    conv, and projected to *num_classes* output channels.
+    """
+
+    def __init__(self, C, depth, num_classes, conv=nn.Conv2d, norm=nn.BatchNorm2d, momentum=0.0003, mult=1):
+        super(ASPP, self).__init__()
+        self._C = C
+        self._depth = depth
+        self._num_classes = num_classes
+
+        self.global_pooling = nn.AdaptiveAvgPool2d(1)
+        self.relu = nn.ReLU(inplace=True)
+        self.aspp1 = conv(C, depth, kernel_size=1, stride=1, bias=False)
+        self.aspp2 = conv(C, depth, kernel_size=3, stride=1,
+                          dilation=int(6*mult), padding=int(6*mult),
+                          bias=False)
+        self.aspp3 = conv(C, depth, kernel_size=3, stride=1,
+                          dilation=int(12*mult), padding=int(12*mult),
+                          bias=False)
+        self.aspp4 = conv(C, depth, kernel_size=3, stride=1,
+                          dilation=int(18*mult), padding=int(18*mult),
+                          bias=False)
+        self.aspp5 = conv(C, depth, kernel_size=1, stride=1, bias=False)
+        self.aspp1_bn = norm(depth, momentum)
+        self.aspp2_bn = norm(depth, momentum)
+        self.aspp3_bn = norm(depth, momentum)
+        self.aspp4_bn = norm(depth, momentum)
+        self.aspp5_bn = norm(depth, momentum)
+        self.conv2 = conv(depth * 5, depth, kernel_size=1, stride=1,
+                          bias=False)
+        self.bn2 = norm(depth, momentum)
+        self.conv3 = nn.Conv2d(depth, num_classes, kernel_size=1, stride=1)
+
+    def forward(self, x):
+        x1 = self.aspp1(x)
+        x1 = self.aspp1_bn(x1)
+        x1 = self.relu(x1)
+        x2 = self.aspp2(x)
+        x2 = self.aspp2_bn(x2)
+        x2 = self.relu(x2)
+        x3 = self.aspp3(x)
+        x3 = self.aspp3_bn(x3)
+        x3 = self.relu(x3)
+        x4 = self.aspp4(x)
+        x4 = self.aspp4_bn(x4)
+        x4 = self.relu(x4)
+        # Image-level branch: global pool, 1x1 conv, upsample back to x's size.
+        x5 = self.global_pooling(x)
+        x5 = self.aspp5(x5)
+        x5 = self.aspp5_bn(x5)
+        x5 = self.relu(x5)
+        x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear',
+                         align_corners=True)(x5)
+        x = torch.cat((x1, x2, x3, x4, x5), 1)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+        x = self.conv3(x)
+
+        return x
+
+
+class Bottleneck(nn.Module):
+    """ResNet bottleneck residual block.
+
+    1x1 reduce -> 3x3 (optionally strided/dilated) -> 1x1 expand by
+    ``expansion`` (4). When *downsample* is given it projects the identity
+    branch so it matches the expanded output before the residual addition.
+    """
+    # Output channels = planes * expansion.
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, conv=None, norm=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = conv(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = norm(planes)
+        self.conv2 = conv(planes, planes, kernel_size=3, stride=stride,
+                          dilation=dilation, padding=dilation, bias=False)
+        self.bn2 = norm(planes)
+        self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = norm(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        # Project the identity branch when channel/stride shapes differ.
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
class ResNet(nn.Module):
    """ResNet backbone with an ASPP head (DeepLabV3-style segmentation network).

    Args:
        block: residual block class (e.g. Bottleneck); must expose `expansion`.
        layers: number of blocks in each of the four stages.
        num_classes: number of output segmentation classes.
        num_groups: if given, use GroupNorm with this many groups instead of BatchNorm.
        weight_std: if True, use the weight-standardized Conv2d defined in this file.
        beta: if True, replace the 7x7 stem conv with three 3x3 convs.
    """

    def __init__(self, block, layers, num_classes, num_groups=None, weight_std=False, beta=False):
        # Factories chosen once so every layer is built consistently.
        # (Plain attributes, so assigning them before super().__init__() is safe.)
        self.inplanes = 64
        self.norm = lambda planes, momentum=0.05: nn.BatchNorm2d(planes, momentum=momentum) if num_groups is None else nn.GroupNorm(num_groups, planes)
        self.conv = Conv2d if weight_std else nn.Conv2d

        super(ResNet, self).__init__()
        if not beta:
            self.conv1 = self.conv(3, 64, kernel_size=7, stride=2, padding=3,
                                   bias=False)
        else:
            self.conv1 = nn.Sequential(
                self.conv(3, 64, 3, stride=2, padding=1, bias=False),
                self.conv(64, 64, 3, stride=1, padding=1, bias=False),
                self.conv(64, 64, 3, stride=1, padding=1, bias=False))
        self.bn1 = self.norm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        # layer4 keeps stride 1 and uses dilation instead, preserving spatial
        # resolution for the segmentation head.
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                       dilation=2)
        self.aspp = ASPP(512 * block.expansion, 256, num_classes, conv=self.conv, norm=self.norm)

        # He initialization for convs; unit-scale/zero-shift for norm layers.
        for m in self.modules():
            if isinstance(m, self.conv):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one residual stage; the first block carries stride/downsample."""
        downsample = None
        if stride != 1 or dilation != 1 or self.inplanes != planes * block.expansion:
            # BUGFIX: use integer division. `dilation / 2` produces a float in
            # Python 3 (e.g. 2 / 2 == 1.0) and nn.Conv2d requires an int
            # dilation; `dilation // 2` yields the same values as the original
            # Python 2 semantics.
            downsample = nn.Sequential(
                self.conv(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, dilation=max(1, dilation // 2), bias=False),
                self.norm(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, dilation=max(1, dilation // 2), conv=self.conv, norm=self.norm))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, conv=self.conv, norm=self.norm))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the backbone + ASPP head and upsample logits to the input size."""
        size = (x.shape[2], x.shape[3])
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.aspp(x)
        # Bilinear upsample back to the original resolution.
        x = nn.Upsample(size, mode='bilinear', align_corners=True)(x)
        return x
+
+
def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return net
+
+
def resnet101(pretrained=False, num_groups=None, weight_std=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_groups=num_groups, weight_std=weight_std, **kwargs)
    if not pretrained:
        return model

    model_dict = model.state_dict()
    if num_groups and weight_std:
        # GN+WS weights ship as a local checkpoint whose keys carry a
        # 'module.' (DataParallel) prefix — strip it before matching.
        pretrained_dict = torch.load('deeplab_model/R-101-GN-WS.pth.tar')
        overlap_dict = {key[7:]: value
                        for key, value in pretrained_dict.items()
                        if key[7:] in model_dict}
        assert len(overlap_dict) == 312
    elif not num_groups and not weight_std:
        # Plain BN variant comes straight from the torchvision model zoo.
        pretrained_dict = model_zoo.load_url(model_urls['resnet101'])
        overlap_dict = {key: value for key, value in pretrained_dict.items()
                        if key in model_dict}
    else:
        raise ValueError('Currently only support BN or GN+WS')
    model_dict.update(overlap_dict)
    model.load_state_dict(model_dict)
    return model
+
+
def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return net
diff --git a/util/html.py b/util/html.py
new file mode 100755
index 0000000..f144550
--- /dev/null
+++ b/util/html.py
@@ -0,0 +1,82 @@
+### Copyright (C) 2020 Roy Or-El. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import dominate
+import math
+from dominate.tags import *
+import os
+
+
class HTML:
    """Minimal helper that builds an image-table web page via dominate.

    Creates `web_dir` (and `web_dir`/images) on construction; images referenced
    by `add_images` are expected to live in that images/ sub-directory.
    """

    def __init__(self, web_dir, title, refresh=0):
        """
        Args:
            web_dir: output directory for index.html and its images/ folder.
            title: page title.
            refresh: if > 0, auto-refresh interval (seconds) via a meta tag.
        """
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)

        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def get_image_dir(self):
        """Return the directory where images for this page should be saved."""
        return self.img_dir

    def add_header(self, str):
        """Append an h3 header with the given text."""
        with self.doc:
            h3(str)

    def add_table(self, border=1):
        """Start a new fixed-layout table; subsequent rows are added into it."""
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def _add_cell(self, im, txt, link, width):
        """Render one table cell: a thumbnail of `im` linking to `link`, captioned `txt`."""
        with td(style="word-wrap: break-word;", halign="center", valign="top"):
            with p():
                with a(href=os.path.join('images', link)):
                    img(style="width:%dpx" % width, src=os.path.join('images', im))
                br()
                p(txt)

    def add_images(self, ims, txts, links, width=512, cols=0):
        """Add a table of images.

        Args:
            ims, txts, links: parallel lists of image file names, captions and
                link targets (all relative to the images/ sub-directory).
            width: thumbnail width in pixels.
            cols: images per row; 0 puts everything on a single row.
        """
        imNum = len(ims)
        self.add_table()
        with self.t:
            if cols == 0:
                with tr():
                    for im, txt, link in zip(ims, txts, links):
                        self._add_cell(im, txt, link, width)
            else:
                rows = int(math.ceil(float(imNum) / float(cols)))
                for i in range(rows):
                    with tr():
                        for j in range(cols):
                            idx = i * cols + j
                            # BUGFIX: guard the last (possibly partial) row;
                            # the original indexed past the end of the lists
                            # whenever imNum was not a multiple of cols.
                            if idx >= imNum:
                                break
                            self._add_cell(ims[idx], txts[idx], links[idx], width)

    def save(self):
        """Write the accumulated document to <web_dir>/index.html."""
        html_file = '%s/index.html' % self.web_dir
        # Context manager guarantees the handle is closed even if render() raises.
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())
+
+
if __name__ == '__main__':
    # Smoke test: build a tiny page with four placeholder images.
    page = HTML('web/', 'test_html')
    page.add_header('hello world')

    names = ['image_%d.png' % n for n in range(4)]
    captions = ['text_%d' % n for n in range(4)]
    targets = ['image_%d.png' % n for n in range(4)]
    page.add_images(names, captions, targets)
    page.save()
diff --git a/util/preprocess_itw_im.py b/util/preprocess_itw_im.py
new file mode 100644
index 0000000..989611c
--- /dev/null
+++ b/util/preprocess_itw_im.py
@@ -0,0 +1,188 @@
+import os
+import dlib
+import shutil
+import requests
+import numpy as np
+import scipy.ndimage
+import torch
+import torchvision.transforms as transforms
+import util.deeplab as deeplab
+from PIL import Image
+from util.util import download_file
+from pdb import set_trace as st
+
# Locations of the pretrained assets this module needs (fetched by
# download_models.py / util.util.download_pretrained_models).
resnet_file_path = 'deeplab_model/R-101-GN-WS.pth.tar'
deeplab_file_path = 'deeplab_model/deeplab_model.pth'
predictor_file_path = 'util/shape_predictor_68_face_landmarks.dat'
# NOTE(review): duplicates deeplab_file_path — consider consolidating.
model_fname = 'deeplab_model/deeplab_model.pth'
# Face-region labels predicted by the DeeplabV3 parsing model; the class
# count below (len == 19) sizes the network's output layer.
deeplab_classes = ['background' ,'skin','nose','eye_g','l_eye','r_eye','l_brow','r_brow','l_ear','r_ear','mouth','u_lip','l_lip','hair','hat','ear_r','neck_l','neck','cloth']
+
+
class preprocessInTheWildImage():
    """Turn an arbitrary ("in the wild") face photo into the aligned crop and
    segmentation map expected by the aging model.

    Pipeline (see forward): dlib landmark detection -> FFHQ-style face
    alignment -> DeeplabV3 face parsing.
    """

    def __init__(self, out_size=256):
        # out_size: side length (pixels) of the aligned square output image.
        self.out_size = out_size

        # load landmark detector models
        self.detector = dlib.get_frontal_face_detector()
        if not os.path.isfile(predictor_file_path):
            print('Cannot find landmarks shape predictor model.\n'\
                  'Please run download_models.py to download the model')
            raise OSError

        self.predictor = dlib.shape_predictor(predictor_file_path)

        # deeplab data properties
        # Standard ImageNet mean/std normalization for the ResNet backbone.
        self.deeplab_data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.deeplab_input_size = 513

        # load deeplab model
        # NOTE(review): hard requirement on a CUDA device; CPU-only machines
        # fail here with an AssertionError.
        assert torch.cuda.is_available()
        torch.backends.cudnn.benchmark = True
        if not os.path.isfile(resnet_file_path):
            print('Cannot find DeeplabV3 backbone Resnet model.\n' \
                  'Please run download_models.py to download the model')
            raise OSError

        # GN + weight-standardization variant of the ResNet-101 backbone.
        self.deeplab_model = getattr(deeplab, 'resnet101')(
            pretrained=True,
            num_classes=len(deeplab_classes),
            num_groups=32,
            weight_std=True,
            beta=False)

        self.deeplab_model.eval()
        if not os.path.isfile(deeplab_file_path):
            print('Cannot find DeeplabV3 model.\n' \
                  'Please run download_models.py to download the model')
            raise OSError

        # Strip the 'module.' (DataParallel) prefix from checkpoint keys and
        # drop the BatchNorm num_batches_tracked entries.
        checkpoint = torch.load(model_fname)
        state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items() if 'tracked' not in k}
        self.deeplab_model.load_state_dict(state_dict)

    def dlib_shape_to_landmarks(self, shape):
        """Convert a dlib shape object into a (68, 2) float32 coordinate array."""
        # initialize the list of (x, y)-coordinates
        landmarks = np.zeros((68, 2), dtype=np.float32)
        # loop over the 68 facial landmarks and convert them
        # to a 2-tuple of (x, y)-coordinates
        for i in range(0, 68):
            landmarks[i] = (shape.part(i).x, shape.part(i).y)
        # return the list of (x, y)-coordinates
        return landmarks

    def extract_face_landmarks(self, img):
        """Detect the largest face in `img` and return its 68 landmarks."""
        # detect all faces in the image and
        # keep the detection with the largest bounding box
        dets = self.detector(img, 1)
        if len(dets) == 0:
            print ('Could not detect any face in the image, please try again with a different image')
            # NOTE(review): bare `raise` with no active exception produces
            # "RuntimeError: No active exception to re-raise" — probably meant
            # to raise an explicit exception type with the message above.
            raise

        max_area = 0
        max_idx = -1
        for k, d in enumerate(dets):
            area = (d.right() - d.left()) * (d.bottom() - d.top())
            if area > max_area:
                max_area = area
                max_idx = k

        # Get the landmarks/parts for the face in box d.
        dlib_shape = self.predictor(img, dets[max_idx])
        landmarks = self.dlib_shape_to_landmarks(dlib_shape)
        return landmarks

    def align_in_the_wild_image(self, np_img, lm, transform_size=4096, enable_padding=True):
        """Rotate, crop and resize `np_img` so the face is centered and upright
        (FFHQ-style alignment driven by the 68 landmarks in `lm`).

        Returns a PIL image of size (out_size, out_size).
        """
        # Parse landmarks.
        lm_chin          = lm[0  : 17]  # left-right
        lm_eyebrow_left  = lm[17 : 22]  # left-right
        lm_eyebrow_right = lm[22 : 27]  # left-right
        lm_nose          = lm[27 : 31]  # top-down
        lm_nostrils      = lm[31 : 36]  # top-down
        lm_eye_left      = lm[36 : 42]  # left-clockwise
        lm_eye_right     = lm[42 : 48]  # left-clockwise
        lm_mouth_outer   = lm[48 : 60]  # left-clockwise
        lm_mouth_inner   = lm[60 : 68]  # left-clockwise

        # Calculate auxiliary vectors.
        eye_left     = np.mean(lm_eye_left, axis=0)
        eye_right    = np.mean(lm_eye_right, axis=0)
        eye_avg      = (eye_left + eye_right) * 0.5
        eye_to_eye   = eye_right - eye_left
        mouth_left   = lm_mouth_outer[0]
        mouth_right  = lm_mouth_outer[6]
        mouth_avg    = (mouth_left + mouth_right) * 0.5
        eye_to_mouth = mouth_avg - eye_avg

        # Choose oriented crop rectangle.
        x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
        x /= np.hypot(*x)
        x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 2.2) # This results in larger crops than the original FFHQ. For the original crops, replace 2.2 with 1.8
        y = np.flipud(x) * [-1, 1]
        c = eye_avg + eye_to_mouth * 0.1
        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
        qsize = np.hypot(*x) * 2

        # Load in-the-wild image.
        img = Image.fromarray(np_img)

        # Shrink.
        # Downscale first when the face region is much larger than out_size,
        # so the subsequent transforms work on fewer pixels.
        shrink = int(np.floor(qsize / self.out_size * 0.5))
        if shrink > 1:
            rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
            # NOTE(review): Image.ANTIALIAS was removed in Pillow 10;
            # newer Pillow requires Image.LANCZOS here.
            img = img.resize(rsize, Image.ANTIALIAS)
            quad /= shrink
            qsize /= shrink

        # Crop.
        border = max(int(np.rint(qsize * 0.1)), 3)
        crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
        crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
        if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
            img = img.crop(crop)
            quad -= crop[0:2]

        # Pad.
        # When the quad extends beyond the image, reflect-pad and blend the
        # padded region so the transform does not introduce hard edges.
        pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
        pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
        if enable_padding and max(pad) > border - 4:
            pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
            img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
            h, w, _ = img.shape
            y, x, _ = np.ogrid[:h, :w, :1]
            mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
            blur = qsize * 0.02
            img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
            img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
            img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
            quad += pad[:2]

        # Transform.
        # QUAD transform maps the oriented rectangle to an axis-aligned square.
        img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR)
        if self.out_size < transform_size:
            img = img.resize((self.out_size, self.out_size), Image.ANTIALIAS)

        return img


    def get_segmentation_maps(self, img):
        """Run DeeplabV3 face parsing on a PIL image; returns an
        (out_size, out_size) uint8 label map (indices into deeplab_classes)."""
        img = img.resize((self.deeplab_input_size,self.deeplab_input_size),Image.BILINEAR)
        img = self.deeplab_data_transform(img)
        img = img.cuda()
        # Move the model onto the GPU only for this call, then back to CPU so
        # GPU memory is free between invocations.
        self.deeplab_model.cuda()
        outputs = self.deeplab_model(img.unsqueeze(0))
        self.deeplab_model.cpu()
        _, pred = torch.max(outputs, 1)
        pred = pred.data.cpu().numpy().squeeze().astype(np.uint8)
        seg_map = Image.fromarray(pred)
        seg_map = np.uint8(seg_map.resize((self.out_size,self.out_size), Image.NEAREST))
        return seg_map

    def forward(self, img):
        """Full pipeline: landmarks -> alignment -> segmentation.

        Args:
            img: RGB uint8 numpy array (H, W, 3) — TODO confirm channel order
                against the caller.
        Returns:
            (aligned_img, seg_map): aligned uint8 image array of shape
            (out_size, out_size, 3) and its uint8 segmentation map.
        """
        landmarks = self.extract_face_landmarks(img)
        aligned_img = self.align_in_the_wild_image(img, landmarks)
        seg_map = self.get_segmentation_maps(aligned_img)
        aligned_img = np.array(aligned_img.getdata(), dtype=np.uint8).reshape(self.out_size, self.out_size, 3)
        return aligned_img, seg_map
diff --git a/util/util.py b/util/util.py
new file mode 100755
index 0000000..40df520
--- /dev/null
+++ b/util/util.py
@@ -0,0 +1,196 @@
+### Copyright (C) 2020 Roy Or-El. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import os
+import html
+import glob
+import uuid
+import hashlib
+import requests
+import torch
+import zipfile
+import numpy as np
+from tqdm import tqdm
+from PIL import Image
+from pdb import set_trace as st
+
+
# Download descriptors for each pretrained asset, consumed by download_file()
# below: primary Google Drive URL, fallback mirror URL, destination path, and
# the expected byte size / MD5 used to validate the downloaded payload.
males_model_spec = dict(file_url='https://drive.google.com/uc?id=1MsXN54hPi9PWDmn1HKdmKfv-J5hWYFVZ',
                        alt_url='https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/males_model.zip',
                        file_path='checkpoints/males_model.zip', file_size=213175683, file_md5='0079186147ec816176b946a073d1f396')
females_model_spec = dict(file_url='https://drive.google.com/uc?id=1LNm0zAuiY0CIJnI0lHTq1Ttcu9_M1NAJ',
                          alt_url='https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/females_model.zip',
                          file_path='checkpoints/females_model.zip', file_size=213218113, file_md5='0675f809413c026170cf1f22b27f3c5d')
resnet_file_spec = dict(file_url='https://drive.google.com/uc?id=1oRGgrI4KNdefbWVpw0rRkEP1gbJIRokM',
                        alt_url='https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/R-101-GN-WS.pth.tar',
                        file_path='deeplab_model/R-101-GN-WS.pth.tar', file_size=178260167, file_md5='aa48cc3d3ba3b7ac357c1489b169eb32')
deeplab_file_spec = dict(file_url='https://drive.google.com/uc?id=1w2XjDywFr2NjuUWaLQDRktH7VwIfuNlY',
                         alt_url='https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/deeplab_model.pth',
                         file_path='deeplab_model/deeplab_model.pth', file_size=464446305, file_md5='8e8345b1b9d95e02780f9bed76cc0293')
predictor_file_spec = dict(file_url='https://drive.google.com/uc?id=1fhq5lvWy-rjrzuHdMoZfLsULvF0gJGwD',
                           alt_url='https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/shape_predictor_68_face_landmarks.dat',
                           file_path='util/shape_predictor_68_face_landmarks.dat', file_size=99693937, file_md5='73fde5e05226548677a050913eed4e04')
+
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
    """Convert an image tensor in [-1, 1] to a numpy array in [0, 255].

    Channel axes are moved last; a leading singleton batch dimension is
    squeezed. (`normalize` is accepted for API compatibility but unused.)
    """
    array = image_tensor.cpu().float().numpy()
    ndims = array.ndim

    # A 1xCxHxW tensor is treated as a single CxHxW image.
    if ndims == 4 and array.shape[0] == 1:
        array = array[0]
        ndims = 3

    # Move the channel axis to the end for every remaining layout;
    # a plain 2-D map needs no transpose.
    if ndims == 3:
        array = np.transpose(array, (1, 2, 0))
    elif ndims == 4:
        array = np.transpose(array, (0, 2, 3, 1))
    elif ndims != 2:  # 5-D (any higher rank fell into this branch before too)
        array = np.transpose(array, (0, 1, 3, 4, 2))

    # Map [-1, 1] -> [0, 255] and cast to the requested dtype.
    array = (array + 1) / 2.0 * 255.0
    return array.astype(imtype)
+
def save_image(image_numpy, image_path):
    """Write an HxWxC uint8 numpy image to `image_path` via PIL."""
    Image.fromarray(image_numpy).save(image_path)
+
def mkdirs(paths):
    """Create one directory, or — given a list of paths — each directory in it."""
    if not isinstance(paths, list):
        mkdir(paths)
        return
    for path in paths:
        mkdir(path)
+
def mkdir(path):
    """Create `path` (including parents) if it does not already exist.

    Uses exist_ok=True instead of an os.path.exists() pre-check, which avoids
    the check-then-create race when several processes create the same folder.
    """
    os.makedirs(path, exist_ok=True)
+
def _download_with_fallback(session, file_spec):
    """Try the Google Drive URL first, then fall back to the mirror server."""
    try:
        download_file(session, file_spec)
    except Exception:
        # Narrowed from a bare `except` so Ctrl-C still interrupts a download.
        print('Google Drive download failed.\n' \
              'Trying do download from alternate server')
        download_file(session, file_spec, use_alt_url=True)


def _download_and_extract_model(file_spec, description):
    """Download a zipped model, unpack it into ./checkpoints, delete the zip."""
    print('Downloading %s' % description)
    with requests.Session() as session:
        _download_with_fallback(session, file_spec)

    print('Extracting %s zip file' % description)
    with zipfile.ZipFile(file_spec['file_path'], 'r') as zip_fname:
        zip_fname.extractall('./checkpoints')

    print('Done!')
    os.remove(file_spec['file_path'])


def download_pretrained_models():
    """Fetch every pretrained asset the pipeline needs: both gender models
    (zipped), the dlib landmark predictor, and the DeeplabV3 backbone and
    weights (see the *_spec dicts above for URLs and checksums)."""
    _download_and_extract_model(males_model_spec, 'males model')
    _download_and_extract_model(females_model_spec, 'females model')

    # Plain (non-zip) files: just download and keep in place.
    for description, file_spec in (
            ('face landmarks shape predictor', predictor_file_spec),
            ('DeeplabV3 backbone Resnet Model parameters', resnet_file_spec),
            ('DeeplabV3 Model parameters', deeplab_file_spec)):
        print('Downloading %s' % description)
        with requests.Session() as session:
            _download_with_fallback(session, file_spec)
        print('Done!')
+
def download_file(session, file_spec, use_alt_url=False, chunk_size=128, num_attempts=10):
    """Download file_spec['file_url'] (or 'alt_url') to file_spec['file_path'].

    Streams into a uniquely named temp file, validates size/MD5 when the spec
    provides them, retries up to `num_attempts` times, and follows the Google
    Drive "virus scan" interstitial's confirm link when one is returned.

    Args:
        session: requests.Session reused across all attempts.
        file_spec: dict with file_path/file_url and optional alt_url,
            file_size (bytes) and file_md5 (hex digest).
        use_alt_url: download from the fallback server instead of Google Drive.
        chunk_size: streaming chunk size in KiB (shifted left by 10 below).
        num_attempts: total tries before the last error is re-raised.
    """
    file_path = file_spec['file_path']
    if use_alt_url:
        file_url = file_spec['alt_url']
    else:
        file_url = file_spec['file_url']

    file_dir = os.path.dirname(file_path)
    # Unique temp name so concurrent or aborted downloads never collide.
    tmp_path = file_path + '.tmp.' + uuid.uuid4().hex
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)

    progress_bar = tqdm(total=file_spec['file_size'], unit='B', unit_scale=True)
    for attempts_left in reversed(range(num_attempts)):
        data_size = 0
        progress_bar.reset()
        try:
            # Download.
            data_md5 = hashlib.md5()
            with session.get(file_url, stream=True) as res:
                res.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    for chunk in res.iter_content(chunk_size=chunk_size<<10):
                        progress_bar.update(len(chunk))
                        f.write(chunk)
                        data_size += len(chunk)
                        data_md5.update(chunk)

            # Validate.
            if 'file_size' in file_spec and data_size != file_spec['file_size']:
                raise IOError('Incorrect file size', file_path)
            if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']:
                raise IOError('Incorrect file MD5', file_path)
            break

        except:
            # Last attempt => raise error.
            if not attempts_left:
                raise

            # Handle Google Drive virus checker nag.
            # A tiny HTML response is Drive's "can't scan this file" page:
            # extract the single confirm link and retry against it.
            if data_size > 0 and data_size < 8192:
                with open(tmp_path, 'rb') as f:
                    data = f.read()
                links = [html.unescape(link) for link in data.decode('utf-8').split('"') if 'export=download' in link]
                if len(links) == 1:
                    file_url = requests.compat.urljoin(file_url, links[0])
                    continue

    progress_bar.close()

    # Rename temp file to the correct name.
    os.replace(tmp_path, file_path) # atomic

    # Attempt to clean up any leftover temps.
    for filename in glob.glob(file_path + '.tmp.*'):
        try:
            os.remove(filename)
        except:
            pass
diff --git a/util/visualizer.py b/util/visualizer.py
new file mode 100755
index 0000000..a21d04d
--- /dev/null
+++ b/util/visualizer.py
@@ -0,0 +1,239 @@
+### Copyright (C) 2020 Roy Or-El. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import numpy as np
+import os
+import cv2
+import time
+import unidecode
+from . import util
+from . import html
+from pdb import set_trace as st
+
+class Visualizer():
    def __init__(self, opt):
        """Cache display/save options and set up visdom and/or HTML logging.

        `opt` is the project's options namespace; only the attributes read
        below are required (display_id, isTrain, no_html, display_winsize,
        name, numClasses, checkpoints_dir, and — conditionally —
        save_display_freq, display_port, display_single_pane_ncols).
        """
        # self.opt = opt
        self.display_id = opt.display_id  # base visdom window id; <= 0 disables visdom
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        self.numClasses = opt.numClasses
        self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'images')
        self.isTrain = opt.isTrain
        if self.isTrain:
            self.save_freq = opt.save_display_freq

        if self.display_id > 0:
            # Imported lazily so non-display runs don't need visdom installed.
            import visdom
            self.vis = visdom.Visdom(port = opt.display_port)
            self.display_single_pane_ncols = opt.display_single_pane_ncols

        if self.use_html:
            # HTML output lives under <checkpoints_dir>/<name>/web/images;
            # this overrides the img_dir set above.
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        if self.isTrain:
            with open(self.log_name, "a") as log_file:
                now = time.strftime("%c")
                log_file.write('================ Training Loss (%s) ================\n' % now)
+
+ # |visuals|: dictionary of images to display or save
+ def display_current_results(self, visuals, it, classes, ncols):
+ if self.display_single_pane_ncols > 0:
+ h, w = next(iter(visuals.values())).shape[:2]
+ table_css = """""" % (w, h)
+ # ncols = self.display_single_pane_ncols
+ title = self.name
+ label_html = ''
+ label_html_row = ''
+ nrows = int(np.ceil(len(visuals.items()) / ncols))
+ images = []
+ idx = 0
+ for label, image_numpy in visuals.items():
+ label_html_row += '%s | ' % label
+ if image_numpy.ndim < 3:
+ image_numpy = np.expand_dims(image_numpy, 2)
+ image_numpy = np.tile(image_numpy, (1, 1, 3))
+
+ images.append(image_numpy.transpose([2, 0, 1]))
+ idx += 1
+ if idx % ncols == 0:
+ label_html += '%s
' % label_html_row
+ label_html_row = ''
+ white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
+ while idx % ncols != 0:
+ images.append(white_image)
+ label_html_row += ' | '
+ idx += 1
+ if label_html_row != '':
+ label_html += '%s
' % label_html_row
+
+ self.vis.images(images, nrow=ncols, win=self.display_id + 1,
+ padding=2, opts=dict(title=title + ' images'))
+ label_html = '' % label_html
+ self.vis.text(table_css + label_html, win = self.display_id + 2,
+ opts=dict(title=title + ' labels'))
+ else:
+ idx = 1
+ for label, image_numpy in visuals.items():
+ self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
+ win=self.display_id + idx)
+ idx += 1
+
+
+ # errors: dictionary of error labels and values
+ def plot_current_errors(self, epoch, counter_ratio, opt, errors):
+ if not hasattr(self, 'plot_data'):
+ self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
+ self.plot_data['X'].append(epoch + counter_ratio)
+ self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
+ self.vis.line(
+ X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
+ Y=np.array(self.plot_data['Y']),
+ opts={
+ 'title': self.name + ' loss over time',
+ 'legend': self.plot_data['legend'],
+ 'xlabel': 'epoch',
+ 'ylabel': 'loss'},
+ win=self.display_id)
+
+ # errors: same format as |errors| of plotCurrentErrors
+ def print_current_errors(self, epoch, i, errors, t):
+ message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
+ for k, v in errors.items():
+ message += '%s: %.3f ' % (k, v)
+
+ print(message)
+ with open(self.log_name, "a") as log_file:
+ log_file.write('%s\n' % message)
+
+ def save_matrix_image(self, visuals, epoch):
+ for i in range(len(visuals)):
+ visual = visuals[i]
+ orig_img = visual['orig_img_cls_' + str(i)]
+ curr_row_img = orig_img
+ for cls in range(self.numClasses):
+ next_im = visual['tex_trans_to_class_' + str(cls)]
+ curr_row_img = np.concatenate((curr_row_img, next_im), 1)
+
+ if i == 0:
+ matrix_img = curr_row_img
+ else:
+ matrix_img = np.concatenate((matrix_img, curr_row_img), 0)
+
+ if epoch != 'latest':
+ epoch_txt = 'epoch_' + str(epoch)
+ else:
+ epoch_txt = epochs
+
+ image_path = os.path.join(self.img_dir,'sample_batch_{}.png'.format(epoch_txt))
+ util.save_image(matrix_img, image_path)
+
+ def save_row_image(self, visuals, image_path, traverse=False):
+ visual = visuals[0]
+ orig_img = visual['orig_img']
+ h, w, c = orig_img.shape
+ traversal_img = np.concatenate((orig_img, np.full((h, 10, c), 255, dtype=np.uint8)), 1)
+ if traverse:
+ out_classes = len(visual) - 1
+ else:
+ out_classes = self.numClasses
+ for cls in range(out_classes):
+ next_im = visual['tex_trans_to_class_' + str(cls)]
+ traversal_img = np.concatenate((traversal_img, next_im), 1)
+
+ util.save_image(traversal_img, image_path)
+
+ def make_video(self, visuals, video_path):
+ fps = 20#25
+ visual = visuals[0]
+ orig_img = visual['orig_img']
+ h, w = orig_img.shape[0], orig_img.shape[1]
+ writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w,h))
+ out_classes = len(visual) - 1
+ for cls in range(out_classes):
+ next_im = visual['tex_trans_to_class_' + str(cls)]
+ writer.write(next_im[:,:,::-1])
+
+ writer.release()
+
+ # save image to the disk
+ def save_images_deploy(self, visuals, image_path):
+ for i in range(len(visuals)):
+ visual = visuals[i]
+ for label, image_numpy in visual.items():
+ save_path = '%s_%s.png' % (image_path, label)
+ util.save_image(image_numpy, save_path)
+
+
+ # save image to the disk
+ def save_images(self, webpage, visuals, image_path, gt_visuals=None, gt_path=None):
+ cols = self.numClasses+1
+ image_dir = webpage.get_image_dir()
+ if gt_visuals == None or gt_path == None:
+ for i in range(len(visuals)):
+ visual = visuals[i]
+ short_path = os.path.basename(image_path[i])
+ name = unidecode.unidecode(os.path.splitext(short_path)[0]) #removes accents which cause html load error
+ webpage.add_header(name)
+ ims = []
+ txts = []
+ links = []
+ for label, image_numpy in visual.items():
+ image_name = '%s_%s.png' % (name, label)
+ save_path = os.path.join(image_dir, image_name)
+ util.save_image(image_numpy, save_path)
+
+ ims.append(image_name)
+ txts.append(label)
+ links.append(image_name)
+
+ webpage.add_images(ims, txts, links, width=self.win_size,cols=cols)
+ else:
+ batchSize = len(image_path)
+
+ # save ground truth images
+ if gt_path is not None:
+ gt_short_path = os.path.basename(gt_path[0])
+ gt_name = os.path.splitext(gt_path)[0]
+ gt_ims = []
+ gt_txts = []
+ gt_links = []
+ for label, image_numpy in gt_visuals.items():
+ image_name = '%s_%s.png' % (gt_name, label)
+ save_path = os.path.join(image_dir, image_name)
+ util.save_image(image_numpy, save_path)
+
+ gt_ims.append(image_name)
+ gt_txts.append(label)
+ gt_links.append(image_name)
+
+ for i in range(batchSize):
+ short_path = os.path.basename(image_path[i])
+ name = os.path.splitext(short_path)[0]
+
+ # webpage.add_header(name)
+ ims = []
+ txts = []
+ links = []
+
+ for label, image_numpy in visuals[i].items():
+ image_name = '%s_%s.png' % (name, label)
+ save_path = os.path.join(image_dir, image_name)
+ util.save_image(image_numpy, save_path)
+
+ ims.append(image_name)
+ txts.append(label)
+ links.append(image_name)
+ print("saving results for: " + name)
+
+ if gt_path is not None:
+ webpage.add_header(gt_name)
+ webpage.add_images(gt_ims, gt_txts, gt_links, width=self.win_size, cols=batchSize)
+
+ webpage.add_header(name)
+ webpage.add_images(ims, txts, links, width=self.win_size, cols=self.numClasses + 1)