diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..5cb922e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+outdir/
+__pycache__
+*.egg-info/
+*.swp
diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 0000000..82d676a
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,20 @@
+[MESSAGES CONTROL]
+
+disable=
+ bad-continuation,
+ bad-whitespace,
+ invalid-name,
+ no-else-return,
+ superfluous-parens,
+ too-few-public-methods,
+ trailing-newlines,
+ duplicate-code,
+ missing-function-docstring,
+ missing-class-docstring,
+ missing-module-docstring,
+ consider-using-f-string,
+
+[TYPECHECK]
+generated-members=
+ torch
+
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..9d0fff5
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,26 @@
+Copyright (c) 2021-2025, The LS4GAN Project Developers
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice,
+ this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
+
+The above copyright and license notices apply to all files in this repository
+except for any file that contains its own copyright and/or license declaration.
diff --git a/README.md b/README.md
index cf5ed5b..023e838 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,308 @@
-The code will be available shortly.
+# UVCGAN-S: Stratified CycleGAN for Unsupervised Data Decomposition
+
+
+
+## Overview
+
+
+This repository provides a reference implementation of Stratified CycleGAN
+(UVCGAN-S), an architecture for unsupervised signal extraction from mixed data.
+
+What problem does Stratified CycleGAN solve?
+
+Imagine you have three datasets. The first contains clean signals. The second
+contains backgrounds. The third contains mixed data where signals and
+backgrounds have been combined in some complicated way. You don't know exactly
+how the mixing happens, and you can't find pairs that show which clean signal
+corresponds to which mixed observation. But you need to take new mixed data and
+decompose it back into signal and background components.
+
+Stratified CycleGAN learns to do this decomposition from unpaired examples.
+You show it random samples of signals, backgrounds, and mixed data, and it
+figures out both how to combine signals and backgrounds into realistic mixed
+data and how to decompose mixed data back into its parts.
+
+
+
+
+
+
+See the [Quick Start](#quick-start-guide) section for a concrete example
+using cat and dog images.
+
+
+## Installation
+
+The package has been tested only on Linux systems.
+
+
+### Environment Setup
+
+The development environment is based on the
+`pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime` container.
+
+There are several ways to set up the package environment:
+
+**Option 1: Docker**
+
+Pull the Docker image `pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime` and
+start a container.
+Inside the container, create a virtual environment to avoid package conflicts:
+```bash
+python3 -m venv --system-site-packages ~/.venv/uvcgan-s
+source ~/.venv/uvcgan-s/bin/activate
+```
+
+**Option 2: Conda**
+```bash
+conda env create -f contrib/conda_env.yaml
+conda activate uvcgan-s
+```
+
+### Install Package
+
+Once the environment is set up, install the `uvcgan-s` package and its
+requirements:
+
+```bash
+pip install -r requirements.txt
+pip install -e .
+```
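+
+As an optional check, verify that the package and PyTorch import cleanly:
+```bash
+python3 -c "import uvcgan_s, torch; print(torch.__version__)"
+```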
+
+### Environment Variables
+
+By default, UVCGAN-S reads datasets from `./data` and saves models to
+`./outdir`. If a different location is desired, these defaults can be overridden
+with:
+```bash
+export UVCGAN_S_DATA=/path/to/datasets
+export UVCGAN_S_OUTDIR=/path/to/models
+```
+
+## Quick Start Guide
+
+This package was developed for sPHENIX jet signal extraction. However, jet
+signal analysis requires familiarity with sPHENIX-specific reconstruction
+algorithms and jet quality analysis procedures. This section therefore
+demonstrates the Stratified CycleGAN method on a simpler toy problem with an
+intuitive interpretation. The toy example illustrates the basic workflow and
+serves as a template for applying the method to your own data.
+
+
+
+
+
+The toy problem is this: we have images that contain a blurry mix of cat and
+dog faces. The goal is to automatically decompose these mixed images into
+separate cat and dog images. To this end, we present Stratified CycleGAN with
+the mixed images, a random sample of cat images, and a random sample of dog
+images. By observing these three collections, the model learns how cats look on
+average, how dogs look on average, and how best to decompose a given mixed
+image into a cat and a dog. Importantly, the model is never
+shown training pairs like "this specific mixed image was created from this
+specific cat image and this specific dog image." It only sees random examples
+from each collection and figures out the decomposition on its own.
+
+### Installation
+
+Before proceeding, install the package by following the instructions at the
+top of this README, if you have not already done so.
+
+### Dataset Preparation
+
+The toy example uses cat and dog images from the AFHQ dataset. To download and
+preprocess it:
+
+```bash
+# Download the AFHQ dataset
+./scripts/download_dataset.sh afhq
+
+# Resize all images to 256x256 pixels
+python3 scripts/downsize_right.py -s 256 256 -i lanczos \
+ "${UVCGAN_S_DATA:-./data}/afhq/" \
+ "${UVCGAN_S_DATA:-./data}/afhq_resized_lanczos"
+```
+
+The resizing script creates a new directory `afhq_resized_lanczos` containing
+256x256 versions of all images, which is the format expected by the training
+script.
+
+### Training
+
+To train the model, run the following command:
+```bash
+python3 scripts/train/toy_mix_blur/train_uvcgan-s.py
+```
+
+The script trains the Stratified CycleGAN model for 100 epochs. On an RTX 3090
+GPU, each epoch takes approximately 3 minutes, so the complete training process
+requires about 5 hours. The trained model and intermediate checkpoints are
+saved in the directory
+`${UVCGAN_S_OUTDIR:-./outdir}/toy_mix_blur/uvcgan-s/model_m(uvcgan-s)_d(resnet)_g(vit-modnet)_cat_dog_sub/`.
+
+The structure of the model directory is described in the
+[F.A.Q.](#what-is-the-structure-of-a-model-directory).
+
+### Evaluation
+
+
+
+
+
+
+After training completes, the model can be used to decompose images from the
+validation set of AFHQ. Run:
+```bash
+python3 scripts/translate_images.py \
+ "${UVCGAN_S_OUTDIR:-./outdir}/toy_mix_blur/uvcgan-s/cat_dog_sub" \
+ --split val \
+ --domain 2 \
+ --format image
+```
+
+This command takes the mixed cat-dog images (Domain B) and decomposes them
+into separate cat and dog components (Domain A). The `--domain 2` flag
+specifies that the input images come from Domain B, which contains the mixed
+data.
+
+The results are saved in the model directory under
+`evals/final/translated(None)_domain(2)_eval-val/`. This evaluation directory
+contains several subdirectories:
+
+- `fake_a0/` - extracted cat components
+- `fake_a1/` - extracted dog components
+- `real_b/` - original blurred mixture inputs
+
+Each subdirectory contains numbered image files (`sample_0.png`,
+ `sample_1.png`, etc.) corresponding to the validation set.
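+
+For a quick visual check of the decomposition, the saved images can be placed
+side by side with a few lines of Python (a minimal sketch that uses only
+Pillow, which is already among the package requirements; adjust the evaluation
+directory to your output location):
+```python
+from pathlib import Path
+from PIL import Image
+
+# Evaluation directory produced by translate_images.py (see the listing above).
+eval_dir = Path(
+    "outdir/toy_mix_blur/uvcgan-s/"
+    "model_m(uvcgan-s)_d(resnet)_g(vit-modnet)_cat_dog_sub/"
+    "evals/final/translated(None)_domain(2)_eval-val"
+)
+
+# Load the mixed input and its two extracted components for the first sample.
+mixed = Image.open(eval_dir / "real_b"  / "sample_0.png")
+cat   = Image.open(eval_dir / "fake_a0" / "sample_0.png")
+dog   = Image.open(eval_dir / "fake_a1" / "sample_0.png")
+
+# Paste the three images next to each other and save the result.
+canvas = Image.new("RGB", (3 * mixed.width, mixed.height))
+for idx, img in enumerate((mixed, cat, dog)):
+    canvas.paste(img, (idx * mixed.width, 0))
+canvas.save("decomposition_sample_0.png")
+```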
+
+
+### Adapting to Your Own Data
+
+To apply Stratified CycleGAN to a different decomposition problem, use the toy
+example training script as a starting point. The script
+`scripts/train/toy_mix_blur/train_uvcgan-s.py` contains a declarative
+configuration showing how to structure the three required datasets and set up
+the domain structure for decomposition. For a more complex example, see
+`scripts/train/sphenix/train_uvcgan-s.py`.
+
+
+## sPHENIX Application: Jet Background Subtraction
+
+The package was developed for extracting particle jets from heavy-ion collision
+backgrounds in sPHENIX calorimeter data. This section describes how to
+reproduce the paper results.
+
+### Dataset
+
+The sPHENIX dataset can be downloaded from Zenodo: https://zenodo.org/records/17783990
+
+Alternatively, use the download script:
+```bash
+./scripts/download_dataset.sh sphenix
+```
+
+The dataset contains HDF5 files with calorimeter energy measurements organized
+as 24×64 eta-phi grids. The data is split into training, validation, and test
+sets. Training data consists of three components: PYTHIA jets (the signal
+component), HIJING minimum-bias events (the background component), and
+embedded PYTHIA+HIJING events (the mixed data). The test set uses JEWEL jets
+embedded in HIJING backgrounds. JEWEL models jet-medium interactions
+differently from PYTHIA, providing an out-of-distribution test of the model's
+generalization capability.
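+
+To get a first look at the downloaded files, the HDF5 contents can be listed
+with `h5py` (a minimal sketch; the file name below is hypothetical, so point
+it at any file from the extracted dataset):
+```python
+import h5py
+
+PATH = "data/sphenix/train.h5"  # hypothetical path; replace with an actual file
+
+def describe(name, obj):
+    # Print every dataset together with its shape and dtype.
+    if isinstance(obj, h5py.Dataset):
+        print(name, obj.shape, obj.dtype)
+
+with h5py.File(PATH, "r") as f:
+    f.visititems(describe)
+```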
+
+
+### Training or Using Pre-trained Model
+
+There are two options for obtaining a trained model: training from scratch or
+downloading the pre-trained model from the paper.
+
+To train a new model from scratch:
+```bash
+python3 scripts/train/sphenix/train_uvcgan-s.py
+```
+
+The training configuration uses the same Stratified CycleGAN architecture as
+the toy example, adapted for single-channel calorimeter data. The trained model
+is saved in the directory
+`${UVCGAN_S_OUTDIR:-./outdir}/sphenix/uvcgan-s/model_m(uvcgan-s)_d(resnet)_g(vit-modnet)_sgn_bkg_sub`
+(see [F.A.Q.](#what-is-the-structure-of-a-model-directory) for details on the
+ model directory structure).
+
+Alternatively, a pre-trained model can be downloaded from Zenodo: https://zenodo.org/records/17809156
+
+The pre-trained model can be used directly for evaluation without retraining.
+
+
+### Evaluation
+
+To evaluate the model on the test set:
+```bash
+python3 scripts/translate_images.py \
+ "${UVCGAN_S_OUTDIR:-./outdir}/path/to/sphenix/model" \
+ --split test \
+ --domain 2 \
+ --format ndarray
+```
+
+The `--format ndarray` flag saves results as NumPy arrays rather than images.
+The output structure is similar to the toy example: extracted signal and
+background components are saved in separate directories under `evals/final/`.
+Each output file contains a 24×64 calorimeter energy grid that can be used for
+physics analysis.
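+
+The saved arrays can be loaded directly with NumPy. The snippet below is a
+minimal sketch: it assumes that the test-split evaluation directory and the
+file naming mirror the toy example (e.g. `fake_a0/sample_0.npy` for the
+extracted signal), which should be checked against the actual output of
+`translate_images.py`:
+```python
+import numpy as np
+
+# Assumed locations, mirroring the toy example layout (not verified).
+EVAL_DIR = "evals/final/translated(None)_domain(2)_eval-test"
+
+signal = np.load(f"{EVAL_DIR}/fake_a0/sample_0.npy")
+mixed  = np.load(f"{EVAL_DIR}/real_b/sample_0.npy")
+
+# Each array holds a 24x64 eta-phi calorimeter energy grid.
+print("grid shape             :", signal.shape)
+print("extracted signal energy:", signal.sum())
+print("mixed event energy     :", mixed.sum())
+```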
+
+
+# F.A.Q.
+
+## I am training my model on a multi-GPU node. How do I make sure that only one GPU is used?
+
+You can specify which GPUs `pytorch` will use with the help of the
+`CUDA_VISIBLE_DEVICES` environment variable. This variable can be set to a
+list of comma-separated GPU indices. When it is set, `pytorch` will only use
+the GPUs whose indices are listed in `CUDA_VISIBLE_DEVICES`.
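+
+For example, to restrict training to the first GPU:
+```bash
+CUDA_VISIBLE_DEVICES=0 python3 scripts/train/toy_mix_blur/train_uvcgan-s.py
+```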
+
+
+## What is the structure of a model directory?
+
+`uvcgan-s` saves each model in a separate directory that contains:
+ - `MODEL/config.json` -- model architecture, training, and evaluation
+ configurations
+ - `MODEL/net_*.pth` -- PyTorch weights of the model networks
+ - `MODEL/opt_*.pth` -- PyTorch states of the training optimizers
+ - `MODEL/sched_*.pth` -- PyTorch states of the training schedulers
+ - `MODEL/checkpoints/` -- training checkpoints
+ - `MODEL/evals/` -- evaluation results
+
+
+## Training fails with "Config collision detected" error
+
+`uvcgan-s` enforces a one-model-per-directory policy to prevent accidental
+overwrites of existing models. Each model directory must have a unique
+configuration: if you try to place a model with different settings in a
+directory that already contains a model, you will receive a "Config collision
+detected" error.
+
+This safeguard helps prevent situations where you might accidentally lose
+trained models by starting a new training run with different parameters in the
+same directory.
+
+Solutions:
+1. To overwrite the old model: delete the old `config.json` configuration file
+ and restart the training process.
+2. To preserve the old model: modify the training script of the new model and
+ update the `label` or `outdir` configuration options to avoid collisions.
+
+
+# LICENSE
+
+`uvcgan-s` is distributed under the `BSD-2` license.
+
+The `uvcgan-s` repository contains some code (primarily in the `uvcgan_s/base`
+subdirectory) from [pytorch-CycleGAN-and-pix2pix][cyclegan_repo].
+That code is also licensed under the `BSD-2` license (please refer to
+`uvcgan_s/base/LICENSE` for details).
+
+Each code snippet that was taken from
+[pytorch-CycleGAN-and-pix2pix][cyclegan_repo] has a note about proper copyright
+attribution.
+
+[cyclegan_repo]: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
diff --git a/contrib/conda_env.yaml b/contrib/conda_env.yaml
new file mode 100644
index 0000000..0f8e8a8
--- /dev/null
+++ b/contrib/conda_env.yaml
@@ -0,0 +1,119 @@
+name: uvcgan-s
+channels:
+ - pytorch
+ - nvidia
+ - defaults
+ - conda-forge
+dependencies:
+ - _libgcc_mutex=0.1
+ - _openmp_mutex=4.5
+ - backcall=0.2.0
+ - beautifulsoup4=4.11.1
+ - blas=1.0
+ - brotlipy=0.7.0
+ - bzip2=1.0.8
+ - ca-certificates=2022.07.19
+ - certifi=2022.6.15
+ - cffi=1.15.0
+ - chardet=4.0.0
+ - charset-normalizer=2.0.4
+ - colorama=0.4.4
+ - conda=4.13.0
+ - conda-build=3.21.9
+ - conda-content-trust=0.1.1
+ - conda-package-handling=1.8.1
+ - cryptography=37.0.1
+ - cudatoolkit=11.3.1
+ - decorator=5.1.1
+ - ffmpeg=4.3
+ - filelock=3.6.0
+ - freetype=2.11.0
+ - giflib=5.2.1
+ - glob2=0.7
+ - gmp=6.2.1
+ - gnutls=3.6.15
+ - icu=58.2
+ - idna=3.3
+ - intel-openmp=2021.4.0
+ - ipython=7.31.1
+ - jedi=0.18.1
+ - jinja2=2.10.1
+ - jpeg=9e
+ - lame=3.100
+ - lcms2=2.12
+ - ld_impl_linux-64=2.35.1
+ - libarchive=3.5.2
+ - libffi=3.3
+ - libgcc-ng=9.3.0
+ - libgomp=9.3.0
+ - libiconv=1.16
+ - libidn2=2.3.2
+ - liblief=0.11.5
+ - libpng=1.6.37
+ - libstdcxx-ng=9.3.0
+ - libtasn1=4.16.0
+ - libtiff=4.2.0
+ - libunistring=0.9.10
+ - libwebp=1.2.2
+ - libwebp-base=1.2.2
+ - libxml2=2.9.14
+ - lz4-c=1.9.3
+ - markupsafe=2.0.1
+ - matplotlib-inline=0.1.2
+ - mkl=2021.4.0
+ - mkl-service=2.4.0
+ - mkl_fft=1.3.1
+ - mkl_random=1.2.2
+ - ncurses=6.3
+ - nettle=3.7.3
+ - numpy=1.21.5
+ - numpy-base=1.21.5
+ - openh264=2.1.1
+ - openssl=1.1.1q
+ - parso=0.8.3
+ - patchelf=0.13
+ - pexpect=4.8.0
+ - pickleshare=0.7.5
+ - pillow=9.0.1
+ - pip=22.1.2
+ - pkginfo=1.8.2
+ - prompt-toolkit=3.0.20
+ - psutil=5.8.0
+ - ptyprocess=0.7.0
+ - py-lief=0.11.5
+ - pycosat=0.6.3
+ - pycparser=2.21
+ - pygments=2.11.2
+ - pyopenssl=22.0.0
+ - pysocks=1.7.1
+ - python=3.7.13
+ - python-libarchive-c=2.9
+ - pytorch=1.12.1
+ - pytorch-mutex=1.0
+ - pytz=2022.1
+ - pyyaml=6.0
+ - readline=8.1.2
+ - requests=2.27.1
+ - ripgrep=12.1.1
+ - ruamel_yaml=0.15.100
+ - setuptools=61.2.0
+ - six=1.16.0
+ - soupsieve=2.3.1
+ - sqlite=3.38.2
+ - tk=8.6.11
+ - torchtext=0.13.1
+ - torchvision=0.13.1
+ - tqdm=4.63.0
+ - traitlets=5.1.1
+ - typing_extensions=4.3.0
+ - tzdata=2022a
+ - urllib3=1.26.8
+ - wcwidth=0.2.5
+ - wheel=0.37.1
+ - xz=5.2.5
+ - yaml=0.2.5
+ - zlib=1.2.12
+ - zstd=1.5.2
+ - einops=0.4.*
+ - h5py=3.6.*
+ - pandas=1.3.*
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..ca95049
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+einops==0.4.1
+h5py==3.6.0
+numpy==1.21.5
+pandas==1.3.3
+Pillow==9.0.1
+torch==1.12.1
+torchvision==0.13.1
+tqdm==4.63.0
diff --git a/scripts/download_dataset.sh b/scripts/download_dataset.sh
new file mode 100755
index 0000000..7300076
--- /dev/null
+++ b/scripts/download_dataset.sh
@@ -0,0 +1,166 @@
+#!/usr/bin/env bash
+
+DATADIR="${UVCGAN_S_DATA:-data}"
+
+declare -A URL_LIST=(
+ [afhq]="https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip"
+ [sphenix]="https://zenodo.org/record/17783990/files/2025-06-05_jet_bkg_sub.tar"
+)
+
+declare -A CHECKSUMS=(
+ [afhq]="7f63dcc14ef58c0e849b59091287e1844da97016073aac20403ae6c6132b950f"
+)
+
+die ()
+{
+ echo "${*}"
+ exit 1
+}
+
+usage ()
+{
+ cat <<EOF
diff --git a/uvcgan_s/base/image_pool.py b/uvcgan_s/base/image_pool.py
new file mode 100644
--- /dev/null
+++ b/uvcgan_s/base/image_pool.py
+import random
+
+import torch
+
+class ImagePool():
+ """This class implements an image buffer that stores previously generated images.
+
+ This buffer enables us to update discriminators using a history of generated
+ images rather than the ones produced by the latest generators.
+ """
+
+ def __init__(self, pool_size):
+ """Initialize the ImagePool class
+
+ Parameters:
+ pool_size (int) -- the size of image buffer; if pool_size = 0, no buffer will be created
+ """
+ self.pool_size = pool_size
+ if self.pool_size > 0: # create an empty pool
+ self.num_imgs = 0
+ self.images = []
+
+ def query(self, images):
+ """Return an image from the pool.
+
+ Parameters:
+ images: the latest generated images from the generator
+
+ Returns images from the buffer.
+
+ With probability 0.5, the buffer will return the input images.
+ Otherwise, the buffer will return images previously stored in the buffer
+ and insert the current images into the buffer.
+ """
+ if self.pool_size == 0: # if the buffer size is 0, do nothing
+ return images
+ return_images = []
+ for image in images:
+ image = torch.unsqueeze(image.data, 0)
+ if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
+ self.num_imgs = self.num_imgs + 1
+ self.images.append(image)
+ return_images.append(image)
+ else:
+ p = random.uniform(0, 1)
+ if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
+ random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
+ tmp = self.images[random_id].clone()
+ self.images[random_id] = image
+ return_images.append(tmp)
+ else: # by another 50% chance, the buffer will return the current image
+ return_images.append(image)
+ return_images = torch.cat(return_images, 0) # collect all the images and return
+ return return_images
diff --git a/uvcgan_s/base/losses.py b/uvcgan_s/base/losses.py
new file mode 100644
index 0000000..025b03a
--- /dev/null
+++ b/uvcgan_s/base/losses.py
@@ -0,0 +1,200 @@
+import torch
+from torch import nn
+
+def reduce_loss(loss, reduction):
+ if (reduction is None) or (reduction == 'none'):
+ return loss
+
+ if reduction == 'mean':
+ return loss.mean()
+
+ if reduction == 'sum':
+ return loss.sum()
+
+ raise ValueError(f"Unknown reduction method: '{reduction}'")
+
+class GANLoss(nn.Module):
+ """Define different GAN objectives.
+
+ The GANLoss class abstracts away the need to create the target label tensor
+ that has the same size as the input.
+ """
+
+ def __init__(
+ self, gan_mode, target_real_label = 1.0, target_fake_label = 0.0,
+ reduction = 'mean'
+ ):
+ """ Initialize the GANLoss class.
+
+ Parameters:
+ gan_mode (str) -- the type of GAN objective.
+ Choices: vanilla, lsgan, and wgangp.
+ target_real_label (float) -- label for a real image
+ target_fake_label (float) -- label for a fake image
+
+ Note: Do not use sigmoid as the last layer of Discriminator.
+ LSGAN needs no sigmoid. Vanilla GANs will handle it with
+ BCEWithLogitsLoss.
+ """
+ super().__init__()
+
+ # pylint: disable=not-callable
+ self.register_buffer('real_label', torch.tensor(target_real_label))
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
+
+ self.gan_mode = gan_mode
+ self.reduction = reduction
+
+ if gan_mode == 'lsgan':
+ self.loss = nn.MSELoss(reduction = reduction)
+
+ elif gan_mode == 'vanilla':
+ self.loss = nn.BCEWithLogitsLoss(reduction = reduction)
+
+ elif gan_mode == 'softplus':
+ self.loss = nn.Softplus()
+
+ elif gan_mode == 'wgan':
+ self.loss = None
+
+ else:
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
+
+ def get_target_tensor(self, prediction, target_is_real):
+ """Create label tensors with the same size as the input.
+
+ Parameters:
+ prediction (tensor) -- typically the prediction from a
+ discriminator
+ target_is_real (bool) -- if the ground truth label is for real
+ images or fake images
+
+ Returns:
+ A label tensor filled with ground truth label, and with the size of
+ the input
+ """
+
+ if target_is_real:
+ target_tensor = self.real_label
+ else:
+ target_tensor = self.fake_label
+ return target_tensor.expand_as(prediction)
+
+ def eval_wgan_loss(self, prediction, target_is_real):
+ if target_is_real:
+ result = -prediction.mean()
+ else:
+ result = prediction.mean()
+
+ return reduce_loss(result, self.reduction)
+
+ def eval_softplus_loss(self, prediction, target_is_real):
+ if target_is_real:
+ result = self.loss(prediction)
+ else:
+ result = self.loss(-prediction)
+
+ return reduce_loss(result, self.reduction)
+
+ def forward(self, prediction, target_is_real):
+ """Calculate loss given Discriminator's output and grount truth labels.
+
+ Parameters:
+ prediction (tensor) -- tpyically the prediction output from a
+ discriminator
+ target_is_real (bool) -- if the ground truth label is for real
+ images or fake images
+
+ Returns:
+ the calculated loss.
+ """
+
+ if isinstance(prediction, (list, tuple)):
+ result = sum(self.forward(x, target_is_real) for x in prediction)
+ return result / len(prediction)
+
+ if self.gan_mode == 'wgan':
+ return self.eval_wgan_loss(prediction, target_is_real)
+
+ if self.gan_mode == 'softplus':
+ return self.eval_softplus_loss(prediction, target_is_real)
+
+ target_tensor = self.get_target_tensor(prediction, target_is_real)
+ return self.loss(prediction, target_tensor)
+
+# pylint: disable=too-many-arguments
+# pylint: disable=redefined-builtin
+def cal_gradient_penalty(
+ netD, real_data, fake_data, device,
+ type = 'mixed', constant = 1.0, lambda_gp = 10.0
+):
+ """Calculate the gradient penalty loss, used in WGAN-GP
+
+ source: https://arxiv.org/abs/1704.00028
+
+ Arguments:
+ netD (network) -- discriminator network
+ real_data (tensor array) -- real images
+ fake_data (tensor array) -- generated images from the generator
+ device (str) -- torch device
+ type (str) -- if we mix real and fake data or not
+ Choices: [real | fake | mixed].
+ constant (float) -- the constant used in formula:
+ (||gradient||_2 - constant)^2
+ lambda_gp (float) -- weight for this loss
+
+ Returns the gradient penalty loss
+ """
+ if lambda_gp == 0.0:
+ return 0.0, None
+
+ if type == 'real':
+ interpolatesv = real_data
+ elif type == 'fake':
+ interpolatesv = fake_data
+ elif type == 'mixed':
+ alpha = torch.rand(real_data.shape[0], 1, device = device)
+ alpha = alpha.expand(
+ real_data.shape[0], real_data.nelement() // real_data.shape[0]
+ ).contiguous().view(*real_data.shape)
+
+ interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
+ else:
+ raise NotImplementedError('{} not implemented'.format(type))
+
+ interpolatesv.requires_grad_(True)
+ disc_interpolates = netD(interpolatesv)
+
+ gradients = torch.autograd.grad(
+ outputs=disc_interpolates, inputs=interpolatesv,
+ grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+ create_graph=True, retain_graph=True, only_inputs=True
+ )
+
+ gradients = gradients[0].view(real_data.size(0), -1)
+
+ gradient_penalty = (
+ ((gradients + 1e-16).norm(2, dim=1) - constant) ** 2
+ ).mean() * lambda_gp
+
+ return gradient_penalty, gradients
+
+def calc_zero_gp(model, x, **model_kwargs):
+ x.requires_grad_(True)
+ y = model(x, **model_kwargs)
+
+ grad = torch.autograd.grad(
+ outputs = y,
+ inputs = x,
+ grad_outputs = torch.ones(y.size()).to(y.device),
+ create_graph = True,
+ retain_graph = True,
+ only_inputs = True
+ )
+
+ grad = grad[0].view(x.shape[0], -1)
+ # NOTE: 1/2 for backward compatibility
+ gp = 1/2 * torch.sum(grad.square(), dim = 1).mean()
+
+ return gp, grad
+
diff --git a/uvcgan_s/base/networks.py b/uvcgan_s/base/networks.py
new file mode 100644
index 0000000..e70024c
--- /dev/null
+++ b/uvcgan_s/base/networks.py
@@ -0,0 +1,411 @@
+# pylint: disable=line-too-long
+# pylint: disable=redefined-builtin
+# pylint: disable=too-many-arguments
+# pylint: disable=unidiomatic-typecheck
+# pylint: disable=super-with-arguments
+
+import functools
+
+import torch
+from torch import nn
+
+class Identity(nn.Module):
+ # pylint: disable=no-self-use
+ def forward(self, x):
+ return x
+
+def get_norm_layer(norm_type='instance'):
+ """Return a normalization layer
+
+ Parameters:
+ norm_type (str) -- the name of the normalization layer: batch | instance | none
+
+ For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
+ For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
+ """
+ if norm_type == 'batch':
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
+ elif norm_type == 'instance':
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+ elif norm_type == 'none':
+ norm_layer = lambda _features : Identity()
+ else:
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+
+ return norm_layer
+
+def join_args(a, b):
+ return { **a, **b }
+
+def select_base_generator(model, **kwargs):
+ default_args = dict(norm = 'instance', use_dropout = False, ngf = 64)
+ kwargs = join_args(default_args, kwargs)
+
+ if model == 'resnet_9blocks':
+ return ResnetGenerator(n_blocks = 9, **kwargs)
+
+ if model == 'resnet_6blocks':
+ return ResnetGenerator(n_blocks = 6, **kwargs)
+
+ if model == 'unet_128':
+ return UnetGenerator(num_downs = 7, **kwargs)
+
+ if model == 'unet_256':
+ return UnetGenerator(num_downs = 8, **kwargs)
+
+ raise ValueError("Unknown generator: %s" % model)
+
+def select_base_discriminator(model, **kwargs):
+ default_args = dict(norm = 'instance', ndf = 64)
+ kwargs = join_args(default_args, kwargs)
+
+ if model == 'basic':
+ return NLayerDiscriminator(n_layers = 3, **kwargs)
+
+ if model == 'n_layers':
+ return NLayerDiscriminator(**kwargs)
+
+ if model == 'pixel':
+ return PixelDiscriminator(**kwargs)
+
+ raise ValueError("Unknown discriminator: %s" % model)
+
+class ResnetGenerator(nn.Module):
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
+
+ We adapt the Torch code and ideas from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style)
+ """
+
+ def __init__(self, image_shape, ngf=64, norm = 'batch', use_dropout=False, n_blocks=6, padding_type='reflect'):
+ """Construct a Resnet-based generator
+
+ Parameters:
+ image_shape (tuple) -- shape of the input/output images (channels first)
+ ngf (int) -- the number of filters in the last conv layer
+ norm (str) -- the type of normalization layer: batch | instance | none
+ use_dropout (bool) -- if use dropout layers
+ n_blocks (int) -- the number of ResNet blocks
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
+ """
+
+ assert n_blocks >= 0
+ super().__init__()
+
+ norm_layer = get_norm_layer(norm_type = norm)
+
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+
+ model = [nn.ReflectionPad2d(3),
+ nn.Conv2d(image_shape[0], ngf, kernel_size=7, padding=0, bias=use_bias),
+ norm_layer(ngf),
+ nn.ReLU(True)]
+
+ n_downsampling = 2
+ for i in range(n_downsampling): # add downsampling layers
+ mult = 2 ** i
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
+ norm_layer(ngf * mult * 2),
+ nn.ReLU(True)]
+
+ mult = 2 ** n_downsampling
+ for i in range(n_blocks): # add ResNet blocks
+
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
+
+ for i in range(n_downsampling): # add upsampling layers
+ mult = 2 ** (n_downsampling - i)
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
+ kernel_size=3, stride=2,
+ padding=1, output_padding=1,
+ bias=use_bias),
+ norm_layer(int(ngf * mult / 2)),
+ nn.ReLU(True)]
+ model += [nn.ReflectionPad2d(3)]
+ model += [nn.Conv2d(ngf, image_shape[0], kernel_size=7, padding=0)]
+
+ if image_shape[0] == 3:
+ model.append(nn.Sigmoid())
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input):
+ """Standard forward"""
+ return self.model(input)
+
+
+class ResnetBlock(nn.Module):
+ """Define a Resnet block"""
+
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+ """Initialize the Resnet block
+
+ A resnet block is a conv block with skip connections.
+ We construct a conv block with the build_conv_block function,
+ and implement skip connections in the forward function.
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
+ """
+ super().__init__()
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
+
+ # pylint: disable=no-self-use
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+ """Construct a convolutional block.
+
+ Parameters:
+ dim (int) -- the number of channels in the conv layer.
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers.
+ use_bias (bool) -- if the conv layer uses bias or not
+
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
+ """
+ conv_block = []
+ p = 0
+ if padding_type == 'reflect':
+ conv_block += [nn.ReflectionPad2d(1)]
+ elif padding_type == 'replicate':
+ conv_block += [nn.ReplicationPad2d(1)]
+ elif padding_type == 'zero':
+ p = 1
+ else:
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
+ if use_dropout:
+ conv_block += [nn.Dropout(0.5)]
+
+ p = 0
+ if padding_type == 'reflect':
+ conv_block += [nn.ReflectionPad2d(1)]
+ elif padding_type == 'replicate':
+ conv_block += [nn.ReplicationPad2d(1)]
+ elif padding_type == 'zero':
+ p = 1
+ else:
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
+
+ return nn.Sequential(*conv_block)
+
+ def forward(self, x):
+ """Forward function (with skip connections)"""
+ out = x + self.conv_block(x) # add skip connections
+ return out
+
+
+class UnetGenerator(nn.Module):
+ """Create a Unet-based generator"""
+
+ def __init__(self, image_shape, num_downs, ngf=64, norm = 'batch', use_dropout=False):
+ """Construct a Unet generator
+ Parameters:
+ image_shape (tuple) -- shape of the input/output images (channels first)
+ num_downs (int) -- the number of downsamplings in UNet. For example,
+ if |num_downs| == 7, an image of size 128x128 will become
+ of size 1x1 at the bottleneck
+ ngf (int) -- the number of filters in the last conv layer
+ norm (str) -- the type of normalization layer: batch | instance | none
+ use_dropout (bool) -- if use dropout layers
+
+ We construct the U-Net from the innermost layer to the outermost layer.
+ It is a recursive process.
+ """
+ super(UnetGenerator, self).__init__()
+ norm_layer = get_norm_layer(norm_type=norm)
+
+ # construct unet structure
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
+ for _i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+ # gradually reduce the number of filters from ngf * 8 to ngf
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ self.model = UnetSkipConnectionBlock(image_shape[0], ngf, input_nc=image_shape[0], submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
+
+ def forward(self, input):
+ """Standard forward"""
+ return self.model(input)
+
+
+class UnetSkipConnectionBlock(nn.Module):
+ """Defines the Unet submodule with skip connection.
+ X -------------------identity----------------------
+ |-- downsampling -- |submodule| -- upsampling --|
+ """
+
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
+ """Construct a Unet submodule with skip connections.
+
+ Parameters:
+ outer_nc (int) -- the number of filters in the outer conv layer
+ inner_nc (int) -- the number of filters in the inner conv layer
+ input_nc (int) -- the number of channels in input images/features
+ submodule (UnetSkipConnectionBlock) -- previously defined submodules
+ outermost (bool) -- if this module is the outermost module
+ innermost (bool) -- if this module is the innermost module
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers.
+ """
+ # pylint: disable=too-many-locals
+ super().__init__()
+ self.outermost = outermost
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+ if input_nc is None:
+ input_nc = outer_nc
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
+ stride=2, padding=1, bias=use_bias)
+ downrelu = nn.LeakyReLU(0.2, True)
+ downnorm = norm_layer(inner_nc)
+ uprelu = nn.ReLU(True)
+ upnorm = norm_layer(outer_nc)
+
+ if outermost:
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+ kernel_size=4, stride=2,
+ padding=1)
+ down = [downconv]
+ up = [uprelu, upconv ]
+
+ if outer_nc == 3:
+ up.append(nn.Sigmoid())
+
+ model = down + [submodule] + up
+ elif innermost:
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+ kernel_size=4, stride=2,
+ padding=1, bias=use_bias)
+ down = [downrelu, downconv]
+ up = [uprelu, upconv, upnorm]
+ model = down + up
+ else:
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+ kernel_size=4, stride=2,
+ padding=1, bias=use_bias)
+ down = [downrelu, downconv, downnorm]
+ up = [uprelu, upconv, upnorm]
+
+ if use_dropout:
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
+ else:
+ model = down + [submodule] + up
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ if self.outermost:
+ return self.model(x)
+ else: # add skip connections
+ return torch.cat([x, self.model(x)], 1)
+
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator"""
+
+ def __init__(
+ self, image_shape, ndf=64, n_layers=3, norm='batch', max_mult=8,
+ shrink_output = True, return_intermediate_activations = False
+ ):
+ # pylint: disable=too-many-locals
+ """Construct a PatchGAN discriminator
+
+ Parameters:
+ image_shape (tuple) -- shape of the input images (channels first)
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ norm (str) -- the type of normalization layer: batch | instance | none
+ """
+ super(NLayerDiscriminator, self).__init__()
+
+ norm_layer = get_norm_layer(norm_type = norm)
+
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+
+ kw = 4
+ padw = 1
+ sequence = [nn.Conv2d(image_shape[0], ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n, max_mult)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n_layers, max_mult)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ self.model = nn.Sequential(*sequence)
+ self.shrink_conv = None
+
+ if shrink_output:
+ self.shrink_conv = nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
+
+ self._intermediate = return_intermediate_activations
+
+ def forward(self, input):
+ """Standard forward."""
+ z = self.model(input)
+
+ if self.shrink_conv is None:
+ return z
+
+ y = self.shrink_conv(z)
+
+ if self._intermediate:
+ return (y, z)
+
+ return y
+
+class PixelDiscriminator(nn.Module):
+ """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
+
+ def __init__(self, image_shape, ndf=64, norm='batch'):
+ """Construct a 1x1 PatchGAN discriminator
+
+ Parameters:
+ image_shape (tuple) -- shape of the input images (channels first)
+ ndf (int) -- the number of filters in the last conv layer
+ norm (str) -- the type of normalization layer: batch | instance | none
+ """
+ super(PixelDiscriminator, self).__init__()
+
+ norm_layer = get_norm_layer(norm_type=norm)
+
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+
+ self.net = [
+ nn.Conv2d(image_shape[0], ndf, kernel_size=1, stride=1, padding=0),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
+ norm_layer(ndf * 2),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
+
+ self.net = nn.Sequential(*self.net)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.net(input)
diff --git a/uvcgan_s/base/schedulers.py b/uvcgan_s/base/schedulers.py
new file mode 100644
index 0000000..b133123
--- /dev/null
+++ b/uvcgan_s/base/schedulers.py
@@ -0,0 +1,61 @@
+from torch.optim import lr_scheduler
+from uvcgan_s.torch.select import extract_name_kwargs
+
+def linear_scheduler(optimizer, epochs_warmup, epochs_anneal, verbose = True):
+
+ def lambda_rule(epoch, epochs_warmup, epochs_anneal):
+ if epoch < epochs_warmup:
+ return 1.0
+
+ return 1.0 - (epoch - epochs_warmup) / (epochs_anneal + 1)
+
+ lr_fn = lambda epoch : lambda_rule(epoch, epochs_warmup, epochs_anneal)
+
+ return lr_scheduler.LambdaLR(optimizer, lr_fn, verbose = verbose)
+
+SCHED_DICT = {
+ 'step' : lr_scheduler.StepLR,
+ 'plateau' : lr_scheduler.ReduceLROnPlateau,
+ 'cosine' : lr_scheduler.CosineAnnealingLR,
+ 'cosine-restarts' : lr_scheduler.CosineAnnealingWarmRestarts,
+ 'constant' : lr_scheduler.ConstantLR,
+ # lr scheds below are for backward compatibility
+ 'linear' : linear_scheduler,
+ 'linear-v2' : lr_scheduler.LinearLR,
+ 'CosineAnnealingWarmRestarts' : lr_scheduler.CosineAnnealingWarmRestarts,
+}
+
+def select_single_scheduler(optimizer, scheduler):
+ if scheduler is None:
+ return None
+
+ name, kwargs = extract_name_kwargs(scheduler)
+ kwargs['verbose'] = True
+
+ if name not in SCHED_DICT:
+ raise ValueError(
+ f"Unknown scheduler: '{name}'. Supported: {SCHED_DICT.keys()}"
+ )
+
+ return SCHED_DICT[name](optimizer, **kwargs)
+
+def select_scheduler(optimizer, scheduler, compose = False):
+ if scheduler is None:
+ return None
+
+ if not isinstance(scheduler, (list, tuple)):
+ scheduler = [ scheduler, ]
+
+ result = [ select_single_scheduler(optimizer, x) for x in scheduler ]
+
+ if compose:
+ if len(result) == 1:
+ return result[0]
+ else:
+ return lr_scheduler.ChainedScheduler(result)
+ else:
+ return result
+
+def get_scheduler(optimizer, scheduler):
+ return select_scheduler(optimizer, scheduler, compose = True)
+
diff --git a/uvcgan_s/base/weight_init.py b/uvcgan_s/base/weight_init.py
new file mode 100644
index 0000000..2f8ce81
--- /dev/null
+++ b/uvcgan_s/base/weight_init.py
@@ -0,0 +1,49 @@
+import logging
+from torch.nn import init
+
+from uvcgan_s.torch.select import extract_name_kwargs
+
+LOGGER = logging.getLogger('uvcgan_s.base')
+
+def winit_func(m, init_type = 'normal', init_gain = 0.2):
+ classname = m.__class__.__name__
+
+ if (
+ hasattr(m, 'weight')
+ and (classname.find('Conv') != -1 or classname.find('Linear') != -1)
+ ):
+ if init_type == 'normal':
+ init.normal_(m.weight.data, 0.0, init_gain)
+
+ elif init_type == 'xavier':
+ init.xavier_normal_(m.weight.data, gain = init_gain)
+
+ elif init_type == 'kaiming':
+ init.kaiming_normal_(m.weight.data, a = 0, mode = 'fan_in')
+
+ elif init_type == 'orthogonal':
+ init.orthogonal_(m.weight.data, gain = init_gain)
+
+ else:
+ raise NotImplementedError(
+ 'Initialization method [%s] is not implemented' % init_type
+ )
+
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+
+ elif classname.find('BatchNorm2d') != -1:
+ init.normal_(m.weight.data, 1.0, init_gain)
+ init.constant_(m.bias.data, 0.0)
+
+def init_weights(net, weight_init):
+ if weight_init is None:
+ return
+
+ name, kwargs = extract_name_kwargs(weight_init)
+
+ LOGGER.debug('Initializing network with %s', name)
+ net.apply(
+ lambda m, name=name, kwargs=kwargs : winit_func(m, name, **kwargs)
+ )
+
diff --git a/uvcgan_s/cgan/__init__.py b/uvcgan_s/cgan/__init__.py
new file mode 100644
index 0000000..21eeffa
--- /dev/null
+++ b/uvcgan_s/cgan/__init__.py
@@ -0,0 +1,30 @@
+from .cyclegan import CycleGANModel
+from .pix2pix import Pix2PixModel
+from .autoencoder import Autoencoder
+from .simple_autoencoder import SimpleAutoencoder
+from .uvcgan2 import UVCGAN2
+from .uvcgan_s import UVCGAN_S
+
+CGAN_MODELS = {
+ 'cyclegan' : CycleGANModel,
+ 'pix2pix' : Pix2PixModel,
+ 'autoencoder' : Autoencoder,
+ 'simple-autoencoder' : SimpleAutoencoder,
+ 'uvcgan-v2' : UVCGAN2,
+ 'uvcgan-s' : UVCGAN_S,
+}
+
+def select_model(name, **kwargs):
+ if name not in CGAN_MODELS:
+ raise ValueError("Unknown model: %s" % name)
+
+ return CGAN_MODELS[name](**kwargs)
+
+def construct_model(savedir, config, is_train, device):
+ model = select_model(
+ config.model, savedir = savedir, config = config, is_train = is_train,
+ device = device, **config.model_args
+ )
+
+ return model
+
diff --git a/uvcgan_s/cgan/autoencoder.py b/uvcgan_s/cgan/autoencoder.py
new file mode 100644
index 0000000..f3bcf88
--- /dev/null
+++ b/uvcgan_s/cgan/autoencoder.py
@@ -0,0 +1,166 @@
+# pylint: disable=not-callable
+# NOTE: Mistaken lint:
+# E1102: self.encoder is not callable (not-callable)
+from uvcgan_s.torch.select import select_optimizer, select_loss
+from uvcgan_s.torch.background_penalty import BackgroundPenaltyReduction
+from uvcgan_s.torch.image_masking import select_masking
+from uvcgan_s.models.generator import construct_generator
+
+from .model_base import ModelBase
+from .named_dict import NamedDict
+from .funcs import set_two_domain_input
+
+class Autoencoder(ModelBase):
+
+ def _setup_images(self, _config):
+ images = [ 'real_a', 'reco_a', 'real_b', 'reco_b', ]
+
+ if self.masking is not None:
+ images += [ 'masked_a', 'masked_b' ]
+
+ return NamedDict(*images)
+
+ def _setup_models(self, config):
+ if self.joint:
+ image_shape = config.data.datasets[0].shape
+
+ assert image_shape == config.data.datasets[1].shape, (
+ "Joint autoencoder requires all datasets to have "
+ "the same image shape"
+ )
+
+ return NamedDict(
+ encoder = construct_generator(
+ config.generator, image_shape, image_shape, self.device
+ )
+ )
+
+ models = NamedDict('encoder_a', 'encoder_b')
+ models.encoder_a = construct_generator(
+ config.generator,
+ config.data.datasets[0].shape,
+ config.data.datasets[0].shape,
+ self.device
+ )
+ models.encoder_b = construct_generator(
+ config.generator,
+ config.data.datasets[1].shape,
+ config.data.datasets[1].shape,
+ self.device
+ )
+
+ return models
+
+ def _setup_losses(self, config):
+ self.loss_fn = select_loss(config.loss)
+
+ assert config.gradient_penalty is None, \
+ "Autoencoder model does not support gradient penalty"
+
+ return NamedDict('loss_a', 'loss_b')
+
+ def _setup_optimizers(self, config):
+ if self.joint:
+ return NamedDict(
+ encoder = select_optimizer(
+ self.models.encoder.parameters(),
+ config.generator.optimizer
+ )
+ )
+
+ optimizers = NamedDict('encoder_a', 'encoder_b')
+
+ optimizers.encoder_a = select_optimizer(
+ self.models.encoder_a.parameters(), config.generator.optimizer
+ )
+ optimizers.encoder_b = select_optimizer(
+ self.models.encoder_b.parameters(), config.generator.optimizer
+ )
+
+ return optimizers
+
+ def __init__(
+ self, savedir, config, is_train, device,
+ joint = False, background_penalty = None, masking = None
+ ):
+ # pylint: disable=too-many-arguments
+ self.joint = joint
+ self.masking = select_masking(masking)
+
+ assert len(config.data.datasets) == 2, \
+ "Autoencoder expects a pair of datasets"
+
+ super().__init__(savedir, config, is_train, device)
+
+ if background_penalty is None:
+ self.background_penalty = None
+ else:
+ self.background_penalty = BackgroundPenaltyReduction(
+ **background_penalty
+ )
+
+ assert config.discriminator is None, \
+ "Autoencoder model does not use discriminator"
+
+ def _handle_epoch_end(self):
+ if self.background_penalty is not None:
+ self.background_penalty.end_epoch(self.epoch)
+
+ def _set_input(self, inputs, domain):
+ set_two_domain_input(self.images, inputs, domain, self.device)
+
+ def forward(self):
+ input_a = self.images.real_a
+ input_b = self.images.real_b
+
+ if self.masking is not None:
+ if input_a is not None:
+ input_a = self.masking(input_a)
+
+ if input_b is not None:
+ input_b = self.masking(input_b)
+
+ self.images.masked_a = input_a
+ self.images.masked_b = input_b
+
+ if input_a is not None:
+ if self.joint:
+ self.images.reco_a = self.models.encoder (input_a)
+ else:
+ self.images.reco_a = self.models.encoder_a(input_a)
+
+ if input_b is not None:
+ if self.joint:
+ self.images.reco_b = self.models.encoder (input_b)
+ else:
+ self.images.reco_b = self.models.encoder_b(input_b)
+
+ def backward_generator_base(self, real, reco):
+ if self.background_penalty is not None:
+ reco = self.background_penalty(reco, real)
+
+ loss = self.loss_fn(reco, real)
+ loss.backward()
+
+ return loss
+
+ def backward_generators(self):
+ self.losses.loss_b = self.backward_generator_base(
+ self.images.real_b, self.images.reco_b
+ )
+
+ self.losses.loss_a = self.backward_generator_base(
+ self.images.real_a, self.images.reco_a
+ )
+
+ def optimization_step(self):
+ self.forward()
+
+ for optimizer in self.optimizers.values():
+ optimizer.zero_grad()
+
+ self.backward_generators()
+
+ for optimizer in self.optimizers.values():
+ optimizer.step()
+
diff --git a/uvcgan_s/cgan/checkpoint.py b/uvcgan_s/cgan/checkpoint.py
new file mode 100644
index 0000000..70c6a77
--- /dev/null
+++ b/uvcgan_s/cgan/checkpoint.py
@@ -0,0 +1,73 @@
+import os
+import re
+import torch
+
+CHECKPOINTS_DIR = 'checkpoints'
+
+def find_last_checkpoint_epoch(savedir, prefix = None):
+ root = os.path.join(savedir, CHECKPOINTS_DIR)
+ if not os.path.exists(root):
+ return -1
+
+ if prefix is None:
+ r = re.compile(r'(\d+)_.*')
+ else:
+ r = re.compile(r'(\d+)_' + re.escape(prefix) + '_.*')
+
+ last_epoch = -1
+
+ for fname in os.listdir(root):
+ m = r.match(fname)
+ if m:
+ epoch = int(m.groups()[0])
+ last_epoch = max(last_epoch, epoch)
+
+ return last_epoch
+
+def get_save_path(savedir, name, epoch, mkdir = False):
+ if epoch is None:
+ fname = '%s.pth' % (name)
+ root = savedir
+ else:
+ fname = '%04d_%s.pth' % (epoch, name)
+ root = os.path.join(savedir, CHECKPOINTS_DIR)
+
+ result = os.path.join(root, fname)
+
+ if mkdir:
+ os.makedirs(root, exist_ok = True)
+
+ return result
+
+def save(named_dict, savedir, prefix, epoch = None):
+ for (k,v) in named_dict.items():
+ if v is None:
+ continue
+
+ save_path = get_save_path(
+ savedir, prefix + '_' + k, epoch, mkdir = True
+ )
+
+ if isinstance(v, torch.nn.DataParallel):
+ torch.save(v.module.state_dict(), save_path)
+ else:
+ torch.save(v.state_dict(), save_path)
+
+def load(named_dict, savedir, prefix, epoch, device):
+ for (k,v) in named_dict.items():
+ if v is None:
+ continue
+
+ load_path = get_save_path(
+ savedir, prefix + '_' + k, epoch, mkdir = False
+ )
+
+ if isinstance(v, torch.nn.DataParallel):
+ v.module.load_state_dict(
+ torch.load(load_path, map_location = device)
+ )
+ else:
+ v.load_state_dict(
+ torch.load(load_path, map_location = device)
+ )
+
diff --git a/uvcgan_s/cgan/cyclegan.py b/uvcgan_s/cgan/cyclegan.py
new file mode 100644
index 0000000..295674f
--- /dev/null
+++ b/uvcgan_s/cgan/cyclegan.py
@@ -0,0 +1,226 @@
+# pylint: disable=not-callable
+# NOTE: Mistaken lint:
+# E1102: self.criterion_gan is not callable (not-callable)
+
+import itertools
+import torch
+
+from uvcgan_s.torch.select import select_optimizer
+from uvcgan_s.base.image_pool import ImagePool
+from uvcgan_s.base.losses import GANLoss, cal_gradient_penalty
+from uvcgan_s.models.discriminator import construct_discriminator
+from uvcgan_s.models.generator import construct_generator
+
+from .model_base import ModelBase
+from .named_dict import NamedDict
+from .funcs import set_two_domain_input
+
+class CycleGANModel(ModelBase):
+ # pylint: disable=too-many-instance-attributes
+
+ def _setup_images(self, _config):
+ images = [ 'real_a', 'fake_b', 'reco_a', 'real_b', 'fake_a', 'reco_b' ]
+
+ if self.is_train and self.lambda_idt > 0:
+ images += [ 'idt_a', 'idt_b' ]
+
+ return NamedDict(*images)
+
+ def _setup_models(self, config):
+ models = {}
+
+ models['gen_ab'] = construct_generator(
+ config.generator,
+ config.data.datasets[0].shape,
+ config.data.datasets[1].shape,
+ self.device
+ )
+ models['gen_ba'] = construct_generator(
+ config.generator,
+ config.data.datasets[1].shape,
+ config.data.datasets[0].shape,
+ self.device
+ )
+
+ if self.is_train:
+ models['disc_a'] = construct_discriminator(
+ config.discriminator,
+ config.data.datasets[0].shape,
+ self.device
+ )
+ models['disc_b'] = construct_discriminator(
+ config.discriminator,
+ config.data.datasets[1].shape,
+ self.device
+ )
+
+ return NamedDict(**models)
+
+ def _setup_losses(self, config):
+ losses = [
+ 'gen_ab', 'gen_ba', 'cycle_a', 'cycle_b', 'disc_a', 'disc_b'
+ ]
+
+ if self.is_train and self.lambda_idt > 0:
+ losses += [ 'idt_a', 'idt_b' ]
+
+ return NamedDict(*losses)
+
+ def _setup_optimizers(self, config):
+ optimizers = NamedDict('gen', 'disc')
+
+ optimizers.gen = select_optimizer(
+ itertools.chain(
+ self.models.gen_ab.parameters(),
+ self.models.gen_ba.parameters()
+ ),
+ config.generator.optimizer
+ )
+
+ optimizers.disc = select_optimizer(
+ itertools.chain(
+ self.models.disc_a.parameters(),
+ self.models.disc_b.parameters()
+ ),
+ config.discriminator.optimizer
+ )
+
+ return optimizers
+
+ def __init__(
+ self, savedir, config, is_train, device, pool_size = 50,
+ lambda_a = 10.0, lambda_b = 10.0, lambda_idt = 0.5
+ ):
+ # pylint: disable=too-many-arguments
+ self.lambda_a = lambda_a
+ self.lambda_b = lambda_b
+ self.lambda_idt = lambda_idt
+
+ assert len(config.data.datasets) == 2, \
+ "CycleGAN expects a pair of datasets"
+
+ super().__init__(savedir, config, is_train, device)
+
+ self.criterion_gan = GANLoss(config.loss).to(self.device)
+ self.gradient_penalty = config.gradient_penalty
+ self.criterion_cycle = torch.nn.L1Loss()
+ self.criterion_idt = torch.nn.L1Loss()
+
+ if self.is_train:
+ self.pred_a_pool = ImagePool(pool_size)
+ self.pred_b_pool = ImagePool(pool_size)
+
+ def _set_input(self, inputs, domain):
+ set_two_domain_input(self.images, inputs, domain, self.device)
+
+ def forward(self):
+ def simple_fwd(batch, gen_fwd, gen_bkw):
+ if batch is None:
+ return (None, None)
+
+ fake = gen_fwd(batch)
+ reco = gen_bkw(fake)
+
+ return (fake, reco)
+
+ self.images.fake_b, self.images.reco_a = simple_fwd(
+ self.images.real_a, self.models.gen_ab, self.models.gen_ba
+ )
+
+ self.images.fake_a, self.images.reco_b = simple_fwd(
+ self.images.real_b, self.models.gen_ba, self.models.gen_ab
+ )
+
+ def backward_discriminator_base(self, model, real, fake):
+ pred_real = model(real)
+ loss_real = self.criterion_gan(pred_real, True)
+
+ #
+ # NOTE:
+ # This is a workaround to a pytorch 1.9.0 bug that manifests when
+ # cudnn is enabled. When the bug is solved remove no_grad block and
+ # replace `model(fake)` by `model(fake.detach())`.
+ #
+ # bug: https://github.com/pytorch/pytorch/issues/48439
+ #
+ with torch.no_grad():
+ fake = fake.contiguous()
+
+ pred_fake = model(fake)
+ loss_fake = self.criterion_gan(pred_fake, False)
+
+ loss = (loss_real + loss_fake) * 0.5
+
+ if self.gradient_penalty is not None:
+ loss += cal_gradient_penalty(
+ model, real, fake, real.device, **self.gradient_penalty
+ )[0]
+
+ loss.backward()
+ return loss
+
+ def backward_discriminators(self):
+ fake_a = self.pred_a_pool.query(self.images.fake_a)
+ fake_b = self.pred_b_pool.query(self.images.fake_b)
+
+ self.losses.disc_b = self.backward_discriminator_base(
+ self.models.disc_b, self.images.real_b, fake_b
+ )
+
+ self.losses.disc_a = self.backward_discriminator_base(
+ self.models.disc_a, self.images.real_a, fake_a
+ )
+
+ def backward_generators(self):
+ lambda_idt = self.lambda_idt
+ lambda_a = self.lambda_a
+ lambda_b = self.lambda_b
+
+ self.losses.gen_ab = self.criterion_gan(
+ self.models.disc_b(self.images.fake_b), True
+ )
+ self.losses.gen_ba = self.criterion_gan(
+ self.models.disc_a(self.images.fake_a), True
+ )
+ self.losses.cycle_a = lambda_a * self.criterion_cycle(
+ self.images.reco_a, self.images.real_a
+ )
+ self.losses.cycle_b = lambda_b * self.criterion_cycle(
+ self.images.reco_b, self.images.real_b
+ )
+
+ loss = (
+ self.losses.gen_ab + self.losses.gen_ba
+ + self.losses.cycle_a + self.losses.cycle_b
+ )
+
+ if lambda_idt > 0:
+ self.images.idt_b = self.models.gen_ab(self.images.real_b)
+ self.losses.idt_b = lambda_b * lambda_idt * self.criterion_idt(
+ self.images.idt_b, self.images.real_b
+ )
+
+ self.images.idt_a = self.models.gen_ba(self.images.real_a)
+ self.losses.idt_a = lambda_a * lambda_idt * self.criterion_idt(
+ self.images.idt_a, self.images.real_a
+ )
+
+ loss += (self.losses.idt_a + self.losses.idt_b)
+
+ loss.backward()
+
+ def optimization_step(self):
+ self.forward()
+
+ # Generators
+ self.set_requires_grad([self.models.disc_a, self.models.disc_b], False)
+ self.optimizers.gen.zero_grad()
+ self.backward_generators()
+ self.optimizers.gen.step()
+
+ # Discriminators
+ self.set_requires_grad([self.models.disc_a, self.models.disc_b], True)
+ self.optimizers.disc.zero_grad()
+ self.backward_discriminators()
+ self.optimizers.disc.step()
+
diff --git a/uvcgan_s/cgan/funcs.py b/uvcgan_s/cgan/funcs.py
new file mode 100644
index 0000000..e042d04
--- /dev/null
+++ b/uvcgan_s/cgan/funcs.py
@@ -0,0 +1,55 @@
+import torch
+
+def set_two_domain_input(images, inputs, domain, device):
+ if (domain is None) or (domain == 'both'):
+ images.real_a = inputs[0].to(device, non_blocking = True)
+ images.real_b = inputs[1].to(device, non_blocking = True)
+
+ elif domain in [ 'a', 0 ]:
+ images.real_a = inputs.to(device, non_blocking = True)
+
+ elif domain in [ 'b', 1 ]:
+ images.real_b = inputs.to(device, non_blocking = True)
+
+ else:
+ raise ValueError(
+ f"Unknown domain: '{domain}'."
+ " Supported domains: 'a' (alias 0), 'b' (alias 1), or 'both'"
+ )
+
+def set_asym_two_domain_input(images, inputs, domain, device):
+ if (domain is None) or (domain == 'all'):
+ images.real_a0 = inputs[0].to(device, non_blocking = True)
+ images.real_a1 = inputs[1].to(device, non_blocking = True)
+ images.real_b = inputs[2].to(device, non_blocking = True)
+
+ elif domain in [ 'a0', 0 ]:
+ images.real_a0 = inputs.to(device, non_blocking = True)
+
+ elif domain in [ 'a1', 1 ]:
+ images.real_a1 = inputs.to(device, non_blocking = True)
+
+ elif domain in [ 'b', 2 ]:
+ images.real_b = inputs.to(device, non_blocking = True)
+
+ else:
+ raise ValueError(
+ f"Unknown domain: '{domain}'."
+ " Supported domains: 'a0' (alias 0), 'a1' (alias 1), "
+ "'b' (alias 2), or 'all'"
+ )
+
+def trace_models(models, input_shapes, device):
+ result = {}
+
+ for (name, model) in models.items():
+ if name not in input_shapes:
+ continue
+
+ shape = input_shapes[name]
+ data = torch.randn((1, *shape)).to(device)
+
+ result[name] = torch.jit.trace(model, data)
+
+ return result
+
diff --git a/uvcgan_s/cgan/model_base.py b/uvcgan_s/cgan/model_base.py
new file mode 100644
index 0000000..bbae864
--- /dev/null
+++ b/uvcgan_s/cgan/model_base.py
@@ -0,0 +1,176 @@
+import logging
+import torch
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+from uvcgan_s.base.schedulers import get_scheduler
+from .named_dict import NamedDict
+from .checkpoint import find_last_checkpoint_epoch, save, load
+
+PREFIX_MODEL = 'net'
+PREFIX_OPT = 'opt'
+PREFIX_SCHED = 'sched'
+
+LOGGER = logging.getLogger('uvcgan_s.cgan')
+
+class ModelBase:
+ # pylint: disable=too-many-instance-attributes
+
+ def __init__(self, savedir, config, is_train, device):
+ self.is_train = is_train
+ self.device = device
+ self.savedir = savedir
+ self._config = config
+
+ self.models = self._setup_models(config)
+ self.images = self._setup_images(config)
+ self.losses = self._setup_losses(config)
+ self.metric = 0
+ self.epoch = 0
+
+ self.optimizers = NamedDict()
+ self.schedulers = NamedDict()
+
+ if is_train:
+ self.optimizers = self._setup_optimizers(config)
+ self.schedulers = self._setup_schedulers(config)
+
+ def set_input(self, inputs, domain = None):
+ for key in self.images:
+ self.images[key] = None
+
+ self._set_input(inputs, domain)
+
+ def forward(self):
+ raise NotImplementedError
+
+ def optimization_step(self):
+ raise NotImplementedError
+
+ def _set_input(self, inputs, domain):
+ raise NotImplementedError
+
+ def _setup_images(self, config):
+ raise NotImplementedError
+
+ def _setup_models(self, config):
+ raise NotImplementedError
+
+ def _setup_losses(self, config):
+ raise NotImplementedError
+
+ def _setup_optimizers(self, config):
+ raise NotImplementedError
+
+ def _setup_schedulers(self, config):
+ schedulers = { }
+
+ for (name, opt) in self.optimizers.items():
+ schedulers[name] = get_scheduler(opt, config.scheduler)
+
+ return NamedDict(**schedulers)
+
+ def _save_model_state(self, epoch):
+ pass
+
+ def _load_model_state(self, epoch):
+ pass
+
+ def _handle_epoch_end(self):
+ pass
+
+ def eval(self):
+ self.is_train = False
+
+ for model in self.models.values():
+ model.eval()
+
+ def train(self):
+ self.is_train = True
+
+ for model in self.models.values():
+ model.train()
+
+ def forward_nograd(self):
+ with torch.no_grad():
+ self.forward()
+
+ def find_last_checkpoint_epoch(self):
+ return find_last_checkpoint_epoch(self.savedir, PREFIX_MODEL)
+
+ def load(self, epoch):
+ if (epoch is not None) and (epoch <= 0):
+ return
+
+ LOGGER.debug('Loading model from epoch %s', epoch)
+
+ load(self.models, self.savedir, PREFIX_MODEL, epoch, self.device)
+ load(self.optimizers, self.savedir, PREFIX_OPT, epoch, self.device)
+ load(self.schedulers, self.savedir, PREFIX_SCHED, epoch, self.device)
+
+ self.epoch = epoch
+ self._load_model_state(epoch)
+ self._handle_epoch_end()
+
+ def save(self, epoch = None):
+ LOGGER.debug('Saving model at epoch %s', epoch)
+
+ save(self.models, self.savedir, PREFIX_MODEL, epoch)
+ save(self.optimizers, self.savedir, PREFIX_OPT, epoch)
+ save(self.schedulers, self.savedir, PREFIX_SCHED, epoch)
+
+ self._save_model_state(epoch)
+
+ def end_epoch(self, epoch = None):
+ for scheduler in self.schedulers.values():
+ if scheduler is None:
+ continue
+
+ if isinstance(scheduler, ReduceLROnPlateau):
+ scheduler.step(self.metric)
+ else:
+ scheduler.step()
+
+ self._handle_epoch_end()
+
+ if epoch is None:
+ self.epoch = self.epoch + 1
+ else:
+ self.epoch = epoch
+
+ def pprint(self, verbose):
+ for (name, model) in self.models.items():
+ num_params = 0
+
+ for param in model.parameters():
+ num_params += param.numel()
+
+ if verbose:
+ print(model)
+
+ print(
+ '[Network %s] Total number of parameters : %.3f M' % (
+ name, num_params / 1e6
+ )
+ )
+
+ def set_requires_grad(self, models, requires_grad = False):
+ # pylint: disable=no-self-use
+ if not isinstance(models, list):
+ models = [models, ]
+
+ for model in models:
+ for param in model.parameters():
+ param.requires_grad = requires_grad
+
+ def get_current_losses(self):
+ result = {}
+
+ for (k, v) in self.losses.items():
+ result[k] = float(v)
+
+ return result
+
+ def get_images(self):
+ return self.images
+
+
diff --git a/uvcgan_s/cgan/named_dict.py b/uvcgan_s/cgan/named_dict.py
new file mode 100644
index 0000000..02c0180
--- /dev/null
+++ b/uvcgan_s/cgan/named_dict.py
@@ -0,0 +1,48 @@
+from collections.abc import Mapping
+
+class NamedDict(Mapping):
+ # pylint: disable=too-many-instance-attributes
+
+ _fields = None
+
+ def __init__(self, *args, **kwargs):
+ self._fields = {}
+
+ for arg in args:
+ self._fields[arg] = None
+
+ self._fields.update(**kwargs)
+
+ def __contains__(self, key):
+ return (key in self._fields)
+
+ def __getitem__(self, key):
+ return self._fields[key]
+
+ def __setitem__(self, key, value):
+ self._fields[key] = value
+
+ def __getattr__(self, key):
+ return self._fields[key]
+
+ def __setattr__(self, key, value):
+ if (self._fields is not None) and (key in self._fields):
+ self._fields[key] = value
+ else:
+ super().__setattr__(key, value)
+
+ def __iter__(self):
+ return iter(self._fields)
+
+ def __len__(self):
+ return len(self._fields)
+
+ def items(self):
+ return self._fields.items()
+
+ def keys(self):
+ return self._fields.keys()
+
+ def values(self):
+ return self._fields.values()
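+
+# Illustrative behavior (not exercised in this module): fields are reachable
+# both as attributes and as mapping items, and the class supports the usual
+# Mapping protocol.
+#
+#   d = NamedDict('loss', metric = 0.0)
+#   d.loss       = 1.5        # attribute access
+#   d['metric'] += 1.0        # mapping access
+#   dict(d)                   # {'loss': 1.5, 'metric': 1.0}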
+
diff --git a/uvcgan_s/cgan/pix2pix.py b/uvcgan_s/cgan/pix2pix.py
new file mode 100644
index 0000000..ac97b41
--- /dev/null
+++ b/uvcgan_s/cgan/pix2pix.py
@@ -0,0 +1,166 @@
+# pylint: disable=not-callable
+# NOTE: Mistaken lint:
+# E1102: self.criterion_gan is not callable (not-callable)
+
+import torch
+
+from uvcgan_s.torch.select import select_optimizer
+from uvcgan_s.base.losses import GANLoss, cal_gradient_penalty
+from uvcgan_s.models.discriminator import construct_discriminator
+from uvcgan_s.models.generator import construct_generator
+
+from .model_base import ModelBase
+from .named_dict import NamedDict
+from .funcs import set_two_domain_input
+
+class Pix2PixModel(ModelBase):
+
+ def _setup_images(self, _config):
+ return NamedDict('real_a', 'fake_b', 'real_b', 'fake_a')
+
+ def _setup_models(self, config):
+ models = { }
+
+ image_shape_a = config.data.datasets[0].shape
+ image_shape_b = config.data.datasets[1].shape
+
+ assert image_shape_a[1:] == image_shape_b[1:], \
+ "Pix2Pix needs images in both domains to have the same size"
+
+ models['gen_ab'] = construct_generator(
+ config.generator, image_shape_a, image_shape_b, self.device
+ )
+ models['gen_ba'] = construct_generator(
+ config.generator, image_shape_b, image_shape_a, self.device
+ )
+
+ if self.is_train:
+ extended_image_shape = (
+ image_shape_a[0] + image_shape_b[0], *image_shape_a[1:]
+ )
+
+ for name in [ 'disc_a', 'disc_b' ]:
+ models[name] = construct_discriminator(
+ config.discriminator, extended_image_shape, self.device
+ )
+
+ return NamedDict(**models)
+
+ def _setup_losses(self, config):
+ return NamedDict(
+ 'gen_ab', 'gen_ba', 'l1_ab', 'l1_ba', 'disc_a', 'disc_b'
+ )
+
+ def _setup_optimizers(self, config):
+ optimizers = NamedDict('gen_ab', 'gen_ba', 'disc_a', 'disc_b')
+
+ optimizers.gen_ab = select_optimizer(
+ self.models.gen_ab.parameters(), config.generator.optimizer
+ )
+ optimizers.gen_ba = select_optimizer(
+ self.models.gen_ba.parameters(), config.generator.optimizer
+ )
+
+ optimizers.disc_a = select_optimizer(
+ self.models.disc_a.parameters(), config.discriminator.optimizer
+ )
+ optimizers.disc_b = select_optimizer(
+ self.models.disc_b.parameters(), config.discriminator.optimizer
+ )
+
+ return optimizers
+
+ def __init__(self, savedir, config, is_train, device):
+ super().__init__(savedir, config, is_train, device)
+
+ assert len(config.data.datasets) == 2, \
+ "Pix2Pix expects a pair of datasets"
+
+ self.criterion_gan = GANLoss(config.loss).to(self.device)
+ self.criterion_l1 = torch.nn.L1Loss()
+ self.gradient_penalty = config.gradient_penalty
+
+ def _set_input(self, inputs, domain):
+ set_two_domain_input(self.images, inputs, domain, self.device)
+
+ def forward(self):
+ if self.images.real_a is not None:
+ self.images.fake_b = self.models.gen_ab(self.images.real_a)
+
+ if self.images.real_b is not None:
+ self.images.fake_a = self.models.gen_ba(self.images.real_b)
+
+ def backward_discriminator_base(self, model, real, fake, preimage):
+ cond_real = torch.cat([real, preimage], dim = 1)
+ cond_fake = torch.cat([fake, preimage], dim = 1).detach()
+
+ pred_real = model(cond_real)
+ loss_real = self.criterion_gan(pred_real, True)
+
+ pred_fake = model(cond_fake)
+ loss_fake = self.criterion_gan(pred_fake, False)
+
+ loss = (loss_real + loss_fake) * 0.5
+
+ if self.gradient_penalty is not None:
+ loss += cal_gradient_penalty(
+ model, cond_real, cond_fake, real.device,
+ **self.gradient_penalty
+ )[0]
+
+ loss.backward()
+ return loss
+
+ def backward_discriminators(self):
+ self.losses.disc_b = self.backward_discriminator_base(
+ self.models.disc_b,
+ self.images.real_b, self.images.fake_b, self.images.real_a
+ )
+
+ self.losses.disc_a = self.backward_discriminator_base(
+ self.models.disc_a,
+ self.images.real_a, self.images.fake_a, self.images.real_b
+ )
+
+ def backward_generator_base(self, disc, real, fake, preimage):
+ loss_gen = self.criterion_gan(
+ disc(torch.cat([fake, preimage], dim = 1)), True
+ )
+
+ loss_l1 = self.criterion_l1(fake, real)
+
+ loss = loss_gen + loss_l1
+ loss.backward()
+
+ return (loss_gen, loss_l1)
+
+ def backward_generators(self):
+ self.losses.gen_ab, self.losses.l1_ab = self.backward_generator_base(
+ self.models.disc_b,
+ self.images.real_b, self.images.fake_b, self.images.real_a
+ )
+
+ self.losses.gen_ba, self.losses.l1_ba = self.backward_generator_base(
+ self.models.disc_a,
+ self.images.real_a, self.images.fake_a, self.images.real_b
+ )
+
+ def optimization_step(self):
+ self.forward()
+
+ # Generators
+ self.set_requires_grad([self.models.disc_a, self.models.disc_b], False)
+ self.optimizers.gen_ab.zero_grad()
+ self.optimizers.gen_ba.zero_grad()
+ self.backward_generators()
+ self.optimizers.gen_ab.step()
+ self.optimizers.gen_ba.step()
+
+ # Discriminators
+ self.set_requires_grad([self.models.disc_a, self.models.disc_b], True)
+ self.optimizers.disc_a.zero_grad()
+ self.optimizers.disc_b.zero_grad()
+ self.backward_discriminators()
+ self.optimizers.disc_a.step()
+ self.optimizers.disc_b.step()
+
diff --git a/uvcgan_s/cgan/simple_autoencoder.py b/uvcgan_s/cgan/simple_autoencoder.py
new file mode 100644
index 0000000..bf252a3
--- /dev/null
+++ b/uvcgan_s/cgan/simple_autoencoder.py
@@ -0,0 +1,111 @@
+from uvcgan_s.torch.select import select_optimizer, select_loss
+from uvcgan_s.torch.background_penalty import BackgroundPenaltyReduction
+from uvcgan_s.torch.image_masking import select_masking
+from uvcgan_s.models.generator import construct_generator
+
+from .model_base import ModelBase
+from .named_dict import NamedDict
+
+class SimpleAutoencoder(ModelBase):
+ """Model that tries to train an autoencoder (i.e. target == input).
+
+ This autoencoder expects each input to be either a tuple of the form
+ `(features, target)` or a bare `features` tensor.
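+
+ For example, both a plain batch `x` of shape `(N, C, H, W)` and a tuple
+ `(x, labels)` are accepted; in either case only `x` is used, serving as
+ both the input and the reconstruction target.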
+ """
+
+ def _setup_images(self, _config):
+ images = [ 'real', 'reco' ]
+
+ if self.masking is not None:
+ images.append('masked')
+
+ return NamedDict(*images)
+
+ def _setup_models(self, config):
+ return NamedDict(
+ encoder = construct_generator(
+ config.generator,
+ config.data.datasets[0].shape,
+ config.data.datasets[0].shape,
+ self.device
+ )
+ )
+
+ def _setup_losses(self, config):
+ self.loss_fn = select_loss(config.loss)
+
+ assert config.gradient_penalty is None, \
+ "Autoencoder model does not support gradient penalty"
+
+ return NamedDict('loss')
+
+ def _setup_optimizers(self, config):
+ return NamedDict(
+ encoder = select_optimizer(
+ self.models.encoder.parameters(), config.generator.optimizer
+ )
+ )
+
+ def __init__(
+ self, savedir, config, is_train, device,
+ background_penalty = None, masking = None
+ ):
+ # pylint: disable=too-many-arguments
+ self.masking = select_masking(masking)
+ assert len(config.data.datasets) == 1, \
+ "Simple Autoencoder can work only with a single dataset"
+
+ super().__init__(savedir, config, is_train, device)
+
+ if background_penalty is None:
+ self.background_penalty = None
+ else:
+ self.background_penalty = BackgroundPenaltyReduction(
+ **background_penalty
+ )
+
+ assert config.discriminator is None, \
+ "Autoencoder model does not use discriminator"
+
+ def _handle_epoch_end(self):
+ if self.background_penalty is not None:
+ self.background_penalty.end_epoch(self.epoch)
+
+ def _set_input(self, inputs, _domain):
+ # inputs : image or (image, label)
+ if isinstance(inputs, (list, tuple)):
+ self.images.real = inputs[0].to(self.device)
+ else:
+ self.images.real = inputs.to(self.device)
+
+ def forward(self):
+ if self.masking is None:
+ input_img = self.images.real
+ else:
+ self.images.masked = self.masking(self.images.real)
+ input_img = self.images.masked
+
+ self.images.reco = self.models.encoder(input_img)
+
+ def backward(self):
+ if self.background_penalty is not None:
+ reco = self.background_penalty(self.images.reco, self.images.real)
+ else:
+ reco = self.images.reco
+
+ loss = self.loss_fn(reco, self.images.real)
+ loss.backward()
+
+ self.losses.loss = loss
+
+ def optimization_step(self):
+ self.forward()
+
+ for optimizer in self.optimizers.values():
+ optimizer.zero_grad()
+
+ self.backward()
+
+ for optimizer in self.optimizers.values():
+ optimizer.step()
+
diff --git a/uvcgan_s/cgan/uvcgan2.py b/uvcgan_s/cgan/uvcgan2.py
new file mode 100644
index 0000000..ebaf40a
--- /dev/null
+++ b/uvcgan_s/cgan/uvcgan2.py
@@ -0,0 +1,537 @@
+# pylint: disable=not-callable
+# NOTE: Mistaken lint:
+# E1102: self.criterion_gan is not callable (not-callable)
+
+import itertools
+import torch
+
+from torchvision.transforms import GaussianBlur, Resize
+
+from uvcgan_s.torch.select import select_optimizer, extract_name_kwargs
+from uvcgan_s.torch.gan_losses import select_gan_loss
+from uvcgan_s.torch.queue import FastQueue
+from uvcgan_s.torch.funcs import prepare_model, update_average_model
+
+from uvcgan_s.torch.layers.batch_head import BatchHeadWrapper, get_batch_head
+from uvcgan_s.torch.gradient_penalty import GradientPenalty
+from uvcgan_s.torch.gradient_cacher import GradientCacher
+from uvcgan_s.torch.image_masking import select_masking
+from uvcgan_s.models.discriminator import construct_discriminator
+from uvcgan_s.models.generator import construct_generator
+
+from .model_base import ModelBase
+from .named_dict import NamedDict
+from .funcs import set_two_domain_input
+
+def construct_consistency_model(consist, device):
+ name, kwargs = extract_name_kwargs(consist)
+
+ if name == 'blur':
+ return GaussianBlur(**kwargs).to(device)
+
+ if name == 'resize':
+ return Resize(**kwargs).to(device)
+
+ raise ValueError(f'Unknown consistency type: {name}')
+
+
+def queued_forward(batch_head_model, input_image, queue, update_queue = True):
+ output, pred_body = batch_head_model.forward(
+ input_image, extra_bodies = queue.query(), return_body = True
+ )
+
+ if update_queue:
+ queue.push(pred_body)
+
+ return output
+
+def init_hc(image, n_hidden):
+ return torch.zeros(
+ (image.shape[0], n_hidden, *image.shape[2:]),
+ dtype = image.dtype, device = image.device
+ )
+
+class UVCGAN2(ModelBase):
+ # pylint: disable=too-many-instance-attributes
+
+ def _setup_images(self, _config):
+ images = [
+ 'real_a', 'real_b',
+ 'fake_a', 'fake_b',
+ 'reco_a', 'reco_b',
+ 'real_hc_a', 'real_hc_b',
+ 'fake_hc_a', 'fake_hc_b',
+ 'reco_hc_a', 'reco_hc_b',
+ 'consist_real_a', 'consist_real_b',
+ 'consist_fake_a', 'consist_fake_b',
+ ]
+
+ if self.is_train and self.lambda_idt > 0:
+ images += [
+ 'idt_input_a', 'idt_input_b',
+ 'idt_a', 'idt_b', 'idt_hc_a', 'idt_hc_b'
+ ]
+
+ return NamedDict(*images)
+
+ def _construct_batch_head_disc(self, model_config, input_shape):
+ disc_body = construct_discriminator(
+ model_config, input_shape, self.device
+ )
+
+ disc_head = get_batch_head(self.head_config)
+ disc_head = prepare_model(disc_head, self.device)
+
+ return BatchHeadWrapper(disc_body, disc_head)
+
+ def _add_hc_to_shape(self, shape):
+ return (shape[0] + self.n_hidden, *shape[1:])
+
+ def _setup_models(self, config):
+ models = {}
+
+ shape_a = self._add_hc_to_shape(config.data.datasets[0].shape)
+ shape_b = self._add_hc_to_shape(config.data.datasets[1].shape)
+
+ models['gen_ab'] = construct_generator(
+ config.generator, shape_a, shape_b, self.device
+ )
+ models['gen_ba'] = construct_generator(
+ config.generator, shape_b, shape_a, self.device
+ )
+
+ if self.avg_momentum is not None:
+ models['avg_gen_ab'] = construct_generator(
+ config.generator, shape_a, shape_b, self.device
+ )
+ models['avg_gen_ba'] = construct_generator(
+ config.generator, shape_b, shape_a, self.device
+ )
+
+ models['avg_gen_ab'].load_state_dict(models['gen_ab'].state_dict())
+ models['avg_gen_ba'].load_state_dict(models['gen_ba'].state_dict())
+
+ if self.is_train:
+ models['disc_a'] = self._construct_batch_head_disc(
+ config.discriminator, config.data.datasets[0].shape
+ )
+ models['disc_b'] = self._construct_batch_head_disc(
+ config.discriminator, config.data.datasets[1].shape
+ )
+
+ return NamedDict(**models)
+
+ def _setup_losses(self, config):
+ losses = [
+ 'gen_ab', 'gen_ba', 'cycle_a', 'cycle_b', 'disc_a', 'disc_b'
+ ]
+
+ if self.is_train and self.lambda_idt > 0:
+ losses += [ 'idt_a', 'idt_b' ]
+
+ if self.is_train and config.gradient_penalty is not None:
+ losses += [ 'gp_a', 'gp_b' ]
+
+ if self.consist_model is not None:
+ losses += [ 'consist_a', 'consist_b' ]
+
+ return NamedDict(*losses)
+
+ def _setup_optimizers(self, config):
+ optimizers = NamedDict('gen', 'disc')
+
+ optimizers.gen = select_optimizer(
+ itertools.chain(
+ self.models.gen_ab.parameters(),
+ self.models.gen_ba.parameters()
+ ),
+ config.generator.optimizer
+ )
+
+ optimizers.disc = select_optimizer(
+ itertools.chain(
+ self.models.disc_a.parameters(),
+ self.models.disc_b.parameters()
+ ),
+ config.discriminator.optimizer
+ )
+
+ return optimizers
+
+ def __init__(
+ self, savedir, config, is_train, device, head_config = None,
+ lambda_a = 10.0,
+ lambda_b = 10.0,
+ lambda_idt = 0.5,
+ lambda_consist = 0,
+ n_hidden = 0,
+ head_queue_size = 3,
+ avg_momentum = None,
+ gp_cache_period = 1,
+ consistency = None,
+ masking = None,
+ ):
+ # pylint: disable=too-many-arguments
+ # pylint: disable=too-many-locals
+ self.lambda_a = lambda_a
+ self.lambda_b = lambda_b
+ self.lambda_idt = lambda_idt
+ self.lambda_consist = lambda_consist
+ self.avg_momentum = avg_momentum
+ self.n_hidden = n_hidden
+ self.head_config = head_config or {}
+ self.consist_model = None
+ self.masking = select_masking(masking)
+
+ if (lambda_consist > 0) and (consistency is not None):
+ self.consist_model \
+ = construct_consistency_model(consistency, device)
+
+ assert len(config.data.datasets) == 2, \
+ "CycleGAN expects a pair of datasets"
+
+ super().__init__(savedir, config, is_train, device)
+
+ self.criterion_gan = select_gan_loss(config.loss).to(self.device)
+ self.criterion_cycle = torch.nn.L1Loss()
+ self.criterion_idt = torch.nn.L1Loss()
+ self.criterion_consist = torch.nn.L1Loss()
+ self.gradient_penalty = config.gradient_penalty
+
+ if self.is_train:
+ self.queues = NamedDict(**{
+ name : FastQueue(head_queue_size, device = device)
+ for name in [ 'real_a', 'real_b', 'fake_a', 'fake_b' ]
+ })
+
+ self.gp = None
+ self.gp_cacher_a = None
+ self.gp_cacher_b = None
+
+ if config.gradient_penalty is not None:
+ self.gp = GradientPenalty(**config.gradient_penalty)
+
+ self.gp_cacher_a = GradientCacher(
+ self.models.disc_a, self.gp, gp_cache_period
+ )
+ self.gp_cacher_b = GradientCacher(
+ self.models.disc_b, self.gp, gp_cache_period
+ )
+
+ def _set_input(self, inputs, domain):
+ set_two_domain_input(self.images, inputs, domain, self.device)
+
+ if self.images.real_a is not None:
+ self.images.real_hc_a = init_hc(self.images.real_a, self.n_hidden)
+
+ if self.masking is None:
+ self.images.idt_input_a = self.images.real_a
+ else:
+ self.images.idt_input_a = self.masking(self.images.real_a)
+
+ if self.consist_model is not None:
+ self.images.consist_real_a \
+ = self.consist_model(self.images.real_a)
+
+ if self.images.real_b is not None:
+ self.images.real_hc_b = init_hc(self.images.real_b, self.n_hidden)
+
+ if self.masking is None:
+ self.images.idt_input_b = self.images.real_b
+ else:
+ self.images.idt_input_b = self.masking(self.images.real_b)
+
+ if self.consist_model is not None:
+ self.images.consist_real_b \
+ = self.consist_model(self.images.real_b)
+
+ def cycle_forward_image(self, real, real_hc, gen_fwd, gen_bkw):
+ # (N, C+n_hidden, H, W)
+ real_with_hc = torch.cat((real, real_hc), dim = 1)
+
+ fake_with_hc = gen_fwd(real_with_hc)
+ reco_with_hc = gen_bkw(fake_with_hc)
+
+ if self.n_hidden > 0:
+ fake = fake_with_hc[:, :-self.n_hidden, ...]
+ fake_hc = fake_with_hc[:, -self.n_hidden:, ...]
+
+ reco = reco_with_hc[:, :-self.n_hidden, ...]
+ reco_hc = reco_with_hc[:, -self.n_hidden:, ...]
+ else:
+ fake = fake_with_hc
+ fake_hc = None
+
+ reco = reco_with_hc
+ reco_hc = None
+
+ consist_fake = None
+
+ if self.consist_model is not None:
+ consist_fake = self.consist_model(fake)
+
+ return (fake, fake_hc, reco, reco_hc, consist_fake)
+
+ def idt_forward_image(self, idt_input, real_hc, gen):
+ # (N, C+n_hidden, H, W)
+ real_with_hc = torch.cat((idt_input, real_hc), dim = 1)
+ idt_with_hc = gen(real_with_hc)
+
+ if self.n_hidden > 0:
+ idt = idt_with_hc[:, :-self.n_hidden, ...]
+ idt_hc = idt_with_hc[:, -self.n_hidden:, ...]
+ else:
+ idt = idt_with_hc
+ idt_hc = None
+
+ return (idt, idt_hc)
+
+ def forward_dispatch(self, direction):
+ if direction == 'ab':
+ (
+ self.images.fake_b, self.images.fake_hc_b,
+ self.images.reco_a, self.images.reco_hc_a,
+ self.images.consist_fake_b
+ ) = self.cycle_forward_image(
+ self.images.real_a, self.images.real_hc_a,
+ self.models.gen_ab, self.models.gen_ba
+ )
+
+ elif direction == 'ba':
+ (
+ self.images.fake_a, self.images.fake_hc_a,
+ self.images.reco_b, self.images.reco_hc_b,
+ self.images.consist_fake_a
+ ) = self.cycle_forward_image(
+ self.images.real_b, self.images.real_hc_b,
+ self.models.gen_ba, self.models.gen_ab
+ )
+
+ elif direction == 'aa':
+ self.images.idt_a, self.images.idt_hc_a = \
+ self.idt_forward_image(
+ self.images.idt_input_a, self.images.real_hc_a,
+ self.models.gen_ba
+ )
+
+ elif direction == 'bb':
+ self.images.idt_b, self.images.idt_hc_b = \
+ self.idt_forward_image(
+ self.images.idt_input_b, self.images.real_hc_b,
+ self.models.gen_ab
+ )
+
+ elif direction == 'avg-ab':
+ (
+ self.images.fake_b, self.images.fake_hc_b,
+ self.images.reco_a, self.images.reco_hc_a,
+ self.images.consist_fake_b
+ ) = self.cycle_forward_image(
+ self.images.real_a, self.images.real_hc_a,
+ self.models.avg_gen_ab, self.models.avg_gen_ba
+ )
+
+ elif direction == 'avg-ba':
+ (
+ self.images.fake_a, self.images.fake_hc_a,
+ self.images.reco_b, self.images.reco_hc_b,
+ self.images.consist_fake_a
+ ) = self.cycle_forward_image(
+ self.images.real_b, self.images.real_hc_b,
+ self.models.avg_gen_ba, self.models.avg_gen_ab
+ )
+
+ else:
+ raise ValueError(f"Unknown forward direction: '{direction}'")
+
+ def forward(self):
+ if self.images.real_a is not None:
+ if self.avg_momentum is not None:
+ self.forward_dispatch(direction = 'avg-ab')
+ else:
+ self.forward_dispatch(direction = 'ab')
+
+ if self.images.real_b is not None:
+ if self.avg_momentum is not None:
+ self.forward_dispatch(direction = 'avg-ba')
+ else:
+ self.forward_dispatch(direction = 'ba')
+
+ def eval_consist_loss(
+ self, consist_real_0, consist_fake_1, lambda_cycle_0
+ ):
+ return lambda_cycle_0 * self.lambda_consist * self.criterion_consist(
+ consist_fake_1, consist_real_0
+ )
+
+ def eval_loss_of_cycle_forward(
+ self, disc_1, real_0, fake_1, reco_0, fake_queue_1, lambda_cycle_0
+ ):
+ # pylint: disable=too-many-arguments
+ # NOTE: Queue is updated in discriminator backprop
+ disc_pred_fake_1 = queued_forward(
+ disc_1, fake_1, fake_queue_1, update_queue = False
+ )
+
+ loss_gen = self.criterion_gan(
+ disc_pred_fake_1, is_real = True, is_generator = True
+ )
+ loss_cycle = lambda_cycle_0 * self.criterion_cycle(reco_0, real_0)
+
+ loss = loss_gen + loss_cycle
+
+ return (loss_gen, loss_cycle, loss)
+
+ def eval_loss_of_idt_forward(self, real_0, idt_0, lambda_cycle_0):
+ loss_idt = (
+ lambda_cycle_0
+ * self.lambda_idt
+ * self.criterion_idt(idt_0, real_0)
+ )
+
+ loss = loss_idt
+
+ return (loss_idt, loss)
+
+ def backward_gen(self, direction):
+ if direction == 'ab':
+ (self.losses.gen_ab, self.losses.cycle_a, loss) \
+ = self.eval_loss_of_cycle_forward(
+ self.models.disc_b,
+ self.images.real_a, self.images.fake_b, self.images.reco_a,
+ self.queues.fake_b, self.lambda_a
+ )
+
+ if self.consist_model is not None:
+ self.losses.consist_a = self.eval_consist_loss(
+ self.images.consist_real_a, self.images.consist_fake_b,
+ self.lambda_a
+ )
+
+ loss += self.losses.consist_a
+
+ elif direction == 'ba':
+ (self.losses.gen_ba, self.losses.cycle_b, loss) \
+ = self.eval_loss_of_cycle_forward(
+ self.models.disc_a,
+ self.images.real_b, self.images.fake_a, self.images.reco_b,
+ self.queues.fake_a, self.lambda_b
+ )
+
+ if self.consist_model is not None:
+ self.losses.consist_b = self.eval_consist_loss(
+ self.images.consist_real_b, self.images.consist_fake_a,
+ self.lambda_b
+ )
+
+ loss += self.losses.consist_b
+
+ elif direction == 'aa':
+ (self.losses.idt_a, loss) \
+ = self.eval_loss_of_idt_forward(
+ self.images.real_a, self.images.idt_a, self.lambda_a
+ )
+
+ elif direction == 'bb':
+ (self.losses.idt_b, loss) \
+ = self.eval_loss_of_idt_forward(
+ self.images.real_b, self.images.idt_b, self.lambda_b
+ )
+ else:
+ raise ValueError(f"Unknown forward direction: '{direction}'")
+
+ loss.backward()
+
+ def backward_discriminator_base(
+ self, model, real, fake, queue_real, queue_fake, gp_cacher
+ ):
+ # pylint: disable=too-many-arguments
+ loss_gp = None
+
+ if self.gp is not None:
+ loss_gp = gp_cacher(
+ model, fake, real,
+ model_kwargs_fake = { 'extra_bodies' : queue_fake.query() },
+ model_kwargs_real = { 'extra_bodies' : queue_real.query() },
+ )
+
+ pred_real = queued_forward(
+ model, real, queue_real, update_queue = True
+ )
+ loss_real = self.criterion_gan(pred_real, is_real = True)
+
+ pred_fake = queued_forward(
+ model, fake, queue_fake, update_queue = True
+ )
+ loss_fake = self.criterion_gan(pred_fake, is_real = False)
+
+ loss = (loss_real + loss_fake) * 0.5
+ loss.backward()
+
+ return (loss_gp, loss)
+
+ def backward_discriminators(self):
+ fake_a = self.images.fake_a.detach()
+ fake_b = self.images.fake_b.detach()
+
+ loss_gp_b, self.losses.disc_b \
+ = self.backward_discriminator_base(
+ self.models.disc_b, self.images.real_b, fake_b,
+ self.queues.real_b, self.queues.fake_b, self.gp_cacher_b
+ )
+
+ if loss_gp_b is not None:
+ self.losses.gp_b = loss_gp_b
+
+ loss_gp_a, self.losses.disc_a = \
+ self.backward_discriminator_base(
+ self.models.disc_a, self.images.real_a, fake_a,
+ self.queues.real_a, self.queues.fake_a, self.gp_cacher_a
+ )
+
+ if loss_gp_a is not None:
+ self.losses.gp_a = loss_gp_a
+
+ def optimization_step_gen(self):
+ self.set_requires_grad([self.models.disc_a, self.models.disc_b], False)
+ self.optimizers.gen.zero_grad(set_to_none = True)
+
+ dir_list = [ 'ab', 'ba' ]
+ if self.lambda_idt > 0:
+ dir_list += [ 'aa', 'bb' ]
+
+ for direction in dir_list:
+ self.forward_dispatch(direction)
+ self.backward_gen(direction)
+
+ self.optimizers.gen.step()
+
+ def optimization_step_disc(self):
+ self.set_requires_grad([self.models.disc_a, self.models.disc_b], True)
+ self.optimizers.disc.zero_grad(set_to_none = True)
+
+ self.backward_discriminators()
+
+ self.optimizers.disc.step()
+
+ def _accumulate_averages(self):
+ update_average_model(
+ self.models.avg_gen_ab, self.models.gen_ab, self.avg_momentum
+ )
+ update_average_model(
+ self.models.avg_gen_ba, self.models.gen_ba, self.avg_momentum
+ )
+
+ def optimization_step(self):
+ self.optimization_step_gen()
+ self.optimization_step_disc()
+
+ if self.avg_momentum is not None:
+ self._accumulate_averages()
+
+
diff --git a/uvcgan_s/cgan/uvcgan_s.py b/uvcgan_s/cgan/uvcgan_s.py
new file mode 100644
index 0000000..c0d9148
--- /dev/null
+++ b/uvcgan_s/cgan/uvcgan_s.py
@@ -0,0 +1,601 @@
+# pylint: disable=not-callable
+# NOTE: Mistaken lint:
+# E1102: self.criterion_gan is not callable (not-callable)
+
+import itertools
+import torch
+
+from uvcgan_s.torch.data_norm import select_data_normalization
+from uvcgan_s.torch.gan_losses import select_gan_loss
+from uvcgan_s.torch.select import select_optimizer
+from uvcgan_s.torch.queue import FastQueue
+from uvcgan_s.torch.funcs import (
+ prepare_model, update_average_model, clip_gradients
+)
+from uvcgan_s.torch.layers.batch_head import BatchHeadWrapper, get_batch_head
+from uvcgan_s.torch.gradient_penalty import GradientPenalty
+from uvcgan_s.torch.gradient_cacher import GradientCacher
+from uvcgan_s.models.discriminator import construct_discriminator
+from uvcgan_s.models.generator import construct_generator
+
+from .model_base import ModelBase
+from .named_dict import NamedDict
+from .funcs import set_asym_two_domain_input
+
+def queued_forward(
+ batch_head_model, input_image, queue, data_norm, normalize,
+ update_queue = True
+):
+ # pylint: disable=too-many-arguments
+ if normalize:
+ input_image = data_norm.normalize(input_image)
+
+ output, pred_body = batch_head_model.forward(
+ input_image, extra_bodies = queue.query(), return_body = True
+ )
+
+ if update_queue:
+ queue.push(pred_body)
+
+ return output
+
+def eval_loss_with_norm(loss_fn, a, b, data_norm, normalize):
+ if normalize:
+ a = data_norm.normalize(a)
+ b = data_norm.normalize(b)
+
+ return loss_fn(a, b)
+
+class UVCGAN_S(ModelBase):
+ # pylint: disable=too-many-instance-attributes
+
+ def _setup_images(self, _config):
+ images = [
+ 'real_a0', 'real_a1', 'real_a', 'real_b',
+ 'fake_a0', 'fake_a1', 'fake_a', 'fake_b',
+ 'reco_a0', 'reco_a1', 'reco_a', 'reco_b',
+ ]
+
+ if self.lambda_idt_bb:
+ images += [ 'idt_bb_input', 'idt_bb', ]
+
+ if self.lambda_idt_aa:
+ images += [
+ 'idt_aa_input', 'idt_aa', 'idt_aa_a0', 'idt_aa_a1',
+ ]
+
+ return NamedDict(*images)
+
+ def _construct_batch_head_disc(self, model_config, input_shape):
+ disc_body = construct_discriminator(
+ model_config, input_shape, self.device
+ )
+
+ disc_head = get_batch_head(self.head_config)
+ disc_head = prepare_model(disc_head, self.device)
+
+ return BatchHeadWrapper(disc_body, disc_head)
+
+ def _setup_models(self, config):
+ models = {}
+
+ shape_a0 = tuple(config.data.datasets[0].shape)
+ shape_a1 = tuple(config.data.datasets[1].shape)
+
+ shape_a = (shape_a0[0] + shape_a1[0], *shape_a1[1:])
+ shape_b = tuple(config.data.datasets[2].shape)
+
+ models['gen_ab'] = construct_generator(
+ config.generator, shape_a, shape_b, self.device
+ )
+ models['gen_ba'] = construct_generator(
+ config.generator, shape_b, shape_a, self.device
+ )
+
+ if self.ema_momentum is not None:
+ models['ema_gen_ab'] = construct_generator(
+ config.generator, shape_a, shape_b, self.device
+ )
+ models['ema_gen_ba'] = construct_generator(
+ config.generator, shape_b, shape_a, self.device
+ )
+
+ models['ema_gen_ab'].load_state_dict(models['gen_ab'].state_dict())
+ models['ema_gen_ba'].load_state_dict(models['gen_ba'].state_dict())
+
+ if self.is_train:
+ models['disc_a0'] = self._construct_batch_head_disc(
+ config.discriminator, shape_a0
+ )
+ models['disc_a1'] = self._construct_batch_head_disc(
+ config.discriminator, shape_a1
+ )
+ models['disc_b'] = self._construct_batch_head_disc(
+ config.discriminator, shape_b
+ )
+
+ return NamedDict(**models)
+
+ def _setup_losses(self, config):
+ losses = [
+ 'gen_ab', 'gen_ba0', 'gen_ba1',
+ 'cycle_a0', 'cycle_a1', 'cycle_b',
+ 'disc_a0', 'disc_a1', 'disc_b'
+ ]
+
+ if self.lambda_idt_aa:
+ losses += [
+ 'idt_aa_a0', 'idt_aa_a1',
+ ]
+
+ if self.lambda_idt_bb:
+ losses += [ 'idt_bb' ]
+
+ if config.gradient_penalty is not None:
+ losses += [ 'gp_a0', 'gp_a1', 'gp_b' ]
+
+ return NamedDict(*losses)
+
+ def _setup_optimizers(self, config):
+ optimizers = NamedDict('gen', 'disc')
+
+ optimizers.gen = select_optimizer(
+ itertools.chain(
+ self.models.gen_ab.parameters(),
+ self.models.gen_ba.parameters()
+ ),
+ config.generator.optimizer
+ )
+
+ optimizers.disc = select_optimizer(
+ itertools.chain(
+ self.models.disc_a0.parameters(),
+ self.models.disc_a1.parameters(),
+ self.models.disc_b.parameters()
+ ),
+ config.discriminator.optimizer
+ )
+
+ return optimizers
+
+ def __init__(
+ self, savedir, config, is_train, device, head_config = None,
+ lambda_adv_a0 = 1.0,
+ lambda_adv_a1 = 1.0,
+ lambda_adv_b = 1.0,
+ lambda_cyc_a0 = 10.0,
+ lambda_cyc_a1 = 10.0,
+ lambda_cyc_b = 10.0,
+ lambda_idt_aa = 0.5,
+ lambda_idt_bb = 0.5,
+ head_queue_size = 3,
+ ema_momentum = None,
+ data_norm = None,
+ gp_cache_period = 1,
+ grad_clip = None,
+ norm_loss_a0 = False,
+ norm_loss_a1 = False,
+ norm_loss_b = False,
+ norm_disc_a0 = False,
+ norm_disc_a1 = False,
+ norm_disc_b = False,
+ ):
+ # pylint: disable=too-many-arguments
+ # pylint: disable=too-many-locals
+ self.lambda_adv_a0 = lambda_adv_a0
+ self.lambda_adv_a1 = lambda_adv_a1
+ self.lambda_adv_b = lambda_adv_b
+
+ self.lambda_cyc_a0 = lambda_cyc_a0
+ self.lambda_cyc_a1 = lambda_cyc_a1
+ self.lambda_cyc_b = lambda_cyc_b
+
+ self.lambda_idt_aa = lambda_idt_aa
+ self.lambda_idt_bb = lambda_idt_bb
+
+ self.ema_momentum = ema_momentum
+ self.data_norm = select_data_normalization(data_norm)
+ self.head_config = head_config or {}
+
+ assert len(config.data.datasets) == 3, \
+ "Asymmetric CycleGAN expects a triplet of datasets"
+
+ self._c_a0 = config.data.datasets[0].shape[0]
+ self._c_a1 = config.data.datasets[1].shape[0]
+
+ self._grad_clip = grad_clip or {}
+
+ self._norm_loss_a0 = norm_loss_a0
+ self._norm_loss_a1 = norm_loss_a1
+ self._norm_loss_b = norm_loss_b
+
+ self._norm_disc_a0 = norm_disc_a0
+ self._norm_disc_a1 = norm_disc_a1
+ self._norm_disc_b = norm_disc_b
+
+ super().__init__(savedir, config, is_train, device)
+
+ self.criterion_gan = select_gan_loss(config.loss).to(self.device)
+ self.criterion_idt = torch.nn.L1Loss()
+ self.criterion_cycle = torch.nn.L1Loss()
+ self.gradient_penalty = config.gradient_penalty
+
+ if self.is_train:
+ self.queues = NamedDict(**{
+ name : FastQueue(head_queue_size, device = device)
+ for name in [
+ 'real_a0', 'real_a1', 'real_b',
+ 'fake_a0', 'fake_a1', 'fake_b'
+ ]
+ })
+
+ self.gp = None
+ self.gp_cacher_a0 = None
+ self.gp_cacher_a1 = None
+ self.gp_cacher_b = None
+
+ if config.gradient_penalty is not None:
+ self.gp = GradientPenalty(**config.gradient_penalty)
+
+ self.gp_cacher_a0 = GradientCacher(
+ self.models.disc_a0, self.gp, gp_cache_period
+ )
+ self.gp_cacher_a1 = GradientCacher(
+ self.models.disc_a1, self.gp, gp_cache_period
+ )
+ self.gp_cacher_b = GradientCacher(
+ self.models.disc_b, self.gp, gp_cache_period
+ )
+
+ def split_domain_a_image(self, image_a):
+ image_a0 = image_a[:, :self._c_a0, ...]
+ image_a1 = image_a[:, self._c_a0:, ...]
+
+ return (image_a0, image_a1)
+
+ def merge_domain_a_images(self, image_a0, image_a1):
+ # pylint: disable=no-self-use
+ return torch.cat((image_a0, image_a1), dim = 1)
+
+ def _set_input(self, inputs, domain):
+ set_asym_two_domain_input(self.images, inputs, domain, self.device)
+
+ if (
+ (self.images.real_a0 is not None)
+ and (self.images.real_a1 is not None)
+ ):
+ self.images.real_a = self.merge_domain_a_images(
+ self.images.real_a0, self.images.real_a1
+ )
+
+ if not self.is_train:
+ return
+
+ if self.lambda_idt_bb and (self.images.real_b is not None):
+ self.images.idt_bb_input = self.merge_domain_a_images(
+ self.images.real_b, torch.zeros_like(self.images.real_a1)
+ )
+
+ if self.lambda_idt_aa and (
+ (self.images.real_a0 is not None)
+ and (self.images.real_a1 is not None)
+ ):
+ self.images.idt_aa_input = (
+ self.images.real_a0 + self.images.real_a1
+ )
+
+ def cycle_forward_image(self, real, gen_fwd, gen_bkw):
+ # (N, C, H, W)
+ if self.data_norm is not None:
+ real = self.data_norm.normalize(real)
+
+ fake = gen_fwd(real)
+ reco = gen_bkw(fake)
+
+ if self.data_norm is not None:
+ fake = self.data_norm.denormalize(fake)
+ reco = self.data_norm.denormalize(reco)
+
+ return (fake, reco)
+
+ def idt_forward_image(self, idt_input, gen):
+ if self.data_norm is not None:
+ idt_input = self.data_norm.normalize(idt_input)
+
+ idt = gen(idt_input)
+
+ if self.data_norm is not None:
+ idt = self.data_norm.denormalize(idt)
+
+ return idt
+
+ def forward_dispatch(self, direction):
+ if direction == 'cyc-aba':
+ (self.images.fake_b, self.images.reco_a) \
+ = self.cycle_forward_image(
+ self.images.real_a, self.models.gen_ab, self.models.gen_ba
+ )
+
+ (self.images.reco_a0, self.images.reco_a1) \
+ = self.split_domain_a_image(self.images.reco_a)
+
+ elif direction == 'cyc-bab':
+ (self.images.fake_a, self.images.reco_b) \
+ = self.cycle_forward_image(
+ self.images.real_b, self.models.gen_ba, self.models.gen_ab
+ )
+
+ (self.images.fake_a0, self.images.fake_a1) \
+ = self.split_domain_a_image(self.images.fake_a)
+
+ elif direction == 'idt-bb':
+ self.images.idt_bb = self.idt_forward_image(
+ self.images.idt_bb_input, self.models.gen_ab
+ )
+
+ elif direction == 'idt-aa':
+ self.images.idt_aa = self.idt_forward_image(
+ self.images.idt_aa_input, self.models.gen_ba
+ )
+
+ (self.images.idt_aa_a0, self.images.idt_aa_a1) \
+ = self.split_domain_a_image(self.images.idt_aa)
+
+ elif direction == 'ema-cyc-aba':
+ (self.images.fake_b, self.images.reco_a) \
+ = self.cycle_forward_image(
+ self.images.real_a,
+ self.models.ema_gen_ab, self.models.ema_gen_ba
+ )
+
+ (self.images.reco_a0, self.images.reco_a1) \
+ = self.split_domain_a_image(self.images.reco_a)
+
+ elif direction == 'ema-cyc-bab':
+ (self.images.fake_a, self.images.reco_b) \
+ = self.cycle_forward_image(
+ self.images.real_b,
+ self.models.ema_gen_ba, self.models.ema_gen_ab
+ )
+
+ (self.images.fake_a0, self.images.fake_a1) \
+ = self.split_domain_a_image(self.images.fake_a)
+
+ else:
+ raise ValueError(f"Unknown forward direction: '{direction}'")
+
+ def forward(self):
+ if self.images.real_a is not None:
+ if self.ema_momentum is not None:
+ self.forward_dispatch(direction = 'ema-cyc-aba')
+ else:
+ self.forward_dispatch(direction = 'cyc-aba')
+
+ if self.images.real_b is not None:
+ if self.ema_momentum is not None:
+ self.forward_dispatch(direction = 'ema-cyc-bab')
+ else:
+ self.forward_dispatch(direction = 'cyc-bab')
+
+ def eval_loss_of_cycle_forward_aba(self):
+ # NOTE: Queue is updated in discriminator backprop
+ disc_pred_fake_b = queued_forward(
+ self.models.disc_b, self.images.fake_b, self.queues.fake_b,
+ self.data_norm, self._norm_disc_b, update_queue = False
+ )
+
+ self.losses.gen_ab = self.criterion_gan(
+ disc_pred_fake_b, is_real = True, is_generator = True
+ )
+
+ self.losses.cycle_a0 = eval_loss_with_norm(
+ self.criterion_cycle, self.images.reco_a0, self.images.real_a0,
+ self.data_norm, self._norm_loss_a0
+ )
+ self.losses.cycle_a1 = eval_loss_with_norm(
+ self.criterion_cycle, self.images.reco_a1, self.images.real_a1,
+ self.data_norm, self._norm_loss_a1
+ )
+
+ return (
+ self.lambda_adv_b * self.losses.gen_ab
+ + 0.5 * self.lambda_cyc_a0 * self.losses.cycle_a0
+ + 0.5 * self.lambda_cyc_a1 * self.losses.cycle_a1
+ )
+
+ def eval_loss_of_cycle_forward_bab(self):
+ # NOTE: Queue is updated in discriminator backprop
+ disc_pred_fake_a0 = queued_forward(
+ self.models.disc_a0, self.images.fake_a0, self.queues.fake_a0,
+ self.data_norm, self._norm_disc_a0, update_queue = False
+ )
+ disc_pred_fake_a1 = queued_forward(
+ self.models.disc_a1, self.images.fake_a1, self.queues.fake_a1,
+ self.data_norm, self._norm_disc_a1, update_queue = False
+ )
+
+ self.losses.gen_ba0 = self.criterion_gan(
+ disc_pred_fake_a0, is_real = True, is_generator = True
+ )
+ self.losses.gen_ba1 = self.criterion_gan(
+ disc_pred_fake_a1, is_real = True, is_generator = True
+ )
+
+ self.losses.cycle_b = eval_loss_with_norm(
+ self.criterion_cycle, self.images.reco_b, self.images.real_b,
+ self.data_norm, self._norm_loss_b
+ )
+
+ return (
+ 0.5 * self.lambda_adv_a0 * self.losses.gen_ba0
+ + 0.5 * self.lambda_adv_a1 * self.losses.gen_ba1
+ + self.lambda_cyc_b * self.losses.cycle_b
+ )
+
+ def eval_loss_of_idt_forward_bb(self):
+ self.losses.idt_bb = eval_loss_with_norm(
+ self.criterion_idt, self.images.idt_bb, self.images.real_b,
+ self.data_norm, self._norm_loss_b
+ )
+
+ return self.lambda_idt_bb * self.lambda_cyc_b * self.losses.idt_bb
+
+ def eval_loss_of_idt_forward_aa(self):
+ target_a0 = self.images.real_a0
+ target_a1 = self.images.real_a1
+
+ self.losses.idt_aa_a0 = eval_loss_with_norm(
+ self.criterion_idt, self.images.idt_aa_a0, target_a0,
+ self.data_norm, self._norm_loss_a0
+ )
+ self.losses.idt_aa_a1 = eval_loss_with_norm(
+ self.criterion_idt, self.images.idt_aa_a1, target_a1,
+ self.data_norm, self._norm_loss_a1
+ )
+
+ # * 0.5 to account for the two idt_aa_a{0,1} losses relative to the single bb loss
+ return 0.5 * self.lambda_idt_aa * (
+ self.lambda_cyc_a0 * self.losses.idt_aa_a0
+ + self.lambda_cyc_a1 * self.losses.idt_aa_a1
+ )
+
+ def backward_gen(self, direction):
+ if direction == 'cyc-aba':
+ loss = self.eval_loss_of_cycle_forward_aba()
+
+ elif direction == 'cyc-bab':
+ loss = self.eval_loss_of_cycle_forward_bab()
+
+ elif direction == 'idt-bb':
+ loss = self.eval_loss_of_idt_forward_bb()
+
+ elif direction == 'idt-aa':
+ loss = self.eval_loss_of_idt_forward_aa()
+
+ else:
+ raise ValueError(f"Unknown forward direction: '{direction}'")
+
+ loss.backward()
+
+ def backward_discriminator_base(
+ self, model, real, fake, queue_real, queue_fake, gp_cacher, scale,
+ normalize
+ ):
+ # pylint: disable=too-many-arguments
+ loss_gp = None
+
+ if self.gp is not None:
+ loss_gp = gp_cacher(
+ model, fake, real,
+ model_kwargs_fake = { 'extra_bodies' : queue_fake.query() },
+ model_kwargs_real = { 'extra_bodies' : queue_real.query() },
+ )
+
+ pred_real = queued_forward(
+ model, real, queue_real, self.data_norm, normalize,
+ update_queue = True
+ )
+ loss_real = self.criterion_gan(
+ pred_real, is_real = True, is_generator = False
+ )
+
+ pred_fake = queued_forward(
+ model, fake, queue_fake, self.data_norm, normalize,
+ update_queue = True
+ )
+ loss_fake = self.criterion_gan(
+ pred_fake, is_real = False, is_generator = False
+ )
+
+ loss = (loss_real + loss_fake) * 0.5 * scale
+ loss.backward()
+
+ return (loss_gp, loss)
+
+ def backward_discriminators(self):
+ fake_a0 = self.images.fake_a0.detach()
+ fake_a1 = self.images.fake_a1.detach()
+ fake_b = self.images.fake_b .detach()
+
+ loss_gp_b, self.losses.disc_b \
+ = self.backward_discriminator_base(
+ self.models.disc_b, self.images.real_b, fake_b,
+ self.queues.real_b, self.queues.fake_b, self.gp_cacher_b,
+ self.lambda_adv_b, self._norm_disc_b
+ )
+
+ if loss_gp_b is not None:
+ self.losses.gp_b = loss_gp_b
+
+ loss_gp_a0, self.losses.disc_a0 = \
+ self.backward_discriminator_base(
+ self.models.disc_a0, self.images.real_a0, fake_a0,
+ self.queues.real_a0, self.queues.fake_a0, self.gp_cacher_a0,
+ 0.5 * self.lambda_adv_a0, self._norm_disc_a0
+ )
+
+ if loss_gp_a0 is not None:
+ self.losses.gp_a0 = loss_gp_a0
+
+ loss_gp_a1, self.losses.disc_a1 = \
+ self.backward_discriminator_base(
+ self.models.disc_a1, self.images.real_a1, fake_a1,
+ self.queues.real_a1, self.queues.fake_a1, self.gp_cacher_a1,
+ 0.5 * self.lambda_adv_a1, self._norm_disc_a1
+ )
+
+ if loss_gp_a1 is not None:
+ self.losses.gp_a1 = loss_gp_a1
+
+ def optimization_step_gen(self):
+ self.set_requires_grad(
+ [self.models.disc_a0, self.models.disc_a1, self.models.disc_b],
+ False
+ )
+ self.optimizers.gen.zero_grad(set_to_none = True)
+
+ dir_list = [ 'cyc-aba', 'cyc-bab' ]
+
+ if self.lambda_idt_bb:
+ dir_list += [ 'idt-bb' ]
+
+ if self.lambda_idt_aa:
+ dir_list += [ 'idt-aa' ]
+
+ for direction in dir_list:
+ self.forward_dispatch(direction)
+ self.backward_gen(direction)
+
+ clip_gradients(self.optimizers.gen, **self._grad_clip)
+ self.optimizers.gen.step()
+
+ def optimization_step_disc(self):
+ self.set_requires_grad(
+ [self.models.disc_a0, self.models.disc_a1, self.models.disc_b],
+ True
+ )
+ self.optimizers.disc.zero_grad(set_to_none = True)
+
+ self.backward_discriminators()
+
+ clip_gradients(self.optimizers.disc, **self._grad_clip)
+ self.optimizers.disc.step()
+
+ def _accumulate_averages(self):
+ update_average_model(
+ self.models.ema_gen_ab, self.models.gen_ab, self.ema_momentum
+ )
+ update_average_model(
+ self.models.ema_gen_ba, self.models.gen_ba, self.ema_momentum
+ )
+
+ def optimization_step(self):
+ self.optimization_step_gen()
+ self.optimization_step_disc()
+
+ if self.ema_momentum is not None:
+ self._accumulate_averages()
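+
+# Illustrative `model_args` block for this model as it could appear in a
+# training config; the 'model' registry key and all numeric values below are
+# placeholders, while the keyword names mirror the constructor above.
+#
+#   'model'      : 'uvcgan-s',
+#   'model_args' : {
+#       'lambda_cyc_a0' : 10.0,
+#       'lambda_cyc_a1' : 10.0,
+#       'lambda_cyc_b'  : 10.0,
+#       'lambda_idt_aa' : 0.5,
+#       'lambda_idt_bb' : 0.5,
+#       'ema_momentum'  : None,
+#   },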
+
diff --git a/uvcgan_s/config/__init__.py b/uvcgan_s/config/__init__.py
new file mode 100644
index 0000000..6c20b3e
--- /dev/null
+++ b/uvcgan_s/config/__init__.py
@@ -0,0 +1,4 @@
+from .args import Args
+from .config import Config
+from .model_config import ModelConfig
+
diff --git a/uvcgan_s/config/args.py b/uvcgan_s/config/args.py
new file mode 100644
index 0000000..306ed00
--- /dev/null
+++ b/uvcgan_s/config/args.py
@@ -0,0 +1,101 @@
+import difflib
+import os
+from .config import Config
+
+LABEL_FNAME = 'label'
+
+def get_config_difference(config_old, config_new):
+ diff_gen = difflib.unified_diff(
+ config_old.to_json(sort_keys = True, indent = 4).split('\n'),
+ config_new.to_json(sort_keys = True, indent = 4).split('\n'),
+ fromfile = 'Old Config',
+ tofile = 'New Config',
+ )
+
+ return "\n".join(diff_gen)
+
+class Args:
+ __slots__ = [
+ 'config',
+ 'label',
+ 'savedir',
+ 'checkpoint',
+ 'log_level',
+ ]
+
+ def __init__(
+ self, config, savedir, label,
+ log_level = 'INFO',
+ checkpoint = 100,
+ ):
+ # pylint: disable=too-many-arguments
+ self.config = config
+ self.label = label
+ self.savedir = savedir
+ self.checkpoint = checkpoint
+ self.log_level = log_level
+
+ def __getattr__(self, attr):
+ return getattr(self.config, attr)
+
+ def save(self):
+ self.config.save(self.savedir)
+
+ if self.label is not None:
+ # pylint: disable=unspecified-encoding
+ with open(os.path.join(self.savedir, LABEL_FNAME), 'wt') as f:
+ f.write(self.label)
+
+ def check_no_collision(self):
+ try:
+ old_config = Config.load(self.savedir)
+ except IOError:
+ return
+
+ old = old_config.to_json(sort_keys = True)
+ new = self.config.to_json(sort_keys = True)
+
+ if old != new:
+ diff = get_config_difference(old_config, self.config)
+
+ raise RuntimeError(
+ (
+ f"Config collision detected in '{self.savedir}'"
+ f" . Difference:\n{diff}\n"
+ "If you would like to overwrite the config then delete the"
+ f" old one in '{self.savedir}' ."
+ )
+ )
+
+ @staticmethod
+ def from_args_dict(
+ outdir,
+ label = None,
+ log_level = 'INFO',
+ checkpoint = 100,
+ **args_dict
+ ):
+ config = Config(**args_dict)
+ savedir = config.get_savedir(outdir, label)
+
+ result = Args(config, savedir, label, log_level, checkpoint)
+ result.check_no_collision()
+
+ result.save()
+
+ return result
+
+ @staticmethod
+ def load(savedir):
+ config = Config.load(savedir)
+ label = None
+
+ label_path = os.path.join(savedir, LABEL_FNAME)
+
+ if os.path.exists(label_path):
+ # pylint: disable=unspecified-encoding
+ with open(label_path, 'rt') as f:
+ label = f.read()
+
+ return Args(config, savedir, label)
+
diff --git a/uvcgan_s/config/config.py b/uvcgan_s/config/config.py
new file mode 100644
index 0000000..77bdf7a
--- /dev/null
+++ b/uvcgan_s/config/config.py
@@ -0,0 +1,142 @@
+import json
+import logging
+import os
+
+from uvcgan_s.consts import CONFIG_NAME
+
+from .config_base import ConfigBase
+from .data_config import parse_data_config
+from .model_config import ModelConfig
+from .transfer_config import TransferConfig
+
+LOGGER = logging.getLogger('uvcgan_s.config')
+
+class Config(ConfigBase):
+ # pylint: disable=too-many-instance-attributes
+
+ __slots__ = [
+ 'batch_size',
+ 'data',
+ 'epochs',
+ 'discriminator',
+ 'generator',
+ 'model',
+ 'model_args',
+ 'loss',
+ 'gradient_penalty',
+ 'seed',
+ 'scheduler',
+ 'steps_per_epoch',
+ 'transfer',
+ ]
+
+ def __init__(
+ self,
+ batch_size = 32,
+ data = None,
+ data_args = None,
+ epochs = 100,
+ image_shape = None,
+ discriminator = None,
+ generator = None,
+ model = 'cyclegan',
+ model_args = None,
+ loss = 'lsgan',
+ gradient_penalty = None,
+ seed = 0,
+ scheduler = None,
+ steps_per_epoch = 250,
+ transfer = None,
+ workers = None,
+ ):
+ # pylint: disable=too-many-arguments
+ # pylint: disable=too-many-locals
+ self.data = parse_data_config(data, data_args, image_shape, workers)
+
+ self.batch_size = batch_size
+ self.model = model
+ self.model_args = model_args or {}
+ self.seed = seed
+ self.loss = loss
+ self.epochs = epochs
+ self.scheduler = scheduler
+ self.steps_per_epoch = steps_per_epoch
+
+ if discriminator is not None:
+ discriminator = ModelConfig(**discriminator)
+
+ if generator is not None:
+ generator = ModelConfig(**generator)
+
+ if gradient_penalty is True:
+ gradient_penalty = {}
+
+ if transfer is not None:
+ if isinstance(transfer, list):
+ transfer = [ TransferConfig(**conf) for conf in transfer ]
+ else:
+ transfer = TransferConfig(**transfer)
+
+ self.discriminator = discriminator
+ self.generator = generator
+ self.gradient_penalty = gradient_penalty
+ self.transfer = transfer
+
+ Config._check_deprecated_args(image_shape, workers)
+
+ if image_shape is not None:
+ self._validate_image_shape(image_shape)
+
+ @staticmethod
+ def _check_deprecated_args(image_shape, workers):
+ if image_shape is not None:
+ LOGGER.warning(
+ "Deprecation Warning: Deprecated `image_shape` configuration "
+ "parameter detected."
+ )
+
+ if workers is not None:
+ LOGGER.warning(
+ "Deprecation Warning: Deprecated `workers` configuration "
+ "parameter detected."
+ )
+
+ def _validate_image_shape(self, image_shape):
+ assert all(d.shape == image_shape for d in self.data.datasets), (
+ f"Value of the deprecated `image_shape` parameter {image_shape}"
+ f"does not match shapes of the datasets."
+ )
+
+ def get_savedir(self, outdir, label = None):
+ if label is None:
+ label = self.get_hash()
+
+ discriminator = None
+ if self.discriminator is not None:
+ discriminator = self.discriminator.model
+
+ generator = None
+ if self.generator is not None:
+ generator = self.generator.model
+
+ savedir = 'model_m(%s)_d(%s)_g(%s)_%s' % (
+ self.model, discriminator, generator, label
+ )
+
+ savedir = savedir.replace('/', ':')
+ path = os.path.join(outdir, savedir)
+
+ os.makedirs(path, exist_ok = True)
+ return path
+
+ def save(self, path):
+ # pylint: disable=unspecified-encoding
+ with open(os.path.join(path, CONFIG_NAME), 'wt') as f:
+ f.write(self.to_json(sort_keys = True, indent = ' '))
+
+ @staticmethod
+ def load(path):
+ # pylint: disable=unspecified-encoding
+ with open(os.path.join(path, CONFIG_NAME), 'rt') as f:
+ return Config(**json.load(f))
+
diff --git a/uvcgan_s/config/config_base.py b/uvcgan_s/config/config_base.py
new file mode 100644
index 0000000..813445a
--- /dev/null
+++ b/uvcgan_s/config/config_base.py
@@ -0,0 +1,27 @@
+import json
+import hashlib
+
+class ConfigBase:
+
+ __slots__ = []
+
+ def to_dict(self):
+ return { x : getattr(self, x) for x in self.__slots__ }
+
+ def to_json(self, **kwargs):
+ return json.dumps(self, default = lambda x : x.to_dict(), **kwargs)
+
+ def __getitem__(self, key):
+ return getattr(self, key)
+
+ def __setitem__(self, key, value):
+ return setattr(self, key, value)
+
+ def get_hash(self):
+ s = self.to_json(sort_keys = True)
+
+ md5 = hashlib.md5()
+ md5.update(s.encode())
+
+ return md5.hexdigest()
+
diff --git a/uvcgan_s/config/data_config.py b/uvcgan_s/config/data_config.py
new file mode 100644
index 0000000..ba808ca
--- /dev/null
+++ b/uvcgan_s/config/data_config.py
@@ -0,0 +1,281 @@
+import logging
+
+from uvcgan_s.consts import MERGE_PAIRED, MERGE_UNPAIRED, MERGE_NONE
+from uvcgan_s.utils.funcs import check_value_in_range
+
+from .config_base import ConfigBase
+
+LOGGER = logging.getLogger('uvcgan_s.config')
+MERGE_TYPES = [ MERGE_PAIRED, MERGE_UNPAIRED, MERGE_NONE ]
+
+class DatasetConfig(ConfigBase):
+ """Dataset configuration.
+
+ Parameters
+ ----------
+ dataset : str or dict
+ Dataset specification.
+ shape : tuple of int
+ Shape of inputs.
+ transform_train : None or str or dict or list of those
+ Transformations to be applied to the training dataset.
+ If `transform_train` is None, then no transformations will be applied
+ to the training dataset.
+ If `transform_train` is str, then its value is interpreted as a name
+ of the transformation.
+ If `transform_train` is dict, then it is expected to be of the form
+ `{ 'name' : TRANSFORM_NAME, **kwargs }`, where 'name' is the name of
+ the transformation, and `kwargs` dict will be passed to the
+ transformation constructor.
+ Otherwise, `transform_train` is expected to be a list of values above.
+ The corresponding transformations will be chained together in the
+ order that they are specified.
+ Default: None.
+ transform_test : None or str or dict or list of those
+ Transformations to be applied to the test/validation dataset.
+ C.f. `transform_train`.
+ Default: None.
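+
+ Example
+ -------
+ An illustrative specification; the dataset and transform names below
+ are placeholders rather than a verified list of supported values:
+
+ >>> dset_config = DatasetConfig(
+ ... dataset = { 'name' : 'imagedir', 'path' : 'data/domain_a' },
+ ... shape = (1, 256, 256),
+ ... transform_train = [ { 'name' : 'random-crop', 'size' : 256 }, ],
+ ... )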
+ """
+
+ __slots__ = [
+ 'dataset',
+ 'shape',
+ 'transform_train',
+ 'transform_test',
+ ]
+
+ def __init__(
+ self, dataset, shape,
+ transform_train = None,
+ transform_test = None,
+ ):
+ super().__init__()
+
+ self.dataset = dataset
+ self.shape = shape
+ self.transform_train = transform_train
+ self.transform_test = transform_test
+
+class DataConfig(ConfigBase):
+ """Data configuration.
+
+ Parameters
+ ----------
+ datasets : list of dict
+ List of dataset specifications.
+ merge_type : str, optional
+ How to merge samples from datasets.
+ Choices: 'paired', 'unpaired', 'none'.
+ Default: 'unpaired'
+ workers : int, optional
+ Number of data workers.
+ Default: None
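+
+ Example
+ -------
+ A minimal sketch of the three-dataset (signal, background, mixed)
+ layout used by the stratified model; dataset names, paths, and shapes
+ are placeholders:
+
+ >>> data_config = DataConfig(
+ ... datasets = [
+ ... { 'dataset' : { 'name' : 'imagedir', 'path' : path },
+ ... 'shape' : (1, 256, 256), }
+ ... for path in [ 'data/signal', 'data/background', 'data/mixed' ]
+ ... ],
+ ... merge_type = 'unpaired',
+ ... workers = 4,
+ ... )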
+ """
+
+ __slots__ = [
+ 'datasets',
+ 'merge_type',
+ 'workers',
+ ]
+
+ def __init__(self, datasets, merge_type = MERGE_UNPAIRED, workers = None):
+ super().__init__()
+
+ check_value_in_range(merge_type, MERGE_TYPES, 'merge_type')
+ assert isinstance(datasets, list)
+
+ self.datasets = [ DatasetConfig(**x) for x in datasets ]
+ self.merge_type = merge_type
+ self.workers = workers
+
+def parse_deprecated_data_config_v1_celeba(
+ dataset_args, image_shape, workers, transform_train, transform_val
+):
+ attr = dataset_args.get('attr', None)
+
+ if attr is None:
+ domains = [ None, ]
+ else:
+ domains = [ 'a', 'b' ]
+
+ return DataConfig(
+ datasets = [
+ {
+ 'dataset' : {
+ 'name' : 'celeba',
+ 'attr' : attr,
+ 'domain' : domain,
+ 'path' : dataset_args.get('path', None),
+ },
+ 'shape' : image_shape,
+ 'transform_train' : transform_train,
+ 'transform_test' : transform_val,
+ } for domain in domains
+ ],
+ merge_type = 'unpaired',
+ workers = workers,
+ )
+
+def parse_deprecated_data_config_v1_cyclegan(
+ dataset_args, image_shape, workers, transform_train, transform_val
+):
+ return DataConfig(
+ datasets = [
+ {
+ 'dataset' : {
+ 'name' : 'cyclegan',
+ 'domain' : domain,
+ 'path' : dataset_args.get('path', None),
+ },
+ 'shape' : image_shape,
+ 'transform_train' : transform_train,
+ 'transform_test' : transform_val,
+ } for domain in ['a', 'b']
+ ],
+ merge_type = 'unpaired',
+ workers = workers,
+ )
+
+def parse_deprecated_data_config_v1_imagedir(
+ dataset_args, image_shape, workers, transform_train, transform_val
+):
+ return DataConfig(
+ datasets = [
+ {
+ 'dataset' : {
+ 'name' : 'imagedir',
+ 'path' : dataset_args.get('path', None),
+ },
+ 'shape' : image_shape,
+ 'transform_train' : transform_train,
+ 'transform_test' : transform_val,
+ },
+ ],
+ merge_type = 'none',
+ workers = workers,
+ )
+
+def parse_deprecated_data_config_v1_toyzero_precropped(
+ dataset_args, image_shape, workers, transform_train, transform_val
+):
+ align_train = dataset_args.get('align_train', False)
+
+ return DataConfig(
+ datasets = [
+ {
+ 'dataset' : {
+ 'name' : 'toyzero-precropped-v1',
+ 'domain' : domain,
+ 'path' : dataset_args.get('path', None),
+ },
+ 'shape' : image_shape,
+ 'transform_train' : transform_train,
+ 'transform_test' : transform_val,
+ } for domain in [ 'a', 'b' ]
+ ],
+ merge_type = 'paired' if align_train else 'unpaired',
+ workers = workers,
+ )
+
+def parse_deprecated_data_config_v1_toyzero_presimple(
+ dataset, dataset_args, image_shape, workers, transform_train, transform_val
+):
+ # pylint: disable=too-many-arguments
+ if dataset == 'toyzero-presimple':
+ merge_type = 'paired'
+
+ elif dataset == 'toyzero-preunaligned':
+ if dataset_args.get('align_train', False):
+ merge_type = 'paired'
+ else:
+ merge_type = 'unpaired'
+
+ else:
+ raise ValueError(f"Unsupported dataset '{dataset}'")
+
+ return DataConfig(
+ datasets = [
+ {
+ 'dataset' : {
+ 'name' : 'toyzero-presimple-v1',
+ 'domain' : domain,
+ 'fname' : dataset_args.get('fname', None),
+ 'path' : dataset_args.get('path', None),
+ 'seed' : dataset_args.get('seed', 0),
+ 'shuffle' : dataset_args.get('shuffle', True),
+ },
+ 'shape' : image_shape,
+ 'transform_train' : transform_train,
+ 'transform_test' : transform_val,
+ } for domain in [ 'a', 'b' ]
+ ],
+ merge_type = merge_type,
+ workers = workers,
+ )
+
+def parse_deprecated_data_config_v1(
+ dataset, dataset_args, image_shape, workers,
+ transform_train = None, transform_val = None
+):
+ # pylint: disable=too-many-arguments
+ if dataset == 'celeba':
+ return parse_deprecated_data_config_v1_celeba(
+ dataset_args, image_shape, workers, transform_train, transform_val
+ )
+
+ if dataset == 'cyclegan':
+ return parse_deprecated_data_config_v1_cyclegan(
+ dataset_args, image_shape, workers, transform_train, transform_val
+ )
+
+ if dataset == 'imagedir':
+ return parse_deprecated_data_config_v1_imagedir(
+ dataset_args, image_shape, workers, transform_train, transform_val
+ )
+
+ if dataset == 'toyzero-precropped':
+ return parse_deprecated_data_config_v1_toyzero_precropped(
+ dataset_args, image_shape, workers, transform_train, transform_val
+ )
+
+ if dataset in [ 'toyzero-presimple', 'toyzero-preunaligned' ]:
+ return parse_deprecated_data_config_v1_toyzero_presimple(
+ dataset, dataset_args, image_shape, workers,
+ transform_train, transform_val
+ )
+
+ raise NotImplementedError(
+ f"Do not know how to parse deprecated '{dataset}'"
+ )
+
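+# The `data` argument may come in three forms:
+#   * a plain string (deprecated v0 format), with options in `data_args`;
+#   * a dict with 'dataset' / 'dataset_args' / 'transform_*' keys
+#     (deprecated v1 format);
+#   * a dict of `DataConfig` keyword arguments (current format).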
+def parse_data_config(data, data_args, image_shape, workers):
+ if isinstance(data, str):
+ LOGGER.warning(
+ "Deprecation Warning: Old (v0) dataset configuration detected."
+ " Please modify your configuration and change `data` parameter"
+ " into a dictionary describing `DataConfig` structure."
+ )
+ return parse_deprecated_data_config_v1(
+ data, data_args, image_shape, workers
+ )
+
+ assert data_args is None, \
+ "Deprecated `data_args` argument detected with new data configuration"
+
+ if (
+ ('dataset' in data)
+ or ('dataset_args' in data)
+ or ('transform_train' in data)
+ or ('transform_val' in data)
+ ):
+ LOGGER.warning(
+ "Deprecation Warning: Old (v1) dataset configuration detected."
+ " Please modify your configuration and change `data` parameter"
+ " into a dictionary describing `DataConfig` structure."
+ )
+ return parse_deprecated_data_config_v1(
+ **data, image_shape = image_shape, workers = workers
+ )
+
+ return DataConfig(**data)
+
diff --git a/uvcgan_s/config/funcs.py b/uvcgan_s/config/funcs.py
new file mode 100644
index 0000000..ec599f4
--- /dev/null
+++ b/uvcgan_s/config/funcs.py
@@ -0,0 +1,7 @@
+import os
+
+def create_evaldir(path, eval_name):
+ result = os.path.join(path, eval_name)
+ os.makedirs(result, exist_ok = True)
+
+ return result
diff --git a/uvcgan_s/config/model_config.py b/uvcgan_s/config/model_config.py
new file mode 100644
index 0000000..2b287c5
--- /dev/null
+++ b/uvcgan_s/config/model_config.py
@@ -0,0 +1,34 @@
+
+class ModelConfig:
+
+ __slots__ = [
+ 'model',
+ 'model_args',
+ 'optimizer',
+ 'weight_init',
+ 'lr_equal',
+ 'spectr_norm',
+ ]
+
+ def __init__(
+ self,
+ model,
+ optimizer = None,
+ model_args = None,
+ weight_init = None,
+ lr_equal = False,
+ spectr_norm = False,
+ ):
+ # pylint: disable=too-many-arguments
+ self.model = model
+ self.model_args = model_args or {}
+ self.optimizer = optimizer or {
+ 'name' : 'AdamW', 'betas' : (0.5, 0.999), 'weight_decay' : 1e-5,
+ }
+ self.weight_init = weight_init
+ self.lr_equal = lr_equal
+ self.spectr_norm = spectr_norm
+
+ def to_dict(self):
+ return { x : getattr(self, x) for x in self.__slots__ }
+
diff --git a/uvcgan_s/config/transfer_config.py b/uvcgan_s/config/transfer_config.py
new file mode 100644
index 0000000..cbb1a5a
--- /dev/null
+++ b/uvcgan_s/config/transfer_config.py
@@ -0,0 +1,53 @@
+from .config_base import ConfigBase
+
+class TransferConfig(ConfigBase):
+ """Model transfer configuration.
+
+ Parameters
+ ----------
+ base_model : str
+        Path to the model to transfer parameters from. If the path is
+        relative, it is interpreted relative to `ROOT_OUTDIR`.
+ transfer_map : dict or None
+        Mapping between network names of the current model and those of the
+        model to transfer parameters from. For example, a mapping of the form
+ `{ 'gen_ab' : 'gen' }` will initialize generator `gen_ab` of the
+ current model from the `gen` generator of the base model.
+ Default: None.
+ strict : bool
+        Value of PyTorch's `strict` parameter used when loading parameters.
+ Default: True.
+ allow_partial : bool
+ Whether to allow transfer from the last checkpoint of a partially
+ trained base model.
+ Default: False.
+ fuzzy : str, optional
+        Allow fuzzy transfer, e.g. transferring parameters from a larger
+        model to a smaller one.
+ Choices: [ 'none', 'from-larger-model' ]
+ Default: None.
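+
+    For illustration, a transfer configuration may look like the following
+    (the base model path and network names are illustrative):
+
+        {
+            'base_model'   : 'pretrain/my-base-model',
+            'transfer_map' : { 'gen_ab' : 'gen', 'gen_ba' : 'gen' },
+            'strict'       : True,
+        }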
+ """
+
+ __slots__ = [
+ 'base_model',
+ 'transfer_map',
+ 'strict',
+ 'allow_partial',
+ 'fuzzy',
+ ]
+
+ def __init__(
+ self,
+ base_model,
+ transfer_map = None,
+ strict = True,
+ allow_partial = False,
+ fuzzy = None,
+ ):
+ # pylint: disable=too-many-arguments
+ self.base_model = base_model
+ self.transfer_map = transfer_map or {}
+ self.strict = strict
+ self.allow_partial = allow_partial
+ self.fuzzy = fuzzy
+
diff --git a/uvcgan_s/consts.py b/uvcgan_s/consts.py
new file mode 100644
index 0000000..8e79b95
--- /dev/null
+++ b/uvcgan_s/consts.py
@@ -0,0 +1,16 @@
+import os
+
+CONFIG_NAME = 'config.json'
+ROOT_DATA = os.path.join(os.environ.get('UVCGAN_S_DATA', 'data'))
+ROOT_OUTDIR = os.path.join(os.environ.get('UVCGAN_S_OUTDIR', 'outdir'))
+
+SPLIT_TRAIN = 'train'
+SPLIT_VAL = 'val'
+SPLIT_TEST = 'test'
+
+MERGE_PAIRED = 'paired'
+MERGE_UNPAIRED = 'unpaired'
+MERGE_NONE = 'none'
+
+MODEL_STATE_TRAIN = 'train'
+MODEL_STATE_EVAL = 'eval'
diff --git a/uvcgan_s/data/__init__.py b/uvcgan_s/data/__init__.py
new file mode 100644
index 0000000..9b02bd0
--- /dev/null
+++ b/uvcgan_s/data/__init__.py
@@ -0,0 +1,2 @@
+from .data import construct_data_loaders, construct_datasets
+
diff --git a/uvcgan_s/data/data.py b/uvcgan_s/data/data.py
new file mode 100644
index 0000000..0ddfe66
--- /dev/null
+++ b/uvcgan_s/data/data.py
@@ -0,0 +1,129 @@
+import os
+import torch
+
+import torchvision
+
+from uvcgan_s.consts import (
+ ROOT_DATA, SPLIT_TRAIN, MERGE_PAIRED, MERGE_UNPAIRED
+)
+from uvcgan_s.torch.select import extract_name_kwargs
+
+from .datasets.celeba import CelebaDataset
+from .datasets.image_domain_folder import ImageDomainFolder
+from .datasets.image_domain_hierarchy import ImageDomainHierarchy
+from .datasets.zipper import DatasetZipper
+from .datasets.ndarray_domain_hierarchy import NDArrayDomainHierarchy
+from .datasets.h5array_domain_hierarchy import H5ArrayDomainHierarchy
+from .datasets.toy_mix_blur_dataset import ToyMixBlurDataset
+
+from .loader_zipper import DataLoaderZipper
+from .transforms import select_transform
+
+def select_dataset(name, path, split, transform, **kwargs):
+ # pylint: disable=too-many-return-statements
+ # pylint: disable=too-many-branches
+
+ if name == 'celeba':
+ return CelebaDataset(
+ path, transform = transform, split = split, **kwargs
+ )
+
+ if name in [ 'cyclegan', 'image-domain-folder' ]:
+ return ImageDomainFolder(
+ path, transform = transform, split = split, **kwargs
+ )
+
+ if name in [ 'image-domain-hierarchy' ]:
+ return ImageDomainHierarchy(
+ path, transform = transform, split = split, **kwargs
+ )
+
+ if name == 'imagenet':
+ return torchvision.datasets.ImageNet(
+ path, transform = transform, split = split, **kwargs
+ )
+
+ if name in [ 'imagedir', 'image-folder' ]:
+ return torchvision.datasets.ImageFolder(
+ os.path.join(path, split), transform = transform, **kwargs
+ )
+
+ if name == 'ndarray-domain-hierarchy':
+ return NDArrayDomainHierarchy(
+ path, transform = transform, split = split, **kwargs
+ )
+
+ if name == 'h5array-domain-hierarchy':
+ return H5ArrayDomainHierarchy(
+ path, transform = transform, split = split, **kwargs
+ )
+
+ if name == 'toy-mix-blur':
+ return ToyMixBlurDataset(
+ path, transform = transform, split = split, **kwargs
+ )
+
+ raise ValueError(f"Unknown dataset: '{name}'")
+
+def construct_single_dataset(dataset_config, split):
+ name, kwargs = extract_name_kwargs(dataset_config.dataset)
+ path = os.path.join(ROOT_DATA, kwargs.pop('path', name))
+
+ if split == SPLIT_TRAIN:
+ transform = select_transform(dataset_config.transform_train)
+ else:
+ transform = select_transform(dataset_config.transform_test)
+
+ return select_dataset(name, path, split, transform, **kwargs)
+
+def construct_datasets(data_config, split):
+ return [
+ construct_single_dataset(config, split)
+ for config in data_config.datasets
+ ]
+
+def construct_single_loader(
+ dataset, batch_size, shuffle,
+ workers = None,
+ prefetch_factor = 2,
+ **kwargs
+):
+ if workers is None:
+ workers = min(torch.get_num_threads(), 20)
+
+ return torch.utils.data.DataLoader(
+ dataset, batch_size,
+ shuffle = shuffle,
+ num_workers = workers,
+ prefetch_factor = prefetch_factor,
+ pin_memory = True,
+ **kwargs
+ )
+
+def construct_data_loaders(data_config, batch_size, split):
+ datasets = construct_datasets(data_config, split)
+ shuffle = (split == SPLIT_TRAIN)
+
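+    # merge_type semantics:
+    #   'paired'   : zip datasets sample-wise into a single loader
+    #   'unpaired' : one loader per dataset, zipped batch-wise, with the
+    #                last incomplete batch of each loader dropped
+    #   'none'     : independent loaders (a bare loader if there is one)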
+ if data_config.merge_type == MERGE_PAIRED:
+ dataset = DatasetZipper(datasets)
+
+ return construct_single_loader(
+ dataset, batch_size, shuffle, data_config.workers,
+ drop_last = False
+ )
+
+ loaders = [
+ construct_single_loader(
+ dataset, batch_size, shuffle, data_config.workers,
+ drop_last = (data_config.merge_type == MERGE_UNPAIRED)
+ ) for dataset in datasets
+ ]
+
+ if data_config.merge_type == MERGE_UNPAIRED:
+ return DataLoaderZipper(loaders)
+
+ if len(loaders) == 1:
+ return loaders[0]
+
+ return loaders
+
diff --git a/uvcgan_s/data/datasets/__init__.py b/uvcgan_s/data/datasets/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/uvcgan_s/data/datasets/celeba.py b/uvcgan_s/data/datasets/celeba.py
new file mode 100644
index 0000000..d8f1765
--- /dev/null
+++ b/uvcgan_s/data/datasets/celeba.py
@@ -0,0 +1,112 @@
+import os
+import pandas as pd
+
+from torch.utils.data import Dataset
+from torchvision.datasets.folder import default_loader
+
+from uvcgan_s.consts import SPLIT_TRAIN, SPLIT_VAL, SPLIT_TEST
+from uvcgan_s.utils.funcs import check_value_in_range
+
+FNAME_ATTRS = 'list_attr_celeba.txt'
+FNAME_SPLIT = 'list_eval_partition.txt'
+SUBDIR_IMG = 'img_align_celeba'
+
+SPLITS = {
+ SPLIT_TRAIN : 0,
+ SPLIT_VAL : 1,
+ SPLIT_TEST : 2,
+}
+
+DOMAINS = [ 'a', 'b' ]
+
+class CelebaDataset(Dataset):
+
+ def __init__(
+ self, path,
+ attr = 'Young',
+ domain = 'a',
+ split = SPLIT_TRAIN,
+ transform = None,
+ **kwargs
+ ):
+ # pylint: disable=too-many-arguments
+ check_value_in_range(split, SPLITS, 'CelebaDataset: split')
+
+ if attr is None:
+ assert domain is None
+ else:
+ check_value_in_range(domain, DOMAINS, 'CelebaDataset: domain')
+
+ super().__init__(**kwargs)
+
+ self._path = path
+ self._root_imgs = os.path.join(path, SUBDIR_IMG)
+ self._split = split
+ self._attr = attr
+ self._domain = domain
+ self._imgs = []
+ self._transform = transform
+
+ self._collect_files()
+
+ def _collect_files(self):
+ imgs_specs = CelebaDataset.load_image_specs(self._path)
+
+ imgs = CelebaDataset.partition_images(
+ imgs_specs, self._split, self._attr, self._domain
+ )
+
+ self._imgs = [ os.path.join(self._root_imgs, x) for x in imgs ]
+
+ @staticmethod
+ def load_image_partition(root):
+ path = os.path.join(root, FNAME_SPLIT)
+
+ return pd.read_csv(
+ path, sep = r'\s+', header = None, names = [ 'partition', ],
+ index_col = 0
+ )
+
+ @staticmethod
+ def load_image_attrs(root):
+ path = os.path.join(root, FNAME_ATTRS)
+
+ return pd.read_csv(
+ path, sep = r'\s+', skiprows = 1, header = 0, index_col = 0
+ )
+
+ @staticmethod
+ def load_image_specs(root):
+ df_partition = CelebaDataset.load_image_partition(root)
+ df_attrs = CelebaDataset.load_image_attrs(root)
+
+ return df_partition.join(df_attrs)
+
+ @staticmethod
+ def partition_images(image_specs, split, attr, domain):
+ part_mask = (image_specs.partition == SPLITS[split])
+
+ if attr is None:
+ imgs = image_specs[part_mask].index.to_list()
+ else:
+ if domain == 'a':
+ domain_mask = (image_specs[attr] > 0)
+ else:
+ domain_mask = (image_specs[attr] < 0)
+
+ imgs = image_specs[part_mask & domain_mask].index.to_list()
+
+ return imgs
+
+ def __len__(self):
+ return len(self._imgs)
+
+ def __getitem__(self, index):
+ path = self._imgs[index]
+ result = default_loader(path)
+
+ if self._transform is not None:
+ result = self._transform(result)
+
+ return result
+
diff --git a/uvcgan_s/data/datasets/funcs.py b/uvcgan_s/data/datasets/funcs.py
new file mode 100644
index 0000000..d955f5a
--- /dev/null
+++ b/uvcgan_s/data/datasets/funcs.py
@@ -0,0 +1,11 @@
+
+def cantor_pairing(x, y, mod = (1 << 32)):
+ # https://stackoverflow.com/questions/919612/mapping-two-integers-to-one-in-a-unique-and-deterministic-way
+ # https://en.wikipedia.org/wiki/Pairing_function#Cantor_pairing_function
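+    # e.g. cantor_pairing(3, 4) = (3 + 4) * (3 + 4 + 1) // 2 + 4 = 32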
+ result = (x + y) * (x + y + 1) // 2 + y
+
+ if mod is not None:
+ result = result % mod
+
+ return result
+
diff --git a/uvcgan_s/data/datasets/h5array_domain_hierarchy.py b/uvcgan_s/data/datasets/h5array_domain_hierarchy.py
new file mode 100644
index 0000000..910d84c
--- /dev/null
+++ b/uvcgan_s/data/datasets/h5array_domain_hierarchy.py
@@ -0,0 +1,53 @@
+import os
+
+import numpy as np
+from torch.utils.data import Dataset
+
+from uvcgan_s.consts import SPLIT_TRAIN
+
+DSET_INDEX = 'index'
+DSET_DATA = 'data'
+
+H5_EXT = [ '', '.h5', '.hdf5' ]
+
+
+class H5ArrayDomainHierarchy(Dataset):
+
+ def __init__(
+ self, path, domain,
+ split = SPLIT_TRAIN,
+ transform = None,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self._path = None
+ path_base = os.path.join(path, split, domain)
+
+ for ext in H5_EXT:
+ path = path_base + ext
+ if os.path.exists(path):
+ self._path = path
+ break
+ else:
+ raise RuntimeError(f"Failed to find h5 dataset '{path_base}'")
+
+ # pylint: disable=import-outside-toplevel
+ import h5py
+
+ self._f = h5py.File(self._path, 'r')
+ self._dset = self._f.get(DSET_DATA)
+
+ self._transform = transform
+
+ def __len__(self):
+ return len(self._dset)
+
+ def __getitem__(self, index):
+ result = np.float32(self._dset[index])
+
+ if self._transform is not None:
+ result = self._transform(result)
+
+ return result
+
diff --git a/uvcgan_s/data/datasets/image_domain_folder.py b/uvcgan_s/data/datasets/image_domain_folder.py
new file mode 100644
index 0000000..d7fee75
--- /dev/null
+++ b/uvcgan_s/data/datasets/image_domain_folder.py
@@ -0,0 +1,76 @@
+import os
+
+from torch.utils.data import Dataset
+from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
+
+from uvcgan_s.consts import SPLIT_TRAIN
+
+class ImageDomainFolder(Dataset):
+ """Dataset structure introduced in a CycleGAN paper.
+
+ This dataset expects images to be arranged into subdirectories
+ under `path`: `trainA`, `trainB`, `testA`, `testB`. Here, `trainA`
+ subdirectory contains training images from domain "a", `trainB`
+ subdirectory contains training images from domain "b", and so on.
+
+ Parameters
+ ----------
+ path : str
+ Path where the dataset is located.
+ domain : str
+ Choices: 'a', 'b'.
+ split : str
+ Choices: 'train', 'test', 'val'
+    transform : Callable or None
+ Optional transformation to apply to images.
+ E.g. torchvision.transforms.RandomCrop.
+ Default: None
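+
+    For illustration, a dataset rooted at `path` may look like this
+    (file names are hypothetical):
+
+        path/trainA/cat_001.jpg
+        path/trainB/dog_001.jpg
+        path/testA/cat_101.jpg
+        path/testB/dog_101.jpg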
+ """
+
+ def __init__(
+ self, path,
+ domain = 'a',
+ split = SPLIT_TRAIN,
+ transform = None,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ subdir = split + domain.upper()
+
+ self._path = os.path.join(path, subdir)
+ self._imgs = ImageDomainFolder.find_images_in_dir(self._path)
+ self._transform = transform
+
+ @staticmethod
+ def find_images_in_dir(path):
+ extensions = set(IMG_EXTENSIONS)
+
+ result = []
+ for fname in os.listdir(path):
+ fullpath = os.path.join(path, fname)
+
+ if not os.path.isfile(fullpath):
+ continue
+
+ ext = os.path.splitext(fname)[1]
+ if ext not in extensions:
+ continue
+
+ result.append(fullpath)
+
+ result.sort()
+ return result
+
+ def __len__(self):
+ return len(self._imgs)
+
+ def __getitem__(self, index):
+ path = self._imgs[index]
+ result = default_loader(path)
+
+ if self._transform is not None:
+ result = self._transform(result)
+
+ return result
+
diff --git a/uvcgan_s/data/datasets/image_domain_hierarchy.py b/uvcgan_s/data/datasets/image_domain_hierarchy.py
new file mode 100644
index 0000000..bdbeba5
--- /dev/null
+++ b/uvcgan_s/data/datasets/image_domain_hierarchy.py
@@ -0,0 +1,34 @@
+import os
+
+from torch.utils.data import Dataset
+from torchvision.datasets.folder import default_loader
+
+from uvcgan_s.consts import SPLIT_TRAIN
+from .image_domain_folder import ImageDomainFolder
+
+class ImageDomainHierarchy(Dataset):
+
+ def __init__(
+ self, path, domain,
+ split = SPLIT_TRAIN,
+ transform = None,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self._path = os.path.join(path, split, domain)
+ self._imgs = ImageDomainFolder.find_images_in_dir(self._path)
+ self._transform = transform
+
+ def __len__(self):
+ return len(self._imgs)
+
+ def __getitem__(self, index):
+ path = self._imgs[index]
+ result = default_loader(path)
+
+ if self._transform is not None:
+ result = self._transform(result)
+
+ return result
+
diff --git a/uvcgan_s/data/datasets/ndarray_domain_hierarchy.py b/uvcgan_s/data/datasets/ndarray_domain_hierarchy.py
new file mode 100644
index 0000000..5a05341
--- /dev/null
+++ b/uvcgan_s/data/datasets/ndarray_domain_hierarchy.py
@@ -0,0 +1,55 @@
+import os
+
+import numpy as np
+from torch.utils.data import Dataset
+
+from uvcgan_s.consts import SPLIT_TRAIN
+
+def find_ndarrays_in_dir(path):
+ result = []
+
+ for fname in os.listdir(path):
+ fullpath = os.path.join(path, fname)
+
+ if not os.path.isfile(fullpath):
+ continue
+
+ ext = os.path.splitext(fname)[1]
+ if ext != '.npz':
+ continue
+
+ result.append(fullpath)
+
+ result.sort()
+ return result
+
+def load_ndarray(path):
+ with np.load(path) as f:
+ return f[f.files[0]]
+
+class NDArrayDomainHierarchy(Dataset):
+
+ def __init__(
+ self, path, domain,
+ split = SPLIT_TRAIN,
+ transform = None,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self._path = os.path.join(path, split, domain)
+ self._arrays = find_ndarrays_in_dir(self._path)
+ self._transform = transform
+
+ def __len__(self):
+ return len(self._arrays)
+
+ def __getitem__(self, index):
+ path = self._arrays[index]
+ result = np.float32(load_ndarray(path))
+
+ if self._transform is not None:
+ result = self._transform(result)
+
+ return result
+
diff --git a/uvcgan_s/data/datasets/svhn.py b/uvcgan_s/data/datasets/svhn.py
new file mode 100644
index 0000000..ecd38a2
--- /dev/null
+++ b/uvcgan_s/data/datasets/svhn.py
@@ -0,0 +1,26 @@
+from torchvision.datasets import SVHN
+
+class SVHNDataset(SVHN):
+
+ def __init__(
+ self, path, split, transform, return_target = False, **kwargs
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__(
+ path,
+ split = split,
+ transform = transform,
+ download = True,
+ **kwargs
+ )
+
+ self._return_target = return_target
+
+ def __getitem__(self, index):
+ item = super().__getitem__(index)
+
+ if self._return_target:
+ return item
+
+ return item[0]
+
diff --git a/uvcgan_s/data/datasets/toy_mix_blur_dataset.py b/uvcgan_s/data/datasets/toy_mix_blur_dataset.py
new file mode 100644
index 0000000..b80e788
--- /dev/null
+++ b/uvcgan_s/data/datasets/toy_mix_blur_dataset.py
@@ -0,0 +1,84 @@
+import os
+import random
+
+import torchvision.transforms.functional as TF
+
+from torch.utils.data import Dataset
+from torchvision.datasets.folder import default_loader
+
+from uvcgan_s.consts import SPLIT_TRAIN
+from .image_domain_folder import ImageDomainFolder
+
+
+class ToyMixBlurDataset(Dataset):
+ # pylint: disable=too-many-instance-attributes
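+    """Toy dataset of synthetic 'mixed' images.
+
+    Each sample alpha-blends an image from `domain_a0` with an image from
+    `domain_a1` and applies a Gaussian blur of size `blur_kernel_size`.
+    During training the partner image and the blending coefficient
+    (`alpha` +/- `alpha_range` / 2) are drawn at random; for other splits a
+    fixed, seeded pairing and the nominal `alpha` are used, so evaluation
+    is deterministic.
+    """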
+
+ def __init__(
+ self, path,
+ domain_a0 = 'cat',
+ domain_a1 = 'dog',
+ split = SPLIT_TRAIN,
+ alpha = 0.5,
+ alpha_range = 0.0,
+ blur_kernel_size = 5,
+ seed = 42,
+ transform = None,
+ **kwargs
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__(**kwargs)
+
+ self._split = split
+ self._alpha = alpha
+ self._alpha_range = alpha_range
+ self._blur_ks = blur_kernel_size
+ self._transform = transform
+
+ path_a0 = os.path.join(path, split, domain_a0)
+ path_a1 = os.path.join(path, split, domain_a1)
+
+ self._imgs_a0 = ImageDomainFolder.find_images_in_dir(path_a0)
+ self._imgs_a1 = ImageDomainFolder.find_images_in_dir(path_a1)
+
+ self._rng = random.Random(seed)
+
+ self._pairing_indices = list(range(len(self._imgs_a1)))
+ self._rng.shuffle(self._pairing_indices)
+
+ def __len__(self):
+ return len(self._imgs_a0)
+
+ def _get_alpha(self):
+ if self._split == SPLIT_TRAIN:
+ offset = random.uniform(
+ -self._alpha_range / 2, self._alpha_range / 2
+ )
+ return self._alpha + offset
+ else:
+ return self._alpha
+
+ def _apply_blur(self, img):
+ return TF.gaussian_blur(img, kernel_size=self._blur_ks)
+
+ def __getitem__(self, index):
+
+ if self._split == SPLIT_TRAIN:
+ index_a1 = random.randint(0, len(self._imgs_a1) - 1)
+ else:
+ index_a1 = self._pairing_indices[
+ index % len(self._pairing_indices)
+ ]
+
+ img_a0 = default_loader(self._imgs_a0[index])
+ img_a1 = default_loader(self._imgs_a1[index_a1])
+
+ if self._transform is not None:
+ img_a0 = self._transform(img_a0)
+ img_a1 = self._transform(img_a1)
+
+ alpha = self._get_alpha()
+ mixed = alpha * img_a0 + (1.0 - alpha) * img_a1
+ mixed = self._apply_blur(mixed)
+
+ return mixed
+
diff --git a/uvcgan_s/data/datasets/zipper.py b/uvcgan_s/data/datasets/zipper.py
new file mode 100644
index 0000000..bb89129
--- /dev/null
+++ b/uvcgan_s/data/datasets/zipper.py
@@ -0,0 +1,24 @@
+from torch.utils.data import Dataset
+
+class DatasetZipper(Dataset):
+
+ def __init__(self, datasets, **kwargs):
+ super().__init__(**kwargs)
+
+ assert len(datasets) > 0, \
+ "DatasetZipper does not know how to zip empty list of datasets"
+
+ self._datasets = datasets
+ self._len = len(datasets[0])
+
+ lengths = [ len(dset) for dset in datasets ]
+
+ assert all(x == self._len for x in lengths), \
+ f"DatasetZipper cannot zip datasets of unequal lengths: {lengths}"
+
+ def __len__(self):
+ return self._len
+
+ def __getitem__(self, index):
+ return tuple(d[index] for d in self._datasets)
+
diff --git a/uvcgan_s/data/loader_zipper.py b/uvcgan_s/data/loader_zipper.py
new file mode 100644
index 0000000..db3d3a0
--- /dev/null
+++ b/uvcgan_s/data/loader_zipper.py
@@ -0,0 +1,12 @@
+
+class DataLoaderZipper:
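+    """Iterate several data loaders in lockstep, yielding tuples of batches.
+
+    Iteration stops once the shortest loader is exhausted.
+    """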
+
+ def __init__(self, loaders):
+ self._loaders = loaders
+
+ def __len__(self):
+ return min(len(d) for d in self._loaders)
+
+ def __iter__(self):
+ return zip(*self._loaders)
+
diff --git a/uvcgan_s/data/transforms.py b/uvcgan_s/data/transforms.py
new file mode 100644
index 0000000..67b7281
--- /dev/null
+++ b/uvcgan_s/data/transforms.py
@@ -0,0 +1,105 @@
+import torch
+import torchvision
+from torchvision import transforms
+
+from uvcgan_s.torch.select import extract_name_kwargs
+
+FromNumpy = lambda : torch.from_numpy
+
+TRANSFORM_DICT = {
+ 'center-crop' : transforms.CenterCrop,
+ 'color-jitter' : transforms.ColorJitter,
+ 'random-crop' : transforms.RandomCrop,
+ 'random-flip-vertical' : transforms.RandomVerticalFlip,
+ 'random-flip-horizontal' : transforms.RandomHorizontalFlip,
+ 'random-rotation' : transforms.RandomRotation,
+ 'random-resize-crop' : transforms.RandomResizedCrop,
+ 'random-solarize' : transforms.RandomSolarize,
+ 'random-invert' : transforms.RandomInvert,
+ 'gaussian-blur' : transforms.GaussianBlur,
+ 'resize' : transforms.Resize,
+ 'normalize' : transforms.Normalize,
+ 'pad' : transforms.Pad,
+ 'grayscale' : transforms.Grayscale,
+ 'to-tensor' : transforms.ToTensor,
+ 'from-numpy' : FromNumpy,
+ 'CenterCrop' : transforms.CenterCrop,
+ 'ColorJitter' : transforms.ColorJitter,
+ 'RandomCrop' : transforms.RandomCrop,
+ 'RandomVerticalFlip' : transforms.RandomVerticalFlip,
+ 'RandomHorizontalFlip' : transforms.RandomHorizontalFlip,
+ 'RandomRotation' : transforms.RandomRotation,
+ 'Resize' : transforms.Resize,
+}
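+
+# For illustration, a transform specification is either a name from the
+# table above or a dict with a 'name' key plus kwargs, e.g. (values are
+# only illustrative):
+#   [ 'random-flip-horizontal', { 'name' : 'resize', 'size' : 256 } ]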
+
+INTERPOLATION_DICT = {
+ 'nearest' : transforms.InterpolationMode.NEAREST,
+ 'bilinear' : transforms.InterpolationMode.BILINEAR,
+ 'bicubic' : transforms.InterpolationMode.BICUBIC,
+ 'lanczos' : transforms.InterpolationMode.LANCZOS,
+}
+
+def parse_interpolation(kwargs):
+ if 'interpolation' in kwargs:
+ kwargs['interpolation'] = INTERPOLATION_DICT[kwargs['interpolation']]
+
+def select_single_transform(transform):
+ name, kwargs = extract_name_kwargs(transform)
+
+ if name == 'random-apply':
+ transform = select_transform_basic(kwargs.pop('transforms'))
+ return transforms.RandomApply(transform, **kwargs)
+
+ if name not in TRANSFORM_DICT:
+ raise ValueError(f"Unknown transform: '{name}'")
+
+ parse_interpolation(kwargs)
+
+ return TRANSFORM_DICT[name](**kwargs)
+
+def select_transform_basic(transform, compose = False):
+ result = []
+
+ if transform is not None:
+ if not isinstance(transform, (list, tuple)):
+ transform = [ transform, ]
+
+ result = [
+ select_single_transform(x) for x in transform if x != 'none'
+ ]
+
+ if compose:
+ if len(result) == 1:
+ return result[0]
+ else:
+ return torchvision.transforms.Compose(result)
+ else:
+ return result
+
+def select_transform(transform, add_to_tensor = True):
+ if transform == 'none':
+ return None
+
+ result = select_transform_basic(transform)
+
+    # NOTE: this ugliness is kept for backward compatibility
+ if add_to_tensor:
+ if not isinstance(transform, (list, tuple)):
+ transform = [ transform, ]
+
+ need_transform = True
+
+ if any(t == 'to-tensor' for t in transform):
+ need_transform = False
+
+ if any(t == 'from-numpy' for t in transform):
+ need_transform = False
+
+ if any(t == 'none' for t in transform):
+ need_transform = False
+
+ if need_transform:
+ result.append(torchvision.transforms.ToTensor())
+
+ return torchvision.transforms.Compose(result)
+
diff --git a/uvcgan_s/eval/__init__.py b/uvcgan_s/eval/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/uvcgan_s/eval/funcs.py b/uvcgan_s/eval/funcs.py
new file mode 100644
index 0000000..999506c
--- /dev/null
+++ b/uvcgan_s/eval/funcs.py
@@ -0,0 +1,112 @@
+import os
+import math
+from itertools import islice
+
+from uvcgan_s.config import Args
+from uvcgan_s.consts import (
+ MODEL_STATE_TRAIN, MODEL_STATE_EVAL, MERGE_NONE
+)
+from uvcgan_s.data import construct_data_loaders
+from uvcgan_s.torch.funcs import get_torch_device_smart, seed_everything
+from uvcgan_s.cgan import construct_model
+
+def slice_data_loader(loader, batch_size, n_samples = None):
+ if n_samples is None:
+ return (loader, len(loader))
+
+ steps = min(math.ceil(n_samples / batch_size), len(loader))
+ sliced_loader = islice(loader, steps)
+
+ return (sliced_loader, steps)
+
+def tensor_to_image(tensor):
+ result = tensor.cpu().detach().numpy()
+
+ if tensor.ndim == 4:
+ result = result.squeeze(0)
+
+ result = result.transpose((1, 2, 0))
+ return result
+
+def override_config(config, config_overrides):
+ if config_overrides is None:
+ return
+
+ for (k,v) in config_overrides.items():
+ config[k] = v
+
+def get_evaldir(root, epoch, mkdir = False):
+ if epoch is None:
+ result = os.path.join(root, 'evals', 'final')
+ else:
+ result = os.path.join(root, 'evals', 'epoch_%04d' % epoch)
+
+ if mkdir:
+ os.makedirs(result, exist_ok = True)
+
+ return result
+
+def set_model_state(model, state):
+ if state == MODEL_STATE_TRAIN:
+ model.train()
+ elif state == MODEL_STATE_EVAL:
+ model.eval()
+ else:
+ raise ValueError(f"Unknown model state '{state}'")
+
+def start_model_eval(path, epoch, model_state, merge_type, **config_overrides):
+ args = Args.load(path)
+ device = get_torch_device_smart()
+
+ override_config(args.config, config_overrides)
+
+ if merge_type is not None:
+ args.config.data.merge_type = merge_type
+
+ model = construct_model(
+ args.savedir, args.config,
+ is_train = (model_state == MODEL_STATE_TRAIN),
+ device = device
+ )
+
+ if epoch == -1:
+ epoch = max(model.find_last_checkpoint_epoch(), 0)
+
+ print("Load checkpoint at epoch %s" % epoch)
+
+ seed_everything(args.config.seed)
+ model.load(epoch)
+
+ set_model_state(model, model_state)
+ evaldir = get_evaldir(path, epoch, mkdir = True)
+
+ return (args, model, evaldir)
+
+def load_eval_model_dset_from_cmdargs(
+ cmdargs, merge_type = MERGE_NONE, **config_overrides
+):
+ args, model, evaldir = start_model_eval(
+ cmdargs.model, cmdargs.epoch, cmdargs.model_state,
+ merge_type = merge_type,
+ batch_size = cmdargs.batch_size, **config_overrides
+ )
+
+ data_it = construct_data_loaders(
+ args.config.data, args.config.batch_size, split = cmdargs.split
+ )
+
+ return (args, model, data_it, evaldir)
+
+def get_eval_savedir(evaldir, prefix, model_state, split, mkdir = False):
+ result = os.path.join(evaldir, f'{prefix}_{model_state}-{split}')
+
+ if mkdir:
+ os.makedirs(result, exist_ok = True)
+
+ return result
+
+def make_image_subdirs(model, savedir):
+ for name in model.images:
+ path = os.path.join(savedir, name)
+ os.makedirs(path, exist_ok = True)
+
diff --git a/uvcgan_s/models/__init__.py b/uvcgan_s/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/uvcgan_s/models/discriminator/__init__.py b/uvcgan_s/models/discriminator/__init__.py
new file mode 100644
index 0000000..e701fb9
--- /dev/null
+++ b/uvcgan_s/models/discriminator/__init__.py
@@ -0,0 +1,26 @@
+from uvcgan_s.base.networks import select_base_discriminator
+from uvcgan_s.models.funcs import default_model_init
+
+from .resnet import ResNetDisc, ResNetFFTDisc
+from .dcgan import DCGANDiscriminator
+
+DISC_DICT = {
+ 'resnet' : ResNetDisc,
+ 'resnet-fft' : ResNetFFTDisc,
+ 'dcgan' : DCGANDiscriminator,
+}
+
+def select_discriminator(name, **kwargs):
+ if name in DISC_DICT:
+ return DISC_DICT[name](**kwargs)
+
+ return select_base_discriminator(name, **kwargs)
+
+def construct_discriminator(model_config, image_shape, device):
+ model = select_discriminator(
+ model_config.model, image_shape = image_shape,
+ **model_config.model_args
+ )
+
+ return default_model_init(model, model_config, device)
+
diff --git a/uvcgan_s/models/discriminator/dcgan.py b/uvcgan_s/models/discriminator/dcgan.py
new file mode 100644
index 0000000..0ae08e4
--- /dev/null
+++ b/uvcgan_s/models/discriminator/dcgan.py
@@ -0,0 +1,121 @@
+import math
+import logging
+
+from torch import nn
+from torchvision.transforms import CenterCrop
+
+from uvcgan_s.torch.select import get_activ_layer, get_norm_layer
+
+LOGGER = logging.getLogger('models.discriminator.dcgan')
+
+DEF_NORM = 'batch'
+DEF_ACTIV = {
+ 'name' : 'leakyrelu',
+ 'negative_slope' : 0.2,
+}
+
+def math_prod(shape):
+ result = 1
+
+ for x in shape:
+ result *= x
+
+ return result
+
+def get_padding_layer(input_shape, downscale_factor):
+ need_pad = False
+
+ if (
+ (input_shape[1] < downscale_factor)
+ or (input_shape[2] < downscale_factor)
+ ):
+ need_pad = True
+
+ LOGGER.warning(
+ "DCGAN input shape '%s' is smaller than the downscale factor '%d'."
+ " Adding padding.", tuple(input_shape), downscale_factor
+ )
+
+ if (
+ (input_shape[1] % downscale_factor != 0)
+ or (input_shape[2] % downscale_factor != 0)
+ ):
+ need_pad = True
+
+ LOGGER.warning(
+ "DCGAN input shape '%s' is not divisible by the downscale "
+ " factor '%d'. Adding padding.",
+ tuple(input_shape), downscale_factor
+ )
+
+ h = math.ceil(input_shape[1] / downscale_factor) * downscale_factor
+ w = math.ceil(input_shape[2] / downscale_factor) * downscale_factor
+
+ if not need_pad:
+ return None, h, w
+
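+    # NOTE: CenterCrop zero-pads inputs smaller than the target size, so it
+    # also serves as the padding layer here.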
+    return CenterCrop((h, w)), h, w
+
+class DCGANDiscriminator(nn.Module):
+
+ def __init__(
+ self, image_shape, features_list, activ = DEF_ACTIV, norm = DEF_NORM,
+ ):
+ # pylint: disable=dangerous-default-value
+ # pylint: disable=too-many-arguments
+ super().__init__()
+
+ self._input_shape = image_shape
+
+ curr_features = image_shape[0]
+ downscale_factor = 1
+ layers = []
+
+ layers.append(nn.Sequential(
+ nn.Conv2d(
+ curr_features, features_list[0], kernel_size = 5,
+ stride = 2, padding = 2
+ ),
+ get_activ_layer(activ)
+ ))
+
+ downscale_factor *= 2
+ curr_features = features_list[0]
+
+ for features in features_list[1:]:
+ layers.append(nn.Sequential(
+ nn.Conv2d(
+ curr_features, features, kernel_size = 5,
+ stride = 2, padding = 2
+ ),
+ get_norm_layer(norm, features),
+ get_activ_layer(activ)
+ ))
+
+ downscale_factor *= 2
+ curr_features = features
+
+ padding, h, w = get_padding_layer(image_shape, downscale_factor)
+ curr_shape = (
+ curr_features, h // downscale_factor, w // downscale_factor
+ )
+
+ if padding is not None:
+ layers = [ padding, ] + layers
+
+ self._net_main = nn.Sequential(*layers)
+ dense_features = math_prod(curr_shape)
+
+ self._net_output = nn.Sequential(
+ nn.Flatten(),
+ nn.Linear(dense_features, 1),
+ )
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+
+ # y : (N, Ci, Hi, Wi)
+ y = self._net_main(x)
+
+ return self._net_output(y)
+
diff --git a/uvcgan_s/models/discriminator/resnet.py b/uvcgan_s/models/discriminator/resnet.py
new file mode 100644
index 0000000..ded2642
--- /dev/null
+++ b/uvcgan_s/models/discriminator/resnet.py
@@ -0,0 +1,99 @@
+import torch
+from torch import nn
+
+from uvcgan_s.torch.select import get_activ_layer
+from uvcgan_s.torch.layers.resnet import ResNetEncoder
+
+class ResNetDisc(nn.Module):
+
+ def __init__(
+ self, image_shape, block_specs, activ, norm,
+ rezero = True, activ_output = None, reduce_output_channels = True
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__()
+
+ self._net = ResNetEncoder(
+ image_shape, block_specs, activ, norm, rezero
+ )
+
+ self._output_shape = self._net.output_shape
+ self._input_shape = image_shape
+
+ output_layers = []
+
+ if reduce_output_channels:
+ output_layers.append(
+ nn.Conv2d(self._output_shape[0], 1, kernel_size = 1)
+ )
+ self._output_shape = (1, *self._output_shape[1:])
+
+ if activ_output is not None:
+ output_layers.append(get_activ_layer(activ_output))
+
+ self._out = nn.Sequential(*output_layers)
+
+ @property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ def forward(self, x):
+ y = self._net(x)
+ return self._out(y)
+
+class ResNetFFTDisc(nn.Module):
+
+ def __init__(
+ self, image_shape, block_specs, activ, norm,
+ rezero = True, activ_output = None, reduce_output_channels = True
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__()
+
+ # Doubling channels for FFT amplitudes
+ input_shape = (2 * image_shape[0], *image_shape[1:])
+
+ self._net = ResNetEncoder(
+ input_shape, block_specs, activ, norm, rezero
+ )
+
+ self._output_shape = self._net.output_shape
+ self._input_shape = image_shape
+
+ output_layers = []
+
+ if reduce_output_channels:
+ output_layers.append(
+ nn.Conv2d(self._output_shape[0], 1, kernel_size = 1)
+ )
+ self._output_shape = (1, *self._output_shape[1:])
+
+ if activ_output is not None:
+ output_layers.append(get_activ_layer(activ_output))
+
+ self._out = nn.Sequential(*output_layers)
+
+ @property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+ # x_fft : (N, C, H, W)
+ x_fft = torch.fft.fft2(x).abs()
+
+ # z : (N, 2C, H, W)
+ z = torch.cat([ x, x_fft ], dim = 1)
+
+ y = self._net(z)
+
+ return self._out(y)
+
diff --git a/uvcgan_s/models/funcs.py b/uvcgan_s/models/funcs.py
new file mode 100644
index 0000000..86e1efb
--- /dev/null
+++ b/uvcgan_s/models/funcs.py
@@ -0,0 +1,17 @@
+from uvcgan_s.base.weight_init import init_weights
+from uvcgan_s.torch.funcs import prepare_model
+from uvcgan_s.torch.lr_equal import apply_lr_equal
+from uvcgan_s.torch.spectr_norm import apply_sn
+
+def default_model_init(model, model_config, device):
+ model = prepare_model(model, device)
+ init_weights(model, model_config.weight_init)
+
+ if model_config.lr_equal:
+ apply_lr_equal(model)
+
+ if model_config.spectr_norm:
+ apply_sn(model)
+
+ return model
+
diff --git a/uvcgan_s/models/generator/__init__.py b/uvcgan_s/models/generator/__init__.py
new file mode 100644
index 0000000..e5afc02
--- /dev/null
+++ b/uvcgan_s/models/generator/__init__.py
@@ -0,0 +1,42 @@
+from uvcgan_s.base.networks import select_base_generator
+from uvcgan_s.models.funcs import default_model_init
+
+from .vitunet import ViTUNetGenerator
+from .vitmodnet import ViTModNetGenerator, CViTModNetGenerator
+from .resnet import ResNetGen
+from .dcgan import DCGANGenerator
+
+def select_generator(name, **kwargs):
+ # pylint: disable=too-many-return-statements
+
+ if name == 'vit-unet':
+ return ViTUNetGenerator(**kwargs)
+
+ if name == 'vit-modnet':
+ return ViTModNetGenerator(**kwargs)
+
+ if name == 'cvit-modnet':
+ return CViTModNetGenerator(**kwargs)
+
+ if name == 'resnet':
+ return ResNetGen(**kwargs)
+
+ if name == 'dcgan':
+ return DCGANGenerator(**kwargs)
+
+ input_shape = kwargs.pop('input_shape')
+ output_shape = kwargs.pop('output_shape')
+
+ assert input_shape == output_shape
+ return select_base_generator(name, image_shape = input_shape, **kwargs)
+
+def construct_generator(model_config, input_shape, output_shape, device):
+ model = select_generator(
+ model_config.model,
+ input_shape = input_shape,
+ output_shape = output_shape,
+ **model_config.model_args
+ )
+
+ return default_model_init(model, model_config, device)
+
diff --git a/uvcgan_s/models/generator/dcgan.py b/uvcgan_s/models/generator/dcgan.py
new file mode 100644
index 0000000..ef88c84
--- /dev/null
+++ b/uvcgan_s/models/generator/dcgan.py
@@ -0,0 +1,98 @@
+import logging
+
+from torch import nn
+from torchvision.transforms import CenterCrop
+
+from uvcgan_s.torch.select import get_activ_layer, get_norm_layer
+
+LOGGER = logging.getLogger('models.generator.dcgan')
+
+DEF_FEATURES = 1024
+DEF_NORM = 'batch'
+DEF_ACTIV = 'relu'
+
+def math_prod(shape):
+ result = 1
+
+ for x in shape:
+ result *= x
+
+ return result
+
+class DCGANGenerator(nn.Module):
+
+ def __init__(
+ self, input_shape, output_shape, features_list,
+ activ = DEF_ACTIV,
+ norm = DEF_NORM,
+ activ_output = None
+ ):
+ # pylint: disable=dangerous-default-value
+ # pylint: disable=too-many-arguments
+ super().__init__()
+
+        self._input_shape = input_shape
+ self._output_shape = output_shape
+
+ # to reshape into (2, 2)
+ dense_features = 4 * features_list[0]
+
+ self._net_in = nn.Sequential(
+ nn.Flatten(),
+ nn.Linear(math_prod(input_shape), dense_features),
+ get_activ_layer(activ),
+ )
+
+ layers = []
+
+ curr_shape = (features_list[0], 2, 2)
+
+ for features in features_list[1:]:
+ layers.append(nn.Sequential(
+ nn.ConvTranspose2d(
+ curr_shape[0], features,
+ kernel_size = 5, stride = 2, padding = 2,
+ output_padding = 1,
+ ),
+ get_norm_layer(norm, features),
+ get_activ_layer(activ),
+ ))
+
+ curr_shape = (features, 2 * curr_shape[1], 2 * curr_shape[2])
+
+ layers.append(nn.Sequential(
+ nn.ConvTranspose2d(
+ curr_shape[0], output_shape[0],
+ kernel_size = 5, stride = 2, padding = 2,
+ output_padding = 1,
+ ),
+ get_activ_layer(activ_output),
+ ))
+
+ curr_shape = (output_shape[0], 2 * curr_shape[1], 2 * curr_shape[2])
+
+ if curr_shape != tuple(output_shape):
+ LOGGER.warning(
+ "DCGAN output shape '%s' is not equal to the expected output"
+ " shape '%s'. Adding center cropping.",
+ curr_shape, tuple(output_shape)
+ )
+
+ layers.append(CenterCrop(tuple(output_shape[1:])))
+
+ self._net_main = nn.Sequential(*layers)
+
+ def forward(self, z):
+ # z : (N, F)
+
+ # x : (N, 4 * C)
+ x = self._net_in(z)
+
+ # x : (N, C, 2, 2)
+ x = x.reshape((x.shape[0], -1, 2, 2))
+
+ result = self._net_main(x)
+
+ return result
+
diff --git a/uvcgan_s/models/generator/resnet.py b/uvcgan_s/models/generator/resnet.py
new file mode 100644
index 0000000..466058b
--- /dev/null
+++ b/uvcgan_s/models/generator/resnet.py
@@ -0,0 +1,40 @@
+from torch import nn
+
+from uvcgan_s.torch.select import get_activ_layer
+from uvcgan_s.torch.layers.resnet import ResNetEncoder
+
+class ResNetGen(nn.Module):
+
+ def __init__(
+ self, input_shape, output_shape, block_specs, activ, norm,
+ rezero = True, activ_output = None
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__()
+
+ self._net = ResNetEncoder(
+ input_shape, block_specs, activ, norm, rezero
+ )
+
+ self._output_shape = self._net.output_shape
+ self._input_shape = input_shape
+
+ assert tuple(output_shape) == self._output_shape, (
+ f"Output shape {self._output_shape}"
+ f" != desired shape {output_shape}"
+ )
+
+ self._out = get_activ_layer(activ_output)
+
+ @property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ def forward(self, x):
+ y = self._net(x)
+ return self._out(y)
+
diff --git a/uvcgan_s/models/generator/vit.py b/uvcgan_s/models/generator/vit.py
new file mode 100644
index 0000000..53ad83e
--- /dev/null
+++ b/uvcgan_s/models/generator/vit.py
@@ -0,0 +1,78 @@
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-instance-attributes
+
+import numpy as np
+from torch import nn
+
+from uvcgan_s.torch.layers.transformer import (
+ calc_tokenized_size, ViTInput, TransformerEncoder, img_to_tokens,
+ img_from_tokens
+)
+
+class ViTGenerator(nn.Module):
+
+ def __init__(
+ self, features, n_heads, n_blocks, ffn_features, embed_features,
+ activ, norm, input_shape, output_shape, token_size,
+ rescale = False, rezero = True, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ assert input_shape == output_shape
+ image_shape = input_shape
+
+ self.image_shape = image_shape
+ self.token_size = token_size
+ self.token_shape = (image_shape[0], *token_size)
+ self.token_features = np.prod([image_shape[0], *token_size])
+ self.N_h, self.N_w = calc_tokenized_size(image_shape, token_size)
+ self.rescale = rescale
+
+ self.gan_input = ViTInput(
+ self.token_features, embed_features, features, self.N_h, self.N_w
+ )
+
+ self.trans = TransformerEncoder(
+ features, ffn_features, n_heads, n_blocks, activ, norm, rezero
+ )
+
+ self.gan_output = nn.Linear(features, self.token_features)
+
+ # pylint: disable=no-self-use
+ def calc_scale(self, x):
+ # x : (N, C, H, W)
+ return x.abs().mean(dim = (1, 2, 3), keepdim = True) + 1e-8
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+ if self.rescale:
+ scale = self.calc_scale(x)
+ x = x / scale
+
+ # itokens : (N, N_h, N_w, C, H_c, W_c)
+ itokens = img_to_tokens(x, self.token_shape[1:])
+
+ # itokens : (N, N_h, N_w, C, H_c, W_c)
+ # -> (N, N_h * N_w, C * H_c * W_c)
+ # = (N, L, in_features)
+ itokens = itokens.reshape((itokens.shape[0], self.N_h * self.N_w, -1))
+
+ # y : (N, L, features)
+ y = self.gan_input(itokens)
+ y = self.trans(y)
+
+ # otokens : (N, L, in_features)
+ otokens = self.gan_output(y)
+
+ # otokens : (N, L, in_features)
+ # -> (N, N_h, N_w, C, H_c, W_c)
+ otokens = otokens.reshape((
+ otokens.shape[0], self.N_h, self.N_w, *self.token_shape
+ ))
+
+ result = img_from_tokens(otokens)
+ if self.rescale:
+ result = result * scale
+
+ return result
+
diff --git a/uvcgan_s/models/generator/vitgan.py b/uvcgan_s/models/generator/vitgan.py
new file mode 100644
index 0000000..0796870
--- /dev/null
+++ b/uvcgan_s/models/generator/vitgan.py
@@ -0,0 +1,216 @@
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-instance-attributes
+
+import torch
+
+import numpy as np
+from torch import nn
+
+from uvcgan_s.torch.select import select_activation
+from uvcgan_s.torch.layers.transformer import (
+ calc_tokenized_size, img_to_tokens, img_from_tokens,
+ ViTInput, TransformerEncoder, FourierEmbedding
+)
+
+class ModulatedLinear(nn.Module):
+ # arXiv: 2011.13775
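+    # Linear layer with per-sample weight modulation: a conditioning vector
+    # `w` is mapped to scales `s = A(w)`, the weight matrix is scaled by `s`
+    # and then demodulated (normalized over its input dimension) before
+    # being applied to each token.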
+
+ def __init__(
+ self, in_features, out_features, w_features, activ, eps = 1e-8,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.weight = nn.Parameter(torch.empty((out_features, in_features)))
+ self.bias = nn.Parameter(torch.empty((out_features,)))
+ self._eps = eps
+
+ self.A = nn.Linear(w_features, in_features)
+ self.activ = select_activation(activ)
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.kaiming_normal_(self.weight)
+ nn.init.zeros_(self.bias)
+
+ def get_modulated_weight(self, s):
+ # s : (N, in_features)
+
+ # w : (1, out_features, in_features)
+ w = self.weight.unsqueeze(0)
+
+ # s : (N, in_features)
+ # -> (N, 1, in_features)
+ s = s.unsqueeze(1)
+
+ # weight : (N, out_features, in_features)
+ weight = w * s
+
+ # norm : (N, out_features, 1)
+ norm = torch.rsqrt(
+ self._eps + torch.sum(weight**2, dim = 2, keepdim = True)
+ )
+
+ return weight * norm
+
+ def forward(self, x, w):
+ # x : (N, L, in_features)
+ # w : (N, w_features)
+
+ # s : (N, in_features)
+ s = self.A(w)
+
+ # weight : (N, out_features, in_features)
+ weight = self.get_modulated_weight(s)
+
+ # weight : (N, out_features, in_features)
+ # x : (N, L, in_features)
+ # result : (N, L, out_features)
+ result = torch.matmul(weight.unsqueeze(1), x.unsqueeze(-1)).squeeze(-1)
+
+ result = result + self.bias
+
+ return self.activ(result)
+
+class ViTGANOutput(nn.Module):
+
+ def __init__(self, features, token_shape, activ, **kwargs):
+ super().__init__(**kwargs)
+
+ self.embed = ViTGANInput(features, token_shape[1], token_shape[2])
+
+ self.fc1 = ModulatedLinear(features, features, features, activ)
+ self.fc2 = ModulatedLinear(features, token_shape[0], features, None)
+
+ self._token_shape = token_shape
+
+ def _map_fn(self, x, w):
+ # x : (N, H_t * W_t, features)
+ # w : (N, features)
+
+ # result : (N, H_t * W_t, features)
+ result = self.fc1(x, w)
+
+ # result : (N, H_t * W_t, C)
+ result = self.fc2(result, w)
+
+ # result : (N, H_t * W_t, C)
+ # -> (N, C, H_t * W_t)
+ result = result.permute(0, 2, 1)
+
+ # (N, 1, C, H_t, W_t)
+ return result.reshape(
+ (result.shape[0], 1, result.shape[1], *self._token_shape[1:])
+ )
+
+ def forward(self, y):
+ # y : (N, L, features)
+ # e : (N, H_t * W_t, features)
+ e = self.embed(len(y))
+
+ # result : (N, L, C, H_t, W_t)
+ result = torch.stack(
+ [ self._map_fn(e, w) for w in torch.unbind(y, dim = 1) ],
+ dim = 1
+ )
+
+ return result
+
+class ViTGANInput(nn.Module):
+
+ def __init__(self, features, height, width, **kwargs):
+ super().__init__(**kwargs)
+ self._height = height
+ self._width = width
+
+ x = torch.arange(width).to(torch.float32)
+ y = torch.arange(height).to(torch.float32)
+
+ x, y = torch.meshgrid(x, y)
+ self.x = x.reshape((1, -1))
+ self.y = y.reshape((1, -1))
+
+ self.register_buffer('x_const', self.x)
+ self.register_buffer('y_const', self.y)
+
+ self.embed = FourierEmbedding(features, height, width)
+
+ def forward(self, batch_size = None):
+ # result : (1, height * width, features)
+ result = self.embed(self.y_const, self.x_const)
+
+ if batch_size is not None:
+ # result : (1, height * width, features)
+ # -> (batch_size, height * width, features)
+ result = result.expand((batch_size, *result.shape[1:]))
+
+ return result
+
+class ViTGANGenerator(nn.Module):
+
+ def __init__(
+ self, features, n_heads, n_blocks, ffn_features, embed_features,
+ activ, norm, input_shape, output_shape, token_size, rescale = False,
+ rezero = True, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ assert input_shape == output_shape
+ image_shape = input_shape
+
+ self.image_shape = image_shape
+ self.token_size = token_size
+ self.token_shape = (image_shape[0], *token_size)
+ self.token_features = np.prod([image_shape[0], *token_size])
+ self.N_h, self.N_w = calc_tokenized_size(image_shape, token_size)
+ self.rescale = rescale
+
+ self.gan_input = ViTInput(
+ self.token_features, embed_features, features, self.N_h, self.N_w
+ )
+
+ self.encoder = TransformerEncoder(
+ features, ffn_features, n_heads, n_blocks, activ, norm, rezero
+ )
+
+ self.gan_output = ViTGANOutput(features, self.token_shape, 'relu')
+
+ # pylint: disable=no-self-use
+ def calc_scale(self, x):
+ # x : (N, C, H, W)
+ return x.abs().mean(dim = (1, 2, 3), keepdim = True) + 1e-8
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+ if self.rescale:
+ scale = self.calc_scale(x)
+ x = x / scale
+
+ # itokens : (N, N_h, N_w, C, H_c, W_c)
+ itokens = img_to_tokens(x, self.token_shape[1:])
+
+ # itokens : (N, N_h, N_w, C, H_c, W_c)
+ # -> (N, N_h * N_w, C * H_c * W_c)
+ # = (N, L, in_features)
+ itokens = itokens.reshape((itokens.shape[0], self.N_h * self.N_w, -1))
+
+ # y : (N, L, features)
+ y = self.gan_input(itokens)
+        y = self.encoder(y)
+
+ # otokens : (N, L, C, H_t, W_t)
+ otokens = self.gan_output(y)
+
+ # otokens : (N, L, C, H_t, W_t)
+ # -> (N, N_h, N_w, C, H_c, W_c)
+ otokens = otokens.reshape(
+ (otokens.shape[0], self.N_h, self.N_w, *otokens.shape[3:])
+ )
+
+ result = img_from_tokens(otokens)
+ if self.rescale:
+ result = result * scale
+
+ return result
+
diff --git a/uvcgan_s/models/generator/vithybrid.py b/uvcgan_s/models/generator/vithybrid.py
new file mode 100644
index 0000000..5c33a56
--- /dev/null
+++ b/uvcgan_s/models/generator/vithybrid.py
@@ -0,0 +1,137 @@
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-instance-attributes
+
+from torch import nn
+
+from uvcgan_s.torch.select import get_norm_layer, get_activ_layer
+from uvcgan_s.torch.layers.transformer import PixelwiseViT
+from uvcgan_s.torch.layers.cnn import (
+ get_downsample_x2_layer, get_upsample_x2_layer
+)
+
+def construct_downsample_stem(
+ features_list, activ, norm, downsample, image_shape
+):
+ result = nn.Sequential()
+
+ result.add_module(
+ "downsample_base",
+ nn.Sequential(
+ nn.Conv2d(
+ image_shape[0], features_list[0], kernel_size = 3, padding = 1
+ ),
+ get_activ_layer(activ),
+ )
+ )
+
+ prev_features = features_list[0]
+ curr_size = image_shape[1:]
+
+ for idx,features in enumerate(features_list):
+ layer_down, next_features = \
+ get_downsample_x2_layer(downsample, features)
+
+ result.add_module(
+ "downsample_block_%d" % idx,
+ nn.Sequential(
+ get_norm_layer(norm, prev_features),
+ nn.Conv2d(
+ prev_features, features, kernel_size = 3, padding = 1
+ ),
+ get_activ_layer(activ),
+
+ layer_down,
+ )
+ )
+
+ prev_features = next_features
+ curr_size = (curr_size[0] // 2, curr_size[1] // 2)
+
+ return result, prev_features, curr_size
+
+def construct_upsample_stem(
+ features_list, input_features, activ, norm, upsample, image_shape
+):
+ result = nn.Sequential()
+ prev_features = input_features
+
+ for idx,features in reversed(list(enumerate(features_list))):
+ layer_up, next_features \
+ = get_upsample_x2_layer(upsample, prev_features)
+
+ result.add_module(
+ "upsample_block_%d" % idx,
+ nn.Sequential(
+ layer_up,
+
+ get_norm_layer(norm, next_features),
+ nn.Conv2d(
+ next_features, features, kernel_size = 3, padding = 1
+ ),
+ get_activ_layer(activ)
+ )
+ )
+
+ prev_features = features
+
+
+ result.add_module(
+ "upsample_base",
+ nn.Sequential(
+ nn.Conv2d(prev_features, image_shape[0], kernel_size = 1),
+ )
+ )
+
+ return result
+
+class ViTHybridGenerator(nn.Module):
+
+ def __init__(
+ self, features, n_heads, n_blocks, ffn_features, embed_features,
+ activ, norm, input_shape, output_shape,
+ stem_features_list, stem_activ, stem_norm,
+ stem_downsample = 'conv',
+ stem_upsample = 'upsample-conv',
+ rezero = True,
+ **kwargs
+ ):
+ # pylint: disable=too-many-locals
+ super().__init__(**kwargs)
+
+ assert input_shape == output_shape
+ image_shape = input_shape
+
+ self.image_shape = image_shape
+
+ self.stem_down, self.token_features, output_size = \
+ construct_downsample_stem(
+ stem_features_list, stem_activ, stem_norm, stem_downsample,
+ image_shape
+ )
+
+ self.N_h, self.N_w = output_size
+
+ self.bottleneck = PixelwiseViT(
+ features, n_heads, n_blocks, ffn_features, embed_features,
+ activ, norm,
+ image_shape = (self.token_features, self.N_h, self.N_w),
+ rezero = rezero
+ )
+
+ self.stem_up = construct_upsample_stem(
+ stem_features_list, self.token_features, stem_activ, stem_norm,
+ stem_upsample, image_shape
+ )
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+
+ # z : (N, token_features, N_h, N_w)
+ z = self.stem_down(x)
+ z = self.bottleneck(z)
+
+ # result : (N, C, H, W)
+ result = self.stem_up(z)
+
+ return result
+
diff --git a/uvcgan_s/models/generator/vitmodnet.py b/uvcgan_s/models/generator/vitmodnet.py
new file mode 100644
index 0000000..54fbf71
--- /dev/null
+++ b/uvcgan_s/models/generator/vitmodnet.py
@@ -0,0 +1,119 @@
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-instance-attributes
+
+from torch import nn
+
+from uvcgan_s.torch.layers.transformer import (
+ ExtendedPixelwiseViT, CExtPixelwiseViT
+)
+from uvcgan_s.torch.layers.modnet import ModNet
+from uvcgan_s.torch.select import get_activ_layer
+
+class ViTModNetGenerator(nn.Module):
+
+ def __init__(
+ self, features, n_heads, n_blocks, ffn_features, embed_features,
+ activ, norm, input_shape, output_shape, modnet_features_list,
+ modnet_activ,
+ modnet_norm = None,
+ modnet_downsample = 'conv',
+ modnet_upsample = 'upsample-conv',
+ modnet_rezero = False,
+ modnet_demod = True,
+ rezero = True,
+ activ_output = None,
+ style_rezero = True,
+ style_bias = True,
+ n_ext = 1,
+ **kwargs
+ ):
+ # pylint: disable = too-many-locals
+ super().__init__(**kwargs)
+
+ mod_features = features * n_ext
+
+ self.net = ModNet(
+ modnet_features_list, modnet_activ, modnet_norm,
+ input_shape, output_shape,
+ modnet_downsample, modnet_upsample, mod_features, modnet_rezero,
+ modnet_demod, style_rezero, style_bias, return_mod = False
+ )
+
+ bottleneck = ExtendedPixelwiseViT(
+ features, n_heads, n_blocks, ffn_features, embed_features,
+ activ, norm,
+ image_shape = self.net.get_inner_shape(),
+ rezero = rezero,
+ n_ext = n_ext,
+ )
+
+ self.net.set_bottleneck(bottleneck)
+
+ self.output = get_activ_layer(activ_output)
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+ result = self.net(x)
+ return self.output(result)
+
+class CViTModNetGenerator(nn.Module):
+
+ def __init__(
+ self, features, n_heads, n_blocks, ffn_features, embed_features,
+ activ, norm, input_shape, output_shape, modnet_features_list,
+ modnet_activ,
+ modnet_norm = None,
+ modnet_downsample = 'conv',
+ modnet_upsample = 'upsample-conv',
+ modnet_rezero = False,
+ modnet_demod = True,
+ rezero = True,
+ activ_output = None,
+ style_rezero = True,
+ style_bias = True,
+ n_control_in = 100,
+ n_control_out = 100,
+ return_feedback = True,
+ **kwargs
+ ):
+ # pylint: disable = too-many-locals
+ super().__init__(**kwargs)
+
+ self.control_in = nn.Linear(n_control_in, features)
+ self.control_out = nn.Linear(features, n_control_out)
+ self.return_feedback = return_feedback
+ mod_features = features
+
+ self.net = ModNet(
+ modnet_features_list, modnet_activ, modnet_norm,
+ input_shape, output_shape,
+ modnet_downsample, modnet_upsample, mod_features, modnet_rezero,
+ modnet_demod, style_rezero, style_bias, return_mod = True
+ )
+
+ bottleneck = CExtPixelwiseViT(
+ features, n_heads, n_blocks, ffn_features, embed_features,
+ activ, norm,
+ image_shape = self.net.get_inner_shape(),
+ rezero = rezero,
+ )
+
+ self.net.set_bottleneck(bottleneck)
+
+ self.output = get_activ_layer(activ_output)
+
+ def forward(self, x, control):
+ # x : (N, C, H, W)
+
+ mod = self.control_in(control)
+
+ result, mod = self.net(x, mod)
+
+ result = self.output(result)
+
+ if not self.return_feedback:
+ return result
+
+ feedback = self.control_out(mod)
+ return result, feedback
+
diff --git a/uvcgan_s/models/generator/vitunet.py b/uvcgan_s/models/generator/vitunet.py
new file mode 100644
index 0000000..442beac
--- /dev/null
+++ b/uvcgan_s/models/generator/vitunet.py
@@ -0,0 +1,49 @@
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-instance-attributes
+
+from torch import nn
+
+from uvcgan_s.torch.layers.transformer import PixelwiseViT
+from uvcgan_s.torch.layers.unet import UNet
+from uvcgan_s.torch.select import get_activ_layer
+
+class ViTUNetGenerator(nn.Module):
+
+ def __init__(
+ self, features, n_heads, n_blocks, ffn_features, embed_features,
+ activ, norm, input_shape, output_shape,
+ unet_features_list, unet_activ, unet_norm,
+ unet_downsample = 'conv',
+ unet_upsample = 'upsample-conv',
+ unet_rezero = False,
+ rezero = True,
+ activ_output = None,
+ **kwargs
+ ):
+ # pylint: disable = too-many-locals
+ super().__init__(**kwargs)
+
+ self.image_shape = input_shape
+
+ self.net = UNet(
+ unet_features_list, unet_activ, unet_norm, input_shape,
+ unet_downsample, unet_upsample, unet_rezero,
+ output_shape = output_shape
+ )
+
+ bottleneck = PixelwiseViT(
+ features, n_heads, n_blocks, ffn_features, embed_features,
+ activ, norm,
+ image_shape = self.net.get_inner_shape(),
+ rezero = rezero
+ )
+
+ self.net.set_bottleneck(bottleneck)
+
+ self.output = get_activ_layer(activ_output)
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+ result = self.net(x)
+ return self.output(result)
+
diff --git a/uvcgan_s/torch/__init__.py b/uvcgan_s/torch/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/uvcgan_s/torch/background_penalty.py b/uvcgan_s/torch/background_penalty.py
new file mode 100644
index 0000000..6dd319f
--- /dev/null
+++ b/uvcgan_s/torch/background_penalty.py
@@ -0,0 +1,35 @@
+from torch import nn
+
+class BackgroundPenaltyReduction(nn.Module):
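+    """Anneal away the suppression of generated values in background regions.
+
+    Background pixels are those where the reference tensor `real` is zero.
+    For the first `epochs_warmup` epochs the background of `fake` is fully
+    zeroed out; over the following `epochs_anneal` epochs the suppression is
+    linearly relaxed, after which `fake` passes through unchanged.
+    """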
+
+ def __init__(self, epochs_warmup, epochs_anneal, **kwargs):
+ super().__init__(**kwargs)
+
+ self._epochs_warmup = epochs_warmup
+ self._epochs_anneal = epochs_anneal
+
+ self._alpha = 0
+
+ def end_epoch(self, epoch):
+ if epoch is None:
+ self._alpha = 1
+ return
+
+ if epoch < self._epochs_warmup:
+ self._alpha = 0
+ return
+
+ progression = (epoch - self._epochs_warmup)
+
+ if progression < self._epochs_anneal:
+ self._alpha = progression / self._epochs_anneal
+ else:
+ self._alpha = 1
+
+ def forward(self, fake, real):
+ if self._alpha == 1:
+ return fake
+
+ result = fake - (1 - self._alpha) * fake * (real == 0)
+ return result
+
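+# Example schedule (epochs_warmup = 10, epochs_anneal = 20): alpha stays at 0
+# for the first 10 epochs, grows linearly afterwards (e.g. 0.25 at epoch 15),
+# and saturates at 1 from epoch 30 on, at which point forward() returns `fake`
+# unchanged.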
diff --git a/uvcgan_s/torch/data_norm.py b/uvcgan_s/torch/data_norm.py
new file mode 100644
index 0000000..4409a41
--- /dev/null
+++ b/uvcgan_s/torch/data_norm.py
@@ -0,0 +1,220 @@
+import torch
+from torchvision.transforms import Normalize
+from .select import extract_name_kwargs
+
+class DataNorm:
+
+ def normalize(self, x):
+ raise NotImplementedError
+
+ def denormalize(self, y):
+ raise NotImplementedError
+
+ def normalize_nograd(self, x):
+ with torch.no_grad():
+ return self.normalize(x)
+
+ def denormalize_nograd(self, y):
+ with torch.no_grad():
+ return self.denormalize(y)
+
+class Standardizer(DataNorm):
+
+ def __init__(self, mean, stdev):
+ super().__init__()
+ # fwd: (x - m) / s
+ # bkw: s * y + m
+
+ bkw_mean = [ -m/s for (m, s) in zip(mean, stdev) ]
+ bkw_stdev = [ 1/s for s in stdev ]
+
+ self._norm_fwd = Normalize(mean, stdev)
+ self._norm_bkw = Normalize(bkw_mean, bkw_stdev)
+
+ def normalize(self, x):
+ return self._norm_fwd(x)
+
+ def denormalize(self, y):
+ return self._norm_bkw(y)
+
+class ScaleNorm(DataNorm):
+
+ def __init__(self, scale = 1.0):
+ super().__init__()
+ self._scale = scale
+
+ def normalize(self, x):
+ return self._scale * x
+
+ def denormalize(self, y):
+ return y / self._scale
+
+class LogNorm(DataNorm):
+
+ def __init__(self, clip_min = None, bias = None):
+ super().__init__()
+
+ self._clip_min = clip_min
+ self._bias = bias
+
+ def normalize(self, x):
+ if self._clip_min is not None:
+ x = x.clip(min = self._clip_min, max = None)
+
+ if self._bias is not None:
+ x = x + self._bias
+
+ return torch.log(x)
+
+ def denormalize(self, y):
+ x = torch.exp(y)
+
+ if self._bias is not None:
+ x = x - self._bias
+
+ return x
+
+class SymLogNorm(DataNorm):
+ """
+ To ensure continuity of the function and its derivative:
+ y = scale * x if (x <= T)
+ y = scale * T * (log(x/T) + 1) if (x > T)
+
+ Inverse:
+ x = y / scale if (y <= scale * T)
+ x = T * exp(y / (scale * T) - 1) otherwise
+ """
+
+ def __init__(self, threshold = 1.0, scale = 1.0):
+ super().__init__()
+
+ self._scale = scale
+ self._T = threshold
+ self._inv_T = scale * threshold
+
+ def normalize(self, x):
+ x_abs = x.abs()
+
+ y_lin = self._scale * x
+ y_log = torch.sign(x) * self._inv_T * (torch.log(x_abs / self._T) + 1)
+
+ return torch.where(x_abs > self._T, y_log, y_lin)
+
+ def denormalize(self, y):
+ y_abs = y.abs()
+
+ x_lin = y / self._scale
+ x_log = torch.sign(y) * self._T * torch.exp(y_abs / self._inv_T - 1)
+
+ return torch.where(y_abs > self._inv_T, x_log, x_lin)
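+
+# Worked example for SymLogNorm (scale = 1, threshold = 1): x = 0.5 stays in
+# the linear branch (y = 0.5), while x = e falls in the log branch and maps to
+# log(e) + 1 = 2; denormalize() inverts both branches exactly.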
+
+class MinMaxScaler(DataNorm):
+
+ def __init__(self, feature_min, feature_max, dim):
+ super().__init__()
+
+ self._min = feature_min
+ self._max = feature_max
+ self._dim = dim
+
+ @staticmethod
+ def align_shapes(source, target, dim):
+ if source.ndim == 0:
+ source = source.unsqueeze(0)
+
+ n_before = dim
+ n_after = target.ndim - (dim + source.ndim)
+
+ return source.reshape(
+ ( 1, ) * n_before + source.shape + ( 1, ) * n_after
+ )
+
+ def normalize(self, x):
+ fmin = torch.tensor(self._min, dtype = x.dtype, device = x.device)
+ fmax = torch.tensor(self._max, dtype = x.dtype, device = x.device)
+
+ fmin = MinMaxScaler.align_shapes(fmin, x, self._dim)
+ fmax = MinMaxScaler.align_shapes(fmax, x, self._dim)
+
+ return (x - fmin) / (fmax - fmin)
+
+ def denormalize(self, y):
+ fmin = torch.tensor(self._min, dtype = y.dtype, device = y.device)
+ fmax = torch.tensor(self._max, dtype = y.dtype, device = y.device)
+
+ fmin = MinMaxScaler.align_shapes(fmin, y, self._dim)
+ fmax = MinMaxScaler.align_shapes(fmax, y, self._dim)
+
+ return (y * (fmax - fmin) + fmin)
+
+class DoublePrecision(DataNorm):
+
+ def __init__(self, norm):
+ self._norm = norm
+
+ def normalize(self, x):
+ old_dtype = x.dtype
+
+ result = x.double()
+ result = self._norm.normalize(result)
+
+ return result.to(dtype = old_dtype)
+
+ def denormalize(self, y):
+ old_dtype = y.dtype
+
+ result = y.double()
+ result = self._norm.denormalize(result)
+
+ return result.to(dtype = old_dtype)
+
+class Compose(DataNorm):
+
+ def __init__(self, norms):
+ self._norms = norms
+
+ def normalize(self, x):
+ for norm in self._norms:
+ x = norm.normalize(x)
+
+ return x
+
+ def denormalize(self, y):
+ for norm in reversed(self._norms):
+ y = norm.denormalize(y)
+
+ return y
+
+def select_single_data_normalization(norm):
+ name, kwargs = extract_name_kwargs(norm)
+
+ if name == 'double-precision':
+ return DoublePrecision(select_data_normalization(**kwargs))
+
+ if name == 'scale':
+ return ScaleNorm(**kwargs)
+
+ if name == 'log':
+ return LogNorm(**kwargs)
+
+ if name == 'symlog':
+ return SymLogNorm(**kwargs)
+
+ if name == 'standardize':
+ return Standardizer(**kwargs)
+
+ if name == 'min-max-scaler':
+ return MinMaxScaler(**kwargs)
+
+ raise ValueError(f"Unknown data normalization '{name}'")
+
+def select_data_normalization(norm):
+ if norm is None:
+ return None
+
+ if isinstance(norm, (tuple, list)):
+ norm = [ select_single_data_normalization(n) for n in norm ]
+ return Compose(norm)
+
+ return select_single_data_normalization(norm)
+
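+# Usage sketch (illustrative values): compose a biased log transform with a
+# rescaling; denormalization applies the inverses in reverse order.
+#
+#   norm = Compose([ LogNorm(bias = 1.0), ScaleNorm(0.5) ])
+#   y = norm.normalize(x)       # 0.5 * log(x + 1)
+#   x = norm.denormalize(y)     # exp(y / 0.5) - 1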
diff --git a/uvcgan_s/torch/funcs.py b/uvcgan_s/torch/funcs.py
new file mode 100644
index 0000000..b7b2b94
--- /dev/null
+++ b/uvcgan_s/torch/funcs.py
@@ -0,0 +1,68 @@
+import logging
+import random
+import torch
+import numpy as np
+
+from torch import nn
+
+LOGGER = logging.getLogger('uvcgan_s.torch')
+
+def seed_everything(seed):
+ torch.manual_seed(seed)
+ random.seed(seed)
+ np.random.seed(seed)
+
+def get_torch_device_smart():
+ if torch.cuda.is_available():
+ return 'cuda'
+
+ return 'cpu'
+
+def prepare_model(model, device):
+ model = model.to(device)
+
+ if torch.cuda.device_count() > 1:
+ LOGGER.warning(
+ "Multiple (%d) GPUs found. Using Data Parallelism",
+ torch.cuda.device_count()
+ )
+ model = nn.DataParallel(model)
+
+ return model
+
+@torch.no_grad()
+def update_average_model(average_model, model, momentum):
+ # TODO: Maybe it is better to copy buffers, instead of
+ # averaging them.
+ # Think about this later.
+ online_params = dict(model.named_parameters())
+ online_bufs = dict(model.named_buffers())
+
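+    # NOTE: lerp_(end, w) computes v <- v + w * (end - v), so with
+    # w = 1 - momentum it matches the 0-dim update
+    # momentum * v + (1 - momentum) * end.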
+ for (k, v) in average_model.named_parameters():
+ if v.ndim == 0:
+ v.copy_(momentum * v + (1 - momentum) * online_params[k])
+ else:
+ v.lerp_(online_params[k], (1 - momentum))
+
+ for (k, v) in average_model.named_buffers():
+ if v.ndim == 0:
+ v.copy_(momentum * v + (1 - momentum) * online_bufs[k])
+ else:
+ v.lerp_(online_bufs[k], (1 - momentum))
+
+def clip_gradients(optimizer, norm = None, value = None):
+ if (norm is None) and (value is None):
+ return
+
+ params = [
+ param
+ for param_group in optimizer.param_groups
+ for param in param_group['params']
+ ]
+
+ if norm is not None:
+ torch.nn.utils.clip_grad_norm_(params, max_norm = norm)
+
+ if value is not None:
+ torch.nn.utils.clip_grad_value_(params, clip_value = value)
+
diff --git a/uvcgan_s/torch/gan_losses.py b/uvcgan_s/torch/gan_losses.py
new file mode 100644
index 0000000..059bfcb
--- /dev/null
+++ b/uvcgan_s/torch/gan_losses.py
@@ -0,0 +1,135 @@
+import torch
+from torch import nn
+
+from .select import extract_name_kwargs
+
+def reduce_loss(loss, reduction):
+ if (reduction is None) or (reduction == 'none'):
+ return loss
+
+ if reduction == 'mean':
+ return loss.mean()
+
+ if reduction == 'sum':
+ return loss.sum()
+
+ raise ValueError(f"Unknown reduction method: '{reduction}'")
+
+class GANLoss(nn.Module):
+
+ def __init__(
+ self, label_real = 1, label_fake = 0, reduction = 'mean',
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.reduction = reduction
+ self.register_buffer('label_real', torch.tensor(label_real))
+ self.register_buffer('label_fake', torch.tensor(label_fake))
+
+ def _expand_label_as(self, x, is_real):
+ result = self.label_real if is_real else self.label_fake
+ return result.to(dtype = x.dtype).expand_as(x)
+
+ def eval_loss(self, x, is_real, is_generator):
+ raise NotImplementedError
+
+ def forward(self, x, is_real, is_generator = False):
+ if isinstance(x, (list, tuple)):
+ result = sum(self.forward(y, is_real, is_generator) for y in x)
+ return result / len(x)
+
+ return self.eval_loss(x, is_real, is_generator)
+
+class LSGANLoss(GANLoss):
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.loss = nn.MSELoss(reduction = self.reduction)
+
+ def eval_loss(self, x, is_real, is_generator = False):
+ label = self._expand_label_as(x, is_real)
+ return self.loss(x, label)
+
+class BCEGANLoss(GANLoss):
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.loss = nn.BCEWithLogitsLoss(reduction = self.reduction)
+
+ def eval_loss(self, x, is_real, is_generator = False):
+ label = self._expand_label_as(x, is_real)
+ return self.loss(x, label)
+
+class SoftplusGANLoss(GANLoss):
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.loss = nn.Softplus()
+
+ def eval_loss(self, x, is_real, is_generator = False):
+ if is_real:
+ result = self.loss(x)
+ else:
+ result = self.loss(-x)
+
+ return reduce_loss(result, self.reduction)
+
+class WGANLoss(GANLoss):
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ if self.reduction != 'mean':
+ raise NotImplementedError
+
+ def eval_loss(self, x, is_real, is_generator = False):
+ if is_real:
+ result = -x.mean()
+ else:
+ result = x.mean()
+
+ return reduce_loss(result, self.reduction)
+
+class HingeGANLoss(GANLoss):
+
+ def __init__(self, margin = 1, **kwargs):
+ super().__init__(**kwargs)
+ self._margin = margin
+ self._relu = nn.ReLU()
+
+ if self.reduction != 'mean':
+ raise NotImplementedError
+
+ def eval_loss(self, x, is_real, is_generator = False):
+ if is_generator:
+ if is_real:
+ result = -x.mean()
+ else:
+ result = x.mean()
+ else:
+ if is_real:
+ result = self._relu(self._margin - x).mean()
+ else:
+ result = self._relu(self._margin + x).mean()
+
+ return reduce_loss(result, self.reduction)
+
+GAN_LOSSES = {
+ 'lsgan' : LSGANLoss,
+ 'wgan' : WGANLoss,
+ 'softplus' : SoftplusGANLoss,
+ 'hinge' : HingeGANLoss,
+ 'bce' : BCEGANLoss,
+ 'vanilla' : BCEGANLoss,
+}
+
+def select_gan_loss(gan_loss):
+ name, kwargs = extract_name_kwargs(gan_loss)
+
+ if name in GAN_LOSSES:
+ return GAN_LOSSES[name](**kwargs)
+
+ raise ValueError(f"Unknown gan loss: '{name}'")
+
+
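+# Usage sketch (illustrative): hinge objective for a discriminator and the
+# matching generator term; d_real / d_fake stand for discriminator outputs.
+#
+#   loss_fn = select_gan_loss('hinge')
+#   loss_d  = loss_fn(d_real, is_real = True) + loss_fn(d_fake, is_real = False)
+#   loss_g  = loss_fn(d_fake, is_real = True, is_generator = True)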
diff --git a/uvcgan_s/torch/gradient_cacher.py b/uvcgan_s/torch/gradient_cacher.py
new file mode 100644
index 0000000..1c812f9
--- /dev/null
+++ b/uvcgan_s/torch/gradient_cacher.py
@@ -0,0 +1,41 @@
+
+class GradientCacher:
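+    # Evaluates the wrapped loss (and its backward pass) only periodically:
+    # once the gradients are cached, the next `cache_period` calls replay the
+    # cached per-parameter gradients and the cached loss value instead of
+    # recomputing them.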
+
+ def __init__(self, model, loss, cache_period = 10):
+ self._cache_period = cache_period
+ self._cache_grad = {}
+ self._cache_loss = None
+ self._iter = 0
+ self._model = model
+ self._loss = loss
+
+ def __call__(self, *args, **kwargs):
+ if (
+ (self._iter < self._cache_period)
+ and (self._cache_loss is not None)
+ ):
+ for name, p in self._model.named_parameters():
+ cached_grad = self._cache_grad[name]
+
+ if cached_grad is None:
+ p.grad = None
+ else:
+ p.grad = cached_grad.clone()
+
+ result = self._cache_loss
+ self._iter += 1
+
+ else:
+ result = self._loss(*args, **kwargs)
+ result.backward()
+
+ self._cache_loss = result
+ self._cache_grad = {
+ name : p.grad.detach().clone() if p.grad is not None else None
+ for (name, p) in self._model.named_parameters()
+ }
+
+ self._iter = 0
+
+ return result
+
diff --git a/uvcgan_s/torch/gradient_penalty.py b/uvcgan_s/torch/gradient_penalty.py
new file mode 100644
index 0000000..2f82ff2
--- /dev/null
+++ b/uvcgan_s/torch/gradient_penalty.py
@@ -0,0 +1,134 @@
+import torch
+
+def mix_tensors(alpha, a, b):
+ return alpha * a + (1 - alpha) * b
+
+def recursively_mix_args(alpha, a, b):
+ if a is None:
+ assert b is None
+ return None
+
+ if isinstance(a, (tuple, list)):
+ return type(a)(
+ recursively_mix_args(alpha, x, y) for (x, y) in zip(a, b)
+ )
+
+ if isinstance(a, dict):
+ return {
+ k : recursively_mix_args(alpha, a[k], b[k]) for k in a
+ }
+
+ if isinstance(a, torch.Tensor):
+ return mix_tensors(alpha, a, b)
+
+ assert a == b
+ return a
+
+def reduce_tensor(y, reduction, reduce_batch = False):
+ if isinstance(y, (list, tuple)):
+ return [ reduce_tensor(x, reduction) for x in y ]
+
+ if (reduction is None) or (reduction == 'none'):
+ return y
+
+ if reduction == 'mean':
+ if reduce_batch:
+ return y.mean()
+ else:
+ return y.mean(dim = tuple(i for i in range(1, y.dim())))
+
+ if reduction == 'sum':
+ if reduce_batch:
+ return y.sum()
+ else:
+ return y.sum(dim = tuple(i for i in range(1, y.dim())))
+
+ raise ValueError(f"Unknown reduction: '{reduction}'")
+
+class GradientPenalty:
+
+ def __init__(
+ self, mix_type, center, lambda_gp, seed = 0,
+ reduction = 'mean', gp_reduction = 'mean'
+ ):
+ # pylint: disable=too-many-arguments
+ self._mix_type = mix_type
+ self._center = center
+ self._reduce = reduction
+ self._gp_reduce = gp_reduction
+ self._lambda = lambda_gp
+
+ self._rng = torch.Generator()
+ self._rng.manual_seed(seed)
+
+ def eval_at(self, model, x, **model_kwargs):
+ x.requires_grad_(True)
+
+ y = model(x, **model_kwargs)
+ y = reduce_tensor(y, self._reduce, reduce_batch = False)
+
+ if not isinstance(y, list):
+ y = [ y, ]
+
+ grad = torch.autograd.grad(
+ outputs = y,
+ inputs = x,
+ grad_outputs = [ torch.ones(z.size()).to(z.device) for z in y ],
+ create_graph = True,
+ retain_graph = True,
+ )
+
+ grad_x = grad[0].reshape((x.shape[0], -1))
+ result = torch.square(
+ torch.norm(grad_x, p = 2, dim = 1) - self._center
+ )
+
+ return self._lambda * result
+
+ def mix_gp_args(self, a, b, model_kwargs_a, model_kwargs_b):
+ alpha = torch.rand(1, generator = self._rng).to(a.device)
+ result = mix_tensors(alpha, a, b)
+
+ mixed_kwargs = recursively_mix_args(
+ alpha, model_kwargs_a, model_kwargs_b
+ )
+ return (result, mixed_kwargs)
+
+ def get_eval_point(
+ self, fake, real, model_kwargs_fake = None, model_kwargs_real = None
+ ):
+ if self._mix_type == 'real':
+ return (real.clone(), model_kwargs_real)
+
+ if self._mix_type == 'fake':
+ return (fake.clone(), model_kwargs_fake)
+
+ if self._mix_type == 'real-fake':
+ return self.mix_gp_args(
+ real, fake, model_kwargs_real, model_kwargs_fake
+ )
+
+ if self._mix_type == 'real-or-fake':
+ alpha = torch.rand(1, generator = self._rng).cpu().item()
+ if alpha > 0.5:
+ return (real.clone(), model_kwargs_real)
+ else:
+ return (fake.clone(), model_kwargs_fake)
+
+ raise ValueError(f"Unknown mix type: {self._mix_type}")
+
+ def __call__(
+ self, model, fake, real,
+ model_kwargs_fake = None,
+ model_kwargs_real = None,
+ ):
+ # pylint: disable=too-many-arguments
+ x, model_kwargs = self.get_eval_point(
+ fake, real, model_kwargs_fake, model_kwargs_real
+ )
+
+ model_kwargs = model_kwargs or {}
+ result = self.eval_at(model, x, **model_kwargs)
+
+ return reduce_tensor(result, self._gp_reduce, reduce_batch = True)
+
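+# Usage sketch (illustrative): a zero-centered penalty evaluated at the real
+# samples (R1-style), weighted by lambda_gp.
+#
+#   gp = GradientPenalty(mix_type = 'real', center = 0, lambda_gp = 10)
+#   penalty = gp(discriminator, fake_batch, real_batch)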
diff --git a/uvcgan_s/torch/image_masking.py b/uvcgan_s/torch/image_masking.py
new file mode 100644
index 0000000..cffc0c1
--- /dev/null
+++ b/uvcgan_s/torch/image_masking.py
@@ -0,0 +1,129 @@
+import torch
+from torch import nn
+
+from .select import extract_name_kwargs
+from .layers.transformer import calc_tokenized_size
+
+class SequenceRandomMasking(nn.Module):
+
+ def __init__(self, fraction = 0.4, seed = 0, **kwargs):
+ super().__init__(**kwargs)
+ self._fraction = fraction
+
+ self._rng = torch.Generator()
+ self._rng.manual_seed(seed)
+
+ def forward(self, sequence):
+ # sequence : (N, L, features)
+ mask = (
+ torch.rand((*sequence.shape[:2], 1), generator = self._rng)
+ > self._fraction
+ )
+ return mask.to(sequence.device) * sequence
+
+class ImagePatchRandomMasking(nn.Module):
+
+ def __init__(self, patch_size, fraction = 0.4, seed = 0, **kwargs):
+ super().__init__(**kwargs)
+
+ self._patch_size = patch_size
+ self._fraction = fraction
+
+ self._rng = torch.Generator()
+ self._rng.manual_seed(seed)
+
+ def forward(self, image):
+ # image : (N, C, H, W)
+ N_h, N_w = calc_tokenized_size(image.shape[1:], self._patch_size)
+
+ # mask : (N, 1, N_h, N_w)
+ mask = (
+ torch.rand((image.shape[0], 1, N_h, N_w), generator = self._rng)
+ > self._fraction
+ )
+
+ # mask : (N, 1, N_h, N_w)
+ # -> (N, 1, H, W)
+ mask = mask.repeat_interleave(self._patch_size[0], dim = 2)
+ mask = mask.repeat_interleave(self._patch_size[1], dim = 3)
+
+ return mask.to(image.device) * image
+
+# pylint: disable=trailing-whitespace
+# class BlockwiseMasking(ImageMaskingBase):
+# # Algorithm 1 of arXiv:2106.08254
+#
+# def __init__(
+# self,
+# mask_ratio = 0.4,
+# min_block_size = 16,
+# aspect_ratio = 0.3,
+# seed = 0,
+# ):
+# self._mask_ratio = mask_ratio
+# self._min_block_size = min_block_size
+# self._aspect_ratio = aspect_ratio
+#
+# self._prg = np.random.default_rng(seed)
+#
+# def get_mask_region(self, image, h, w, masking_threshold, masked_patches):
+# min_block_size = self._min_block_size
+# max_block_size = \
+# max(min_block_size, masking_threshold - len(masked_patches))
+#
+# block_size = self._prg.integers(min_block_size, max_block_size)
+# aspect_ratio = self._prg.uniform(
+# self._aspect_ratio, 1/self._aspect_ratio
+# )
+#
+# y_range = int(np.round(np.sqrt(block_size * aspect_ratio)))
+# x_range = int(np.round(np.sqrt(block_size / aspect_ratio)))
+#
+# y_range = min(y_range, h)
+# x_range = min(x_range, w)
+#
+# y0 = self._prg.integers(0, h - y_range)
+# x0 = self._prg.integers(0, w - x_range)
+#
+# return (y0, x0, y_range, x_range)
+#
+# def mask(self, image):
+# # image : (..., H, W)
+# h = image.shape[-2]
+# w = image.shape[-1]
+#
+# n_patches = h * w
+# masked_patches = set()
+#
+# masking_threshold = self._mask_ratio * n_patches
+#
+# while len(masked_patches) < masking_threshold:
+# (y0, x0, y_range, x_range) = self.get_mask_region(
+# image, h, w, masking_threshold, masked_patches
+# )
+#
+# for y in range(y0, y0 + y_range):
+# for x in range(x0, x0 + x_range):
+# coord = (x, y)
+# if coord in masked_patches:
+# continue
+#
+# image[..., y, x] = 0
+# masked_patches.add(coord)
+#
+# return image
+
+def select_masking(masking):
+ if masking is None:
+ return None
+
+ name, kwargs = extract_name_kwargs(masking)
+
+ if name in [ 'transformer-random', 'sequence-random' ]:
+ return SequenceRandomMasking(**kwargs)
+
+ if name == 'image-patch-random':
+ return ImagePatchRandomMasking(**kwargs)
+
+ raise ValueError("Unknown masking: '%s'" % name)
+
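+# Usage sketch (illustrative shapes): zero out roughly 40% of the 8x8 patches
+# of an image batch.
+#
+#   masking = ImagePatchRandomMasking(patch_size = (8, 8), fraction = 0.4)
+#   masked  = masking(torch.randn(2, 1, 64, 64))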
diff --git a/uvcgan_s/torch/layers/__init__.py b/uvcgan_s/torch/layers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/uvcgan_s/torch/layers/activation.py b/uvcgan_s/torch/layers/activation.py
new file mode 100644
index 0000000..bb14ab8
--- /dev/null
+++ b/uvcgan_s/torch/layers/activation.py
@@ -0,0 +1,12 @@
+import torch
+from torch import nn
+
+class Exponential(nn.Module):
+
+ def __init__(self, beta = 1):
+ super().__init__()
+ self._beta = beta
+
+ def forward(self, x):
+ return torch.exp(self._beta * x)
+
diff --git a/uvcgan_s/torch/layers/alt_trans.py b/uvcgan_s/torch/layers/alt_trans.py
new file mode 100644
index 0000000..9e9c315
--- /dev/null
+++ b/uvcgan_s/torch/layers/alt_trans.py
@@ -0,0 +1,125 @@
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-instance-attributes
+
+import torch
+from torch import nn
+
+from uvcgan_s.torch.select import get_norm_layer
+from .attention import select_attention
+from .transformer import (
+ img_to_pixelwise_tokens, img_from_pixelwise_tokens,
+ PositionWiseFFN, ViTInput
+)
+
+class AltTransformerBlock(nn.Module):
+
+ def __init__(
+ self, features, ffn_features, n_heads, activ = 'gelu', norm = None,
+ attention = 'dot', rezero = True, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.norm1 = get_norm_layer(norm, features)
+ self.atten = select_attention(
+ attention,
+ embed_dim = features,
+ num_heads = n_heads,
+ batch_first = False,
+ )
+
+ self.norm2 = get_norm_layer(norm, features)
+ self.ffn = PositionWiseFFN(features, ffn_features, activ)
+
+ self.rezero = rezero
+
+ if rezero:
+ self.re_alpha = nn.Parameter(torch.zeros((1, )))
+ else:
+ self.re_alpha = 1
+
+ def forward(self, x):
+ # x: (L, N, features)
+
+ # Step 1: Multi-Head Self Attention
+ y1 = self.norm1(x)
+ y1, _atten_weights = self.atten(y1, y1, y1)
+
+ y = x + self.re_alpha * y1
+
+ # Step 2: PositionWise Feed Forward Network
+ y2 = self.norm2(y)
+ y2 = self.ffn(y2)
+
+ y = y + self.re_alpha * y2
+
+ return y
+
+ def extra_repr(self):
+ return 're_alpha = %e' % (self.re_alpha, )
+
+class AltTransformerEncoder(nn.Module):
+
+ def __init__(
+ self, features, ffn_features, n_heads, n_blocks, activ, norm,
+ attention = 'dot', rezero = True, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.encoder = nn.Sequential(*[
+ AltTransformerBlock(
+ features, ffn_features, n_heads, activ, norm, attention, rezero
+ ) for _ in range(n_blocks)
+ ])
+
+ def forward(self, x):
+ # x : (N, L, features)
+
+ # y : (L, N, features)
+ y = x.permute((1, 0, 2))
+ y = self.encoder(y)
+
+ # result : (N, L, features)
+ result = y.permute((1, 0, 2))
+
+ return result
+
+class AltPixelwiseViT(nn.Module):
+
+ def __init__(
+ self, features, n_heads, n_blocks, ffn_features, embed_features,
+ activ, norm, image_shape, attention = 'dot', rezero = True, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.image_shape = image_shape
+
+ self.trans_input = ViTInput(
+ image_shape[0], embed_features, features,
+ image_shape[1], image_shape[2],
+ )
+
+ self.encoder = AltTransformerEncoder(
+ features, ffn_features, n_heads, n_blocks, activ, norm,
+ attention, rezero
+ )
+
+ self.trans_output = nn.Linear(features, image_shape[0])
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+
+ # itokens : (N, L, C)
+ itokens = img_to_pixelwise_tokens(x)
+
+ # y : (N, L, features)
+ y = self.trans_input(itokens)
+ y = self.encoder(y)
+
+ # otokens : (N, L, C)
+ otokens = self.trans_output(y)
+
+ # result : (N, C, H, W)
+ result = img_from_pixelwise_tokens(otokens, self.image_shape)
+
+ return result
+
diff --git a/uvcgan_s/torch/layers/attention.py b/uvcgan_s/torch/layers/attention.py
new file mode 100644
index 0000000..c67adc5
--- /dev/null
+++ b/uvcgan_s/torch/layers/attention.py
@@ -0,0 +1,160 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from einops import rearrange
+from uvcgan_s.torch.select import extract_name_kwargs
+
+def expand_heads(values, n_heads):
+ return rearrange(
+ values, 'N L (D_h n_heads) -> (N n_heads) L D_h', n_heads = n_heads
+ )
+
+def contract_heads(values, n_heads):
+ return rearrange(
+ values, '(N n_heads) L D_h -> N L (D_h n_heads)', n_heads = n_heads
+ )
+
+def return_result_and_atten_weights(
+ result, A, need_weights, average_attn_weights, batch_first, n_heads
+):
+ # pylint: disable=too-many-arguments
+
+ # result : (N, L, embed_dim)
+ # A : ((N n_heads), L, S)
+
+ if not batch_first:
+ # result : (N, L, embed_dim)
+ # -> (L, N, embed_dim)
+ result = result.swapaxes(0, 1)
+
+ if not need_weights:
+ return (result, None)
+
+    # A : (N, n_heads, L, S)
+ A = rearrange(
+ A, '(N n_heads) L S -> N n_heads L S', n_heads = n_heads
+ )
+
+ if average_attn_weights:
+ # A : (N, n_heads, L, S)
+ # -> (N, L, S)
+ A = A.mean(dim = 1)
+
+ return (result, A)
+
+class QuadraticAttention(nn.Module):
+
+    # This attention is Lipschitz continuous when equal_kq is set
+
+ def __init__(
+ self, embed_dim, num_heads,
+ bias = False,
+ add_bias_kv = False,
+ kdim = None,
+ vdim = None,
+ batch_first = False,
+ equal_kq = False,
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__()
+
+ if kdim is None:
+ kdim = embed_dim
+ elif equal_kq:
+ assert kdim == embed_dim
+
+ if vdim is None:
+ vdim = embed_dim
+
+ self._batch_first = batch_first
+ self._n_heads = num_heads
+ self._dh = embed_dim // num_heads
+
+ self.w_q = nn.Linear(embed_dim, embed_dim, bias = bias)
+
+ if equal_kq:
+ self.w_k = self.w_q
+ else:
+ self.w_k = nn.Linear(kdim, embed_dim, bias = add_bias_kv)
+
+ self.w_v = nn.Linear(vdim, embed_dim, bias = add_bias_kv)
+ self.w_o = nn.Linear(embed_dim, embed_dim, bias = bias)
+
+ def compute_attention_matrix(self, w_query, w_key):
+ # w_query : (N, L, d)
+ # w_key : (N, S, d)
+
+ # w_query : (N, L, d)
+ # -> (N, L, 1, d)
+ w_query = w_query.unsqueeze(dim = 2)
+
+ # w_key : (N, S, d)
+ # -> (N, 1, S, d)
+ w_key = w_key.unsqueeze(dim = 1)
+
+ # L : (N, L, S)
+ L = -torch.norm(w_query - w_key, p = 2, dim = 3)
+ L = L / math.sqrt(self._dh)
+
+ # result : (N, L, S)
+ return torch.softmax(L, dim = 2)
+
+ def forward(
+ self, query: Tensor, key: Tensor, value: Tensor,
+ key_padding_mask : Optional[Tensor] = None,
+ need_weights : bool = True,
+ attn_mask : Optional[Tensor] = None,
+ average_attn_weights : bool = True
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ # pylint: disable=too-many-arguments
+
+ assert key_padding_mask is None, "key_padding_mask is not supported"
+ assert attn_mask is None, "attn_mask is not supported"
+
+ if not self._batch_first:
+ # (k, q, v) : (L, N, ...)
+ query = query.swapaxes(0, 1)
+ key = key.swapaxes(0, 1)
+ value = value.swapaxes(0, 1)
+
+ # (query, key, value) : (N, L, ...)
+
+ # w_query : ((N n_heads), L, D_h)
+ # w_key : ((N n_heads), S, D_h)
+ # w_value : ((N n_heads), S, D_h)
+ w_query = expand_heads(self.w_q(query), self._n_heads)
+ w_key = expand_heads(self.w_k(key), self._n_heads)
+        w_value = expand_heads(self.w_v(value), self._n_heads)
+
+ # A : ((N n_heads), L, S)
+ A = self.compute_attention_matrix(w_query, w_key)
+
+ # w_output : ((N n_heads), L, D_h)
+ w_output = torch.bmm(A, w_value)
+
+ # output : (N, L, embed_dim)
+ output = contract_heads(w_output, self._n_heads)
+
+ # result : (N, L, embed_dim)
+ result = self.w_o(output)
+
+ return return_result_and_atten_weights(
+ result, A, need_weights, average_attn_weights,
+ self._batch_first, self._n_heads
+ )
+
+def select_attention(attention, **extra_kwargs):
+ name, kwargs = extract_name_kwargs(attention)
+
+ if name in [ 'default', 'standard', 'scalar', 'dot' ]:
+ return nn.MultiheadAttention(**kwargs, **extra_kwargs)
+
+ if name in [ 'quadratic', 'l2' ]:
+ return QuadraticAttention(**kwargs, **extra_kwargs)
+
+ raise ValueError(f"Unknown attention {name}")
+
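+# Usage sketch (illustrative shapes): QuadraticAttention is a drop-in
+# replacement for nn.MultiheadAttention that uses an L2 attention matrix.
+#
+#   atten = QuadraticAttention(embed_dim = 64, num_heads = 4, equal_kq = True)
+#   q = torch.randn(16, 2, 64)     # (L, N, embed_dim), batch_first = False
+#   y, weights = atten(q, q, q)    # y : (16, 2, 64), weights : (2, 16, 16)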
diff --git a/uvcgan_s/torch/layers/batch_head.py b/uvcgan_s/torch/layers/batch_head.py
new file mode 100644
index 0000000..6d5464c
--- /dev/null
+++ b/uvcgan_s/torch/layers/batch_head.py
@@ -0,0 +1,331 @@
+import torch
+from torch import nn
+
+from uvcgan_s.torch.select import get_activ_layer, extract_name_kwargs
+from .alt_trans import AltTransformerEncoder
+
+# References:
+# arXiv: 1912.04958
+# https://github.com/moono/stylegan2-tf-2.x/blob/master/stylegan2/discriminator.py
+# https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py
+
+class BatchStdev(nn.Module):
+
+ # pylint: disable=useless-super-delegation
+ def __init__(self, **kwargs):
+ """ arXiv: 1710.10196 """
+ super().__init__(**kwargs)
+
+ @staticmethod
+ def safe_stdev(x, dim = 0, eps = 1e-6):
+ var = torch.var(x, dim = dim, unbiased = False, keepdim = True)
+ stdev = torch.sqrt(var + eps)
+
+ return stdev
+
+ # pylint: disable=no-self-use
+ def forward(self, x):
+ """
+ NOTE: Reference impl has fixed minibatch size.
+
+ arXiv: 1710.10196
+
+ 1. We first compute the standard deviation for each feature in each
+ spatial location over the minibatch.
+
+ 2. We then average these estimates over all features and spatial
+ locations to arrive at a single value.
+
+ 3. We replicate the value and concatenate it to all spatial locations
+ and over the minibatch, yielding one additional (con-stant) feature
+ map.
+ """
+
+ # x : (N, C, H, W)
+ # x_stdev : (1, C, H, W)
+ x_stdev = BatchStdev.safe_stdev(x, dim = 0)
+
+ # x_norm : (1, 1, 1, 1)
+ x_norm = torch.mean(x_stdev, dim = (1, 2, 3), keepdim = True)
+
+ # x_norm : (N, 1, H, W)
+ x_norm = x_norm.expand((x.shape[0], 1, *x.shape[2:]))
+
+ # y : (N, C + 1, H, W)
+ y = torch.cat((x, x_norm), dim = 1)
+
+ return y
+
+class BatchAttention(nn.Module):
+
+ def __init__(
+ self, input_features, features, ffn_features, n_heads, n_blocks, activ,
+ norm, attention = 'dot', rezero = True,
+ input_flatten_method = 'average', **kwargs
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__(**kwargs)
+
+ self.embed = nn.Linear(input_features, features)
+ self.attention = AltTransformerEncoder(
+ features, ffn_features, n_heads, n_blocks, activ, norm, attention,
+ rezero
+ )
+
+ self.flat_method = input_flatten_method
+
+        assert input_flatten_method in [ 'average', 'stdev', 'flatten' ], \
+            f"Unknown input flattening method: {input_flatten_method}"
+
+ def flatten_input(self, x):
+ # x : (N, C_in, ...)
+
+ if len(x.shape) == 2:
+ return x
+
+ if self.flat_method == 'average':
+ return x.mean(dim = tuple(range(2, len(x.shape))), keepdim = False)
+
+        if self.flat_method == 'stdev':
+            # flatten the per-feature stdev to (N, C) to match the other modes
+            return BatchStdev.safe_stdev(
+                x, dim = tuple(range(2, len(x.shape)))
+            ).reshape((x.shape[0], -1))
+
+ if self.flat_method == 'flatten':
+ return x.reshape((x.shape[0], -1))
+
+ raise ValueError(f"Unknown flatten method: {self.flat_method}")
+
+ def forward(self, x):
+ # x : (N, C_in, ...)
+
+ # x_flat : (N, C)
+ x_flat = self.flatten_input(x)
+
+ # e : (1, N, features)
+ e = self.embed(x_flat).unsqueeze(0)
+
+ # y : (1, N, features)
+ y = self.attention(e)
+
+ # result : (N, features)
+ return y.squeeze(0)
+
+class BatchHead1d(nn.Module):
+
+ def __init__(
+ self, input_features, mid_features = None, output_features = None,
+ activ = 'relu', activ_output = None, **kwargs
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__(**kwargs)
+
+ if mid_features is None:
+ mid_features = input_features
+
+ if output_features is None:
+ output_features = mid_features
+
+ self.net = nn.Sequential(
+ nn.Linear(input_features, mid_features),
+ nn.BatchNorm1d(mid_features),
+ get_activ_layer(activ),
+
+ nn.Linear(mid_features, output_features),
+ get_activ_layer(activ_output),
+ )
+
+ def forward(self, x):
+ # x : (N, C)
+ return self.net(x)
+
+class BatchHead2d(nn.Module):
+
+ def __init__(
+ self, input_features, mid_features = None, output_features = None,
+ activ = 'relu', activ_output = None, n_signal = None, **kwargs
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__(**kwargs)
+
+ if mid_features is None:
+ mid_features = input_features
+
+ if output_features is None:
+ output_features = mid_features
+
+ self._n_signal = n_signal
+
+ self.norm = nn.BatchNorm2d(input_features)
+ self.net = nn.Sequential(
+ nn.Conv2d(
+ input_features, mid_features, kernel_size = 3, padding = 1
+ ),
+ get_activ_layer(activ),
+
+ nn.Conv2d(
+ mid_features, output_features, kernel_size = 3, padding = 1
+ ),
+ get_activ_layer(activ_output),
+ )
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+ y = self.norm(x)
+
+ if self._n_signal is not None:
+ # Drop queue tokens
+ y = y[:self._n_signal, ...]
+
+ return self.net(y)
+
+class BatchStdevHead(nn.Module):
+
+ def __init__(
+ self, input_features, mid_features = None, output_features = None,
+ activ = 'relu', activ_output = None, **kwargs
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__(**kwargs)
+
+ if mid_features is None:
+ mid_features = input_features
+
+ if output_features is None:
+ output_features = mid_features
+
+ self.net = nn.Sequential(
+ BatchStdev(),
+ nn.Conv2d(
+ input_features + 1, mid_features, kernel_size = 3, padding = 1
+ ),
+ get_activ_layer(activ),
+
+ nn.Conv2d(
+ mid_features, output_features, kernel_size = 3, padding = 1
+ ),
+ get_activ_layer(activ_output),
+ )
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+ return self.net(x)
+
+class BatchAverageHead(nn.Module):
+
+ def __init__(
+ self, input_features, reduce_channels = True, average_spacial = False,
+ activ_output = None, **kwargs
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__(**kwargs)
+
+ layers = []
+
+ if reduce_channels:
+ layers.append(
+ nn.Conv2d(input_features, 1, kernel_size = 3, padding = 1)
+ )
+
+ if average_spacial:
+ layers.append(nn.AdaptiveAvgPool2d(1))
+
+ if activ_output is not None:
+ layers.append(get_activ_layer(activ_output))
+
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+ return self.net(x)
+
+class BatchAttentionHead(nn.Module):
+
+ def __init__(
+ self, input_features, features, ffn_features, n_heads, n_blocks, activ,
+ norm,
+ attention = 'dot',
+ rezero = True,
+ input_flatten_method = 'average',
+ output_features = None,
+ activ_output = None,
+ **kwargs
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__(**kwargs)
+
+ if output_features is None:
+ output_features = input_features
+
+ self.net = nn.Sequential(
+ BatchAttention(
+ input_features, features, ffn_features, n_heads, n_blocks,
+ activ, norm, attention, rezero, input_flatten_method
+ ),
+
+ nn.Linear(features, output_features),
+ get_activ_layer(activ_output),
+ )
+
+ def forward(self, x):
+ # x : (N, C, ...)
+ return self.net(x)
+
+class BatchHeadWrapper(nn.Module):
+
+ def __init__(self, body, head, **kwargs):
+ super().__init__(**kwargs)
+ self._body = body
+ self._head = head
+
+ def forward_head(self, x_body):
+ return self._head(x_body)
+
+ def forward_body(self, x):
+ return self._body(x)
+
+ def forward(self, x, extra_bodies = None, return_body = False):
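+        # Extra body features, if given, are appended along the batch
+        # dimension before the head (so batch-statistics heads can see them),
+        # and only the head outputs for the main batch are kept afterwards.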
+ y_body = self._body(x)
+
+ if isinstance(y_body, (list, tuple)):
+ y_body_main = list(y_body[:-1])
+ y_body_last = y_body[-1]
+ else:
+ y_body_main = tuple()
+ y_body_last = y_body
+
+ if extra_bodies is not None:
+ all_bodies = torch.cat((y_body_last, extra_bodies), dim = 0)
+ y_head = self._head(all_bodies)
+ else:
+ y_head = self._head(y_body_last)
+
+ y_head = y_head[:y_body_last.shape[0]]
+
+ if len(y_body_main) == 0:
+ result = y_head
+ else:
+ result = y_body_main + [ y_head, ]
+
+ if return_body:
+ return (result, y_body_last)
+
+ return result
+
+BATCH_HEADS = {
+ 'batch-norm-1d' : BatchHead1d,
+ 'batch-norm-2d' : BatchHead2d,
+ 'batch-stdev' : BatchStdevHead,
+ 'batch-atten' : BatchAttentionHead,
+ 'simple-average' : BatchAverageHead,
+ 'idt' : nn.Identity,
+}
+
+def get_batch_head(batch_head):
+ name, kwargs = extract_name_kwargs(batch_head)
+
+ if name not in BATCH_HEADS:
+ raise ValueError("Unknown Batch Head: '%s'" % name)
+
+ return BATCH_HEADS[name](**kwargs)
+
diff --git a/uvcgan_s/torch/layers/cnn.py b/uvcgan_s/torch/layers/cnn.py
new file mode 100644
index 0000000..e76e6a8
--- /dev/null
+++ b/uvcgan_s/torch/layers/cnn.py
@@ -0,0 +1,143 @@
+from itertools import repeat
+
+from torch import nn
+from uvcgan_s.torch.select import extract_name_kwargs
+
+def calc_conv1d_output_size(input_size, kernel_size, padding, stride):
+ return (input_size + 2 * padding - kernel_size) // stride + 1
+
+def calc_conv_transpose1d_output_size(
+ input_size, kernel_size, padding, stride
+):
+ return (input_size - 1) * stride - 2 * padding + kernel_size
+
+def calc_conv_output_size(input_size, kernel_size, padding, stride):
+ if isinstance(kernel_size, int):
+ kernel_size = repeat(kernel_size, len(input_size))
+
+ if isinstance(stride, int):
+ stride = repeat(stride, len(input_size))
+
+ if isinstance(padding, int):
+ padding = repeat(padding, len(input_size))
+
+ return tuple(
+ calc_conv1d_output_size(sz, ks, p, s)
+ for (sz, ks, p, s) in zip(input_size, kernel_size, padding, stride)
+ )
+
+def calc_conv_transpose_output_size(input_size, kernel_size, padding, stride):
+ if isinstance(kernel_size, int):
+ kernel_size = repeat(kernel_size, len(input_size))
+
+ if isinstance(stride, int):
+ stride = repeat(stride, len(input_size))
+
+ if isinstance(padding, int):
+ padding = repeat(padding, len(input_size))
+
+ return tuple(
+ calc_conv_transpose1d_output_size(sz, ks, p, s)
+ for (sz, ks, p, s) in zip(input_size, kernel_size, padding, stride)
+ )
+
+def get_downsample_x2_conv2_layer(features, **kwargs):
+ return (
+ nn.Conv2d(features, features, kernel_size = 2, stride = 2, **kwargs),
+ features
+ )
+
+def get_downsample_x2_conv3_layer(features, **kwargs):
+ return (
+ nn.Conv2d(
+ features, features, kernel_size = 3, stride = 2, padding = 1,
+ **kwargs
+ ),
+ features
+ )
+
+def get_downsample_x2_pixelshuffle_layer(features, **kwargs):
+ out_features = 4 * features
+ return (nn.PixelUnshuffle(downscale_factor = 2, **kwargs), out_features)
+
+def get_downsample_x2_pixelshuffle_conv_layer(features, **kwargs):
+ out_features = features * 4
+
+ layer = nn.Sequential(
+ nn.PixelUnshuffle(downscale_factor = 2, **kwargs),
+ nn.Conv2d(
+ out_features, out_features, kernel_size = 3, padding = 1
+ ),
+ )
+
+ return (layer, out_features)
+
+def get_upsample_x2_deconv2_layer(features, **kwargs):
+ return (
+ nn.ConvTranspose2d(
+ features, features, kernel_size = 2, stride = 2, **kwargs
+ ),
+ features
+ )
+
+def get_upsample_x2_upconv_layer(features, **kwargs):
+ layer = nn.Sequential(
+ nn.Upsample(scale_factor = 2, **kwargs),
+ nn.Conv2d(features, features, kernel_size = 3, padding = 1),
+ )
+
+ return (layer, features)
+
+def get_upsample_x2_pixelshuffle_conv_layer(features, **kwargs):
+ out_features = features // 4
+
+ layer = nn.Sequential(
+ nn.PixelShuffle(upscale_factor = 2, **kwargs),
+ nn.Conv2d(out_features, out_features, kernel_size = 3, padding = 1),
+ )
+
+ return (layer, out_features)
+
+def get_downsample_x2_layer(layer, features):
+ name, kwargs = extract_name_kwargs(layer)
+
+ if name == 'conv':
+ return get_downsample_x2_conv2_layer(features, **kwargs)
+
+ if name == 'conv3':
+ return get_downsample_x2_conv3_layer(features, **kwargs)
+
+ if name == 'avgpool':
+ return (nn.AvgPool2d(kernel_size = 2, stride = 2, **kwargs), features)
+
+ if name == 'maxpool':
+ return (nn.MaxPool2d(kernel_size = 2, stride = 2, **kwargs), features)
+
+ if name == 'pixel-unshuffle':
+ return get_downsample_x2_pixelshuffle_layer(features, **kwargs)
+
+ if name == 'pixel-unshuffle-conv':
+ return get_downsample_x2_pixelshuffle_conv_layer(features, **kwargs)
+
+ raise ValueError("Unknown Downsample Layer: '%s'" % name)
+
+def get_upsample_x2_layer(layer, features):
+ name, kwargs = extract_name_kwargs(layer)
+
+ if name == 'deconv':
+ return get_upsample_x2_deconv2_layer(features, **kwargs)
+
+ if name == 'upsample':
+ return (nn.Upsample(scale_factor = 2, **kwargs), features)
+
+ if name == 'upsample-conv':
+ return get_upsample_x2_upconv_layer(features, **kwargs)
+
+ if name == 'pixel-shuffle':
+ return (nn.PixelShuffle(upscale_factor = 2, **kwargs), features // 4)
+
+ if name == 'pixel-shuffle-conv':
+ return get_upsample_x2_pixelshuffle_conv_layer(features, **kwargs)
+
+ raise ValueError("Unknown Upsample Layer: '%s'" % name)
+
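+# Usage sketch: each factory returns a (layer, output_features) pair, e.g.
+#
+#   layer, out_features = get_downsample_x2_layer('pixel-unshuffle', 64)
+#   # halves H and W; out_features == 256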
diff --git a/uvcgan_s/torch/layers/modnet.py b/uvcgan_s/torch/layers/modnet.py
new file mode 100644
index 0000000..b779d95
--- /dev/null
+++ b/uvcgan_s/torch/layers/modnet.py
@@ -0,0 +1,422 @@
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-instance-attributes
+
+import torch
+from torch import nn
+
+from uvcgan_s.torch.select import get_activ_layer
+from .cnn import get_upsample_x2_layer
+from .unet import UNetEncBlock
+
+# Ref: https://arxiv.org/pdf/1912.04958.pdf
+
+def get_demod_scale(mod_scale, weights, eps = 1e-6):
+ # Ref: https://arxiv.org/pdf/1912.04958.pdf
+ #
+ # demod_scale[alpha] = 1 / sqrt(sigma[alpha]^2 + eps)
+ #
+ # sigma[alpha]^2
+ # = sum_{beta i} (mod_scale[alpha] * weights[alpha, beta, i])^2
+ # = sum_{beta} (mod_scale[alpha])^2 * sum_i (weights[alpha, beta, i])^2
+ #
+
+ # mod_scale : (N, C_in)
+ # weights : (C_out, C_in, h, w)
+
+ # w_sq : (C_out, C_in)
+ w_sq = torch.sum(weights.square(), dim = (2, 3))
+
+ # w_sq : (C_out, C_in) -> (1, C_in, C_out)
+ w_sq = torch.swapaxes(w_sq, 0, 1).unsqueeze(0)
+
+ # mod_scale_sq : (N, C_in, 1)
+ mod_scale_sq = mod_scale.square().unsqueeze(2)
+
+    # sigma_sq : (N, C_out)
+ sigma_sq = torch.sum(mod_scale_sq * w_sq, dim = 1)
+
+ # result : (N, C_out)
+ return 1 / torch.sqrt(sigma_sq + eps)
+
+class ModulatedConv2d(nn.Module):
+
+ def __init__(
+ self, in_channels, out_channels, kernel_size,
+ stride = 1, padding = 0, dilation = 1, groups = 1, eps = 1e-6,
+ demod = True, device = None, dtype = None, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.conv = nn.Conv2d(
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
+ groups, device = device, dtype = dtype
+ )
+
+ self._eps = eps
+ self._demod = demod
+
+ def forward(self, x, s):
+ # x : (N, C_in, H_in, W_in)
+ # s : (N, C_in)
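+
+        # NOTE: scaling the input by s and rescaling the output by the
+        # demodulation factor mimics StyleGAN2 weight (de)modulation; since
+        # the convolution carries a bias, the bias is rescaled as well, a
+        # minor deviation from the reference formulation.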
+
+ # x_mod : (N, C_in, H_in, W_in)
+ x_mod = x * s.unsqueeze(2).unsqueeze(3)
+
+ # y : (N, C_out, H_out, W_out)
+ y = self.conv(x_mod)
+
+ if self._demod:
+ # s_demod : (N, C_out)
+ s_demod = get_demod_scale(s, self.conv.weight)
+ y_demod = y * s_demod.unsqueeze(2).unsqueeze(3)
+
+ return y_demod
+
+ return y
+
+class StyleBlock(nn.Module):
+
+ def __init__(
+ self, mod_features, style_features, rezero = True, bias = True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.affine_mod = nn.Linear(mod_features, style_features, bias = bias)
+ self.rezero = rezero
+
+ if rezero:
+ self.re_alpha = nn.Parameter(torch.zeros((1, )))
+ else:
+ self.re_alpha = 1
+
+ def forward(self, mod):
+ # mod : (N, mod_features)
+ # s : (N, style_features)
+ s = 1 + self.re_alpha * self.affine_mod(mod)
+
+ return s
+
+ def extra_repr(self):
+ return 're_alpha = %e' % (self.re_alpha, )
+
+class ModNetBasicBlock(nn.Module):
+
+ def __init__(
+ self, in_features, out_features, activ, mod_features,
+ mid_features = None,
+ demod = True,
+ style_rezero = True,
+ style_bias = True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ if mid_features is None:
+ mid_features = out_features
+
+ self.style_block1 = StyleBlock(
+ mod_features, in_features, style_rezero, style_bias
+ )
+ self.conv_block1 = ModulatedConv2d(
+ in_features, mid_features, kernel_size = 3, padding = 1,
+ demod = demod
+ )
+ self.activ1 = get_activ_layer(activ)
+
+ self.style_block2 = StyleBlock(
+ mod_features, mid_features, style_rezero, style_bias
+ )
+ self.conv_block2 = ModulatedConv2d(
+ mid_features, out_features, kernel_size = 3, padding = 1,
+ demod = demod
+ )
+ self.activ2 = get_activ_layer(activ)
+
+ def forward(self, x, mod):
+ # x : (N, C_in, H_in, W_in)
+ # mod : (N, mod_features)
+
+        # mod_scale1 : (N, C_in)
+ mod_scale1 = self.style_block1(mod)
+
+ # y1 : (N, C_mid, H_mid, W_mid)
+ y1 = self.conv_block1(x, mod_scale1)
+ y1 = self.activ1(y1)
+
+        # mod_scale2 : (N, C_mid)
+ mod_scale2 = self.style_block2(mod)
+
+ # result : (N, C_out, H_out, W_out)
+ result = self.conv_block2(y1, mod_scale2)
+ result = self.activ2(result)
+
+ return result
+
+class ModNetDecBlock(nn.Module):
+
+ def __init__(
+ self, input_shape, output_features, skip_features, mod_features,
+ activ, upsample,
+ rezero = True,
+ demod = True,
+ style_rezero = True,
+ style_bias = True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ (input_features, H, W) = input_shape
+ self.upsample, input_features = get_upsample_x2_layer(
+ upsample, input_features
+ )
+
+ self.block = ModNetBasicBlock(
+ skip_features + input_features, output_features, activ,
+ mod_features = mod_features,
+ mid_features = max(input_features, input_shape[0]),
+ demod = demod,
+ style_rezero = style_rezero,
+ style_bias = style_bias,
+ )
+
+ self._output_shape = (output_features, 2 * H, 2 * W)
+
+ if rezero:
+ self.re_alpha = nn.Parameter(torch.zeros((1, )))
+ else:
+ self.re_alpha = 1
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ def forward(self, x, r, mod):
+ # x : (N, C, H_in, W_in)
+ # r : (N, C, H_out, W_out)
+ # mod : (N, mod_features)
+
+ # x : (N, C_up, H_out, W_out)
+ x = self.re_alpha * self.upsample(x)
+
+ # y : (N, C + C_up, H_out, W_out)
+ y = torch.cat([x, r], dim = 1)
+
+ # result : (N, C_out, H_out, W_out)
+ return self.block(y, mod)
+
+ def extra_repr(self):
+ return 're_alpha = %e' % (self.re_alpha, )
+
+class ModNetBlock(nn.Module):
+
+ def __init__(
+ self, features, activ, norm, image_shape, downsample, upsample,
+ mod_features,
+ rezero = True,
+ demod = True,
+ style_rezero = True,
+ style_bias = True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.conv = UNetEncBlock(
+ features, activ, norm, downsample, image_shape
+ )
+
+ self.inner_shape = self.conv.output_shape
+ self.inner_module = None
+
+ self.deconv = ModNetDecBlock(
+ input_shape = self.inner_shape,
+ output_features = image_shape[0],
+ skip_features = self.inner_shape[0],
+ mod_features = mod_features,
+ activ = activ,
+ upsample = upsample,
+ rezero = rezero,
+ demod = demod,
+ style_rezero = style_rezero,
+ style_bias = style_bias,
+ )
+
+ def get_inner_shape(self):
+ return self.inner_shape
+
+ def set_inner_module(self, module):
+ self.inner_module = module
+
+ def get_inner_module(self):
+ return self.inner_module
+
+ def forward(self, x, mod = None):
+ # x : (N, C, H, W)
+
+ # y : (N, C_inner, H_inner, W_inner)
+ # r : (N, C_inner, H, W)
+ (y, r) = self.conv(x)
+
+ # y : (N, C_inner, H_inner, W_inner)
+ # mod : (N, mod_features)
+ if mod is None:
+ y, mod = self.inner_module(y)
+ else:
+ y, mod = self.inner_module(y, mod)
+
+ # y : (N, C, H, W)
+ y = self.deconv(y, r, mod)
+
+ return (y, mod)
+
+class ModNetLinearDecoder(nn.Module):
+
+ def __init__(
+ self, features_list, input_shape, output_shape, skip_shapes,
+ mod_features, activ, upsample,
+ rezero = True,
+ demod = True,
+ style_rezero = True,
+ style_bias = True,
+ **kwargs
+ ):
+ # pylint: disable = too-many-locals
+ super().__init__(**kwargs)
+
+ self.net = nn.ModuleList()
+ self._input_shape = input_shape
+ self._output_shape = output_shape
+ curr_shape = input_shape
+
+ for features, skip_shape in zip(
+ features_list[::-1], skip_shapes[::-1]
+ ):
+ layer = ModNetDecBlock(
+ input_shape = curr_shape,
+ output_features = features,
+ skip_features = skip_shape[0],
+ mod_features = mod_features,
+ activ = activ,
+ upsample = upsample,
+ rezero = rezero,
+ demod = demod,
+ style_rezero = style_rezero,
+ style_bias = style_bias,
+ )
+ curr_shape = layer.output_shape
+
+ self.net.append(layer)
+
+ self.output = nn.Conv2d(
+ curr_shape[0], output_shape[0], kernel_size = 1
+ )
+ curr_shape = (output_shape[0], *curr_shape[1:])
+
+ assert tuple(output_shape) == tuple(curr_shape)
+
+ @property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ def forward(self, x, skip_list, mod):
+ # x : (N, C, H, W)
+ # mod : (N, mod_features)
+ # skip_list : List[(N, C_i, H_i, W_i)]
+
+ y = x
+
+ for layer, skip in zip(self.net, skip_list[::-1]):
+ y = layer(y, skip, mod)
+
+ return self.output(y)
+
+class ModNet(nn.Module):
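+    # Modulated UNet: nested ModNetBlock levels whose innermost slot must be
+    # filled via set_bottleneck() with a module mapping (features [, mod]) to
+    # (features, mod) before the first forward pass.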
+
+ def __init__(
+ self, features_list, activ, norm, input_shape, output_shape,
+ downsample, upsample,
+ mod_features,
+ rezero = True,
+ demod = True,
+ style_rezero = True,
+ style_bias = True,
+ return_mod = False,
+ **kwargs
+ ):
+ # pylint: disable = too-many-locals
+ super().__init__(**kwargs)
+ assert tuple(input_shape[1:]) == tuple(output_shape[1:])
+
+ self.features_list = features_list
+ self.input_shape = input_shape
+ self.output_shape = output_shape
+ self.return_mod = return_mod
+
+ self._construct_input_layer(activ)
+ self._construct_output_layer()
+
+ unet_layers = []
+ curr_image_shape = (features_list[0], *input_shape[1:])
+
+ for features in features_list:
+ layer = ModNetBlock(
+ features, activ, norm, curr_image_shape, downsample, upsample,
+ mod_features, rezero, demod, style_rezero, style_bias
+ )
+ curr_image_shape = layer.get_inner_shape()
+ unet_layers.append(layer)
+
+ for idx in range(len(unet_layers)-1):
+ unet_layers[idx].set_inner_module(unet_layers[idx+1])
+
+ self.modnet = unet_layers[0]
+
+ def _construct_input_layer(self, activ):
+ self.layer_input = nn.Sequential(
+ nn.Conv2d(
+ self.input_shape[0], self.features_list[0],
+ kernel_size = 3, padding = 1
+ ),
+ get_activ_layer(activ),
+ )
+
+ def _construct_output_layer(self):
+ self.layer_output = nn.Conv2d(
+ self.features_list[0], self.output_shape[0], kernel_size = 1
+ )
+
+ def get_innermost_block(self):
+ result = self.modnet
+
+ for _ in range(len(self.features_list)-1):
+ result = result.get_inner_module()
+
+ return result
+
+ def set_bottleneck(self, module):
+ self.get_innermost_block().set_inner_module(module)
+
+ def get_bottleneck(self):
+ return self.get_innermost_block().get_inner_module()
+
+ def get_inner_shape(self):
+ return self.get_innermost_block().get_inner_shape()
+
+ def forward(self, x, mod = None):
+ # x : (N, C, H, W)
+
+ y = self.layer_input(x)
+
+ y, mod = self.modnet(y, mod)
+
+ y = self.layer_output(y)
+
+ if self.return_mod:
+ return (y, mod)
+
+ return y
+
diff --git a/uvcgan_s/torch/layers/resnet.py b/uvcgan_s/torch/layers/resnet.py
new file mode 100644
index 0000000..607ecb9
--- /dev/null
+++ b/uvcgan_s/torch/layers/resnet.py
@@ -0,0 +1,480 @@
+import torch
+from torch import nn
+
+from uvcgan_s.torch.select import get_norm_layer, get_activ_layer
+from .cnn import calc_conv_output_size
+
+class ResNetBlock(nn.Module):
+
+ def __init__(
+ self, features, activ, norm, rezero = False,
+ kernel_size = 3, bottlneck_features = None, **kwargs
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__(**kwargs)
+
+ if bottlneck_features is None:
+ bottlneck_features = features
+
+ self.block = nn.Sequential(
+ nn.Conv2d(
+ features, bottlneck_features,
+ kernel_size = kernel_size,
+ padding = 'same',
+ stride = 1,
+ ),
+ get_norm_layer(norm, bottlneck_features),
+ get_activ_layer(activ),
+
+ nn.Conv2d(
+ bottlneck_features, features,
+ kernel_size = kernel_size,
+ padding = 'same',
+ stride = 1
+ ),
+ get_norm_layer(norm, features),
+ )
+
+ self.block_out = get_activ_layer(activ)
+
+ if rezero:
+ self.re_alpha = nn.Parameter(torch.zeros((1, )))
+ else:
+ self.re_alpha = 1
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+ y = self.block(x)
+ z = x + self.re_alpha * y
+
+ return self.block_out(z)
+
+ def extra_repr(self):
+ return 're_alpha = %e' % (self.re_alpha, )
+
+class ResNetBlockv2(nn.Module):
+
+ def __init__(
+ self, features, activ, norm, rezero = False,
+ kernel_size = 3, bottlneck_features = None, **kwargs
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__(**kwargs)
+
+ if bottlneck_features is None:
+ bottlneck_features = features
+
+ self.block = nn.Sequential(
+            get_norm_layer(norm, features),
+ get_activ_layer(activ),
+
+ nn.Conv2d(
+ features, bottlneck_features,
+ kernel_size = kernel_size,
+ padding = 'same',
+ stride = 1,
+ ),
+
+            get_norm_layer(norm, bottlneck_features),
+ get_activ_layer(activ),
+
+ nn.Conv2d(
+ bottlneck_features, features,
+ kernel_size = kernel_size,
+ padding = 'same',
+ stride = 1
+ ),
+ )
+
+ if rezero:
+ self.re_alpha = nn.Parameter(torch.zeros((1, )))
+ else:
+ self.re_alpha = 1
+
+    def forward(self, x):
+        # x : (N, C, H, W)
+        y = self.block(x)
+
+        # pre-activation block: the residual sum is returned without an
+        # extra output activation
+        return x + self.re_alpha * y
+
+ def extra_repr(self):
+ return 're_alpha = %e' % (self.re_alpha, )
+
+class BigGanResDownBlock(nn.Module):
+
+ def __init__(
+ self, input_shape, features, activ, norm, rezero = False,
+ kernel_size = 3, bottlneck_features = None, n_blocks = 1, **kwargs
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__(**kwargs)
+
+ if bottlneck_features is None:
+ bottlneck_features = features
+
+ layers = []
+
+ curr_features = input_shape[0]
+
+ for _ in range(n_blocks):
+ block = nn.Sequential(
+ get_norm_layer(norm, curr_features),
+ get_activ_layer(activ),
+ nn.Conv2d(
+ curr_features, bottlneck_features,
+ kernel_size = kernel_size, padding = 'same', stride = 1,
+ ),
+ get_norm_layer(norm, bottlneck_features),
+ get_activ_layer(activ),
+ nn.Conv2d(
+ bottlneck_features, features,
+ kernel_size = kernel_size, padding = 'same', stride = 1
+ )
+ )
+
+ layers.append(block)
+ curr_features = features
+
+ layers.append(nn.AvgPool2d(kernel_size = 2, stride = 2))
+
+ self.net_main = nn.Sequential(*layers)
+ self.net_res = nn.Sequential(
+ nn.Conv2d(input_shape[0], features, kernel_size = 1),
+ nn.AvgPool2d(kernel_size = 2, stride = 2)
+ )
+
+ self._input_shape = input_shape
+ self._output_shape = (
+ features, input_shape[1] // 2, input_shape[2] // 2
+ )
+
+ if rezero:
+ self.re_alpha = nn.Parameter(torch.zeros((1, )))
+ else:
+ self.re_alpha = 1
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+ main = self.net_main(x)
+ res = self.net_res(x)
+
+ return res + self.re_alpha * main
+
+ @property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ def extra_repr(self):
+ return 're_alpha = %e' % (self.re_alpha, )
+
+class BigGanDeepResDownBlock(nn.Module):
+
+ def __init__(
+ self, input_shape, features, activ, norm, rezero = False,
+ kernel_size = 3, bottlneck_features = None, n_blocks = 1, **kwargs
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__(**kwargs)
+
+ if bottlneck_features is None:
+ bottlneck_features = features
+
+ layers = []
+ layers.append(nn.Sequential(
+ get_norm_layer(norm, input_shape[0]),
+ get_activ_layer(activ),
+ nn.Conv2d(input_shape[0], features, kernel_size = 1)
+ ))
+
+ for _ in range(n_blocks):
+ block = nn.Sequential(
+ get_norm_layer(norm, features),
+ get_activ_layer(activ),
+ nn.Conv2d(
+ features, bottlneck_features,
+ kernel_size = kernel_size, padding = 'same', stride = 1,
+ ),
+ get_norm_layer(norm, bottlneck_features),
+ get_activ_layer(activ),
+ nn.Conv2d(
+ bottlneck_features, features,
+ kernel_size = kernel_size, padding = 'same', stride = 1
+ )
+ )
+
+ layers.append(block)
+
+ layers.append(nn.Sequential(
+ nn.AvgPool2d(kernel_size = 2, stride = 2),
+ nn.Conv2d(features, features, kernel_size = 1)
+ ))
+
+ self.net_main = nn.Sequential(*layers)
+ self.net_res_stem = nn.AvgPool2d(kernel_size = 2, stride = 2)
+
+ if features > input_shape[0]:
+ self.net_res_side = nn.Conv2d(
+ input_shape[0], features - input_shape[0], kernel_size = 1
+ )
+ else:
+ self.net_res_side = None
+
+ self._input_shape = input_shape
+ self._output_shape = (
+ features, input_shape[1] // 2, input_shape[2] // 2
+ )
+
+ if rezero:
+ self.re_alpha = nn.Parameter(torch.zeros((1, )))
+ else:
+ self.re_alpha = 1
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+ main = self.net_main(x)
+ res_stem = self.net_res_stem(x)
+
+ if self.net_res_side is None:
+ res = res_stem
+ else:
+ res_side = self.net_res_side(res_stem)
+ res = torch.cat((res_stem, res_side), dim = 1)
+
+ return res + self.re_alpha * main
+
+ @property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ def extra_repr(self):
+ return 're_alpha = %e' % (self.re_alpha, )
+
+
+class ResNetStem(nn.Module):
+
+ def __init__(
+ self, input_shape, features, norm = None,
+ kernel_size = 4, padding = 0, stride = 4
+ ):
+ # pylint: disable=too-many-arguments
+ super().__init__()
+
+ self.net = nn.Sequential(
+ nn.Conv2d(
+ input_shape[0], features,
+ kernel_size = kernel_size, padding = padding, stride = stride
+ ),
+ get_norm_layer(norm, features),
+ )
+
+ self._input_shape = input_shape
+ self._output_shape = (
+ features,
+ *calc_conv_output_size(
+ input_shape[1:], kernel_size, padding, stride
+ )
+ )
+
+ @property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ def forward(self, x):
+ return self.net(x)
+
+class ResNetEncoder(nn.Module):
+ # pylint: disable=too-many-instance-attributes
+
+ def _make_stem_block(self, block_spec, curr_shape):
+ # pylint: disable=no-self-use
+ block = ResNetStem(curr_shape, **block_spec)
+ curr_shape = block.output_shape
+
+ return (block, curr_shape)
+
+ def _make_resample_block(self, block_spec, curr_shape):
+ # pylint: disable=no-self-use
+ if isinstance(block_spec, (list, tuple)):
+ size, kwargs = block_spec
+ else:
+ size = block_spec
+ kwargs = {}
+
+ if isinstance(size, int):
+ size = (size, size)
+
+ features = curr_shape[0]
+ block = nn.Upsample(size, **kwargs)
+ curr_shape = (features, *size)
+
+ return (block, curr_shape)
+
+ def _make_resnet_block(self, block_spec, curr_shape):
+ if isinstance(block_spec, (list, tuple)):
+ n_blocks, kwargs = block_spec
+ else:
+ n_blocks = block_spec
+ kwargs = {}
+
+ features = curr_shape[0]
+ block = nn.Sequential(
+ *[ ResNetBlock(
+ features,
+ activ = self._activ,
+ norm = self._norm,
+ rezero = self._rezero,
+ **kwargs
+ )
+ for _ in range(n_blocks) ]
+ )
+
+ return (block, curr_shape)
+
+ def _make_resnet_block_v2(self, block_spec, curr_shape):
+ if isinstance(block_spec, (list, tuple)):
+ n_blocks, kwargs = block_spec
+ else:
+ n_blocks = block_spec
+ kwargs = {}
+
+ features = curr_shape[0]
+        block = nn.Sequential(*[
+            ResNetBlockv2(
+                features,
+                activ = self._activ,
+                norm = self._norm,
+                rezero = self._rezero,
+                **kwargs
+            )
+            for _ in range(n_blocks)
+        ])
+
+ return (block, curr_shape)
+
+ def _make_biggan_resdown_block(self, block_spec, curr_shape):
+ # pylint: disable=no-self-use
+ block = BigGanResDownBlock(
+ curr_shape,
+ activ = self._activ,
+ norm = self._norm,
+ rezero = self._rezero,
+ **block_spec
+ )
+ return (block, block.output_shape)
+
+ def _make_biggan_deep_resdown_block(self, block_spec, curr_shape):
+ # pylint: disable=no-self-use
+ block = BigGanDeepResDownBlock(
+ curr_shape,
+ activ = self._activ,
+ norm = self._norm,
+ rezero = self._rezero,
+ **block_spec
+ )
+ return (block, block.output_shape)
+
+ def __init__(
+ self, input_shape, block_specs, activ, norm, rezero = True
+ ):
+ # pylint: disable=too-many-arguments
+ # pylint: disable=too-many-locals
+ super().__init__()
+
+ curr_shape = input_shape
+ self.blocks = nn.ModuleList()
+
+ self._activ = activ
+ self._norm = norm
+ self._rezero = rezero
+
+ block_idx = 0
+ skip_indices = set()
+ skip_shapes = []
+
+ for (block_type, block_spec) in block_specs:
+ if block_type == 'stem':
+ block, curr_shape \
+ = self._make_stem_block(block_spec, curr_shape)
+
+ elif block_type == 'resample':
+ block, curr_shape \
+ = self._make_resample_block(block_spec, curr_shape)
+
+ elif block_type == 'resnet':
+ block, curr_shape \
+ = self._make_resnet_block(block_spec, curr_shape)
+
+ elif block_type == 'resnet-v2':
+ block, curr_shape \
+ = self._make_resnet_block_v2(block_spec, curr_shape)
+
+ elif block_type == 'biggan-resdown':
+ block, curr_shape \
+ = self._make_biggan_resdown_block(block_spec, curr_shape)
+
+ elif block_type == 'biggan-deep-resdown':
+ block, curr_shape = self._make_biggan_deep_resdown_block(
+ block_spec, curr_shape
+ )
+
+ elif block_type == 'skip':
+ skip_indices.add(block_idx)
+ skip_shapes.append(curr_shape)
+ continue
+
+ else:
+ raise ValueError(f"Unknown block type: {block_type}")
+
+ self.blocks.append(block)
+ block_idx += 1
+
+ self._input_shape = input_shape
+ self._output_shape = curr_shape
+ self._skip_shapes = skip_shapes
+ self._skip_indices = skip_indices
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ @property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
+ def skip_indices(self):
+ return self._skip_indices
+
+ @property
+ def skip_shapes(self):
+ return self._skip_shapes
+
+ def forward(self, x, return_skips = False):
+ if return_skips:
+ skips = []
+
+ y = x
+
+ for idx, block in enumerate(self.blocks):
+ if return_skips and (idx in self._skip_indices):
+ skips.append(y)
+
+ y = block(y)
+
+ if return_skips:
+ return (y, skips)
+
+ return y
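+
+# Illustrative sketch (not part of the original code): how a ResNetEncoder
+# could be assembled from a list of (block_type, block_spec) pairs. The
+# feature counts and block choices below are assumptions for demonstration
+# only; ResNetBlock is defined earlier in this module.
+#
+#   encoder = ResNetEncoder(
+#       input_shape = (3, 128, 128),
+#       block_specs = [
+#           ('stem',   { 'features' : 64, 'kernel_size' : 4, 'stride' : 4 }),
+#           ('resnet', 2),          # two ResNetBlock-s at 64 features
+#           ('skip',   None),       # remember this shape for a decoder skip
+#           ('resnet', 2),
+#       ],
+#       activ = 'leakyrelu', norm = 'batch', rezero = True,
+#   )
+#   y, skips = encoder(torch.randn(1, 3, 128, 128), return_skips = True)
+#   # y.shape == (1, 64, 32, 32), len(skips) == 1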
+
diff --git a/uvcgan_s/torch/layers/transformer.py b/uvcgan_s/torch/layers/transformer.py
new file mode 100644
index 0000000..52ac722
--- /dev/null
+++ b/uvcgan_s/torch/layers/transformer.py
@@ -0,0 +1,415 @@
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-instance-attributes
+
+import torch
+from torch import nn
+
+from uvcgan_s.torch.select import get_norm_layer, get_activ_layer
+
+def calc_tokenized_size(image_shape, token_size):
+ # image_shape : (C, H, W)
+ # token_size : (H_t, W_t)
+ if image_shape[1] % token_size[0] != 0:
+ raise ValueError(
+ "Token width %d does not divide image width %d" % (
+ token_size[0], image_shape[1]
+ )
+ )
+
+ if image_shape[2] % token_size[1] != 0:
+ raise ValueError(
+ "Token height %d does not divide image height %d" % (
+ token_size[1], image_shape[2]
+ )
+ )
+
+ # result : (N_h, N_w)
+ return (image_shape[1] // token_size[0], image_shape[2] // token_size[1])
+
+def img_to_tokens(image_batch, token_size):
+ # image_batch : (N, C, H, W)
+ # token_size : (H_t, W_t)
+
+ # result : (N, C, N_h, H_t, W)
+ result = image_batch.view(
+ (*image_batch.shape[:2], -1, token_size[0], image_batch.shape[3])
+ )
+
+ # result : (N, C, N_h, H_t, W )
+ # -> (N, C, N_h, H_t, N_w, W_t)
+ result = result.view((*result.shape[:4], -1, token_size[1]))
+
+ # result : (N, C, N_h, H_t, N_w, W_t)
+ # -> (N, N_h, N_w, C, H_t, W_t)
+ result = result.permute((0, 2, 4, 1, 3, 5))
+
+ return result
+
+def img_from_tokens(tokens):
+ # tokens : (N, N_h, N_w, C, H_t, W_t)
+ # result : (N, C, N_h, H_t, N_w, W_t)
+ result = tokens.permute((0, 3, 1, 4, 2, 5))
+
+ # result : (N, C, N_h, H_t, N_w, W_t)
+ # -> (N, C, N_h, H_t, N_w * W_t)
+ # = (N, C, N_h, H_t, W)
+ result = result.reshape((*result.shape[:4], -1))
+
+ # result : (N, C, N_h, H_t, W)
+ # -> (N, C, N_h * H_t, W)
+ # = (N, C, H, W)
+ result = result.reshape((*result.shape[:2], -1, result.shape[4]))
+
+ return result
+
+def img_to_pixelwise_tokens(image):
+ # image : (N, C, H, W)
+
+ # result : (N, C, H * W)
+ result = image.view(*image.shape[:2], -1)
+
+ # result : (N, C, H * W)
+ # -> (N, H * W, C )
+ # = (N, L, C)
+ result = result.permute((0, 2, 1))
+
+ # (N, L, C)
+ return result
+
+def img_from_pixelwise_tokens(tokens, image_shape):
+ # tokens : (N, L, C)
+ # image_shape : (3, )
+
+ # tokens : (N, L, C)
+ # -> (N, C, L)
+ # = (N, C, H * W)
+ tokens = tokens.permute((0, 2, 1))
+
+ # (N, C, H, W)
+ return tokens.view(*tokens.shape[:2], *image_shape[1:])
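+
+# Illustrative round trip (sketch, not part of the original code): pixel-wise
+# tokenization flattens the spatial dimensions into a sequence of per-pixel
+# tokens and back without changing any values.
+#
+#   image    = torch.randn(2, 8, 16, 16)            # (N, C, H, W)
+#   tokens   = img_to_pixelwise_tokens(image)       # (N, 256, 8) = (N, L, C)
+#   restored = img_from_pixelwise_tokens(tokens, (8, 16, 16))
+#   assert torch.equal(image, restored)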
+
+class PositionWiseFFN(nn.Module):
+
+ def __init__(self, features, ffn_features, activ = 'gelu', **kwargs):
+ super().__init__(**kwargs)
+
+ self.net = nn.Sequential(
+ nn.Linear(features, ffn_features),
+ get_activ_layer(activ),
+ nn.Linear(ffn_features, features),
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+class TransformerBlock(nn.Module):
+
+ def __init__(
+ self, features, ffn_features, n_heads, activ = 'gelu', norm = None,
+ rezero = True, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.norm1 = get_norm_layer(norm, features)
+ self.atten = nn.MultiheadAttention(features, n_heads)
+
+ self.norm2 = get_norm_layer(norm, features)
+ self.ffn = PositionWiseFFN(features, ffn_features, activ)
+
+ self.rezero = rezero
+
+ if rezero:
+ self.re_alpha = nn.Parameter(torch.zeros((1, )))
+ else:
+ self.re_alpha = 1
+
+ def forward(self, x):
+ # x: (L, N, features)
+
+ # Step 1: Multi-Head Self Attention
+ y1 = self.norm1(x)
+ y1, _atten_weights = self.atten(y1, y1, y1)
+
+ y = x + self.re_alpha * y1
+
+ # Step 2: PositionWise Feed Forward Network
+ y2 = self.norm2(y)
+ y2 = self.ffn(y2)
+
+ y = y + self.re_alpha * y2
+
+ return y
+
+ def extra_repr(self):
+ return 're_alpha = %e' % (self.re_alpha, )
+
+class TransformerEncoder(nn.Module):
+
+ def __init__(
+ self, features, ffn_features, n_heads, n_blocks, activ, norm,
+ rezero = True, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.encoder = nn.Sequential(*[
+ TransformerBlock(
+ features, ffn_features, n_heads, activ, norm, rezero
+ ) for _ in range(n_blocks)
+ ])
+
+ def forward(self, x):
+ # x : (N, L, features)
+
+ # y : (L, N, features)
+ y = x.permute((1, 0, 2))
+ y = self.encoder(y)
+
+ # result : (N, L, features)
+ result = y.permute((1, 0, 2))
+
+ return result
+
+class FourierEmbedding(nn.Module):
+ # arXiv: 2011.13775
+
+ def __init__(self, features, height, width, **kwargs):
+ super().__init__(**kwargs)
+ self.projector = nn.Linear(2, features)
+ self._height = height
+ self._width = width
+
+ def forward(self, y, x):
+ # x : (N, L)
+ # y : (N, L)
+ x_norm = 2 * x / (self._width - 1) - 1
+ y_norm = 2 * y / (self._height - 1) - 1
+
+ # z : (N, L, 2)
+ z = torch.cat((x_norm.unsqueeze(2), y_norm.unsqueeze(2)), dim = 2)
+
+ return torch.sin(self.projector(z))
+
+class ViTInput(nn.Module):
+
+ def __init__(
+ self, input_features, embed_features, features, height, width,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self._height = height
+ self._width = width
+
+ x = torch.arange(width).to(torch.float32)
+ y = torch.arange(height).to(torch.float32)
+
+ x, y = torch.meshgrid(x, y)
+ self.x = x.reshape((1, -1))
+ self.y = y.reshape((1, -1))
+
+ self.register_buffer('x_const', self.x)
+ self.register_buffer('y_const', self.y)
+
+ self.embed = FourierEmbedding(embed_features, height, width)
+ self.output = nn.Linear(embed_features + input_features, features)
+
+ def forward(self, x):
+ # x : (N, L, input_features)
+ # embed : (1, height * width, embed_features)
+ # = (1, L, embed_features)
+ embed = self.embed(self.y_const, self.x_const)
+
+ # embed : (1, L, embed_features)
+ # -> (N, L, embed_features)
+ embed = embed.expand((x.shape[0], *embed.shape[1:]))
+
+ # result : (N, L, embed_features + input_features)
+ result = torch.cat([embed, x], dim = 2)
+
+ # (N, L, features)
+ return self.output(result)
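+
+# Usage sketch (sizes below are assumptions): ViTInput concatenates a Fourier
+# positional embedding to every token and projects the result to the
+# transformer feature width.
+#
+#   vit_input = ViTInput(
+#       input_features = 8, embed_features = 16, features = 32,
+#       height = 16, width = 16,
+#   )
+#   tokens = torch.randn(2, 16 * 16, 8)     # (N, L, input_features)
+#   y      = vit_input(tokens)              # (N, L, 32)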
+
+class PixelwiseViT(nn.Module):
+
+ def __init__(
+ self, features, n_heads, n_blocks, ffn_features, embed_features,
+ activ, norm, image_shape, rezero = True, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.image_shape = image_shape
+
+ self.trans_input = ViTInput(
+ image_shape[0], embed_features, features,
+ image_shape[1], image_shape[2],
+ )
+
+ self.encoder = TransformerEncoder(
+ features, ffn_features, n_heads, n_blocks, activ, norm, rezero
+ )
+
+ self.trans_output = nn.Linear(features, image_shape[0])
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+
+ # itokens : (N, C, H * W)
+ itokens = x.view(*x.shape[:2], -1)
+
+ # itokens : (N, C, H * W)
+ # -> (N, H * W, C )
+ # = (N, L, C)
+ itokens = itokens.permute((0, 2, 1))
+
+ # y : (N, L, features)
+ y = self.trans_input(itokens)
+ y = self.encoder(y)
+
+ # otokens : (N, L, C)
+ otokens = self.trans_output(y)
+
+ # otokens : (N, L, C)
+ # -> (N, C, L)
+ # = (N, C, H * W)
+ otokens = otokens.permute((0, 2, 1))
+
+ # result : (N, C, H, W)
+ result = otokens.view(*otokens.shape[:2], *self.image_shape[1:])
+
+ return result
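+
+# Illustrative construction (parameter values are assumptions, not defaults):
+# PixelwiseViT treats every pixel of a (C, H, W) feature map as a token and
+# returns a tensor of the same shape.
+#
+#   vit = PixelwiseViT(
+#       features = 64, n_heads = 4, n_blocks = 2, ffn_features = 128,
+#       embed_features = 32, activ = 'gelu', norm = 'layer',
+#       image_shape = (8, 16, 16), rezero = True,
+#   )
+#   x = torch.randn(2, 8, 16, 16)
+#   y = vit(x)                              # y.shape == x.shape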
+
+class ExtendedTransformerEncoder(nn.Module):
+
+ def __init__(
+ self, features, n_heads, n_blocks, ffn_features, activ, norm,
+ rezero = True, n_ext = 1, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.encoder = TransformerEncoder(
+ features, ffn_features, n_heads, n_blocks, activ, norm, rezero
+ )
+
+ self.extra_tokens = nn.Parameter(torch.empty((1, n_ext, features)))
+ torch.nn.init.normal_(self.extra_tokens)
+
+ def forward(self, x):
+ # x : (N, L, C)
+ N, L, _C = x.shape
+
+ # i_extra_tokens : (N, n_extra, C)
+ i_extra_tokens = self.extra_tokens.tile(N, 1, 1)
+
+ # y : (N, L + n_extra, C)
+ y = torch.cat([ x, i_extra_tokens ], dim = 1)
+ y = self.encoder(y)
+
+ # o_extra_tokens : (N, n_extra, features)
+ o_extra_tokens = y[:, L:, :]
+
+ # result : (N, L, C)
+ result = y[:, :L, :]
+
+ return (result, o_extra_tokens.reshape(N, -1))
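+
+# Usage sketch (sizes are illustrative): the encoder appends n_ext learnable
+# tokens to the input sequence and returns them separately, flattened per
+# sample.
+#
+#   enc = ExtendedTransformerEncoder(
+#       features = 64, n_heads = 4, n_blocks = 2, ffn_features = 128,
+#       activ = 'gelu', norm = 'layer', n_ext = 2,
+#   )
+#   x = torch.randn(2, 100, 64)             # (N, L, features)
+#   y, extra = enc(x)                       # y: (2, 100, 64), extra: (2, 128)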
+
+class ExtendedPixelwiseViT(nn.Module):
+
+ def __init__(
+ self, features, n_heads, n_blocks, ffn_features, embed_features,
+ activ, norm, image_shape, rezero = True, n_ext = 1, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.image_shape = image_shape
+
+ self.trans_input = ViTInput(
+ image_shape[0], embed_features, features,
+ image_shape[1], image_shape[2],
+ )
+
+ self.encoder = TransformerEncoder(
+ features, ffn_features, n_heads, n_blocks, activ, norm, rezero
+ )
+
+ self.extra_tokens = nn.Parameter(torch.empty((1, n_ext, features)))
+ torch.nn.init.normal_(self.extra_tokens)
+
+ self.trans_output = nn.Linear(features, image_shape[0])
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+
+ # itokens : (N, L, C)
+ itokens = img_to_pixelwise_tokens(x)
+ (N, L, _C) = itokens.shape
+
+ # i_extra_tokens : (N, n_extra, C)
+ i_extra_tokens = self.extra_tokens.tile(itokens.shape[0], 1, 1)
+
+ # y : (N, L, features)
+ y = self.trans_input(itokens)
+
+ # y : (N, L + n_extra, C)
+ y = torch.cat([ y, i_extra_tokens ], dim = 1)
+ y = self.encoder(y)
+
+ # o_extra_tokens : (N, n_extra, features)
+ o_extra_tokens = y[:, L:, :]
+
+ # otokens : (N, L, C)
+ otokens = self.trans_output(y[:, :L, :])
+
+ # result : (N, C, H, W)
+ result = img_from_pixelwise_tokens(otokens, self.image_shape)
+
+ return (result, o_extra_tokens.reshape(N, -1))
+
+class CExtPixelwiseViT(nn.Module):
+
+ def __init__(
+ self, features, n_heads, n_blocks, ffn_features, embed_features,
+ activ, norm, image_shape, rezero = True, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.image_shape = image_shape
+
+ self.trans_input = ViTInput(
+ image_shape[0], embed_features, features,
+ image_shape[1], image_shape[2],
+ )
+
+ self.encoder = TransformerEncoder(
+ features, ffn_features, n_heads, n_blocks, activ, norm, rezero
+ )
+
+ self.trans_output = nn.Linear(features, image_shape[0])
+
+ def forward(self, x, control):
+ # x : (N, C, H, W)
+
+ # itokens : (N, L, C)
+ itokens = img_to_pixelwise_tokens(x)
+ (_N, L, _C) = itokens.shape
+
+ # control : (N, C)
+ # i_control : (N, 1, C)
+ i_control = control.unsqueeze(1)
+
+ # y : (N, L, features)
+ y = self.trans_input(itokens)
+
+ # y : (N, L + 1, C)
+ y = torch.cat([ y, i_control ], dim = 1)
+ y = self.encoder(y)
+
+ # o_control : (N, 1, features)
+ o_control = y[:, L:, :]
+
+ # otokens : (N, L, C)
+ otokens = self.trans_output(y[:, :L, :])
+
+ # result : (N, C, H, W)
+ result = img_from_pixelwise_tokens(otokens, self.image_shape)
+
+ return (result, o_control.squeeze(1))
+
diff --git a/uvcgan_s/torch/layers/unet.py b/uvcgan_s/torch/layers/unet.py
new file mode 100644
index 0000000..c8ede23
--- /dev/null
+++ b/uvcgan_s/torch/layers/unet.py
@@ -0,0 +1,331 @@
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-instance-attributes
+
+import torch
+from torch import nn
+
+from uvcgan_s.torch.select import get_norm_layer, get_activ_layer
+
+from .cnn import get_downsample_x2_layer, get_upsample_x2_layer
+
+class UnetBasicBlock(nn.Module):
+
+ def __init__(
+ self, in_features, out_features, activ, norm, mid_features = None,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ if mid_features is None:
+ mid_features = out_features
+
+ self.block = nn.Sequential(
+ get_norm_layer(norm, in_features),
+ nn.Conv2d(in_features, mid_features, kernel_size = 3, padding = 1),
+ get_activ_layer(activ),
+
+ get_norm_layer(norm, mid_features),
+ nn.Conv2d(
+ mid_features, out_features, kernel_size = 3, padding = 1
+ ),
+ get_activ_layer(activ),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+class UNetEncBlock(nn.Module):
+
+ def __init__(
+ self, features, activ, norm, downsample, input_shape, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.downsample, output_features = \
+ get_downsample_x2_layer(downsample, features)
+
+ (C, H, W) = input_shape
+ self.block = UnetBasicBlock(C, features, activ, norm)
+
+ self._output_shape = (output_features, H//2, W//2)
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ def forward(self, x):
+ r = self.block(x)
+ y = self.downsample(r)
+ return (y, r)
+
+class UNetDecBlock(nn.Module):
+
+ def __init__(
+ self, input_shape, output_features, skip_features, activ, norm,
+ upsample, rezero = False, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ (input_features, H, W) = input_shape
+ self.upsample, input_features = get_upsample_x2_layer(
+ upsample, input_features
+ )
+
+ self.block = UnetBasicBlock(
+ input_features + skip_features, output_features, activ, norm,
+ mid_features = max(skip_features, input_features, output_features)
+ )
+
+ self._output_shape = (output_features, 2 * H, 2 * W)
+
+ if rezero:
+ self.re_alpha = nn.Parameter(torch.zeros((1, )))
+ else:
+ self.re_alpha = 1
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ def forward(self, x, r):
+ # x : (N, C, H_in, W_in)
+ # r : (N, C_skip, H_out, W_out)
+
+ # x : (N, C_up, H_out, W_out)
+ x = self.re_alpha * self.upsample(x)
+
+ # y : (N, C_skip + C_up, H_out, W_out)
+ y = torch.cat([x, r], dim = 1)
+
+ # result : (N, C_out, H_out, W_out)
+ return self.block(y)
+
+ def extra_repr(self):
+ return 're_alpha = %e' % (self.re_alpha, )
+
+class UNetBlock(nn.Module):
+
+ def __init__(
+ self, features, activ, norm, image_shape, downsample, upsample,
+ rezero = True, output_features = None, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.conv = UNetEncBlock(
+ features, activ, norm, downsample, image_shape
+ )
+
+ self.inner_shape = self.conv.output_shape
+ self.inner_module = None
+
+ if output_features is None:
+ output_features = image_shape[0]
+
+ self.deconv = UNetDecBlock(
+ self.inner_shape, output_features, self.inner_shape[0],
+ activ, norm, upsample, rezero
+ )
+
+ def get_inner_shape(self):
+ return self.inner_shape
+
+ def set_inner_module(self, module):
+ self.inner_module = module
+
+ def get_inner_module(self):
+ return self.inner_module
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+
+ # y : (N, C_inner, H_inner, W_inner)
+ # r : (N, C_inner, H, W)
+ (y, r) = self.conv(x)
+
+ # y : (N, C_inner, H_inner, W_inner)
+ y = self.inner_module(y)
+
+ # y : (N, C, H, W)
+ y = self.deconv(y, r)
+
+ return y
+
+class UNetLinearEncoder(nn.Module):
+
+ def __init__(
+ self, features_list, image_shape, activ, norm, downsample, **kwargs
+ ):
+ # pylint: disable = too-many-locals
+ super().__init__(**kwargs)
+
+ self.features_list = features_list
+ self.image_shape = image_shape
+ self._skip_shapes = []
+
+ self.net = nn.ModuleList()
+ curr_shape = image_shape
+
+ for features in features_list:
+ layer = UNetEncBlock(features, activ, norm, downsample, curr_shape)
+ self.net.append(layer)
+
+ curr_shape = layer.output_shape
+ self._skip_shapes.append(curr_shape)
+
+ self._output_shape = curr_shape
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ @property
+ def skip_shapes(self):
+ return self._skip_shapes
+
+ def forward(self, x, return_skips = True):
+ # x : (N, C, H, W)
+
+ skips = [ ]
+ y = x
+
+ for layer in self.net:
+ y, r = layer(y)
+ if return_skips:
+ skips.append(r)
+
+ if return_skips:
+ return y, skips
+ else:
+ return y
+
+class UNetLinearDecoder(nn.Module):
+
+ def __init__(
+ self, features_list, input_shape, output_shape, skip_shapes,
+ activ, norm, upsample, **kwargs
+ ):
+ # pylint: disable = too-many-locals
+ super().__init__(**kwargs)
+
+ self.net = nn.ModuleList()
+ self._input_shape = input_shape
+ self._output_shape = output_shape
+ curr_shape = input_shape
+
+ for features, skip_shape in zip(
+ features_list[::-1], skip_shapes[::-1]
+ ):
+ layer = UNetDecBlock(
+ curr_shape, features, skip_shape[0], activ, norm, upsample
+ )
+ curr_shape = layer.output_shape
+
+ self.net.append(layer)
+
+ if output_shape[0] == curr_shape[0]:
+ self.output = nn.Identity()
+ else:
+ self.output = nn.Conv2d(
+ curr_shape[0], output_shape[0], kernel_size = 1
+ )
+
+ curr_shape = (output_shape[0], *curr_shape[1:])
+ assert tuple(output_shape) == tuple(curr_shape)
+
+ @property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ def forward(self, x, skip_list):
+ # x : (N, C, H, W)
+ # skip_list : List[(N, C_i, H_i, W_i)]
+
+ y = x
+
+ for layer, skip in zip(self.net, skip_list[::-1]):
+ y = layer(y, skip)
+
+ return self.output(y)
+
+class UNet(nn.Module):
+
+ def __init__(
+ self, features_list, activ, norm, image_shape, downsample, upsample,
+ rezero = True, output_shape = None, **kwargs
+ ):
+ # pylint: disable = too-many-locals
+ super().__init__(**kwargs)
+
+ self.features_list = features_list
+ self.image_shape = image_shape
+
+ if output_shape is None:
+ output_shape = image_shape
+
+ assert tuple(output_shape[1:]) == tuple(image_shape[1:])
+
+ self.output_shape = output_shape
+
+ self._construct_input_layer(activ)
+ self._construct_output_layer()
+
+ unet_layers = []
+ curr_image_shape = (features_list[0], *image_shape[1:])
+
+ for features in features_list:
+ layer = UNetBlock(
+ features, activ, norm, curr_image_shape, downsample, upsample,
+ rezero
+ )
+ curr_image_shape = layer.get_inner_shape()
+ unet_layers.append(layer)
+
+ for idx in range(len(unet_layers)-1):
+ unet_layers[idx].set_inner_module(unet_layers[idx+1])
+
+ self.unet = unet_layers[0]
+
+ def _construct_input_layer(self, activ):
+ self.layer_input = nn.Sequential(
+ nn.Conv2d(
+ self.image_shape[0], self.features_list[0],
+ kernel_size = 3, padding = 1
+ ),
+ get_activ_layer(activ),
+ )
+
+ def _construct_output_layer(self):
+ self.layer_output = nn.Conv2d(
+ self.features_list[0], self.output_shape[0], kernel_size = 1
+ )
+
+ def get_innermost_block(self):
+ result = self.unet
+
+ for _ in range(len(self.features_list)-1):
+ result = result.get_inner_module()
+
+ return result
+
+ def set_bottleneck(self, module):
+ self.get_innermost_block().set_inner_module(module)
+
+ def get_bottleneck(self):
+ return self.get_innermost_block().get_inner_module()
+
+ def get_inner_shape(self):
+ return self.get_innermost_block().get_inner_shape()
+
+ def forward(self, x):
+ # x : (N, C, H, W)
+
+ y = self.layer_input(x)
+ y = self.unet(y)
+ y = self.layer_output(y)
+
+ return y
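+
+# Usage sketch (not part of the original code). The feature widths are
+# arbitrary, and the downsample / upsample names are placeholders: valid
+# values are whatever get_downsample_x2_layer / get_upsample_x2_layer in
+# cnn.py accept.
+#
+#   unet = UNet(
+#       features_list = [ 48, 96, 192 ],
+#       activ = 'leakyrelu', norm = 'batch',
+#       image_shape = (3, 128, 128),
+#       downsample = 'avgpool', upsample = 'upsample',   # assumed names
+#   )
+#   unet.set_bottleneck(nn.Identity())      # or e.g. a PixelwiseViT
+#   y = unet(torch.randn(1, 3, 128, 128))   # y.shape == (1, 3, 128, 128)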
+
diff --git a/uvcgan_s/torch/lr_equal.py b/uvcgan_s/torch/lr_equal.py
new file mode 100644
index 0000000..4ff4819
--- /dev/null
+++ b/uvcgan_s/torch/lr_equal.py
@@ -0,0 +1,59 @@
+import math
+
+import torch
+from torch import nn
+
+from torch.nn.init import _calculate_correct_fan, calculate_gain
+
+class LearningRateEqualizer(nn.Module):
+
+ def __init__(self, mode = 'fan_in', nonlinearity = 'leaky_relu', a = 0):
+ super().__init__()
+
+ self._mode = mode
+ self._activ = nonlinearity
+ self._param = a
+
+ def _calc_scale(self, w):
+ fan = _calculate_correct_fan(w, self._mode)
+ gain = calculate_gain(self._activ, self._param)
+ std = gain / math.sqrt(fan)
+
+ return std
+
+ def forward(self, w):
+ scale = self._calc_scale(w)
+ return w * scale
+
+# NOTE:
+# WARNING:
+# Behavior of parametrization changes between pytorch 1.9 and 1.10.
+# In pytorch 1.9 original_weight = unparametrized weight
+# In pytorch >1.9 original_weight = right_inverse(unparametrized weight)
+# To make the behavior consistent -- the right_inverse is masked for now.
+#
+# def right_inverse(self, w):
+# scale = self._calc_scale(w)
+# return w / scale
+
+def apply_lr_equal_to_module(module, name, param):
+ if isinstance(module, torch.nn.utils.parametrize.ParametrizationList):
+ return
+
+ if not hasattr(module, name):
+ return
+
+ w = getattr(module, name)
+
+ if (w is None) or len(w.shape) < 2:
+ return
+
+ torch.nn.utils.parametrize.register_parametrization(module, name, param)
+
+def apply_lr_equal(module, tensor_name = "weight", **kwargs):
+ parametrization = LearningRateEqualizer(**kwargs)
+ submodule_list = [ x[1] for x in module.named_modules() ]
+
+ for m in submodule_list:
+ apply_lr_equal_to_module(m, tensor_name, parametrization)
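+
+# Minimal sketch (not part of the original code): equalized learning rate is
+# attached as a weight parametrization, so every forward pass rescales the
+# raw weights by the He-initialization standard deviation gain / sqrt(fan).
+#
+#   net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
+#   apply_lr_equal(net)
+#   y = net(torch.randn(4, 16))             # weights are rescaled on the fly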
+
diff --git a/uvcgan_s/torch/queue.py b/uvcgan_s/torch/queue.py
new file mode 100644
index 0000000..0434dba
--- /dev/null
+++ b/uvcgan_s/torch/queue.py
@@ -0,0 +1,81 @@
+from collections import deque
+import torch
+
+class Queue:
+
+ def __init__(self, size):
+ self._queue = deque(maxlen = size)
+
+ def __len__(self):
+ return len(self._queue)
+
+ def push(self, x):
+ self._queue.append(x.detach())
+
+ def query(self):
+ return tuple(self._queue)
+
+# NOTE: FastQueue differs from Queue in what `size` means:
+#  for Queue,     `size` is the number of pushed objects (e.g. batches) to keep;
+#  for FastQueue, `size` is the number of sample rows in its backing tensor.
+
+class FastQueue:
+
+ def __init__(self, size, device):
+ self._queue = None
+ self._size = size
+ self._device = device
+ self._curr_idx = 0
+ self._full = False
+
+ def __len__(self):
+ if self._full:
+ return self._size
+
+ return self._curr_idx
+
+ def lazy_init_queue(self, x):
+ if self._queue is not None:
+ return
+
+ self._queue = torch.empty(
+ (self._size, *x.shape[1:]), dtype = x.dtype, device = self._device
+ )
+
+ def push(self, x):
+ self.lazy_init_queue(x)
+
+ n = x.shape[0]
+ n_avail_to_end = self._size - self._curr_idx
+
+ if n > self._size:
+ x = x[-self._size:, ...]
+ n = self._size
+
+ if n_avail_to_end <= n:
+ self._queue[self._curr_idx:self._size, ...] \
+ = x[:n_avail_to_end, ...].detach().to(
+ self._device, non_blocking = True
+ )
+
+ self._curr_idx = 0
+ self._full = True
+
+ if n_avail_to_end < n:
+ self.push(x[n_avail_to_end:, ...])
+
+ else:
+ self._queue[self._curr_idx:self._curr_idx + n, ...] \
+ = x.detach().to(self._device)
+
+ self._curr_idx += n
+
+ def query(self):
+ if self._queue is None:
+ return None
+
+ if self._full:
+ return self._queue
+
+ return self._queue[:self._curr_idx, ...]
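+
+# Illustrative comparison (sketch): Queue keeps the last `size` pushed
+# objects as-is, while FastQueue packs pushed samples row-wise into a single
+# pre-allocated tensor with `size` rows.
+#
+#   q = Queue(size = 2)
+#   q.push(torch.randn(4, 3)); q.push(torch.randn(4, 3))
+#   len(q.query())                          # 2 stored batches
+#
+#   fq = FastQueue(size = 8, device = 'cpu')
+#   fq.push(torch.randn(4, 3)); fq.push(torch.randn(4, 3))
+#   fq.query().shape                        # torch.Size([8, 3])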
+
diff --git a/uvcgan_s/torch/select.py b/uvcgan_s/torch/select.py
new file mode 100644
index 0000000..8381da5
--- /dev/null
+++ b/uvcgan_s/torch/select.py
@@ -0,0 +1,107 @@
+import copy
+import torch
+from torch import nn
+
+from .layers.activation import Exponential
+
+def extract_name_kwargs(obj):
+ if isinstance(obj, dict):
+ obj = copy.copy(obj)
+ name = obj.pop('name')
+ kwargs = obj
+ else:
+ name = obj
+ kwargs = {}
+
+ return (name, kwargs)
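+
+# Convention sketch: every selector below accepts either a plain name or a
+# dict with a 'name' key plus layer-specific keyword arguments, e.g.
+#
+#   get_norm_layer('batch', 64)                                # BatchNorm2d(64)
+#   get_norm_layer({ 'name' : 'group', 'num_groups' : 8 }, 64)
+#   get_activ_layer({ 'name' : 'leakyrelu', 'negative_slope' : 0.2 })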
+
+def get_norm_layer(norm, features):
+ name, kwargs = extract_name_kwargs(norm)
+
+ if name is None:
+ return nn.Identity(**kwargs)
+
+ if name == 'layer':
+ return nn.LayerNorm((features,), **kwargs)
+
+ if name == 'batch':
+ return nn.BatchNorm2d(features, **kwargs)
+
+ if name == 'group':
+ return nn.GroupNorm(num_channels = features, **kwargs)
+
+ if name == 'instance':
+ return nn.InstanceNorm2d(features, **kwargs)
+
+ raise ValueError("Unknown Layer: '%s'" % name)
+
+def get_norm_layer_fn(norm):
+ return lambda features : get_norm_layer(norm, features)
+
+def get_activ_layer(activ):
+ # pylint: disable=too-many-return-statements
+ name, kwargs = extract_name_kwargs(activ)
+
+ if (name is None) or (name == 'linear'):
+ return nn.Identity()
+
+ if name == 'gelu':
+ return nn.GELU(**kwargs)
+
+ if name == 'selu':
+ return nn.SELU(**kwargs)
+
+ if name == 'relu':
+ return nn.ReLU(inplace = True, **kwargs)
+
+ if name == 'leakyrelu':
+ return nn.LeakyReLU(inplace = True, **kwargs)
+
+ if name == 'tanh':
+ return nn.Tanh()
+
+ if name == 'sigmoid':
+ return nn.Sigmoid()
+
+ if name == 'exp':
+ return Exponential(**kwargs)
+
+ raise ValueError("Unknown activation: '%s'" % name)
+
+def select_activation(activation):
+ return get_activ_layer(activation)
+
+def select_optimizer(parameters, optimizer):
+ name, kwargs = extract_name_kwargs(optimizer)
+
+ if name == 'AdamW':
+ return torch.optim.AdamW(parameters, **kwargs)
+
+ if name == 'SGD':
+ return torch.optim.SGD(parameters, **kwargs)
+
+ if name == 'Adam':
+ return torch.optim.Adam(parameters, **kwargs)
+
+ raise ValueError("Unknown optimizer: '%s'" % name)
+
+def select_loss(loss, reduction = None):
+ name, kwargs = extract_name_kwargs(loss)
+
+ if reduction is not None:
+ kwargs['reduction'] = reduction
+
+ if name.lower() in [ 'l1', 'mae' ]:
+ return nn.L1Loss(**kwargs)
+
+ if name.lower() in [ 'l2', 'mse' ]:
+ return nn.MSELoss(**kwargs)
+
+ if name.lower() in [ 'bce', 'binary-cross-entropy' ]:
+ return nn.BCELoss(**kwargs)
+
+ if name.lower() in [ 'bce-logits', ]:
+ return nn.BCEWithLogitsLoss(**kwargs)
+
+ raise ValueError("Unknown loss: '%s'" % name)
+
diff --git a/uvcgan_s/torch/spectr_norm.py b/uvcgan_s/torch/spectr_norm.py
new file mode 100644
index 0000000..f714204
--- /dev/null
+++ b/uvcgan_s/torch/spectr_norm.py
@@ -0,0 +1,23 @@
+import torch
+from torch.nn.utils.parametrizations import spectral_norm
+
+def apply_sn_to_module(module, name, n_power_iterations = 1):
+ if isinstance(module, torch.nn.utils.parametrize.ParametrizationList):
+ return
+
+ if not hasattr(module, name):
+ return
+
+ w = getattr(module, name)
+
+ if (w is None) or len(w.shape) < 2:
+ return
+
+ spectral_norm(module, name, n_power_iterations)
+
+def apply_sn(module, tensor_name = 'weight', n_power_iterations = 1):
+ submodule_list = list(module.modules())
+
+ for m in submodule_list:
+ apply_sn_to_module(m, tensor_name, n_power_iterations)
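+
+# Usage sketch (not part of the original code): spectral normalization is
+# applied recursively to every submodule that has a 'weight' tensor with at
+# least two dimensions.
+#
+#   disc = torch.nn.Sequential(
+#       torch.nn.Conv2d(3, 64, 3),
+#       torch.nn.LeakyReLU(),
+#       torch.nn.Conv2d(64, 1, 3),
+#   )
+#   apply_sn(disc)                          # both Conv2d weights get SN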
+
diff --git a/uvcgan_s/train/__init__.py b/uvcgan_s/train/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/uvcgan_s/train/callbacks/__init__.py b/uvcgan_s/train/callbacks/__init__.py
new file mode 100644
index 0000000..92e5104
--- /dev/null
+++ b/uvcgan_s/train/callbacks/__init__.py
@@ -0,0 +1 @@
+from .history import TrainingHistory
diff --git a/uvcgan_s/train/callbacks/history.py b/uvcgan_s/train/callbacks/history.py
new file mode 100644
index 0000000..4da6b92
--- /dev/null
+++ b/uvcgan_s/train/callbacks/history.py
@@ -0,0 +1,39 @@
+import os
+import pandas as pd
+
+HISTORY_NAME = 'history.csv'
+
+class TrainingHistory:
+
+ def __init__(self, savedir):
+ self._history = None
+ self._savedir = savedir
+
+ def end_epoch(self, epoch, metrics):
+ values = metrics.values
+ values['epoch'] = epoch
+ values['time'] = pd.Timestamp.utcnow()
+
+        if self._history is None:
+            self._history = pd.DataFrame([ values, ])
+        else:
+            # NOTE: `DataFrame.append` was removed in pandas 2.0
+            self._history = pd.concat(
+                [ self._history, pd.DataFrame([ values, ]) ],
+                ignore_index = True
+            )
+
+ self.save()
+
+ def save(self):
+ self._history.to_csv(
+ os.path.join(self._savedir, HISTORY_NAME), index = False
+ )
+
+ def load(self):
+ path = os.path.join(self._savedir, HISTORY_NAME)
+
+ if os.path.exists(path):
+ self._history = pd.read_csv(path, parse_dates = [ 'time', ])
+
+ @property
+ def history(self):
+ return self._history
+
+
diff --git a/uvcgan_s/train/metrics/__init__.py b/uvcgan_s/train/metrics/__init__.py
new file mode 100644
index 0000000..f6a8572
--- /dev/null
+++ b/uvcgan_s/train/metrics/__init__.py
@@ -0,0 +1,2 @@
+from .loss_metrics import LossMetrics
+
diff --git a/uvcgan_s/train/metrics/loss_metrics.py b/uvcgan_s/train/metrics/loss_metrics.py
new file mode 100644
index 0000000..2381e84
--- /dev/null
+++ b/uvcgan_s/train/metrics/loss_metrics.py
@@ -0,0 +1,37 @@
+import copy
+
+class LossMetrics:
+
+ def __init__(self, values = None, n = 0):
+ self._values = values
+ self._n = n
+
+ @property
+ def values(self):
+ if self._values is None:
+ return None
+
+ return { k : v / self._n for (k,v) in self._values.items() }
+
+ def update(self, values):
+ if self._values is None:
+ self._values = copy.deepcopy(values)
+ else:
+ for k,v in values.items():
+ self._values[k] += v
+
+ self._n += 1
+
+ def join(self, other, other_prefix = None):
+ self_dict = self.values
+ other_dict = other.values
+
+ if other_prefix is not None:
+ other_dict = {
+ other_prefix + k : v for (k, v) in other_dict.items()
+ }
+
+ values_dict = { **self_dict, **other_dict }
+
+ return LossMetrics(values_dict, n = 1)
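+
+# Illustrative use (sketch): LossMetrics accumulates per-batch loss
+# dictionaries and exposes their running averages via `values`.
+#
+#   metrics = LossMetrics()
+#   metrics.update({ 'loss_gen' : 1.0, 'loss_disc' : 0.5 })
+#   metrics.update({ 'loss_gen' : 0.0, 'loss_disc' : 0.5 })
+#   metrics.values                  # { 'loss_gen' : 0.5, 'loss_disc' : 0.5 }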
+
diff --git a/uvcgan_s/train/train.py b/uvcgan_s/train/train.py
new file mode 100644
index 0000000..4428891
--- /dev/null
+++ b/uvcgan_s/train/train.py
@@ -0,0 +1,84 @@
+from itertools import islice
+import tqdm
+
+from uvcgan_s.config import Args
+from uvcgan_s.data import construct_data_loaders
+from uvcgan_s.torch.funcs import get_torch_device_smart, seed_everything
+from uvcgan_s.cgan import construct_model
+from uvcgan_s.utils.log import setup_logging
+
+from .metrics import LossMetrics
+from .callbacks import TrainingHistory
+from .transfer import transfer
+
+def training_epoch(it_train, model, title, steps_per_epoch):
+ model.train()
+
+ steps = len(it_train)
+ if steps_per_epoch is not None:
+ steps = min(steps, steps_per_epoch)
+
+ progbar = tqdm.tqdm(desc = title, total = steps, dynamic_ncols = True)
+ metrics = LossMetrics()
+
+ for batch in islice(it_train, steps):
+ model.set_input(batch)
+ model.optimization_step()
+
+ metrics.update(model.get_current_losses())
+
+ progbar.set_postfix(metrics.values, refresh = False)
+ progbar.update()
+
+ progbar.close()
+ return metrics
+
+def try_continue_training(args, model):
+ history = TrainingHistory(args.savedir)
+
+ start_epoch = model.find_last_checkpoint_epoch()
+ model.load(start_epoch)
+
+ if start_epoch > 0:
+ history.load()
+
+ start_epoch = max(start_epoch, 0)
+
+ return (start_epoch, history)
+
+def train(args_dict):
+ args = Args.from_args_dict(**args_dict)
+
+ setup_logging(args.log_level)
+ seed_everything(args.config.seed)
+
+ device = get_torch_device_smart()
+ it_train = construct_data_loaders(
+ args.config.data, args.config.batch_size, split = 'train'
+ )
+
+ print("Starting training...")
+ print(args.config.to_json(indent = 4))
+
+ model = construct_model(
+ args.savedir, args.config, is_train = True, device = device
+ )
+ start_epoch, history = try_continue_training(args, model)
+
+ if (start_epoch == 0) and (args.transfer is not None):
+ transfer(model, args.transfer)
+
+ for epoch in range(start_epoch + 1, args.epochs + 1):
+ title = 'Epoch %d / %d' % (epoch, args.epochs)
+ metrics = training_epoch(
+ it_train, model, title, args.config.steps_per_epoch
+ )
+
+ history.end_epoch(epoch, metrics)
+ model.end_epoch(epoch)
+
+ if epoch % args.checkpoint == 0:
+ model.save(epoch)
+
+ model.save(epoch = None)
+
diff --git a/uvcgan_s/train/transfer.py b/uvcgan_s/train/transfer.py
new file mode 100644
index 0000000..165a4df
--- /dev/null
+++ b/uvcgan_s/train/transfer.py
@@ -0,0 +1,189 @@
+import os
+import logging
+
+import torch
+
+from uvcgan_s.consts import ROOT_OUTDIR
+from uvcgan_s.config import Args
+from uvcgan_s.cgan import construct_model
+
+LOGGER = logging.getLogger('uvcgan_s.train')
+
+def load_base_model(model, transfer_config):
+ try:
+ model.load(epoch = None)
+ return
+
+ except IOError as e:
+ if not transfer_config.allow_partial:
+ raise IOError(
+ "Failed to find fully trained model in '%s' for transfer: %s"\
+ % (model.savedir, e)
+ ) from e
+
+ LOGGER.warning(
+ (
+ "Failed to find fully trained model in '%s' for transfer."
+ " Trying to load from a checkpoint..."
+ ), model.savedir
+ )
+
+ epoch = model.find_last_checkpoint_epoch()
+
+ if epoch > 0:
+ LOGGER.warning("Load transfer model from a checkpoint '%d'", epoch)
+ else:
+ raise RuntimeError("Failed to find transfer model checkpoints.")
+
+ model.load(epoch)
+
+def get_base_model(transfer_config, device):
+ base_path = os.path.join(ROOT_OUTDIR, transfer_config.base_model)
+ base_args = Args.load(base_path)
+
+ model = construct_model(
+ base_args.savedir, base_args.config, is_train = True, device = device
+ )
+
+ load_base_model(model, transfer_config)
+
+ return model
+
+def transfer_from_larger_model(module, state_dict, strict):
+ source_keys = set(state_dict.keys())
+ target_keys = set(k for (k, _p) in module.named_parameters())
+
+ matching_dict = { k : state_dict[k] for k in target_keys }
+ module.load_state_dict(matching_dict, strict = strict)
+
+ LOGGER.warning(
+ "Transfer from a large model. Transferred %d / %d parameters",
+ len(target_keys), len(source_keys)
+ )
+
+def collect_keys_for_transfer_to_wider_model(module, state_dict, strict):
+ source_keys = set(state_dict.keys())
+ target_keys = set()
+ matching_keys = set()
+ wider_keys = set()
+ narrower_keys = set()
+
+ for (k, p_target) in module.state_dict().items():
+ target_keys.add(k)
+
+ if strict:
+ assert k in source_keys
+
+ p_source = state_dict[k]
+
+ shape_source = p_source.shape
+ shape_target = p_target.shape
+
+ if (shape_target is None) or (shape_target == shape_source):
+ matching_keys.add(k)
+
+ elif (
+ (len(shape_target) == len(shape_source))
+ and all(t >= s for (t, s) in zip(shape_target, shape_source))
+ ):
+ LOGGER.warning(
+ "Weight Transfer. Found wider parameter: '%s'. "
+ "%s vs %s",
+ k, shape_target, shape_source
+ )
+
+ wider_keys.add(k)
+
+ elif (
+ (len(shape_target) == len(shape_source))
+ and all(t <= s for (t, s) in zip(shape_target, shape_source))
+ ):
+ LOGGER.warning(
+ "Weight Transfer. Found narrower parameter: '%s'. "
+ "%s vs %s",
+ k, shape_target, shape_source
+ )
+
+ narrower_keys.add(k)
+
+ else:
+ raise ValueError(
+ "Weight Transfer. "
+ f"Cannot transfer parameter '{k}' due to mismatching"
+ f" shapes {shape_target} vs {shape_source}"
+ )
+
+ if strict and (source_keys != target_keys):
+ keys_diff = source_keys.symmetric_difference(target_keys)
+
+ raise RuntimeError(
+ "Transfer to wide model. Strict transfer failed due to"
+ f" mismatching keys {keys_diff}"
+ )
+
+ return matching_keys, wider_keys, narrower_keys
+
+def transfer_to_wider_model(module, state_dict, strict):
+ matching_keys, wider_keys, _ = \
+ collect_keys_for_transfer_to_wider_model(module, state_dict, strict)
+
+ matching_dict = { k : state_dict[k] for k in matching_keys }
+ module.load_state_dict(matching_dict, strict = False)
+
+ for k, p in module.named_parameters():
+ if k not in wider_keys:
+ continue
+
+ source_tensor = state_dict[k]
+ target_slice = tuple(slice(0, s) for s in source_tensor.shape)
+
+ with torch.no_grad():
+ p[target_slice] = source_tensor
+
+def fully_fuzzy_transfer(module, state_dict, strict):
+ matching_keys, _wider_keys, _narrower_keys = \
+ collect_keys_for_transfer_to_wider_model(module, state_dict, strict)
+
+ matching_dict = { k : state_dict[k] for k in matching_keys }
+ module.load_state_dict(matching_dict, strict = False)
+
+def transfer_state_dict(module, state_dict, fuzzy, strict):
+ if (fuzzy is None) or (fuzzy == 'none'):
+ module.load_state_dict(state_dict, strict = strict)
+
+ elif fuzzy == 'from-larger-model':
+ transfer_from_larger_model(module, state_dict, strict)
+
+ elif fuzzy == 'to-wider-model':
+ transfer_to_wider_model(module, state_dict, strict)
+
+ elif fuzzy == 'full':
+ fully_fuzzy_transfer(module, state_dict, strict)
+
+ else:
+ raise ValueError(f"Unknown fuzzy transfer type: {fuzzy}")
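+
+# Minimal sketch of the 'to-wider-model' mode using hypothetical toy modules
+# (not part of the original code): matching parameters are copied verbatim,
+# and wider target parameters receive the source values in their leading
+# slice.
+#
+#   src = torch.nn.Linear(2, 2)
+#   dst = torch.nn.Linear(4, 4)
+#   transfer_state_dict(
+#       dst, src.state_dict(), fuzzy = 'to-wider-model', strict = False
+#   )
+#   # now dst.weight[:2, :2] == src.weight and dst.bias[:2] == src.bias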
+
+def transfer_parameters(model, base_model, transfer_config):
+ for (dst, src) in transfer_config.transfer_map.items():
+ transfer_state_dict(
+ model.models[dst], base_model.models[src].state_dict(),
+ transfer_config.fuzzy, transfer_config.strict
+ )
+
+def transfer(model, transfer_config):
+ if transfer_config is None:
+ return
+
+ if isinstance(transfer_config, list):
+ for conf in transfer_config:
+ transfer(model, conf)
+ return
+
+ LOGGER.info(
+ "Initiating parameter transfer : '%s'", transfer_config.to_dict()
+ )
+
+ base_model = get_base_model(transfer_config, model.device)
+ transfer_parameters(model, base_model, transfer_config)
+
+
diff --git a/uvcgan_s/utils/__init__.py b/uvcgan_s/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/uvcgan_s/utils/funcs.py b/uvcgan_s/utils/funcs.py
new file mode 100644
index 0000000..f69d522
--- /dev/null
+++ b/uvcgan_s/utils/funcs.py
@@ -0,0 +1,58 @@
+import copy
+
+import torch
+from torchvision.datasets.folder import default_loader
+
+def recursive_update_dict(base_dict, new_dict):
+ if new_dict is None:
+ return
+
+ for k,v in new_dict.items():
+ if (
+ isinstance(v, dict)
+ and k in base_dict
+ and isinstance(base_dict[k], dict)
+ ):
+ recursive_update_dict(base_dict[k], v)
+ else:
+ base_dict[k] = copy.deepcopy(v)
+
+def join_dicts(*dicts_list):
+ base_dict = {}
+
+ for d in dicts_list:
+ recursive_update_dict(base_dict, d)
+
+ return base_dict
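+
+# Behaviour sketch: later dictionaries override earlier ones, and nested
+# dictionaries are merged recursively rather than replaced.
+#
+#   base     = { 'model' : { 'features' : 64, 'norm' : 'batch' } }
+#   override = { 'model' : { 'norm' : 'instance' } }
+#   join_dicts(base, override)
+#   # -> { 'model' : { 'features' : 64, 'norm' : 'instance' } }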
+
+def check_value_in_range(value, value_range, hint = None):
+ if value in value_range:
+ return
+
+ msg = ''
+
+ if hint is not None:
+ msg = hint + ' '
+
+ msg += f"value '{value}' is not range {value_range}"
+
+ raise ValueError(msg)
+
+@torch.no_grad()
+def load_image_fuzzy(path, transforms, device):
+ result = None
+
+ try:
+ result = default_loader(path)
+ except FileNotFoundError:
+ base_path = path.split('.', maxsplit = 1)[0]
+
+        for ext in [ '.png', '.jpg', '.jpg.png' ]:
+            try:
+                result = default_loader(base_path + ext)
+                break
+            except FileNotFoundError:
+                continue
+
+    if result is None:
+        raise FileNotFoundError(f"Failed to load image: '{path}'")
+
+    result = transforms(result)
+    return result.to(device).unsqueeze(0)
+
diff --git a/uvcgan_s/utils/log.py b/uvcgan_s/utils/log.py
new file mode 100644
index 0000000..c03f80c
--- /dev/null
+++ b/uvcgan_s/utils/log.py
@@ -0,0 +1,23 @@
+import logging
+
+def reduce_pil_verbosity(level):
+ """A hack to stop PIL dumping large amounts of useless DEBUG info"""
+ logger = logging.getLogger('PIL')
+ logger.setLevel(max(logging.WARNING, level))
+
+def setup_logging(level = logging.DEBUG):
+ """Setup logging."""
+ logger = logging.getLogger()
+
+ formatter = logging.Formatter(
+ '[%(asctime)s] [%(name)s]: %(levelname)s %(message)s'
+ )
+
+ logger.setLevel(level)
+
+ handler = logging.StreamHandler()
+ handler.setFormatter(formatter)
+
+ logger.addHandler(handler)
+
+ reduce_pil_verbosity(logger.level)
diff --git a/uvcgan_s/utils/parsers.py b/uvcgan_s/utils/parsers.py
new file mode 100644
index 0000000..149ec82
--- /dev/null
+++ b/uvcgan_s/utils/parsers.py
@@ -0,0 +1,110 @@
+from uvcgan_s.consts import (
+ MODEL_STATE_TRAIN, MODEL_STATE_EVAL, SPLIT_TRAIN, SPLIT_TEST, SPLIT_VAL
+)
+
+def add_model_state_parser(parser):
+ parser.add_argument(
+ '--model-state',
+ choices = [ MODEL_STATE_TRAIN, MODEL_STATE_EVAL ],
+ default = 'eval',
+ dest = 'model_state',
+ help = "evaluate model in 'train' or 'eval' states",
+ type = str,
+ )
+
+def add_plot_extension_parser(parser, default = ( 'png', )):
+ parser.add_argument(
+ '-e', '--ext',
+ default = None if default is None else list(default),
+ dest = 'ext',
+ help = 'plot extensions',
+ type = str,
+ nargs = '+',
+ )
+
+def add_batch_size_parser(parser, default = 1):
+ parser.add_argument(
+ '--batch-size',
+ default = default,
+ dest = 'batch_size',
+ help = 'batch size to use for evaluation',
+ type = int,
+ )
+
+def add_n_eval_samples_parser(parser, default = None):
+ parser.add_argument(
+ '-n',
+ default = default,
+ dest = 'n_eval',
+ help = 'number of samples to use for evaluation',
+ type = int,
+ )
+
+def add_eval_type_parser(parser, default = 'transfer'):
+ parser.add_argument(
+ '--type',
+ choices = [ 'transfer', 'reco', 'masked', 'simple-reco' ],
+ default = default,
+ dest = 'eval_type',
+ help = 'type of evaluation',
+ type = str,
+ )
+
+def add_split_parser(parser, default = SPLIT_TEST):
+ parser.add_argument(
+ '--split',
+ choices = [ SPLIT_TRAIN, SPLIT_TEST, SPLIT_VAL ],
+ default = default,
+ dest = 'split',
+ help = 'data split',
+ type = str,
+ )
+
+def add_eval_epoch_parser(parser, default = None):
+ parser.add_argument(
+ '--epoch',
+ default = default,
+ dest = 'epoch',
+ help = (
+ 'checkpoint epoch to evaluate.'
+ ' If not specified, then the evaluation will be performed for'
+ ' the final model. If epoch is -1, then the evaluation will'
+ ' be performed for the last checkpoint.'
+ ),
+ type = int,
+ )
+
+def add_model_directory_parser(parser):
+ parser.add_argument(
+ 'model',
+ help = 'directory containing model to evaluate',
+ metavar = 'MODEL',
+ type = str,
+ )
+
+def add_preset_name_parser(
+ parser, name, presets, default = None, help_msg = None,
+):
+ parser.add_argument(
+ f'--{name}',
+ default = default,
+ dest = name,
+ choices = list(presets),
+ help = help_msg or name,
+ type = str,
+ )
+
+def add_standard_eval_parsers(
+ parser,
+ default_batch_size = 1,
+ default_epoch = None,
+ default_n_eval = None,
+):
+ add_model_directory_parser(parser)
+ add_model_state_parser(parser)
+ add_split_parser(parser)
+
+ add_batch_size_parser(parser, default_batch_size)
+ add_eval_epoch_parser(parser, default_epoch)
+ add_n_eval_samples_parser(parser, default_n_eval)
+