#!/usr/bin/env python3
"""Command-line script that exports the edge or inpaint model as an ONNX file."""

import cv2
import numpy as np
import os
import random
import torch
import torch.onnx
import argparse

from shutil import copyfile
from src.config import Config
from src.models import EdgeModel, InpaintingModel

# Spatial size of the dummy tensors used to trace the models.
MAX_WIDTH = 600
MAX_HEIGHT = 512


def _make_dummy_input(model_id, device):
    """Build the (image, mask, edge) dummy tensors used to trace a model.

    Args:
        model_id: 1 for the edge model (1-channel image); any other value
            yields the 3-channel image used for the inpaint model.
        device: torch device the tensors are moved to.

    Returns:
        Tuple of three tensors shaped (1, C, MAX_HEIGHT, MAX_WIDTH), where C
        is 1 or 3 for the image and 1 for the mask and the edge map.
    """
    image_channels = 1 if model_id == 1 else 3

    def noise(channels):
        return torch.randn(1, channels, MAX_HEIGHT, MAX_WIDTH, requires_grad=True).to(device)

    # Tuple elements are evaluated left to right: image, mask, edge.
    return noise(image_channels), noise(1), noise(1)


def _export_onnx(model, dummy_input, output_path):
    """Call the model's load(), switch it to eval mode and export it to ONNX.

    Args:
        model: model instance already moved to the target device.
        dummy_input: tuple of tensors passed positionally to the model
            while tracing.
        output_path: path of the ONNX file to write.
    """
    model.load()
    model.eval()

    # NOTE(review): the dummy tensors are passed positionally, so their order
    # (image, mask, edge) and the input names below must match the parameter
    # order of the models' forward() -- confirm against src/models.py.
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=10,
        do_constant_folding=True,
        input_names=["InputImage", "Mask", "Edges"],
        output_names=["OutputImage"],
    )


def main():
    """Export the selected model as an ONNX file.

    Command-line arguments:
        --path / --checkpoints: checkpoints directory (default: ./checkpoints).
            It is created when missing, and ``config.yml`` is copied into it
            from ``./config.yml.example`` when absent.
        --model: 1 exports the edge model to ``edge-model.onnx``; 2 (the
            default) exports the inpaint model to
            ``edge-connect-inpaint.onnx``. Both paths are relative to the
            current working directory.

    Side effects: sets ``CUDA_VISIBLE_DEVICES`` from ``config.GPU``, stores the
    selected device in ``config.DEVICE`` and seeds the torch, numpy and
    ``random`` generators with ``config.SEED``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '--checkpoints', type=str, default='./checkpoints', help='model checkpoints path (default: ./checkpoints)')
    # An omitted --model used to leave args.model as None and fall through to
    # the inpaint branch; the default makes that behaviour explicit.
    parser.add_argument('--model', type=int, choices=[1, 2], default=2, help='1: edge model, 2: inpaint model (default: 2)')

    args = parser.parse_args()
    config_path = os.path.join(args.path, 'config.yml')

    # create checkpoints path if it doesn't exist
    os.makedirs(args.path, exist_ok=True)

    # copy config template if it doesn't exist
    if not os.path.exists(config_path):
        copyfile('./config.yml.example', config_path)

    # load config file
    config = Config(config_path)

    # cuda visible devices
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(e) for e in config.GPU)

    # init device
    if torch.cuda.is_available():
        config.DEVICE = torch.device("cuda")
        torch.backends.cudnn.benchmark = True   # cudnn auto-tuner
    else:
        config.DEVICE = torch.device("cpu")

    # disable OpenCV's internal threading (prevents deadlocks with pytorch dataloader)
    cv2.setNumThreads(0)

    # initialize random seed
    torch.manual_seed(config.SEED)
    torch.cuda.manual_seed_all(config.SEED)
    np.random.seed(config.SEED)
    random.seed(config.SEED)

    # Model dummy input (created before the model, as in the original flow)
    dummy_input = _make_dummy_input(args.model, config.DEVICE)

    if args.model == 1:
        # Edge model
        model = EdgeModel(config).to(config.DEVICE)
        output_path = "edge-model.onnx"
    else:
        # Inpaint model
        model = InpaintingModel(config).to(config.DEVICE)
        output_path = "edge-connect-inpaint.onnx"

    _export_onnx(model, dummy_input, output_path)


if __name__ == "__main__":
    main()
self.training else (1 - mask / 255).astype(bool) # canny if self.edge == 1: @@ -99,12 +99,12 @@ def load_edge(self, img, index, mask): if sigma == 0: sigma = random.randint(1, 4) - return canny(img, sigma=sigma, mask=mask).astype(np.float) + return canny(img, sigma=sigma, mask=mask).astype(float) # external else: imgh, imgw = img.shape[0:2] - edge = imread(self.edge_data[index]) + edge = imread(self.edge_data[index])[:,:,:3] edge = self.resize(edge, imgh, imgw) # non-max suppression @@ -137,7 +137,7 @@ def load_mask(self, img, index): # external if mask_type == 3: mask_index = random.randint(0, len(self.mask_data) - 1) - mask = imread(self.mask_data[mask_index]) + mask = imread(self.mask_data[mask_index])[:,:,:3] mask = self.resize(mask, imgh, imgw) mask = (mask > 0).astype(np.uint8) * 255 # threshold due to interpolation return mask @@ -146,7 +146,8 @@ def load_mask(self, img, index): if mask_type == 6: mask = imread(self.mask_data[index]) mask = self.resize(mask, imgh, imgw, centerCrop=False) - mask = rgb2gray(mask) + if mask.shape[-1] == 3: + mask = rgb2gray(mask) mask = (mask > 0).astype(np.uint8) * 255 return mask @@ -165,7 +166,7 @@ def resize(self, img, height, width, centerCrop=True): i = (imgw - side) // 2 img = img[j:j + side, i:i + side, ...] - img = scipy.misc.imresize(img, [height, width]) + img = np.array(Image.fromarray(img).resize((width, height))) return img