diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5cb922e --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +outdir/ +__pycache__ +*.egg-info/ +*.swp diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..82d676a --- /dev/null +++ b/.pylintrc @@ -0,0 +1,20 @@ +[MESSAGES CONTROL] + +disable= + bad-continuation, + bad-whitespace, + invalid-name, + no-else-return, + superfluous-parens, + too-few-public-methods, + trailing-newlines, + duplicate-code, + missing-function-docstring, + missing-class-docstring, + missing-module-docstring, + consider-using-f-string, + +[TYPECHECK] +generated-members= + torch + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..9d0fff5 --- /dev/null +++ b/LICENSE @@ -0,0 +1,26 @@ +Copyright (c) 2021-2025, The LS4GAN Project Developers +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + +The above copyright and license notices apply to all files in this repository +except for any file that contains its own copyright and/or license declaration. diff --git a/README.md b/README.md index cf5ed5b..023e838 100644 --- a/README.md +++ b/README.md @@ -1 +1,308 @@ -The code will be available shortly. +# UVCGAN-S: Stratified CycleGAN for Unsupervised Data Decomposition + + + +## Overview + + +This repository provides a reference implementation of Stratified CycleGAN +(UVCGAN-S), an architecture for unsupervised signal extraction from mixed data. + +What problem does Stratified CycleGAN solve? + +Imagine you have three datasets. The first contains clean signals. The second +contains backgrounds. The third contains mixed data where signals and +backgrounds have been combined in some complicated way. You don't know exactly +how the mixing happens, and you can't find pairs that show which clean signal +corresponds to which mixed observation. But you need to take new mixed data and +decompose it back into signal and background components. + +Stratified CycleGAN learns to do this decomposition from unpaired examples. +You show it random samples of signals, random samples of backgrounds and mixed +data, and it figures out both how to combine signals and backgrounds into +realistic mixed data, and how to decompose mixed data back into its parts. + +
+ + +See the [Quick Start](#quick-start-guide) section for a concrete example +using cat and dog images. + + +## Installation + +The package has been tested only on Linux systems. + + +### Environment Setup + +The development environment is based on the +`pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime` container. + +There are several ways to set up the package environment: + +**Option 1: Docker** + +Download the Docker container `pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime`. +Inside the container, create a virtual environment to avoid package conflicts: +```bash +python3 -m venv --system-site-packages ~/.venv/uvcgan-s +source ~/.venv/uvcgan-s/bin/activate +``` + +**Option 2: Conda** +```bash +conda env create -f contrib/conda_env.yaml +conda activate uvcgan-s +``` + +### Install Package + +Once the environment is set up, install the `uvcgan-s` package and its +requirements: + +```bash +pip install -r requirements.txt +pip install -e . +``` + +### Environment Variables + +By default, UVCGAN-S reads datasets from `./data` and saves models to +`./outdir`. If any other location is desired, these defaults can be overridden +with: +```bash +export UVCGAN_S_DATA=/path/to/datasets +export UVCGAN_S_OUTDIR=/path/to/models +``` + +## Quick Start Guide + +This package was developed for sPHENIX jet signal extraction. However, jet +signal analysis requires familiarity with sPHENIX-specific reconstruction +algorithms and jet quality analysis procedures. This section demonstrates the +Stratified CycleGAN method on a simpler toy problem with an intuitive +interpretation. The toy example illustrates the basic workflow and +serves as a template for applying the method to your own data. +
+The toy problem is this: we have images that contain a blurry mix of cat and +dog faces. The goal is to automatically decompose these mixed images into +separate cat and dog images. To this end, we present Stratified CycleGAN with +the mixed images, a random sample of cat images, and a random sample of dog +images. By observing these three collections, the model learns how cats look on +average, how dogs look on average, and how best to decompose a given mixed +image into a cat and a dog. Importantly, the model is never +shown training pairs like "this specific mixed image was created from this +specific cat image and this specific dog image." It only sees random examples +from each collection and figures out the decomposition on its own. + +### Installation + +If the package is not installed yet, install it by following the instructions +at the top of this README before proceeding. + +### Dataset Preparation + +The toy example uses cat and dog images from the AFHQ dataset. To download and +preprocess it: + +```bash +# Download the AFHQ dataset +./scripts/download_dataset.sh afhq + +# Resize all images to 256x256 pixels +python3 scripts/downsize_right.py -s 256 256 -i lanczos \ + "${UVCGAN_S_DATA:-./data}/afhq/" \ + "${UVCGAN_S_DATA:-./data}/afhq_resized_lanczos" +``` + +The resizing script creates a new directory `afhq_resized_lanczos` containing +256x256 versions of all images, which is the format expected by the training +script. + +### Training + +To train the model, run the following command: +```bash +python3 scripts/train/toy_mix_blur/train_uvcgan-s.py +``` + +The script trains the Stratified CycleGAN model for 100 epochs. On an RTX 3090 +GPU, each epoch takes approximately 3 minutes, so the complete training process +requires about 5 hours. The trained model and intermediate checkpoints are +saved in the directory +`${UVCGAN_S_OUTDIR:-./outdir}/toy_mix_blur/uvcgan-s/model_m(uvcgan-s)_d(resnet)_g(vit-modnet)_cat_dog_sub/`. + +The structure of the model directory is described in the +[F.A.Q.](#what-is-the-structure-of-a-model-directory). + +### Evaluation + 
+ + +After training completes, the model can be used to decompose images from the +validation set of AFHQ. Run: +```bash +python3 scripts/translate_images.py \ + "${UVCGAN_S_OUTDIR:-./outdir}/toy_mix_blur/uvcgan-s/cat_dog_sub" \ + --split val \ + --domain 2 \ + --format image +``` + +This command takes the mixture cat-dog images (Domain B) and decomposes them +into separate cat and dog components (Domain A). The `--domain 2` flag +specifies that the input images come from Domain B, which contains the mixed +data. + +The results are saved in the model directory under +`evals/final/translated(None)_domain(2)_eval-val/`. This evaluation directory +contains several subdirectories: + +- `fake_a0/` - extracted cat components +- `fake_a1/` - extracted dog components +- `real_b/` - original blurred mixture inputs + +Each subdirectory contains numbered image files (`sample_0.png`, + `sample_1.png`, etc.) corresponding to the validation set. + + +### Adapting to Your Own Data + +To apply Stratified CycleGAN to a different decomposition problem, use the toy +example training script as a starting point. The script +`scripts/train/toy_mix_blur/train_uvcgan-s.py` contains a declarative +configuration showing how to structure the three required datasets and set up +the domain structure for decomposition. For a more complex example, see +`scripts/train/sphenix/train_uvcgan-s.py`. + + +## sPHENIX Application: Jet Background Subtraction + +The package was developed for extracting particle jets from heavy-ion collision +backgrounds in sPHENIX calorimeter data. This section describes how to +reproduce the paper results. + +### Dataset + +The sPHENIX dataset can be downloaded from Zenodo: https://zenodo.org/records/17783990 + +Alternatively, use the download script: +```bash +./scripts/download_dataset.sh sphenix +``` + +The dataset contains HDF5 files with calorimeter energy measurements organized +as 24×64 eta-phi grids. The data is split into training, validation, and test +sets. Training data consists of three components: PYTHIA jets (the signal +component), HIJING minimum-bias events (the background component), and +embedded PYTHIA+HIJING events (the mixed data). The test set uses JEWEL jets +embedded in HIJING backgrounds. JEWEL models jet-medium interactions +differently from PYTHIA, providing an out-of-distribution test of the model's +generalization capability. + + +### Training or Using Pre-trained Model + +There are two options for obtaining a trained model: training from scratch or +downloading the pre-trained model from the paper. + +To train a new model from scratch: +```bash +python3 scripts/train/sphenix/train_uvcgan-s.py +``` + +The training configuration uses the same Stratified CycleGAN architecture as +the toy example, adapted for single-channel calorimeter data. The trained model +is saved in the directory +`${UVCGAN_S_OUTDIR:-./outdir}/sphenix/uvcgan-s/model_m(uvcgan-s)_d(resnet)_g(vit-modnet)_sgn_bkg_sub` +(see [F.A.Q.](#what-is-the-structure-of-a-model-directory) for details on the + model directory structure). + +Alternatively, a pre-trained model can be downloaded from Zenodo: https://zenodo.org/records/17809156 + +The pre-trained model can be used directly for evaluation without retraining. 
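+ +For example, the downloaded archive can be unpacked under the output +directory so that the evaluation command below can be pointed at the resulting +model directory. The snippet is only a sketch: `<model-archive>` is a +placeholder for the file provided by the Zenodo record, and the extracted +directory layout should be verified after unpacking. +```bash +# Sketch only: <model-archive> is a placeholder for the file downloaded from +# the Zenodo record above; verify the extracted directory layout. +mkdir -p "${UVCGAN_S_OUTDIR:-./outdir}" +tar -xf "<model-archive>" -C "${UVCGAN_S_OUTDIR:-./outdir}" +```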
+ + +### Evaluation + +To evaluate the model on the test set: +```bash +python3 scripts/translate_images.py \ + "${UVCGAN_S_OUTDIR:-./outdir}/path/to/sphenix/model" \ + --split test \ + --domain 2 \ + --format ndarray +``` + +The `--format ndarray` flag saves results as NumPy arrays rather than images. +The output structure is similar to the toy example: extracted signal and +background components are saved in separate directories under `evals/final/`. +Each output file contains a 24×64 calorimeter energy grid that can be used for +physics analysis. + + +# F.A.Q. + +## I am training my model on a multi-GPU node. How do I make sure that only one GPU is used? + +You can control which GPUs `pytorch` uses with the +`CUDA_VISIBLE_DEVICES` environment variable. This variable can be set to a list +of comma-separated GPU indices. When it is set, `pytorch` will only use GPUs +whose indices are listed in `CUDA_VISIBLE_DEVICES`. For example, running +`CUDA_VISIBLE_DEVICES=0 python3 scripts/train/sphenix/train_uvcgan-s.py` +restricts training to the first GPU. + + +## What is the structure of a model directory? + +`uvcgan-s` saves each model in a separate directory that contains: + - `MODEL/config.json` -- model architecture, training, and evaluation + configurations + - `MODEL/net_*.pth` -- PyTorch weights of model networks + - `MODEL/opt_*.pth` -- saved states of training optimizers + - `MODEL/sched_*.pth` -- saved states of training schedulers + - `MODEL/checkpoints/` -- training checkpoints + - `MODEL/evals/` -- evaluation results + + +## Training fails with "Config collision detected" error + +`uvcgan-s` enforces a one-model-per-directory policy to prevent accidental +overwrites of existing models. Each model directory must have a unique +configuration: if you try to place a model with different settings in a +directory that already contains a model, you'll receive a "Config collision +detected" error. + +This safeguard helps prevent situations where you might accidentally lose +trained models by starting a new training run with different parameters in the +same directory. + +Solutions: +1. To overwrite the old model: delete the old `config.json` configuration file + and restart the training process. +2. To preserve the old model: modify the training script of the new model and + update the `label` or `outdir` configuration options to avoid collisions. + + +# LICENSE + +`uvcgan-s` is distributed under the `BSD-2` license. + +The `uvcgan-s` repository contains some code (primarily in the `uvcgan_s/base` +subdirectory) from [pytorch-CycleGAN-and-pix2pix][cyclegan_repo]. +This code is also licensed under the `BSD-2` license (please refer to +`uvcgan_s/base/LICENSE` for details). + +Each code snippet that was taken from +[pytorch-CycleGAN-and-pix2pix][cyclegan_repo] has a note about proper copyright +attribution. 
+ +[cyclegan_repo]: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix diff --git a/contrib/conda_env.yaml b/contrib/conda_env.yaml new file mode 100644 index 0000000..0f8e8a8 --- /dev/null +++ b/contrib/conda_env.yaml @@ -0,0 +1,119 @@ +name: uvcgan-s +channels: + - pytorch + - nvidia + - defaults + - conda-forge +dependencies: + - _libgcc_mutex=0.1 + - _openmp_mutex=4.5 + - backcall=0.2.0 + - beautifulsoup4=4.11.1 + - blas=1.0 + - brotlipy=0.7.0 + - bzip2=1.0.8 + - ca-certificates=2022.07.19 + - certifi=2022.6.15 + - cffi=1.15.0 + - chardet=4.0.0 + - charset-normalizer=2.0.4 + - colorama=0.4.4 + - conda=4.13.0 + - conda-build=3.21.9 + - conda-content-trust=0.1.1 + - conda-package-handling=1.8.1 + - cryptography=37.0.1 + - cudatoolkit=11.3.1 + - decorator=5.1.1 + - ffmpeg=4.3 + - filelock=3.6.0 + - freetype=2.11.0 + - giflib=5.2.1 + - glob2=0.7 + - gmp=6.2.1 + - gnutls=3.6.15 + - icu=58.2 + - idna=3.3 + - intel-openmp=2021.4.0 + - ipython=7.31.1 + - jedi=0.18.1 + - jinja2=2.10.1 + - jpeg=9e + - lame=3.100 + - lcms2=2.12 + - ld_impl_linux-64=2.35.1 + - libarchive=3.5.2 + - libffi=3.3 + - libgcc-ng=9.3.0 + - libgomp=9.3.0 + - libiconv=1.16 + - libidn2=2.3.2 + - liblief=0.11.5 + - libpng=1.6.37 + - libstdcxx-ng=9.3.0 + - libtasn1=4.16.0 + - libtiff=4.2.0 + - libunistring=0.9.10 + - libwebp=1.2.2 + - libwebp-base=1.2.2 + - libxml2=2.9.14 + - lz4-c=1.9.3 + - markupsafe=2.0.1 + - matplotlib-inline=0.1.2 + - mkl=2021.4.0 + - mkl-service=2.4.0 + - mkl_fft=1.3.1 + - mkl_random=1.2.2 + - ncurses=6.3 + - nettle=3.7.3 + - numpy=1.21.5 + - numpy-base=1.21.5 + - openh264=2.1.1 + - openssl=1.1.1q + - parso=0.8.3 + - patchelf=0.13 + - pexpect=4.8.0 + - pickleshare=0.7.5 + - pillow=9.0.1 + - pip=22.1.2 + - pkginfo=1.8.2 + - prompt-toolkit=3.0.20 + - psutil=5.8.0 + - ptyprocess=0.7.0 + - py-lief=0.11.5 + - pycosat=0.6.3 + - pycparser=2.21 + - pygments=2.11.2 + - pyopenssl=22.0.0 + - pysocks=1.7.1 + - python=3.7.13 + - python-libarchive-c=2.9 + - pytorch=1.12.1 + - pytorch-mutex=1.0 + - pytz=2022.1 + - pyyaml=6.0 + - readline=8.1.2 + - requests=2.27.1 + - ripgrep=12.1.1 + - ruamel_yaml=0.15.100 + - setuptools=61.2.0 + - six=1.16.0 + - soupsieve=2.3.1 + - sqlite=3.38.2 + - tk=8.6.11 + - torchtext=0.13.1 + - torchvision=0.13.1 + - tqdm=4.63.0 + - traitlets=5.1.1 + - typing_extensions=4.3.0 + - tzdata=2022a + - urllib3=1.26.8 + - wcwidth=0.2.5 + - wheel=0.37.1 + - xz=5.2.5 + - yaml=0.2.5 + - zlib=1.2.12 + - zstd=1.5.2 + - einops=0.4.* + - h5py=3.6.* + - pandas=1.3.* diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ca95049 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +einops==0.4.1 +h5py==3.6.0 +numpy==1.21.5 +pandas==1.3.3 +Pillow==9.0.1 +torch==1.12.1 +torchvision==0.13.1 +tqdm==4.63.0 diff --git a/scripts/download_dataset.sh b/scripts/download_dataset.sh new file mode 100755 index 0000000..7300076 --- /dev/null +++ b/scripts/download_dataset.sh @@ -0,0 +1,166 @@ +#!/usr/bin/env bash + +DATADIR="${UVCGAN_S_DATA:-data}" + +declare -A URL_LIST=( + [afhq]="https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip" + [sphenix]="https://zenodo.org/record/17783990/files/2025-06-05_jet_bkg_sub.tar" +) + +declare -A CHECKSUMS=( + [afhq]="7f63dcc14ef58c0e849b59091287e1844da97016073aac20403ae6c6132b950f" +) + +die () +{ + echo "${*}" + exit 1 +} + +usage () +{ + cat < 0: # create an empty pool + self.num_imgs = 0 + self.images = [] + + def query(self, images): + """Return an image from the pool. 
+ + Parameters: + images: the latest generated images from the generator + + Returns images from the buffer. + + By 50/100, the buffer will return input images. + By 50/100, the buffer will return images previously stored in the buffer, + and insert the current images to the buffer. + """ + if self.pool_size == 0: # if the buffer size is 0, do nothing + return images + return_images = [] + for image in images: + image = torch.unsqueeze(image.data, 0) + if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer + random_id = random.randint(0, self.pool_size - 1) # randint is inclusive + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: # by another 50% chance, the buffer will return the current image + return_images.append(image) + return_images = torch.cat(return_images, 0) # collect all the images and return + return return_images diff --git a/uvcgan_s/base/losses.py b/uvcgan_s/base/losses.py new file mode 100644 index 0000000..025b03a --- /dev/null +++ b/uvcgan_s/base/losses.py @@ -0,0 +1,200 @@ +import torch +from torch import nn + +def reduce_loss(loss, reduction): + if (reduction is None) or (reduction == 'none'): + return loss + + if reduction == 'mean': + return loss.mean() + + if reduction == 'sum': + return loss.sum() + + raise ValueError(f"Unknown reduction method: '{reduction}'") + +class GANLoss(nn.Module): + """Define different GAN objectives. + + The GANLoss class abstracts away the need to create the target label tensor + that has the same size as the input. + """ + + def __init__( + self, gan_mode, target_real_label = 1.0, target_fake_label = 0.0, + reduction = 'mean' + ): + """ Initialize the GANLoss class. + + Parameters: + gan_mode (str) -- the type of GAN objective. + Choices: vanilla, lsgan, and wgangp. + target_real_label (bool) -- label for a real image + target_fake_label (bool) -- label of a fake image + + Note: Do not use sigmoid as the last layer of Discriminator. + LSGAN needs no sigmoid. Vanilla GANs will handle it with + BCEWithLogitsLoss. + """ + super().__init__() + + # pylint: disable=not-callable + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + + self.gan_mode = gan_mode + self.reduction = reduction + + if gan_mode == 'lsgan': + self.loss = nn.MSELoss(reduction = reduction) + + elif gan_mode == 'vanilla': + self.loss = nn.BCEWithLogitsLoss(reduction = reduction) + + elif gan_mode == 'softplus': + self.loss = nn.Softplus() + + elif gan_mode == 'wgan': + self.loss = None + + else: + raise NotImplementedError('gan mode %s not implemented' % gan_mode) + + def get_target_tensor(self, prediction, target_is_real): + """Create label tensors with the same size as the input. 
+ + Parameters: + prediction (tensor) -- tpyically the prediction from a + discriminator + target_is_real (bool) -- if the ground truth label is for real + images or fake images + + Returns: + A label tensor filled with ground truth label, and with the size of + the input + """ + + if target_is_real: + target_tensor = self.real_label + else: + target_tensor = self.fake_label + return target_tensor.expand_as(prediction) + + def eval_wgan_loss(self, prediction, target_is_real): + if target_is_real: + result = -prediction.mean() + else: + result = prediction.mean() + + return reduce_loss(result, self.reduction) + + def eval_softplus_loss(self, prediction, target_is_real): + if target_is_real: + result = self.loss(prediction) + else: + result = self.loss(-prediction) + + return reduce_loss(result, self.reduction) + + def forward(self, prediction, target_is_real): + """Calculate loss given Discriminator's output and grount truth labels. + + Parameters: + prediction (tensor) -- tpyically the prediction output from a + discriminator + target_is_real (bool) -- if the ground truth label is for real + images or fake images + + Returns: + the calculated loss. + """ + + if isinstance(prediction, (list, tuple)): + result = sum(self.forward(x, target_is_real) for x in prediction) + return result / len(prediction) + + if self.gan_mode == 'wgan': + return self.eval_wgan_loss(prediction, target_is_real) + + if self.gan_mode == 'softplus': + return self.eval_softplus_loss(prediction, target_is_real) + + target_tensor = self.get_target_tensor(prediction, target_is_real) + return self.loss(prediction, target_tensor) + +# pylint: disable=too-many-arguments +# pylint: disable=redefined-builtin +def cal_gradient_penalty( + netD, real_data, fake_data, device, + type = 'mixed', constant = 1.0, lambda_gp = 10.0 +): + """Calculate the gradient penalty loss, used in WGAN-GP + + source: https://arxiv.org/abs/1704.00028 + + Arguments: + netD (network) -- discriminator network + real_data (tensor array) -- real images + fake_data (tensor array) -- generated images from the generator + device (str) -- torch device + type (str) -- if we mix real and fake data or not + Choices: [real | fake | mixed]. 
+ constant (float) -- the constant used in formula: + (||gradient||_2 - constant)^2 + lambda_gp (float) -- weight for this loss + + Returns the gradient penalty loss + """ + if lambda_gp == 0.0: + return 0.0, None + + if type == 'real': + interpolatesv = real_data + elif type == 'fake': + interpolatesv = fake_data + elif type == 'mixed': + alpha = torch.rand(real_data.shape[0], 1, device = device) + alpha = alpha.expand( + real_data.shape[0], real_data.nelement() // real_data.shape[0] + ).contiguous().view(*real_data.shape) + + interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) + else: + raise NotImplementedError('{} not implemented'.format(type)) + + interpolatesv.requires_grad_(True) + disc_interpolates = netD(interpolatesv) + + gradients = torch.autograd.grad( + outputs=disc_interpolates, inputs=interpolatesv, + grad_outputs=torch.ones(disc_interpolates.size()).to(device), + create_graph=True, retain_graph=True, only_inputs=True + ) + + gradients = gradients[0].view(real_data.size(0), -1) + + gradient_penalty = ( + ((gradients + 1e-16).norm(2, dim=1) - constant) ** 2 + ).mean() * lambda_gp + + return gradient_penalty, gradients + +def calc_zero_gp(model, x, **model_kwargs): + x.requires_grad_(True) + y = model(x, **model_kwargs) + + grad = torch.autograd.grad( + outputs = y, + inputs = x, + grad_outputs = torch.ones(y.size()).to(y.device), + create_graph = True, + retain_graph = True, + only_inputs = True + ) + + grad = grad[0].view(x.shape[0], -1) + # NOTE: 1/2 for backward compatibility + gp = 1/2 * torch.sum(grad.square(), dim = 1).mean() + + return gp, grad + diff --git a/uvcgan_s/base/networks.py b/uvcgan_s/base/networks.py new file mode 100644 index 0000000..e70024c --- /dev/null +++ b/uvcgan_s/base/networks.py @@ -0,0 +1,411 @@ +# pylint: disable=line-too-long +# pylint: disable=redefined-builtin +# pylint: disable=too-many-arguments +# pylint: disable=unidiomatic-typecheck +# pylint: disable=super-with-arguments + +import functools + +import torch +from torch import nn + +class Identity(nn.Module): + # pylint: disable=no-self-use + def forward(self, x): + return x + +def get_norm_layer(norm_type='instance'): + """Return a normalization layer + + Parameters: + norm_type (str) -- the name of the normalization layer: batch | instance | none + + For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). + For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 
+ """ + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) + elif norm_type == 'none': + norm_layer = lambda _features : Identity() + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + + return norm_layer + +def join_args(a, b): + return { **a, **b } + +def select_base_generator(model, **kwargs): + default_args = dict(norm = 'instance', use_dropout = False, ngf = 64) + kwargs = join_args(default_args, kwargs) + + if model == 'resnet_9blocks': + return ResnetGenerator(n_blocks = 9, **kwargs) + + if model == 'resnet_6blocks': + return ResnetGenerator(n_blocks = 6, **kwargs) + + if model == 'unet_128': + return UnetGenerator(num_downs = 7, **kwargs) + + if model == 'unet_256': + return UnetGenerator(num_downs = 8, **kwargs) + + raise ValueError("Unknown generator: %s" % model) + +def select_base_discriminator(model, **kwargs): + default_args = dict(norm = 'instance', ndf = 64) + kwargs = join_args(default_args, kwargs) + + if model == 'basic': + return NLayerDiscriminator(n_layers = 3, **kwargs) + + if model == 'n_layers': + return NLayerDiscriminator(**kwargs) + + if model == 'pixel': + return PixelDiscriminator(**kwargs) + + raise ValueError("Unknown discriminator: %s" % model) + +class ResnetGenerator(nn.Module): + """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. + + We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__(self, image_shape, ngf=64, norm = 'batch', use_dropout=False, n_blocks=6, padding_type='reflect'): + """Construct a Resnet-based generator + + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero + """ + + assert n_blocks >= 0 + super().__init__() + + norm_layer = get_norm_layer(norm_type = norm) + + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + model = [nn.ReflectionPad2d(3), + nn.Conv2d(image_shape[0], ngf, kernel_size=7, padding=0, bias=use_bias), + norm_layer(ngf), + nn.ReLU(True)] + + n_downsampling = 2 + for i in range(n_downsampling): # add downsampling layers + mult = 2 ** i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), + norm_layer(ngf * mult * 2), + nn.ReLU(True)] + + mult = 2 ** n_downsampling + for i in range(n_blocks): # add ResNet blocks + + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] + + for i in range(n_downsampling): # add upsampling layers + mult = 2 ** (n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=2, + padding=1, output_padding=1, + bias=use_bias), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + model += [nn.ReflectionPad2d(3)] + model += [nn.Conv2d(ngf, image_shape[0], kernel_size=7, padding=0)] + + if image_shape[0] 
== 3: + model.append(nn.Sigmoid()) + + self.model = nn.Sequential(*model) + + def forward(self, input): + """Standard forward""" + return self.model(input) + + +class ResnetBlock(nn.Module): + """Define a Resnet block""" + + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + """Initialize the Resnet block + + A resnet block is a conv block with skip connections + We construct a conv block with build_conv_block function, + and implement skip connections in function. + Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf + """ + super().__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + + # pylint: disable=no-self-use + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + """Construct a convolutional block. + + Parameters: + dim (int) -- the number of channels in the conv layer. + padding_type (str) -- the name of padding layer: reflect | replicate | zero + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + use_bias (bool) -- if the conv layer uses bias or not + + Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) + """ + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + """Forward function (with skip connections)""" + out = x + self.conv_block(x) # add skip connections + return out + + +class UnetGenerator(nn.Module): + """Create a Unet-based generator""" + + def __init__(self, image_shape, num_downs, ngf=64, norm = 'batch', use_dropout=False): + """Construct a Unet generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. 
+ """ + super(UnetGenerator, self).__init__() + norm_layer = get_norm_layer(norm_type=norm) + + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer + for _i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + self.model = UnetSkipConnectionBlock(image_shape[0], ngf, input_nc=image_shape[0], submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer + + def forward(self, input): + """Standard forward""" + return self.model(input) + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet submodule with skip connections. + + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. 
+ """ + # pylint: disable=too-many-locals + super().__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv ] + + if outer_nc == 3: + up.append(nn.Sigmoid()) + + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: # add skip connections + return torch.cat([x, self.model(x)], 1) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator""" + + def __init__( + self, image_shape, ndf=64, n_layers=3, norm='batch', max_mult=8, + shrink_output = True, return_intermediate_activations = False + ): + # pylint: disable=too-many-locals + """Construct a PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + + norm_layer = get_norm_layer(norm_type = norm) + + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(image_shape[0], ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, max_mult) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, max_mult) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + self.model = nn.Sequential(*sequence) + self.shrink_conv = None + + if shrink_output: + self.shrink_conv = nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + + self._intermediate = return_intermediate_activations + + def forward(self, input): + """Standard forward.""" + z = self.model(input) + + if self.shrink_conv is None: + return z + + y = self.shrink_conv(z) + + if self._intermediate: + return (y, z) + + return y + +class PixelDiscriminator(nn.Module): + """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" + + def __init__(self, image_shape, ndf=64, 
norm='batch'): + """Construct a 1x1 PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + """ + super(PixelDiscriminator, self).__init__() + + norm_layer = get_norm_layer(norm_type=norm) + + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.net = [ + nn.Conv2d(image_shape[0], ndf, kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), + norm_layer(ndf * 2), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] + + self.net = nn.Sequential(*self.net) + + def forward(self, input): + """Standard forward.""" + return self.net(input) diff --git a/uvcgan_s/base/schedulers.py b/uvcgan_s/base/schedulers.py new file mode 100644 index 0000000..b133123 --- /dev/null +++ b/uvcgan_s/base/schedulers.py @@ -0,0 +1,61 @@ +from torch.optim import lr_scheduler +from uvcgan_s.torch.select import extract_name_kwargs + +def linear_scheduler(optimizer, epochs_warmup, epochs_anneal, verbose = True): + + def lambda_rule(epoch, epochs_warmup, epochs_anneal): + if epoch < epochs_warmup: + return 1.0 + + return 1.0 - (epoch - epochs_warmup) / (epochs_anneal + 1) + + lr_fn = lambda epoch : lambda_rule(epoch, epochs_warmup, epochs_anneal) + + return lr_scheduler.LambdaLR(optimizer, lr_fn, verbose = verbose) + +SCHED_DICT = { + 'step' : lr_scheduler.StepLR, + 'plateau' : lr_scheduler.ReduceLROnPlateau, + 'cosine' : lr_scheduler.CosineAnnealingLR, + 'cosine-restarts' : lr_scheduler.CosineAnnealingWarmRestarts, + 'constant' : lr_scheduler.ConstantLR, + # lr scheds below are for backward compatibility + 'linear' : linear_scheduler, + 'linear-v2' : lr_scheduler.LinearLR, + 'CosineAnnealingWarmRestarts' : lr_scheduler.CosineAnnealingWarmRestarts, +} + +def select_single_scheduler(optimizer, scheduler): + if scheduler is None: + return None + + name, kwargs = extract_name_kwargs(scheduler) + kwargs['verbose'] = True + + if name not in SCHED_DICT: + raise ValueError( + f"Unknown scheduler: '{name}'. 
Supported: {SCHED_DICT.keys()}" + ) + + return SCHED_DICT[name](optimizer, **kwargs) + +def select_scheduler(optimizer, scheduler, compose = False): + if scheduler is None: + return None + + if not isinstance(scheduler, (list, tuple)): + scheduler = [ scheduler, ] + + result = [ select_single_scheduler(optimizer, x) for x in scheduler ] + + if compose: + if len(result) == 1: + return result[0] + else: + return lr_scheduler.ChainedScheduler(result) + else: + return result + +def get_scheduler(optimizer, scheduler): + return select_scheduler(optimizer, scheduler, compose = True) + diff --git a/uvcgan_s/base/weight_init.py b/uvcgan_s/base/weight_init.py new file mode 100644 index 0000000..2f8ce81 --- /dev/null +++ b/uvcgan_s/base/weight_init.py @@ -0,0 +1,49 @@ +import logging +from torch.nn import init + +from uvcgan_s.torch.select import extract_name_kwargs + +LOGGER = logging.getLogger('uvcgan_s.base') + +def winit_func(m, init_type = 'normal', init_gain = 0.2): + classname = m.__class__.__name__ + + if ( + hasattr(m, 'weight') + and (classname.find('Conv') != -1 or classname.find('Linear') != -1) + ): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain = init_gain) + + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a = 0, mode = 'fan_in') + + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain = init_gain) + + else: + raise NotImplementedError( + 'Initialization method [%s] is not implemented' % init_type + ) + + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + + elif classname.find('BatchNorm2d') != -1: + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + +def init_weights(net, weight_init): + if weight_init is None: + return + + name, kwargs = extract_name_kwargs(weight_init) + + LOGGER.debug('Initializnig network with %s', name) + net.apply( + lambda m, name=name, kwargs=kwargs : winit_func(m, name, **kwargs) + ) + diff --git a/uvcgan_s/cgan/__init__.py b/uvcgan_s/cgan/__init__.py new file mode 100644 index 0000000..21eeffa --- /dev/null +++ b/uvcgan_s/cgan/__init__.py @@ -0,0 +1,30 @@ +from .cyclegan import CycleGANModel +from .pix2pix import Pix2PixModel +from .autoencoder import Autoencoder +from .simple_autoencoder import SimpleAutoencoder +from .uvcgan2 import UVCGAN2 +from .uvcgan_s import UVCGAN_S + +CGAN_MODELS = { + 'cyclegan' : CycleGANModel, + 'pix2pix' : Pix2PixModel, + 'autoencoder' : Autoencoder, + 'simple-autoencoder' : SimpleAutoencoder, + 'uvcgan-v2' : UVCGAN2, + 'uvcgan-s' : UVCGAN_S, +} + +def select_model(name, **kwargs): + if name not in CGAN_MODELS: + raise ValueError("Unknown model: %s" % name) + + return CGAN_MODELS[name](**kwargs) + +def construct_model(savedir, config, is_train, device): + model = select_model( + config.model, savedir = savedir, config = config, is_train = is_train, + device = device, **config.model_args + ) + + return model + diff --git a/uvcgan_s/cgan/autoencoder.py b/uvcgan_s/cgan/autoencoder.py new file mode 100644 index 0000000..f3bcf88 --- /dev/null +++ b/uvcgan_s/cgan/autoencoder.py @@ -0,0 +1,166 @@ +# pylint: disable=not-callable +# NOTE: Mistaken lint: +# E1102: self.encoder is not callable (not-callable) +from uvcgan_s.torch.select import select_optimizer, select_loss +from uvcgan_s.torch.background_penalty import BackgroundPenaltyReduction +from uvcgan_s.torch.image_masking import select_masking +from uvcgan_s.models.generator import 
construct_generator + +from .model_base import ModelBase +from .named_dict import NamedDict +from .funcs import set_two_domain_input + +class Autoencoder(ModelBase): + + def _setup_images(self, _config): + images = [ 'real_a', 'reco_a', 'real_b', 'reco_b', ] + + if self.masking is not None: + images += [ 'masked_a', 'masked_b' ] + + return NamedDict(*images) + + def _setup_models(self, config): + if self.joint: + image_shape = config.data.datasets[0].shape + + assert image_shape == config.data.datasets[1].shape, ( + "Joint autoencoder requires all datasets to have " + "the same image shape" + ) + + return NamedDict( + encoder = construct_generator( + config.generator, image_shape, image_shape, self.device + ) + ) + + models = NamedDict('encoder_a', 'encoder_b') + models.encoder_a = construct_generator( + config.generator, + config.data.datasets[0].shape, + config.data.datasets[0].shape, + self.device + ) + models.encoder_b = construct_generator( + config.generator, + config.data.datasets[1].shape, + config.data.datasets[1].shape, + self.device + ) + + return models + + def _setup_losses(self, config): + self.loss_fn = select_loss(config.loss) + + assert config.gradient_penalty is None, \ + "Autoencoder model does not support gradient penalty" + + return NamedDict('loss_a', 'loss_b') + + def _setup_optimizers(self, config): + if self.joint: + return NamedDict( + encoder = select_optimizer( + self.models.encoder.parameters(), + config.generator.optimizer + ) + ) + + optimizers = NamedDict('encoder_a', 'encoder_b') + + optimizers.encoder_a = select_optimizer( + self.models.encoder_a.parameters(), config.generator.optimizer + ) + optimizers.encoder_b = select_optimizer( + self.models.encoder_b.parameters(), config.generator.optimizer + ) + + return optimizers + + def __init__( + self, savedir, config, is_train, device, + joint = False, background_penalty = None, masking = None + ): + # pylint: disable=too-many-arguments + self.joint = joint + self.masking = select_masking(masking) + + assert len(config.data.datasets) == 2, \ + "Autoencoder expects a pair of datasets" + + super().__init__(savedir, config, is_train, device) + + if background_penalty is None: + self.background_penalty = None + else: + self.background_penalty = BackgroundPenaltyReduction( + **background_penalty + ) + + assert config.discriminator is None, \ + "Autoencoder model does not use discriminator" + + def _handle_epoch_end(self): + if self.background_penalty is not None: + self.background_penalty.end_epoch(self.epoch) + + def _set_input(self, inputs, domain): + set_two_domain_input(self.images, inputs, domain, self.device) + + def forward(self): + input_a = self.images.real_a + input_b = self.images.real_b + + if self.masking is not None: + if input_a is not None: + input_a = self.masking(input_a) + + if input_b is not None: + input_b = self.masking(input_b) + + self.images.masked_a = input_a + self.images.masked_b = input_b + + if input_a is not None: + if self.joint: + self.images.reco_a = self.models.encoder (input_a) + else: + self.images.reco_a = self.models.encoder_a(input_a) + + if input_b is not None: + if self.joint: + self.images.reco_b = self.models.encoder (input_b) + else: + self.images.reco_b = self.models.encoder_b(input_b) + + def backward_generator_base(self, real, reco): + if self.background_penalty is not None: + reco = self.background_penalty(reco, real) + + loss = self.loss_fn(reco, real) + loss.backward() + + return loss + + def backward_generators(self): + self.losses.loss_b = 
self.backward_generator_base( + self.images.real_b, self.images.reco_b + ) + + self.losses.loss_a = self.backward_generator_base( + self.images.real_a, self.images.reco_a + ) + + def optimization_step(self): + self.forward() + + for optimizer in self.optimizers.values(): + optimizer.zero_grad() + + self.backward_generators() + + for optimizer in self.optimizers.values(): + optimizer.step() + diff --git a/uvcgan_s/cgan/checkpoint.py b/uvcgan_s/cgan/checkpoint.py new file mode 100644 index 0000000..70c6a77 --- /dev/null +++ b/uvcgan_s/cgan/checkpoint.py @@ -0,0 +1,73 @@ +import os +import re +import torch + +CHECKPOINTS_DIR = 'checkpoints' + +def find_last_checkpoint_epoch(savedir, prefix = None): + root = os.path.join(savedir, CHECKPOINTS_DIR) + if not os.path.exists(root): + return -1 + + if prefix is None: + r = re.compile(r'(\d+)_.*') + else: + r = re.compile(r'(\d+)_' + re.escape(prefix) + '_.*') + + last_epoch = -1 + + for fname in os.listdir(root): + m = r.match(fname) + if m: + epoch = int(m.groups()[0]) + last_epoch = max(last_epoch, epoch) + + return last_epoch + +def get_save_path(savedir, name, epoch, mkdir = False): + if epoch is None: + fname = '%s.pth' % (name) + root = savedir + else: + fname = '%04d_%s.pth' % (epoch, name) + root = os.path.join(savedir, CHECKPOINTS_DIR) + + result = os.path.join(root, fname) + + if mkdir: + os.makedirs(root, exist_ok = True) + + return result + +def save(named_dict, savedir, prefix, epoch = None): + for (k,v) in named_dict.items(): + if v is None: + continue + + save_path = get_save_path( + savedir, prefix + '_' + k, epoch, mkdir = True + ) + + if isinstance(v, torch.nn.DataParallel): + torch.save(v.module.state_dict(), save_path) + else: + torch.save(v.state_dict(), save_path) + +def load(named_dict, savedir, prefix, epoch, device): + for (k,v) in named_dict.items(): + if v is None: + continue + + load_path = get_save_path( + savedir, prefix + '_' + k, epoch, mkdir = False + ) + + if isinstance(v, torch.nn.DataParallel): + v.module.load_state_dict( + torch.load(load_path, map_location = device) + ) + else: + v.load_state_dict( + torch.load(load_path, map_location = device) + ) + diff --git a/uvcgan_s/cgan/cyclegan.py b/uvcgan_s/cgan/cyclegan.py new file mode 100644 index 0000000..295674f --- /dev/null +++ b/uvcgan_s/cgan/cyclegan.py @@ -0,0 +1,226 @@ +# pylint: disable=not-callable +# NOTE: Mistaken lint: +# E1102: self.criterion_gan is not callable (not-callable) + +import itertools +import torch + +from uvcgan_s.torch.select import select_optimizer +from uvcgan_s.base.image_pool import ImagePool +from uvcgan_s.base.losses import GANLoss, cal_gradient_penalty +from uvcgan_s.models.discriminator import construct_discriminator +from uvcgan_s.models.generator import construct_generator + +from .model_base import ModelBase +from .named_dict import NamedDict +from .funcs import set_two_domain_input + +class CycleGANModel(ModelBase): + # pylint: disable=too-many-instance-attributes + + def _setup_images(self, _config): + images = [ 'real_a', 'fake_b', 'reco_a', 'real_b', 'fake_a', 'reco_b' ] + + if self.is_train and self.lambda_idt > 0: + images += [ 'idt_a', 'idt_b' ] + + return NamedDict(*images) + + def _setup_models(self, config): + models = {} + + models['gen_ab'] = construct_generator( + config.generator, + config.data.datasets[0].shape, + config.data.datasets[1].shape, + self.device + ) + models['gen_ba'] = construct_generator( + config.generator, + config.data.datasets[1].shape, + config.data.datasets[0].shape, + self.device + ) + + if 
self.is_train: + models['disc_a'] = construct_discriminator( + config.discriminator, + config.data.datasets[0].shape, + self.device + ) + models['disc_b'] = construct_discriminator( + config.discriminator, + config.data.datasets[1].shape, + self.device + ) + + return NamedDict(**models) + + def _setup_losses(self, config): + losses = [ + 'gen_ab', 'gen_ba', 'cycle_a', 'cycle_b', 'disc_a', 'disc_b' + ] + + if self.is_train and self.lambda_idt > 0: + losses += [ 'idt_a', 'idt_b' ] + + return NamedDict(*losses) + + def _setup_optimizers(self, config): + optimizers = NamedDict('gen', 'disc') + + optimizers.gen = select_optimizer( + itertools.chain( + self.models.gen_ab.parameters(), + self.models.gen_ba.parameters() + ), + config.generator.optimizer + ) + + optimizers.disc = select_optimizer( + itertools.chain( + self.models.disc_a.parameters(), + self.models.disc_b.parameters() + ), + config.discriminator.optimizer + ) + + return optimizers + + def __init__( + self, savedir, config, is_train, device, pool_size = 50, + lambda_a = 10.0, lambda_b = 10.0, lambda_idt = 0.5 + ): + # pylint: disable=too-many-arguments + self.lambda_a = lambda_a + self.lambda_b = lambda_b + self.lambda_idt = lambda_idt + + assert len(config.data.datasets) == 2, \ + "CycleGAN expects a pair of datasets" + + super().__init__(savedir, config, is_train, device) + + self.criterion_gan = GANLoss(config.loss).to(self.device) + self.gradient_penalty = config.gradient_penalty + self.criterion_cycle = torch.nn.L1Loss() + self.criterion_idt = torch.nn.L1Loss() + + if self.is_train: + self.pred_a_pool = ImagePool(pool_size) + self.pred_b_pool = ImagePool(pool_size) + + def _set_input(self, inputs, domain): + set_two_domain_input(self.images, inputs, domain, self.device) + + def forward(self): + def simple_fwd(batch, gen_fwd, gen_bkw): + if batch is None: + return (None, None) + + fake = gen_fwd(batch) + reco = gen_bkw(fake) + + return (fake, reco) + + self.images.fake_b, self.images.reco_a = simple_fwd( + self.images.real_a, self.models.gen_ab, self.models.gen_ba + ) + + self.images.fake_a, self.images.reco_b = simple_fwd( + self.images.real_b, self.models.gen_ba, self.models.gen_ab + ) + + def backward_discriminator_base(self, model, real, fake): + pred_real = model(real) + loss_real = self.criterion_gan(pred_real, True) + + # + # NOTE: + # This is a workaround to a pytorch 1.9.0 bug that manifests when + # cudnn is enabled. When the bug is solved remove no_grad block and + # replace `model(fake)` by `model(fake.detach())`. 
+ # + # bug: https://github.com/pytorch/pytorch/issues/48439 + # + with torch.no_grad(): + fake = fake.contiguous() + + pred_fake = model(fake) + loss_fake = self.criterion_gan(pred_fake, False) + + loss = (loss_real + loss_fake) * 0.5 + + if self.gradient_penalty is not None: + loss += cal_gradient_penalty( + model, real, fake, real.device, **self.gradient_penalty + )[0] + + loss.backward() + return loss + + def backward_discriminators(self): + fake_a = self.pred_a_pool.query(self.images.fake_a) + fake_b = self.pred_b_pool.query(self.images.fake_b) + + self.losses.disc_b = self.backward_discriminator_base( + self.models.disc_b, self.images.real_b, fake_b + ) + + self.losses.disc_a = self.backward_discriminator_base( + self.models.disc_a, self.images.real_a, fake_a + ) + + def backward_generators(self): + lambda_idt = self.lambda_idt + lambda_a = self.lambda_a + lambda_b = self.lambda_b + + self.losses.gen_ab = self.criterion_gan( + self.models.disc_b(self.images.fake_b), True + ) + self.losses.gen_ba = self.criterion_gan( + self.models.disc_a(self.images.fake_a), True + ) + self.losses.cycle_a = lambda_a * self.criterion_cycle( + self.images.reco_a, self.images.real_a + ) + self.losses.cycle_b = lambda_b * self.criterion_cycle( + self.images.reco_b, self.images.real_b + ) + + loss = ( + self.losses.gen_ab + self.losses.gen_ba + + self.losses.cycle_a + self.losses.cycle_b + ) + + if lambda_idt > 0: + self.images.idt_b = self.models.gen_ab(self.images.real_b) + self.losses.idt_b = lambda_b * lambda_idt * self.criterion_idt( + self.images.idt_b, self.images.real_b + ) + + self.images.idt_a = self.models.gen_ba(self.images.real_a) + self.losses.idt_a = lambda_a * lambda_idt * self.criterion_idt( + self.images.idt_a, self.images.real_a + ) + + loss += (self.losses.idt_a + self.losses.idt_b) + + loss.backward() + + def optimization_step(self): + self.forward() + + # Generators + self.set_requires_grad([self.models.disc_a, self.models.disc_b], False) + self.optimizers.gen.zero_grad() + self.backward_generators() + self.optimizers.gen.step() + + # Discriminators + self.set_requires_grad([self.models.disc_a, self.models.disc_b], True) + self.optimizers.disc.zero_grad() + self.backward_discriminators() + self.optimizers.disc.step() + diff --git a/uvcgan_s/cgan/funcs.py b/uvcgan_s/cgan/funcs.py new file mode 100644 index 0000000..e042d04 --- /dev/null +++ b/uvcgan_s/cgan/funcs.py @@ -0,0 +1,55 @@ +import torch + +def set_two_domain_input(images, inputs, domain, device): + if (domain is None) or (domain == 'both'): + images.real_a = inputs[0].to(device, non_blocking = True) + images.real_b = inputs[1].to(device, non_blocking = True) + + elif domain in [ 'a', 0 ]: + images.real_a = inputs.to(device, non_blocking = True) + + elif domain in [ 'b', 1 ]: + images.real_b = inputs.to(device, non_blocking = True) + + else: + raise ValueError( + f"Unknown domain: '{domain}'." 
+ " Supported domains: 'a' (alias 0), 'b' (alias 1), or 'both'" + ) + +def set_asym_two_domain_input(images, inputs, domain, device): + if (domain is None) or (domain == 'all'): + images.real_a0 = inputs[0].to(device, non_blocking = True) + images.real_a1 = inputs[1].to(device, non_blocking = True) + images.real_b = inputs[2].to(device, non_blocking = True) + + elif domain in [ 'a0', 0 ]: + images.real_a0 = inputs.to(device, non_blocking = True) + + elif domain in [ 'a1', 1 ]: + images.real_a1 = inputs.to(device, non_blocking = True) + + elif domain in [ 'b', 2 ]: + images.real_b = inputs.to(device, non_blocking = True) + + else: + raise ValueError( + f"Unknown domain: '{domain}'." + " Supported domains: 'a0' (alias 0), 'a1' (alias 1), " + "'b' (alias 2), or 'all'" + ) + +def trace_models(models, input_shapes, device): + result = {} + + for (name, model) in models.items(): + if name not in input_shapes: + continue + + shape = input_shapes[name] + data = torch.randn((1, *shape)).to(device) + + result[name] = torch.jit.trace(model, data) + + return result + diff --git a/uvcgan_s/cgan/model_base.py b/uvcgan_s/cgan/model_base.py new file mode 100644 index 0000000..bbae864 --- /dev/null +++ b/uvcgan_s/cgan/model_base.py @@ -0,0 +1,176 @@ +import logging +import torch +from torch.optim.lr_scheduler import ReduceLROnPlateau + +from uvcgan_s.base.schedulers import get_scheduler +from .named_dict import NamedDict +from .checkpoint import find_last_checkpoint_epoch, save, load + +PREFIX_MODEL = 'net' +PREFIX_OPT = 'opt' +PREFIX_SCHED = 'sched' + +LOGGER = logging.getLogger('uvcgan_s.cgan') + +class ModelBase: + # pylint: disable=too-many-instance-attributes + + def __init__(self, savedir, config, is_train, device): + self.is_train = is_train + self.device = device + self.savedir = savedir + self._config = config + + self.models = self._setup_models(config) + self.images = self._setup_images(config) + self.losses = self._setup_losses(config) + self.metric = 0 + self.epoch = 0 + + self.optimizers = NamedDict() + self.schedulers = NamedDict() + + if is_train: + self.optimizers = self._setup_optimizers(config) + self.schedulers = self._setup_schedulers(config) + + def set_input(self, inputs, domain = None): + for key in self.images: + self.images[key] = None + + self._set_input(inputs, domain) + + def forward(self): + raise NotImplementedError + + def optimization_step(self): + raise NotImplementedError + + def _set_input(self, inputs, domain): + raise NotImplementedError + + def _setup_images(self, config): + raise NotImplementedError + + def _setup_models(self, config): + raise NotImplementedError + + def _setup_losses(self, config): + raise NotImplementedError + + def _setup_optimizers(self, config): + raise NotImplementedError + + def _setup_schedulers(self, config): + schedulers = { } + + for (name, opt) in self.optimizers.items(): + schedulers[name] = get_scheduler(opt, config.scheduler) + + return NamedDict(**schedulers) + + def _save_model_state(self, epoch): + pass + + def _load_model_state(self, epoch): + pass + + def _handle_epoch_end(self): + pass + + def eval(self): + self.is_train = False + + for model in self.models.values(): + model.eval() + + def train(self): + self.is_train = True + + for model in self.models.values(): + model.train() + + def forward_nograd(self): + with torch.no_grad(): + self.forward() + + def find_last_checkpoint_epoch(self): + return find_last_checkpoint_epoch(self.savedir, PREFIX_MODEL) + + def load(self, epoch): + if (epoch is not None) and (epoch <= 0): + return 
+ + LOGGER.debug('Loading model from epoch %s', epoch) + + load(self.models, self.savedir, PREFIX_MODEL, epoch, self.device) + load(self.optimizers, self.savedir, PREFIX_OPT, epoch, self.device) + load(self.schedulers, self.savedir, PREFIX_SCHED, epoch, self.device) + + self.epoch = epoch + self._load_model_state(epoch) + self._handle_epoch_end() + + def save(self, epoch = None): + LOGGER.debug('Saving model at epoch %s', epoch) + + save(self.models, self.savedir, PREFIX_MODEL, epoch) + save(self.optimizers, self.savedir, PREFIX_OPT, epoch) + save(self.schedulers, self.savedir, PREFIX_SCHED, epoch) + + self._save_model_state(epoch) + + def end_epoch(self, epoch = None): + for scheduler in self.schedulers.values(): + if scheduler is None: + continue + + if isinstance(scheduler, ReduceLROnPlateau): + scheduler.step(self.metric) + else: + scheduler.step() + + self._handle_epoch_end() + + if epoch is None: + self.epoch = self.epoch + 1 + else: + self.epoch = epoch + + def pprint(self, verbose): + for name,model in self.models.items(): + num_params = 0 + + for param in model.parameters(): + num_params += param.numel() + + if verbose: + print(model) + + print( + '[Network %s] Total number of parameters : %.3f M' % ( + name, num_params / 1e6 + ) + ) + + def set_requires_grad(self, models, requires_grad = False): + # pylint: disable=no-self-use + if not isinstance(models, list): + models = [models, ] + + for model in models: + for param in model.parameters(): + param.requires_grad = requires_grad + + def get_current_losses(self): + result = {} + + for (k,v) in self.losses.items(): + result[k] = float(v) + + return result + + def get_images(self): + return self.images + + diff --git a/uvcgan_s/cgan/named_dict.py b/uvcgan_s/cgan/named_dict.py new file mode 100644 index 0000000..02c0180 --- /dev/null +++ b/uvcgan_s/cgan/named_dict.py @@ -0,0 +1,48 @@ +from collections.abc import Mapping + +class NamedDict(Mapping): + # pylint: disable=too-many-instance-attributes + + _fields = None + + def __init__(self, *args, **kwargs): + self._fields = {} + + for arg in args: + self._fields[arg] = None + + self._fields.update(**kwargs) + + def __contains__(self, key): + return (key in self._fields) + + def __getitem__(self, key): + return self._fields[key] + + def __setitem__(self, key, value): + self._fields[key] = value + + def __getattr__(self, key): + return self._fields[key] + + def __setattr__(self, key, value): + if (self._fields is not None) and (key in self._fields): + self._fields[key] = value + else: + super().__setattr__(key, value) + + def __iter__(self): + return iter(self._fields) + + def __len__(self): + return len(self._fields) + + def items(self): + return self._fields.items() + + def keys(self): + return self._fields.keys() + + def values(self): + return self._fields.values() + diff --git a/uvcgan_s/cgan/pix2pix.py b/uvcgan_s/cgan/pix2pix.py new file mode 100644 index 0000000..ac97b41 --- /dev/null +++ b/uvcgan_s/cgan/pix2pix.py @@ -0,0 +1,166 @@ +# pylint: disable=not-callable +# NOTE: Mistaken lint: +# E1102: self.criterion_gan is not callable (not-callable) + +import torch + +from uvcgan_s.torch.select import select_optimizer +from uvcgan_s.base.losses import GANLoss, cal_gradient_penalty +from uvcgan_s.models.discriminator import construct_discriminator +from uvcgan_s.models.generator import construct_generator + +from .model_base import ModelBase +from .named_dict import NamedDict +from .funcs import set_two_domain_input + +class Pix2PixModel(ModelBase): + + def _setup_images(self, 
_config): + return NamedDict('real_a', 'fake_b', 'real_b', 'fake_a') + + def _setup_models(self, config): + models = { } + + image_shape_a = config.data.datasets[0].shape + image_shape_b = config.data.datasets[1].shape + + assert image_shape_a[1:] == image_shape_b[1:], \ + "Pix2Pix needs images in both domains to have the same size" + + models['gen_ab'] = construct_generator( + config.generator, image_shape_a, image_shape_b, self.device + ) + models['gen_ba'] = construct_generator( + config.generator, image_shape_b, image_shape_a, self.device + ) + + if self.is_train: + extended_image_shape = ( + image_shape_a[0] + image_shape_b[0], *image_shape_a[1:] + ) + + for name in [ 'disc_a', 'disc_b' ]: + models[name] = construct_discriminator( + config.discriminator, extended_image_shape, self.device + ) + + return NamedDict(**models) + + def _setup_losses(self, config): + return NamedDict( + 'gen_ab', 'gen_ba', 'l1_ab', 'l1_ba', 'disc_a', 'disc_b' + ) + + def _setup_optimizers(self, config): + optimizers = NamedDict('gen_ab', 'gen_ba', 'disc_a', 'disc_b') + + optimizers.gen_ab = select_optimizer( + self.models.gen_ab.parameters(), config.generator.optimizer + ) + optimizers.gen_ba = select_optimizer( + self.models.gen_ba.parameters(), config.generator.optimizer + ) + + optimizers.disc_a = select_optimizer( + self.models.disc_a.parameters(), config.discriminator.optimizer + ) + optimizers.disc_b = select_optimizer( + self.models.disc_b.parameters(), config.discriminator.optimizer + ) + + return optimizers + + def __init__(self, savedir, config, is_train, device): + super().__init__(savedir, config, is_train, device) + + assert len(config.data.datasets) == 2, \ + "Pix2Pix expects a pair of datasets" + + self.criterion_gan = GANLoss(config.loss).to(self.device) + self.criterion_l1 = torch.nn.L1Loss() + self.gradient_penalty = config.gradient_penalty + + def _set_input(self, inputs, domain): + set_two_domain_input(self.images, inputs, domain, self.device) + + def forward(self): + if self.images.real_a is not None: + self.images.fake_b = self.models.gen_ab(self.images.real_a) + + if self.images.real_b is not None: + self.images.fake_a = self.models.gen_ba(self.images.real_b) + + def backward_discriminator_base(self, model, real, fake, preimage): + cond_real = torch.cat([real, preimage], dim = 1) + cond_fake = torch.cat([fake, preimage], dim = 1).detach() + + pred_real = model(cond_real) + loss_real = self.criterion_gan(pred_real, True) + + pred_fake = model(cond_fake) + loss_fake = self.criterion_gan(pred_fake, False) + + loss = (loss_real + loss_fake) * 0.5 + + if self.gradient_penalty is not None: + loss += cal_gradient_penalty( + model, cond_real, cond_fake, real.device, + **self.gradient_penalty + )[0] + + loss.backward() + return loss + + def backward_discriminators(self): + self.losses.disc_b = self.backward_discriminator_base( + self.models.disc_b, + self.images.real_b, self.images.fake_b, self.images.real_a + ) + + self.losses.disc_a = self.backward_discriminator_base( + self.models.disc_a, + self.images.real_a, self.images.fake_a, self.images.real_b + ) + + def backward_generator_base(self, disc, real, fake, preimage): + loss_gen = self.criterion_gan( + disc(torch.cat([fake, preimage], dim = 1)), True + ) + + loss_l1 = self.criterion_l1(fake, real) + + loss = loss_gen + loss_l1 + loss.backward() + + return (loss_gen, loss_l1) + + def backward_generators(self): + self.losses.gen_ab, self.losses.l1_ab = self.backward_generator_base( + self.models.disc_b, + self.images.real_b, 
self.images.fake_b, self.images.real_a + ) + + self.losses.gen_ba, self.losses.l1_ba = self.backward_generator_base( + self.models.disc_a, + self.images.real_a, self.images.fake_a, self.images.real_b + ) + + def optimization_step(self): + self.forward() + + # Generators + self.set_requires_grad([self.models.disc_a, self.models.disc_b], False) + self.optimizers.gen_ab.zero_grad() + self.optimizers.gen_ba.zero_grad() + self.backward_generators() + self.optimizers.gen_ab.step() + self.optimizers.gen_ba.step() + + # Discriminators + self.set_requires_grad([self.models.disc_a, self.models.disc_b], True) + self.optimizers.disc_a.zero_grad() + self.optimizers.disc_b.zero_grad() + self.backward_discriminators() + self.optimizers.disc_a.step() + self.optimizers.disc_b.step() + diff --git a/uvcgan_s/cgan/simple_autoencoder.py b/uvcgan_s/cgan/simple_autoencoder.py new file mode 100644 index 0000000..bf252a3 --- /dev/null +++ b/uvcgan_s/cgan/simple_autoencoder.py @@ -0,0 +1,111 @@ +from uvcgan_s.torch.select import select_optimizer, select_loss +from uvcgan_s.torch.background_penalty import BackgroundPenaltyReduction +from uvcgan_s.torch.image_masking import select_masking +from uvcgan_s.models.generator import construct_generator + +from .model_base import ModelBase +from .named_dict import NamedDict + +class SimpleAutoencoder(ModelBase): + """Model that tries to train an autoencoder (i.e. target == input). + + This autoencoder expects inputs to be either tuples of the form + `(features, target)` or the `features` itself. + """ + + def _setup_images(self, _config): + images = [ 'real', 'reco' ] + + if self.masking is not None: + images.append('masked') + + return NamedDict(*images) + + def _setup_models(self, config): + return NamedDict( + encoder = construct_generator( + config.generator, + config.data.datasets[0].shape, + config.data.datasets[0].shape, + self.device + ) + ) + + def _setup_losses(self, config): + self.loss_fn = select_loss(config.loss) + + assert config.gradient_penalty is None, \ + "Autoencoder model does not support gradient penalty" + + return NamedDict('loss') + + def _setup_optimizers(self, config): + return NamedDict( + encoder = select_optimizer( + self.models.encoder.parameters(), config.generator.optimizer + ) + ) + + def __init__( + self, savedir, config, is_train, device, + background_penalty = None, masking = None + ): + # pylint: disable=too-many-arguments + self.masking = select_masking(masking) + assert len(config.data.datasets) == 1, \ + "Simple Autoencoder can work only with a single dataset" + + super().__init__(savedir, config, is_train, device) + + if background_penalty is None: + self.background_penalty = None + else: + self.background_penalty = BackgroundPenaltyReduction( + **background_penalty + ) + + assert config.discriminator is None, \ + "Autoencoder model does not use discriminator" + + def _handle_epoch_end(self): + if self.background_penalty is not None: + self.background_penalty.end_epoch(self.epoch) + + def _set_input(self, inputs, _domain): + # inputs : image or (image, label) + if isinstance(inputs, (list, tuple)): + self.images.real = inputs[0].to(self.device) + else: + self.images.real = inputs.to(self.device) + + def forward(self): + if self.masking is None: + input_img = self.images.real + else: + self.images.masked = self.masking(self.images.real) + input_img = self.images.masked + + self.images.reco = self.models.encoder(input_img) + + def backward(self): + if self.background_penalty is not None: + reco = 
self.background_penalty(self.images.reco, self.images.real) + else: + reco = self.images.reco + + loss = self.loss_fn(reco, self.images.real) + loss.backward() + + self.losses.loss = loss + + def optimization_step(self): + self.forward() + + for optimizer in self.optimizers.values(): + optimizer.zero_grad() + + self.backward() + + for optimizer in self.optimizers.values(): + optimizer.step() + diff --git a/uvcgan_s/cgan/uvcgan2.py b/uvcgan_s/cgan/uvcgan2.py new file mode 100644 index 0000000..ebaf40a --- /dev/null +++ b/uvcgan_s/cgan/uvcgan2.py @@ -0,0 +1,537 @@ +# pylint: disable=not-callable +# NOTE: Mistaken lint: +# E1102: self.criterion_gan is not callable (not-callable) + +import itertools +import torch + +from torchvision.transforms import GaussianBlur, Resize + +from uvcgan_s.torch.select import select_optimizer, extract_name_kwargs +from uvcgan_s.torch.gan_losses import select_gan_loss +from uvcgan_s.torch.queue import FastQueue +from uvcgan_s.torch.funcs import prepare_model, update_average_model + +from uvcgan_s.torch.layers.batch_head import BatchHeadWrapper, get_batch_head +from uvcgan_s.torch.gradient_penalty import GradientPenalty +from uvcgan_s.torch.gradient_cacher import GradientCacher +from uvcgan_s.torch.image_masking import select_masking +from uvcgan_s.models.discriminator import construct_discriminator +from uvcgan_s.models.generator import construct_generator + +from .model_base import ModelBase +from .named_dict import NamedDict +from .funcs import set_two_domain_input + +def construct_consistency_model(consist, device): + name, kwargs = extract_name_kwargs(consist) + + if name == 'blur': + return GaussianBlur(**kwargs).to(device) + + if name == 'resize': + return Resize(**kwargs).to(device) + + raise ValueError(f'Unknown consistency type: {name}') + + +def queued_forward(batch_head_model, input_image, queue, update_queue = True): + output, pred_body = batch_head_model.forward( + input_image, extra_bodies = queue.query(), return_body = True + ) + + if update_queue: + queue.push(pred_body) + + return output + +def init_hc(image, n_hidden): + return torch.zeros( + (image.shape[0], n_hidden, *image.shape[2:]), + dtype = image.dtype, device = image.device + ) + +class UVCGAN2(ModelBase): + # pylint: disable=too-many-instance-attributes + + def _setup_images(self, _config): + images = [ + 'real_a', 'real_b', + 'fake_a', 'fake_b', + 'reco_a', 'reco_b', + 'real_hc_a', 'real_hc_b', + 'fake_hc_a', 'fake_hc_b', + 'reco_hc_a', 'reco_hc_b', + 'consist_real_a', 'consist_real_b', + 'consist_fake_a', 'consist_fake_b', + ] + + if self.is_train and self.lambda_idt > 0: + images += [ + 'idt_input_a', 'idt_input_b', + 'idt_a', 'idt_b', 'idt_hc_a', 'idt_hc_b' + ] + + return NamedDict(*images) + + def _construct_batch_head_disc(self, model_config, input_shape): + disc_body = construct_discriminator( + model_config, input_shape, self.device + ) + + disc_head = get_batch_head(self.head_config) + disc_head = prepare_model(disc_head, self.device) + + return BatchHeadWrapper(disc_body, disc_head) + + def _add_hc_to_shape(self, shape): + return (shape[0] + self.n_hidden, *shape[1:]) + + def _setup_models(self, config): + models = {} + + shape_a = self._add_hc_to_shape(config.data.datasets[0].shape) + shape_b = self._add_hc_to_shape(config.data.datasets[1].shape) + + models['gen_ab'] = construct_generator( + config.generator, shape_a, shape_b, self.device + ) + models['gen_ba'] = construct_generator( + config.generator, shape_b, shape_a, self.device + ) + + if self.avg_momentum is not None: + 
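+            # Keep running-average copies of both generators: they start as
+            # exact clones of gen_ab / gen_ba and are updated after every
+            # optimization step via update_average_model().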
models['avg_gen_ab'] = construct_generator( + config.generator, shape_a, shape_b, self.device + ) + models['avg_gen_ba'] = construct_generator( + config.generator, shape_b, shape_a, self.device + ) + + models['avg_gen_ab'].load_state_dict(models['gen_ab'].state_dict()) + models['avg_gen_ba'].load_state_dict(models['gen_ba'].state_dict()) + + if self.is_train: + models['disc_a'] = self._construct_batch_head_disc( + config.discriminator, config.data.datasets[0].shape + ) + models['disc_b'] = self._construct_batch_head_disc( + config.discriminator, config.data.datasets[1].shape + ) + + return NamedDict(**models) + + def _setup_losses(self, config): + losses = [ + 'gen_ab', 'gen_ba', 'cycle_a', 'cycle_b', 'disc_a', 'disc_b' + ] + + if self.is_train and self.lambda_idt > 0: + losses += [ 'idt_a', 'idt_b' ] + + if self.is_train and config.gradient_penalty is not None: + losses += [ 'gp_a', 'gp_b' ] + + if self.consist_model is not None: + losses += [ 'consist_a', 'consist_b' ] + + return NamedDict(*losses) + + def _setup_optimizers(self, config): + optimizers = NamedDict('gen', 'disc') + + optimizers.gen = select_optimizer( + itertools.chain( + self.models.gen_ab.parameters(), + self.models.gen_ba.parameters() + ), + config.generator.optimizer + ) + + optimizers.disc = select_optimizer( + itertools.chain( + self.models.disc_a.parameters(), + self.models.disc_b.parameters() + ), + config.discriminator.optimizer + ) + + return optimizers + + def __init__( + self, savedir, config, is_train, device, head_config = None, + lambda_a = 10.0, + lambda_b = 10.0, + lambda_idt = 0.5, + lambda_consist = 0, + n_hidden = 0, + head_queue_size = 3, + avg_momentum = None, + gp_cache_period = 1, + consistency = None, + masking = None, + ): + # pylint: disable=too-many-arguments + # pylint: disable=too-many-locals + self.lambda_a = lambda_a + self.lambda_b = lambda_b + self.lambda_idt = lambda_idt + self.lambda_consist = lambda_consist + self.avg_momentum = avg_momentum + self.n_hidden = n_hidden + self.head_config = head_config or {} + self.consist_model = None + self.masking = select_masking(masking) + + if (lambda_consist > 0) and (consistency is not None): + self.consist_model \ + = construct_consistency_model(consistency, device) + + assert len(config.data.datasets) == 2, \ + "CycleGAN expects a pair of datasets" + + super().__init__(savedir, config, is_train, device) + + self.criterion_gan = select_gan_loss(config.loss).to(self.device) + self.criterion_cycle = torch.nn.L1Loss() + self.criterion_idt = torch.nn.L1Loss() + self.criterion_consist = torch.nn.L1Loss() + self.gradient_penalty = config.gradient_penalty + + if self.is_train: + self.queues = NamedDict(**{ + name : FastQueue(head_queue_size, device = device) + for name in [ 'real_a', 'real_b', 'fake_a', 'fake_b' ] + }) + + self.gp = None + self.gp_cacher_a = None + self.gp_cacher_b = None + + if config.gradient_penalty is not None: + self.gp = GradientPenalty(**config.gradient_penalty) + + self.gp_cacher_a = GradientCacher( + self.models.disc_a, self.gp, gp_cache_period + ) + self.gp_cacher_b = GradientCacher( + self.models.disc_b, self.gp, gp_cache_period + ) + + def _set_input(self, inputs, domain): + set_two_domain_input(self.images, inputs, domain, self.device) + + if self.images.real_a is not None: + self.images.real_hc_a = init_hc(self.images.real_a, self.n_hidden) + + if self.masking is None: + self.images.idt_input_a = self.images.real_a + else: + self.images.idt_input_a = self.masking(self.images.real_a) + + if self.consist_model is not None: + 
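+                # Pre-compute the consistency target: a blurred or resized
+                # copy of real_a (depending on the configured consistency
+                # transform), compared against consist_fake_b later in
+                # eval_consist_loss().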
self.images.consist_real_a \ + = self.consist_model(self.images.real_a) + + if self.images.real_b is not None: + self.images.real_hc_b = init_hc(self.images.real_b, self.n_hidden) + + if self.masking is None: + self.images.idt_input_b = self.images.real_b + else: + self.images.idt_input_b = self.masking(self.images.real_b) + + if self.consist_model is not None: + self.images.consist_real_b \ + = self.consist_model(self.images.real_b) + + def cycle_forward_image(self, real, real_hc, gen_fwd, gen_bkw): + # pylint: disable=no-self-use + + # (N, C+n_hidden, H, W) + real_with_hc = torch.cat((real, real_hc), dim = 1) + + fake_with_hc = gen_fwd(real_with_hc) + reco_with_hc = gen_bkw(fake_with_hc) + + if self.n_hidden > 0: + fake = fake_with_hc[:, :-self.n_hidden, ...] + fake_hc = fake_with_hc[:, -self.n_hidden:, ...] + + reco = reco_with_hc[:, :-self.n_hidden, ...] + reco_hc = reco_with_hc[:, -self.n_hidden:, ...] + else: + fake = fake_with_hc + fake_hc = None + + reco = reco_with_hc + reco_hc = None + + consist_fake = None + + if self.consist_model is not None: + consist_fake = self.consist_model(fake) + + return (fake, fake_hc, reco, reco_hc, consist_fake) + + def idt_forward_image(self, idt_input, real_hc, gen): + # pylint: disable=no-self-use + + # (N, C+n_hidden, H, W) + real_with_hc = torch.cat((idt_input, real_hc), dim = 1) + idt_with_hc = gen(real_with_hc) + + if self.n_hidden > 0: + idt = idt_with_hc[:, :-self.n_hidden, ...] + idt_hc = idt_with_hc[:, -self.n_hidden:, ...] + else: + idt = idt_with_hc + idt_hc = None + + return (idt, idt_hc) + + def forward_dispatch(self, direction): + if direction == 'ab': + ( + self.images.fake_b, self.images.fake_hc_b, + self.images.reco_a, self.images.reco_hc_a, + self.images.consist_fake_b + ) = self.cycle_forward_image( + self.images.real_a, self.images.real_hc_a, + self.models.gen_ab, self.models.gen_ba + ) + + elif direction == 'ba': + ( + self.images.fake_a, self.images.fake_hc_a, + self.images.reco_b, self.images.reco_hc_b, + self.images.consist_fake_a + ) = self.cycle_forward_image( + self.images.real_b, self.images.real_hc_b, + self.models.gen_ba, self.models.gen_ab + ) + + elif direction == 'aa': + self.images.idt_a, self.images.idt_hc_a = \ + self.idt_forward_image( + self.images.idt_input_a, self.images.real_hc_a, + self.models.gen_ba + ) + + elif direction == 'bb': + self.images.idt_b, self.images.idt_hc_b = \ + self.idt_forward_image( + self.images.idt_input_b, self.images.real_hc_b, + self.models.gen_ab + ) + + elif direction == 'avg-ab': + ( + self.images.fake_b, self.images.fake_hc_b, + self.images.reco_a, self.images.reco_hc_a, + self.images.consist_fake_b + ) = self.cycle_forward_image( + self.images.real_a, self.images.real_hc_a, + self.models.avg_gen_ab, self.models.avg_gen_ba + ) + + elif direction == 'avg-ba': + ( + self.images.fake_a, self.images.fake_hc_a, + self.images.reco_b, self.images.reco_hc_b, + self.images.consist_fake_a + ) = self.cycle_forward_image( + self.images.real_b, self.images.real_hc_b, + self.models.avg_gen_ba, self.models.avg_gen_ab + ) + + else: + raise ValueError(f"Unknown forward direction: '{direction}'") + + def forward(self): + if self.images.real_a is not None: + if self.avg_momentum is not None: + self.forward_dispatch(direction = 'avg-ab') + else: + self.forward_dispatch(direction = 'ab') + + if self.images.real_b is not None: + if self.avg_momentum is not None: + self.forward_dispatch(direction = 'avg-ba') + else: + self.forward_dispatch(direction = 'ba') + + def eval_consist_loss( + self, 
consist_real_0, consist_fake_1, lambda_cycle_0 + ): + return lambda_cycle_0 * self.lambda_consist * self.criterion_consist( + consist_fake_1, consist_real_0 + ) + + def eval_loss_of_cycle_forward( + self, disc_1, real_0, fake_1, reco_0, fake_queue_1, lambda_cycle_0 + ): + # pylint: disable=too-many-arguments + # NOTE: Queue is updated in discriminator backprop + disc_pred_fake_1 = queued_forward( + disc_1, fake_1, fake_queue_1, update_queue = False + ) + + loss_gen = self.criterion_gan( + disc_pred_fake_1, is_real = True, is_generator = True + ) + loss_cycle = lambda_cycle_0 * self.criterion_cycle(reco_0, real_0) + + loss = loss_gen + loss_cycle + + return (loss_gen, loss_cycle, loss) + + def eval_loss_of_idt_forward(self, real_0, idt_0, lambda_cycle_0): + loss_idt = ( + lambda_cycle_0 + * self.lambda_idt + * self.criterion_idt(idt_0, real_0) + ) + + loss = loss_idt + + return (loss_idt, loss) + + def backward_gen(self, direction): + if direction == 'ab': + (self.losses.gen_ab, self.losses.cycle_a, loss) \ + = self.eval_loss_of_cycle_forward( + self.models.disc_b, + self.images.real_a, self.images.fake_b, self.images.reco_a, + self.queues.fake_b, self.lambda_a + ) + + if self.consist_model is not None: + self.losses.consist_a = self.eval_consist_loss( + self.images.consist_real_a, self.images.consist_fake_b, + self.lambda_a + ) + + loss += self.losses.consist_a + + elif direction == 'ba': + (self.losses.gen_ba, self.losses.cycle_b, loss) \ + = self.eval_loss_of_cycle_forward( + self.models.disc_a, + self.images.real_b, self.images.fake_a, self.images.reco_b, + self.queues.fake_a, self.lambda_b + ) + + if self.consist_model is not None: + self.losses.consist_b = self.eval_consist_loss( + self.images.consist_real_b, self.images.consist_fake_a, + self.lambda_b + ) + + loss += self.losses.consist_b + + elif direction == 'aa': + (self.losses.idt_a, loss) \ + = self.eval_loss_of_idt_forward( + self.images.real_a, self.images.idt_a, self.lambda_a + ) + + elif direction == 'bb': + (self.losses.idt_b, loss) \ + = self.eval_loss_of_idt_forward( + self.images.real_b, self.images.idt_b, self.lambda_b + ) + else: + raise ValueError(f"Unknown forward direction: '{direction}'") + + + loss.backward() + + def backward_discriminator_base( + self, model, real, fake, queue_real, queue_fake, gp_cacher + ): + # pylint: disable=too-many-arguments + loss_gp = None + + if self.gp is not None: + loss_gp = gp_cacher( + model, fake, real, + model_kwargs_fake = { 'extra_bodies' : queue_fake.query() }, + model_kwargs_real = { 'extra_bodies' : queue_real.query() }, + ) + + pred_real = queued_forward( + model, real, queue_real, update_queue = True + ) + loss_real = self.criterion_gan(pred_real, is_real = True) + + pred_fake = queued_forward( + model, fake, queue_fake, update_queue = True + ) + loss_fake = self.criterion_gan(pred_fake, is_real = False) + + loss = (loss_real + loss_fake) * 0.5 + loss.backward() + + return (loss_gp, loss) + + def backward_discriminators(self): + fake_a = self.images.fake_a.detach() + fake_b = self.images.fake_b.detach() + + loss_gp_b, self.losses.disc_b \ + = self.backward_discriminator_base( + self.models.disc_b, self.images.real_b, fake_b, + self.queues.real_b, self.queues.fake_b, self.gp_cacher_b + ) + + if loss_gp_b is not None: + self.losses.gp_b = loss_gp_b + + loss_gp_a, self.losses.disc_a = \ + self.backward_discriminator_base( + self.models.disc_a, self.images.real_a, fake_a, + self.queues.real_a, self.queues.fake_a, self.gp_cacher_a + ) + + if loss_gp_a is not None: + 
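+            # Track the domain-a gradient-penalty term separately so it
+            # appears in the per-loss report from get_current_losses().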
self.losses.gp_a = loss_gp_a + + def optimization_step_gen(self): + self.set_requires_grad([self.models.disc_a, self.models.disc_b], False) + self.optimizers.gen.zero_grad(set_to_none = True) + + dir_list = [ 'ab', 'ba' ] + if self.lambda_idt > 0: + dir_list += [ 'aa', 'bb' ] + + for direction in dir_list: + self.forward_dispatch(direction) + self.backward_gen(direction) + + self.optimizers.gen.step() + + def optimization_step_disc(self): + self.set_requires_grad([self.models.disc_a, self.models.disc_b], True) + self.optimizers.disc.zero_grad(set_to_none = True) + + self.backward_discriminators() + + self.optimizers.disc.step() + + def _accumulate_averages(self): + update_average_model( + self.models.avg_gen_ab, self.models.gen_ab, self.avg_momentum + ) + update_average_model( + self.models.avg_gen_ba, self.models.gen_ba, self.avg_momentum + ) + + def optimization_step(self): + self.optimization_step_gen() + self.optimization_step_disc() + + if self.avg_momentum is not None: + self._accumulate_averages() + + diff --git a/uvcgan_s/cgan/uvcgan_s.py b/uvcgan_s/cgan/uvcgan_s.py new file mode 100644 index 0000000..c0d9148 --- /dev/null +++ b/uvcgan_s/cgan/uvcgan_s.py @@ -0,0 +1,601 @@ +# pylint: disable=not-callable +# NOTE: Mistaken lint: +# E1102: self.criterion_gan is not callable (not-callable) + +import itertools +import torch + +from uvcgan_s.torch.data_norm import select_data_normalization +from uvcgan_s.torch.gan_losses import select_gan_loss +from uvcgan_s.torch.select import select_optimizer +from uvcgan_s.torch.queue import FastQueue +from uvcgan_s.torch.funcs import ( + prepare_model, update_average_model, clip_gradients +) +from uvcgan_s.torch.layers.batch_head import BatchHeadWrapper, get_batch_head +from uvcgan_s.torch.gradient_penalty import GradientPenalty +from uvcgan_s.torch.gradient_cacher import GradientCacher +from uvcgan_s.models.discriminator import construct_discriminator +from uvcgan_s.models.generator import construct_generator + +from .model_base import ModelBase +from .named_dict import NamedDict +from .funcs import set_asym_two_domain_input + +def queued_forward( + batch_head_model, input_image, queue, data_norm, normalize, + update_queue = True +): + # pylint: disable=too-many-arguments + if normalize: + input_image = data_norm.normalize(input_image) + + output, pred_body = batch_head_model.forward( + input_image, extra_bodies = queue.query(), return_body = True + ) + + if update_queue: + queue.push(pred_body) + + return output + +def eval_loss_with_norm(loss_fn, a, b, data_norm, normalize): + if normalize: + a = data_norm.normalize(a) + b = data_norm.normalize(b) + + return loss_fn(a, b) + +class UVCGAN_S(ModelBase): + # pylint: disable=too-many-instance-attributes + + def _setup_images(self, _config): + images = [ + 'real_a0', 'real_a1', 'real_a', 'real_b', + 'fake_a0', 'fake_a1', 'fake_a', 'fake_b', + 'reco_a0', 'reco_a1', 'reco_a', 'reco_b', + ] + + if self.lambda_idt_bb: + images += [ 'idt_bb_input', 'idt_bb', ] + + if self.lambda_idt_aa: + images += [ + 'idt_aa_input', 'idt_aa', 'idt_aa_a0', 'idt_aa_a1', + ] + + return NamedDict(*images) + + def _construct_batch_head_disc(self, model_config, input_shape): + disc_body = construct_discriminator( + model_config, input_shape, self.device + ) + + disc_head = get_batch_head(self.head_config) + disc_head = prepare_model(disc_head, self.device) + + return BatchHeadWrapper(disc_body, disc_head) + + def _setup_models(self, config): + models = {} + + shape_a0 = tuple(config.data.datasets[0].shape) + shape_a1 = 
tuple(config.data.datasets[1].shape) + + shape_a = (shape_a0[0] + shape_a1[0], *shape_a1[1:]) + shape_b = tuple(config.data.datasets[2].shape) + + models['gen_ab'] = construct_generator( + config.generator, shape_a, shape_b, self.device + ) + models['gen_ba'] = construct_generator( + config.generator, shape_b, shape_a, self.device + ) + + if self.ema_momentum is not None: + models['ema_gen_ab'] = construct_generator( + config.generator, shape_a, shape_b, self.device + ) + models['ema_gen_ba'] = construct_generator( + config.generator, shape_b, shape_a, self.device + ) + + models['ema_gen_ab'].load_state_dict(models['gen_ab'].state_dict()) + models['ema_gen_ba'].load_state_dict(models['gen_ba'].state_dict()) + + if self.is_train: + models['disc_a0'] = self._construct_batch_head_disc( + config.discriminator, shape_a0 + ) + models['disc_a1'] = self._construct_batch_head_disc( + config.discriminator, shape_a1 + ) + models['disc_b'] = self._construct_batch_head_disc( + config.discriminator, shape_b + ) + + return NamedDict(**models) + + def _setup_losses(self, config): + losses = [ + 'gen_ab', 'gen_ba0', 'gen_ba1', + 'cycle_a0', 'cycle_a1', 'cycle_b', + 'disc_a0', 'disc_a1', 'disc_b' + ] + + if self.lambda_idt_aa: + losses += [ + 'idt_aa_a0', 'idt_aa_a1', + ] + + if self.lambda_idt_bb: + losses += [ 'idt_bb' ] + + if config.gradient_penalty is not None: + losses += [ 'gp_a0', 'gp_a1', 'gp_b' ] + + return NamedDict(*losses) + + def _setup_optimizers(self, config): + optimizers = NamedDict('gen', 'disc') + + optimizers.gen = select_optimizer( + itertools.chain( + self.models.gen_ab.parameters(), + self.models.gen_ba.parameters() + ), + config.generator.optimizer + ) + + optimizers.disc = select_optimizer( + itertools.chain( + self.models.disc_a0.parameters(), + self.models.disc_a1.parameters(), + self.models.disc_b.parameters() + ), + config.discriminator.optimizer + ) + + return optimizers + + def __init__( + self, savedir, config, is_train, device, head_config = None, + lambda_adv_a0 = 1.0, + lambda_adv_a1 = 1.0, + lambda_adv_b = 1.0, + lambda_cyc_a0 = 10.0, + lambda_cyc_a1 = 10.0, + lambda_cyc_b = 10.0, + lambda_idt_aa = 0.5, + lambda_idt_bb = 0.5, + head_queue_size = 3, + ema_momentum = None, + data_norm = None, + gp_cache_period = 1, + grad_clip = None, + norm_loss_a0 = False, + norm_loss_a1 = False, + norm_loss_b = False, + norm_disc_a0 = False, + norm_disc_a1 = False, + norm_disc_b = False, + ): + # pylint: disable=too-many-arguments + # pylint: disable=too-many-locals + self.lambda_adv_a0 = lambda_adv_a0 + self.lambda_adv_a1 = lambda_adv_a1 + self.lambda_adv_b = lambda_adv_b + + self.lambda_cyc_a0 = lambda_cyc_a0 + self.lambda_cyc_a1 = lambda_cyc_a1 + self.lambda_cyc_b = lambda_cyc_b + + self.lambda_idt_aa = lambda_idt_aa + self.lambda_idt_bb = lambda_idt_bb + + self.ema_momentum = ema_momentum + self.data_norm = select_data_normalization(data_norm) + self.head_config = head_config or {} + + assert len(config.data.datasets) == 3, \ + "Asymmetric CycleGAN expects a triplet of datasets" + + self._c_a0 = config.data.datasets[0].shape[0] + self._c_a1 = config.data.datasets[1].shape[0] + + self._grad_clip = grad_clip or {} + + self._norm_loss_a0 = norm_loss_a0 + self._norm_loss_a1 = norm_loss_a1 + self._norm_loss_b = norm_loss_b + + self._norm_disc_a0 = norm_disc_a0 + self._norm_disc_a1 = norm_disc_a1 + self._norm_disc_b = norm_disc_b + + super().__init__(savedir, config, is_train, device) + + self.criterion_gan = select_gan_loss(config.loss).to(self.device) + self.criterion_idt = 
torch.nn.L1Loss() + self.criterion_cycle = torch.nn.L1Loss() + self.gradient_penalty = config.gradient_penalty + + if self.is_train: + self.queues = NamedDict(**{ + name : FastQueue(head_queue_size, device = device) + for name in [ + 'real_a0', 'real_a1', 'real_b', + 'fake_a0', 'fake_a1', 'fake_b' + ] + }) + + self.gp = None + self.gp_cacher_a0 = None + self.gp_cacher_a1 = None + self.gp_cacher_b = None + + if config.gradient_penalty is not None: + self.gp = GradientPenalty(**config.gradient_penalty) + + self.gp_cacher_a0 = GradientCacher( + self.models.disc_a0, self.gp, gp_cache_period + ) + self.gp_cacher_a1 = GradientCacher( + self.models.disc_a1, self.gp, gp_cache_period + ) + self.gp_cacher_b = GradientCacher( + self.models.disc_b, self.gp, gp_cache_period + ) + + def split_domain_a_image(self, image_a): + image_a0 = image_a[:, :self._c_a0, ...] + image_a1 = image_a[:, self._c_a0:, ...] + + return (image_a0, image_a1) + + def merge_domain_a_images(self, image_a0, image_a1): + # pylint: disable=no-self-use + return torch.cat((image_a0, image_a1), dim = 1) + + def _set_input(self, inputs, domain): + set_asym_two_domain_input(self.images, inputs, domain, self.device) + + if ( + (self.images.real_a0 is not None) + and (self.images.real_a1 is not None) + ): + self.images.real_a = self.merge_domain_a_images( + self.images.real_a0, self.images.real_a1 + ) + + if not self.is_train: + return + + if self.lambda_idt_bb and (self.images.real_b is not None): + self.images.idt_bb_input = self.merge_domain_a_images( + self.images.real_b, torch.zeros_like(self.images.real_a1) + ) + + if self.lambda_idt_aa and ( + (self.images.real_a0 is not None) + and (self.images.real_a1 is not None) + ): + self.images.idt_aa_input = ( + self.images.real_a0 + self.images.real_a1 + ) + + def cycle_forward_image(self, real, gen_fwd, gen_bkw): + # pylint: disable=no-self-use + + # (N, C, H, W) + if self.data_norm is not None: + real = self.data_norm.normalize(real) + + fake = gen_fwd(real) + reco = gen_bkw(fake) + + if self.data_norm is not None: + fake = self.data_norm.denormalize(fake) + reco = self.data_norm.denormalize(reco) + + return (fake, reco) + + def idt_forward_image(self, idt_input, gen): + # pylint: disable=no-self-use + if self.data_norm is not None: + idt_input = self.data_norm.normalize(idt_input) + + idt = gen(idt_input) + + if self.data_norm is not None: + idt = self.data_norm.denormalize(idt) + + return idt + + def forward_dispatch(self, direction): + if direction == 'cyc-aba': + (self.images.fake_b, self.images.reco_a) \ + = self.cycle_forward_image( + self.images.real_a, self.models.gen_ab, self.models.gen_ba + ) + + (self.images.reco_a0, self.images.reco_a1) \ + = self.split_domain_a_image(self.images.reco_a) + + elif direction == 'cyc-bab': + (self.images.fake_a, self.images.reco_b) \ + = self.cycle_forward_image( + self.images.real_b, self.models.gen_ba, self.models.gen_ab + ) + + (self.images.fake_a0, self.images.fake_a1) \ + = self.split_domain_a_image(self.images.fake_a) + + elif direction == 'idt-bb': + self.images.idt_bb = self.idt_forward_image( + self.images.idt_bb_input, self.models.gen_ab + ) + + elif direction == 'idt-aa': + self.images.idt_aa = self.idt_forward_image( + self.images.idt_aa_input, self.models.gen_ba + ) + + (self.images.idt_aa_a0, self.images.idt_aa_a1) \ + = self.split_domain_a_image(self.images.idt_aa) + + elif direction == 'ema-cyc-aba': + (self.images.fake_b, self.images.reco_a) \ + = self.cycle_forward_image( + self.images.real_a, + self.models.ema_gen_ab, 
self.models.ema_gen_ba + ) + + (self.images.reco_a0, self.images.reco_a1) \ + = self.split_domain_a_image(self.images.reco_a) + + elif direction == 'ema-cyc-bab': + (self.images.fake_a, self.images.reco_b) \ + = self.cycle_forward_image( + self.images.real_b, + self.models.ema_gen_ba, self.models.ema_gen_ab + ) + + (self.images.fake_a0, self.images.fake_a1) \ + = self.split_domain_a_image(self.images.fake_a) + + else: + raise ValueError(f"Unknown forward direction: '{direction}'") + + def forward(self): + if self.images.real_a is not None: + if self.ema_momentum is not None: + self.forward_dispatch(direction = 'ema-cyc-aba') + else: + self.forward_dispatch(direction = 'cyc-aba') + + if self.images.real_b is not None: + if self.ema_momentum is not None: + self.forward_dispatch(direction = 'ema-cyc-bab') + else: + self.forward_dispatch(direction = 'cyc-bab') + + def eval_loss_of_cycle_forward_aba(self): + # NOTE: Queue is updated in discriminator backprop + disc_pred_fake_b = queued_forward( + self.models.disc_b, self.images.fake_b, self.queues.fake_b, + self.data_norm, self._norm_disc_b, update_queue = False + ) + + self.losses.gen_ab = self.criterion_gan( + disc_pred_fake_b, is_real = True, is_generator = True + ) + + self.losses.cycle_a0 = eval_loss_with_norm( + self.criterion_cycle, self.images.reco_a0, self.images.real_a0, + self.data_norm, self._norm_loss_a0 + ) + self.losses.cycle_a1 = eval_loss_with_norm( + self.criterion_cycle, self.images.reco_a1, self.images.real_a1, + self.data_norm, self._norm_loss_a1 + ) + + return ( + self.lambda_adv_b * self.losses.gen_ab + + 0.5 * self.lambda_cyc_a0 * self.losses.cycle_a0 + + 0.5 * self.lambda_cyc_a1 * self.losses.cycle_a1 + ) + + def eval_loss_of_cycle_forward_bab(self): + # NOTE: Queue is updated in discriminator backprop + disc_pred_fake_a0 = queued_forward( + self.models.disc_a0, self.images.fake_a0, self.queues.fake_a0, + self.data_norm, self._norm_disc_a0, update_queue = False + ) + disc_pred_fake_a1 = queued_forward( + self.models.disc_a1, self.images.fake_a1, self.queues.fake_a1, + self.data_norm, self._norm_disc_a1, update_queue = False + ) + + self.losses.gen_ba0 = self.criterion_gan( + disc_pred_fake_a0, is_real = True, is_generator = True + ) + self.losses.gen_ba1 = self.criterion_gan( + disc_pred_fake_a1, is_real = True, is_generator = True + ) + + self.losses.cycle_b = eval_loss_with_norm( + self.criterion_cycle, self.images.reco_b, self.images.real_b, + self.data_norm, self._norm_loss_b + ) + + return ( + 0.5 * self.lambda_adv_a0 * self.losses.gen_ba0 + + 0.5 * self.lambda_adv_a1 * self.losses.gen_ba1 + + self.lambda_cyc_b * self.losses.cycle_b + ) + + def eval_loss_of_idt_forward_bb(self): + self.losses.idt_bb = eval_loss_with_norm( + self.criterion_idt, self.images.idt_bb, self.images.real_b, + self.data_norm, self._norm_loss_b + ) + + return self.lambda_idt_bb * self.lambda_cyc_b * self.losses.idt_bb + + def eval_loss_of_idt_forward_aa(self): + target_a0 = self.images.real_a0 + target_a1 = self.images.real_a1 + + self.losses.idt_aa_a0 = eval_loss_with_norm( + self.criterion_idt, self.images.idt_aa_a0, target_a0, + self.data_norm, self._norm_loss_a0 + ) + self.losses.idt_aa_a1 = eval_loss_with_norm( + self.criterion_idt, self.images.idt_aa_a1, target_a1, + self.data_norm, self._norm_loss_a1 + ) + + # * 0.5 to account for 2 losses ist_aa_a{0,1} relative to bb + return 0.5 * self.lambda_idt_aa * ( + self.lambda_cyc_a0 * self.losses.idt_aa_a0 + + self.lambda_cyc_a1 * self.losses.idt_aa_a1 + ) + + def backward_gen(self, 
direction): + if direction == 'cyc-aba': + loss = self.eval_loss_of_cycle_forward_aba() + + elif direction == 'cyc-bab': + loss = self.eval_loss_of_cycle_forward_bab() + + elif direction == 'idt-bb': + loss = self.eval_loss_of_idt_forward_bb() + + elif direction == 'idt-aa': + loss = self.eval_loss_of_idt_forward_aa() + + else: + raise ValueError(f"Unknown forward direction: '{direction}'") + + loss.backward() + + def backward_discriminator_base( + self, model, real, fake, queue_real, queue_fake, gp_cacher, scale, + normalize + ): + # pylint: disable=too-many-arguments + loss_gp = None + + if self.gp is not None: + loss_gp = gp_cacher( + model, fake, real, + model_kwargs_fake = { 'extra_bodies' : queue_fake.query() }, + model_kwargs_real = { 'extra_bodies' : queue_real.query() }, + ) + + pred_real = queued_forward( + model, real, queue_real, self.data_norm, normalize, + update_queue = True + ) + loss_real = self.criterion_gan( + pred_real, is_real = True, is_generator = False + ) + + pred_fake = queued_forward( + model, fake, queue_fake, self.data_norm, normalize, + update_queue = True + ) + loss_fake = self.criterion_gan( + pred_fake, is_real = False, is_generator = False + ) + + loss = (loss_real + loss_fake) * 0.5 * scale + loss.backward() + + return (loss_gp, loss) + + def backward_discriminators(self): + fake_a0 = self.images.fake_a0.detach() + fake_a1 = self.images.fake_a1.detach() + fake_b = self.images.fake_b .detach() + + loss_gp_b, self.losses.disc_b \ + = self.backward_discriminator_base( + self.models.disc_b, self.images.real_b, fake_b, + self.queues.real_b, self.queues.fake_b, self.gp_cacher_b, + self.lambda_adv_b, self._norm_disc_b + ) + + if loss_gp_b is not None: + self.losses.gp_b = loss_gp_b + + loss_gp_a0, self.losses.disc_a0 = \ + self.backward_discriminator_base( + self.models.disc_a0, self.images.real_a0, fake_a0, + self.queues.real_a0, self.queues.fake_a0, self.gp_cacher_a0, + 0.5 * self.lambda_adv_a0, self._norm_disc_a0 + ) + + if loss_gp_a0 is not None: + self.losses.gp_a0 = loss_gp_a0 + + loss_gp_a1, self.losses.disc_a1 = \ + self.backward_discriminator_base( + self.models.disc_a1, self.images.real_a1, fake_a1, + self.queues.real_a1, self.queues.fake_a1, self.gp_cacher_a1, + 0.5 * self.lambda_adv_a1, self._norm_disc_a1 + ) + + if loss_gp_a1 is not None: + self.losses.gp_a1 = loss_gp_a1 + + def optimization_step_gen(self): + self.set_requires_grad( + [self.models.disc_a0, self.models.disc_a1, self.models.disc_b], + False + ) + self.optimizers.gen.zero_grad(set_to_none = True) + + dir_list = [ 'cyc-aba', 'cyc-bab' ] + + if self.lambda_idt_bb: + dir_list += [ 'idt-bb' ] + + if self.lambda_idt_aa: + dir_list += [ 'idt-aa' ] + + for direction in dir_list: + self.forward_dispatch(direction) + self.backward_gen(direction) + + clip_gradients(self.optimizers.gen, **self._grad_clip) + self.optimizers.gen.step() + + def optimization_step_disc(self): + self.set_requires_grad( + [self.models.disc_a0, self.models.disc_a1, self.models.disc_b], + True + ) + self.optimizers.disc.zero_grad(set_to_none = True) + + self.backward_discriminators() + + clip_gradients(self.optimizers.disc, **self._grad_clip) + self.optimizers.disc.step() + + def _accumulate_averages(self): + update_average_model( + self.models.ema_gen_ab, self.models.gen_ab, self.ema_momentum + ) + update_average_model( + self.models.ema_gen_ba, self.models.gen_ba, self.ema_momentum + ) + + def optimization_step(self): + self.optimization_step_gen() + self.optimization_step_disc() + + if self.ema_momentum is not None: + 
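+            # Fold the freshly optimized generator weights into the EMA
+            # copies (ema_gen_ab / ema_gen_ba) that forward() uses whenever
+            # ema_momentum is enabled.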
self._accumulate_averages() + diff --git a/uvcgan_s/config/__init__.py b/uvcgan_s/config/__init__.py new file mode 100644 index 0000000..6c20b3e --- /dev/null +++ b/uvcgan_s/config/__init__.py @@ -0,0 +1,4 @@ +from .args import Args +from .config import Config +from .model_config import ModelConfig + diff --git a/uvcgan_s/config/args.py b/uvcgan_s/config/args.py new file mode 100644 index 0000000..306ed00 --- /dev/null +++ b/uvcgan_s/config/args.py @@ -0,0 +1,101 @@ +import difflib +import os +from .config import Config + +LABEL_FNAME = 'label' + +def get_config_difference(config_old, config_new): + diff_gen = difflib.unified_diff( + config_old.to_json(sort_keys = True, indent = 4).split('\n'), + config_new.to_json(sort_keys = True, indent = 4).split('\n'), + fromfile = 'Old Config', + tofile = 'New Config', + ) + + return "\n".join(diff_gen) + +class Args: + __slots__ = [ + 'config', + 'label', + 'savedir', + 'checkpoint', + 'log_level', + ] + + def __init__( + self, config, savedir, label, + log_level = 'INFO', + checkpoint = 100, + ): + # pylint: disable=too-many-arguments + self.config = config + self.label = label + self.savedir = savedir + self.checkpoint = checkpoint + self.log_level = log_level + + def __getattr__(self, attr): + return getattr(self.config, attr) + + def save(self): + self.config.save(self.savedir) + + if self.label is not None: + # pylint: disable=unspecified-encoding + with open(os.path.join(self.savedir, LABEL_FNAME), 'wt') as f: + f.write(self.label) + + def check_no_collision(self): + try: + old_config = Config.load(self.savedir) + except IOError: + return + + old = old_config.to_json(sort_keys = True) + new = self.config.to_json(sort_keys = True) + + if old != new: + diff = get_config_difference(old_config, self.config) + + raise RuntimeError( + ( + f"Config collision detected in '{self.savedir}'" + f" . Difference:\n{diff}\n" + "If you would like to overwrite the config then delete the" + f" old one in '{self.savedir}' ." 
+ ) + ) + + @staticmethod + def from_args_dict( + outdir, + label = None, + log_level = 'INFO', + checkpoint = 100, + **args_dict + ): + config = Config(**args_dict) + savedir = config.get_savedir(outdir, label) + + result = Args(config, savedir, label, log_level, checkpoint) + result.check_no_collision() + + result.save() + + return result + + @staticmethod + def load(savedir): + config = Config.load(savedir) + label = None + + label_path = os.path.join(savedir, LABEL_FNAME) + + if os.path.exists(label_path): + # pylint: disable=unspecified-encoding + with open(label_path, 'rt') as f: + label = f.read() + + return Args(config, savedir, label) + diff --git a/uvcgan_s/config/config.py b/uvcgan_s/config/config.py new file mode 100644 index 0000000..77bdf7a --- /dev/null +++ b/uvcgan_s/config/config.py @@ -0,0 +1,142 @@ +import json +import logging +import os + +from uvcgan_s.consts import CONFIG_NAME + +from .config_base import ConfigBase +from .data_config import parse_data_config +from .model_config import ModelConfig +from .transfer_config import TransferConfig + +LOGGER = logging.getLogger('uvcgan_s.config') + +class Config(ConfigBase): + # pylint: disable=too-many-instance-attributes + + __slots__ = [ + 'batch_size', + 'data', + 'epochs', + 'discriminator', + 'generator', + 'model', + 'model_args', + 'loss', + 'gradient_penalty', + 'seed', + 'scheduler', + 'steps_per_epoch', + 'transfer', + ] + + def __init__( + self, + batch_size = 32, + data = None, + data_args = None, + epochs = 100, + image_shape = None, + discriminator = None, + generator = None, + model = 'cyclegan', + model_args = None, + loss = 'lsgan', + gradient_penalty = None, + seed = 0, + scheduler = None, + steps_per_epoch = 250, + transfer = None, + workers = None, + ): + # pylint: disable=too-many-arguments + # pylint: disable=too-many-locals + self.data = parse_data_config(data, data_args, image_shape, workers) + + self.batch_size = batch_size + self.model = model + self.model_args = model_args or {} + self.seed = seed + self.loss = loss + self.epochs = epochs + self.scheduler = scheduler + self.steps_per_epoch = steps_per_epoch + + if discriminator is not None: + discriminator = ModelConfig(**discriminator) + + if generator is not None: + generator = ModelConfig(**generator) + + if gradient_penalty is True: + gradient_penalty = {} + + if transfer is not None: + if isinstance(transfer, list): + transfer = [ TransferConfig(**conf) for conf in transfer ] + else: + transfer = TransferConfig(**transfer) + + self.discriminator = discriminator + self.generator = generator + self.gradient_penalty = gradient_penalty + self.transfer = transfer + + Config._check_deprecated_args(image_shape, workers) + + if image_shape is not None: + self._validate_image_shape(image_shape) + + @staticmethod + def _check_deprecated_args(image_shape, workers): + if image_shape is not None: + LOGGER.warning( + "Deprecation Warning: Deprecated `image_shape` configuration " + "parameter detected." + ) + + if workers is not None: + LOGGER.warning( + "Deprecation Warning: Deprecated `workers` configuration " + "parameter detected." + ) + + def _validate_image_shape(self, image_shape): + assert all(d.shape == image_shape for d in self.data.datasets), ( + f"Value of the deprecated `image_shape` parameter {image_shape}" + f"does not match shapes of the datasets." 
+ ) + + def get_savedir(self, outdir, label = None): + if label is None: + label = self.get_hash() + + discriminator = None + if self.discriminator is not None: + discriminator = self.discriminator.model + + generator = None + if self.generator is not None: + generator = self.generator.model + + savedir = 'model_m(%s)_d(%s)_g(%s)_%s' % ( + self.model, discriminator, generator, label + ) + + savedir = savedir.replace('/', ':') + path = os.path.join(outdir, savedir) + + os.makedirs(path, exist_ok = True) + return path + + def save(self, path): + # pylint: disable=unspecified-encoding + with open(os.path.join(path, CONFIG_NAME), 'wt') as f: + f.write(self.to_json(sort_keys = True, indent = ' ')) + + @staticmethod + def load(path): + # pylint: disable=unspecified-encoding + with open(os.path.join(path, CONFIG_NAME), 'rt') as f: + return Config(**json.load(f)) + diff --git a/uvcgan_s/config/config_base.py b/uvcgan_s/config/config_base.py new file mode 100644 index 0000000..813445a --- /dev/null +++ b/uvcgan_s/config/config_base.py @@ -0,0 +1,27 @@ +import json +import hashlib + +class ConfigBase: + + __slots__ = [] + + def to_dict(self): + return { x : getattr(self, x) for x in self.__slots__ } + + def to_json(self, **kwargs): + return json.dumps(self, default = lambda x : x.to_dict(), **kwargs) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def get_hash(self): + s = self.to_json(sort_keys = True) + + md5 = hashlib.md5() + md5.update(s.encode()) + + return md5.hexdigest() + diff --git a/uvcgan_s/config/data_config.py b/uvcgan_s/config/data_config.py new file mode 100644 index 0000000..ba808ca --- /dev/null +++ b/uvcgan_s/config/data_config.py @@ -0,0 +1,281 @@ +import logging + +from uvcgan_s.consts import MERGE_PAIRED, MERGE_UNPAIRED, MERGE_NONE +from uvcgan_s.utils.funcs import check_value_in_range + +from .config_base import ConfigBase + +LOGGER = logging.getLogger('uvcgan_s.config') +MERGE_TYPES = [ MERGE_PAIRED, MERGE_UNPAIRED, MERGE_NONE ] + +class DatasetConfig(ConfigBase): + """Dataset configuration. + + Parameters + ---------- + dataset : str or dict + Dataset specification. + shape : tuple of int + Shape of inputs. + transform_train : None or str or dict or list of those + Transformations to be applied to the training dataset. + If `transform_train` is None, then no transformations will be applied + to the training dataset. + If `transform_train` is str, then its value is interpreted as a name + of the transformation. + If `transform_train` is dict, then it is expected to be of the form + `{ 'name' : TRANFORM_NAME, **kwargs }`, where 'name' is the name of + the transformation, and `kwargs` dict will be passed to the + transformation constructor. + Otherwise, `transform_train` is expected to be a list of values above. + The corresponding transformations will be chained together in the + order that they are specified. + Default: None. + transform_val : None or str or dict or list of those + Transformations to be applied to the validation dataset. + C.f. `transform_train`. + Default: None. + """ + + __slots__ = [ + 'dataset', + 'shape', + 'transform_train', + 'transform_test', + ] + + def __init__( + self, dataset, shape, + transform_train = None, + transform_test = None, + ): + super().__init__() + + self.dataset = dataset + self.shape = shape + self.transform_train = transform_train + self.transform_test = transform_test + +class DataConfig(ConfigBase): + """Data configuration. 
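+
+    Each entry of `datasets` is parsed into a `DatasetConfig`, and
+    `merge_type` controls how samples drawn from the individual datasets
+    are combined into batches.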
+ + Parameters + ---------- + datasets : list of dict + List of dataset specifications. + merge_type : str, optional + How to merge samples from datasets. + Choices: 'paired', 'unpaired', 'none'. + Default: 'unpaired' + workers : int, optional + Number of data workers. + Default: None + """ + + __slots__ = [ + 'datasets', + 'merge_type', + 'workers', + ] + + def __init__(self, datasets, merge_type = MERGE_UNPAIRED, workers = None): + super().__init__() + + check_value_in_range(merge_type, MERGE_TYPES, 'merge_type') + assert isinstance(datasets, list) + + self.datasets = [ DatasetConfig(**x) for x in datasets ] + self.merge_type = merge_type + self.workers = workers + +def parse_deprecated_data_config_v1_celeba( + dataset_args, image_shape, workers, transform_train, transform_val +): + attr = dataset_args.get('attr', None) + + if attr is None: + domains = [ None, ] + else: + domains = [ 'a', 'b' ] + + return DataConfig( + datasets = [ + { + 'dataset' : { + 'name' : 'celeba', + 'attr' : attr, + 'domain' : domain, + 'path' : dataset_args.get('path', None), + }, + 'shape' : image_shape, + 'transform_train' : transform_train, + 'transform_test' : transform_val, + } for domain in domains + ], + merge_type = 'unpaired', + workers = workers, + ) + +def parse_deprecated_data_config_v1_cyclegan( + dataset_args, image_shape, workers, transform_train, transform_val +): + return DataConfig( + datasets = [ + { + 'dataset' : { + 'name' : 'cyclegan', + 'domain' : domain, + 'path' : dataset_args.get('path', None), + }, + 'shape' : image_shape, + 'transform_train' : transform_train, + 'transform_test' : transform_val, + } for domain in ['a', 'b'] + ], + merge_type = 'unpaired', + workers = workers, + ) + +def parse_deprecated_data_config_v1_imagedir( + dataset_args, image_shape, workers, transform_train, transform_val +): + return DataConfig( + datasets = [ + { + 'dataset' : { + 'name' : 'imagedir', + 'path' : dataset_args.get('path', None), + }, + 'shape' : image_shape, + 'transform_train' : transform_train, + 'transform_test' : transform_val, + }, + ], + merge_type = 'none', + workers = workers, + ) + +def parse_deprecated_data_config_v1_toyzero_precropped( + dataset_args, image_shape, workers, transform_train, transform_val +): + align_train = dataset_args.get('align_train', False) + + return DataConfig( + datasets = [ + { + 'dataset' : { + 'name' : 'toyzero-precropped-v1', + 'domain' : domain, + 'path' : dataset_args.get('path', None), + }, + 'shape' : image_shape, + 'transform_train' : transform_train, + 'transform_test' : transform_val, + } for domain in [ 'a', 'b' ] + ], + merge_type = 'paired' if align_train else 'unpaired', + workers = workers, + ) + +def parse_deprecated_data_config_v1_toyzero_presimple( + dataset, dataset_args, image_shape, workers, transform_train, transform_val +): + # pylint: disable=too-many-arguments + if dataset == 'toyzero-presimple': + merge_type = 'paired' + + elif dataset == 'toyzero-preunaligned': + if dataset_args.get('align_train', False): + merge_type = 'paired' + else: + merge_type = 'unpaired' + + else: + raise ValueError(f"Unsupported dataset '{dataset}'") + + return DataConfig( + datasets = [ + { + 'dataset' : { + 'name' : 'toyzero-presimple-v1', + 'domain' : domain, + 'fname' : dataset_args.get('fname', None), + 'path' : dataset_args.get('path', None), + 'seed' : dataset_args.get('seed', 0), + 'shuffle' : dataset_args.get('shuffle', True), + }, + 'shape' : image_shape, + 'transform_train' : transform_train, + 'transform_test' : transform_val, + } for domain in 
[ 'a', 'b' ] + ], + merge_type = merge_type, + workers = workers, + ) + +def parse_deprecated_data_config_v1( + dataset, dataset_args, image_shape, workers, + transform_train = None, transform_val = None +): + # pylint: disable=too-many-arguments + if dataset == 'celeba': + return parse_deprecated_data_config_v1_celeba( + dataset_args, image_shape, workers, transform_train, transform_val + ) + + if dataset == 'cyclegan': + return parse_deprecated_data_config_v1_cyclegan( + dataset_args, image_shape, workers, transform_train, transform_val + ) + + if dataset == 'imagedir': + return parse_deprecated_data_config_v1_imagedir( + dataset_args, image_shape, workers, transform_train, transform_val + ) + + if dataset == 'toyzero-precropped': + return parse_deprecated_data_config_v1_toyzero_precropped( + dataset_args, image_shape, workers, transform_train, transform_val + ) + + if dataset in [ 'toyzero-presimple', 'toyzero-preunaligned' ]: + return parse_deprecated_data_config_v1_toyzero_presimple( + dataset, dataset_args, image_shape, workers, + transform_train, transform_val + ) + + raise NotImplementedError( + f"Do not know how to parse deprecated '{dataset}'" + ) + +def parse_data_config(data, data_args, image_shape, workers): + if isinstance(data, str): + LOGGER.warning( + "Deprecation Warning: Old (v0) dataset configuration detected." + " Please modify your configuration and change `data` parameter" + " into a dictionary describing `DataConfig` structure." + ) + return parse_deprecated_data_config_v1( + data, data_args, image_shape, workers + ) + + assert data_args is None, \ + "Deprecated `data_args` argument detected with new data configuration" + + if ( + ('dataset' in data) + or ('dataset_args' in data) + or ('transform_train' in data) + or ('transform_val' in data) + ): + LOGGER.warning( + "Deprecation Warning: Old (v1) dataset configuration detected." + " Please modify your configuration and change `data` parameter" + " into a dictionary describing `DataConfig` structure." 
+ ) + return parse_deprecated_data_config_v1( + **data, image_shape = image_shape, workers = workers + ) + + return DataConfig(**data) + diff --git a/uvcgan_s/config/funcs.py b/uvcgan_s/config/funcs.py new file mode 100644 index 0000000..ec599f4 --- /dev/null +++ b/uvcgan_s/config/funcs.py @@ -0,0 +1,7 @@ +import os + +def create_evaldir(path, eval_name): + result = os.path.join(path, eval_name) + os.makedirs(result, exist_ok = True) + + return result diff --git a/uvcgan_s/config/model_config.py b/uvcgan_s/config/model_config.py new file mode 100644 index 0000000..2b287c5 --- /dev/null +++ b/uvcgan_s/config/model_config.py @@ -0,0 +1,34 @@ + +class ModelConfig: + + __slots__ = [ + 'model', + 'model_args', + 'optimizer', + 'weight_init', + 'lr_equal', + 'spectr_norm', + ] + + def __init__( + self, + model, + optimizer = None, + model_args = None, + weight_init = None, + lr_equal = False, + spectr_norm = False, + ): + # pylint: disable=too-many-arguments + self.model = model + self.model_args = model_args or {} + self.optimizer = optimizer or { + 'name' : 'AdamW', 'betas' : (0.5, 0.999), 'weight_decay' : 1e-5, + } + self.weight_init = weight_init + self.lr_equal = lr_equal + self.spectr_norm = spectr_norm + + def to_dict(self): + return { x : getattr(self, x) for x in self.__slots__ } + diff --git a/uvcgan_s/config/transfer_config.py b/uvcgan_s/config/transfer_config.py new file mode 100644 index 0000000..cbb1a5a --- /dev/null +++ b/uvcgan_s/config/transfer_config.py @@ -0,0 +1,53 @@ +from .config_base import ConfigBase + +class TransferConfig(ConfigBase): + """Model transfer configuration. + + Parameters + ---------- + base_model : str + Path to the model to transfer parameters from. If path is relative + then this path is interpreted relative to `ROOT_OUTDIR`. + transfer_map : dict or None + Mapping between networks names of the current model and the model to + transfer parameters from. For example, mapping of the form + `{ 'gen_ab' : 'gen' }` will initialize generator `gen_ab` of the + current model from the `gen` generator of the base model. + Default: None. + strict : bool + Value of the pytorch's strict parameter when loading parameters. + Default: True. + allow_partial : bool + Whether to allow transfer from the last checkpoint of a partially + trained base model. + Default: False. + fuzzy : str, optional + Allow fuzzy transfer. E.g. transfer parameters from a larger model to + a smaller one. + Choices: [ 'none', 'from-larger-model' ] + Default: None. 
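    Example
    -------
    A minimal sketch of a transfer specification (the `base_model` path and
    the generator names below are placeholders, not part of this repository):

        transfer = {
            'base_model'    : 'pretrained/some_base_model',
            'transfer_map'  : { 'gen_ab' : 'gen', 'gen_ba' : 'gen' },
            'strict'        : True,
            'allow_partial' : False,
            'fuzzy'         : None,
        }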
+ """ + + __slots__ = [ + 'base_model', + 'transfer_map', + 'strict', + 'allow_partial', + 'fuzzy', + ] + + def __init__( + self, + base_model, + transfer_map = None, + strict = True, + allow_partial = False, + fuzzy = None, + ): + # pylint: disable=too-many-arguments + self.base_model = base_model + self.transfer_map = transfer_map or {} + self.strict = strict + self.allow_partial = allow_partial + self.fuzzy = fuzzy + diff --git a/uvcgan_s/consts.py b/uvcgan_s/consts.py new file mode 100644 index 0000000..8e79b95 --- /dev/null +++ b/uvcgan_s/consts.py @@ -0,0 +1,16 @@ +import os + +CONFIG_NAME = 'config.json' +ROOT_DATA = os.path.join(os.environ.get('UVCGAN_S_DATA', 'data')) +ROOT_OUTDIR = os.path.join(os.environ.get('UVCGAN_S_OUTDIR', 'outdir')) + +SPLIT_TRAIN = 'train' +SPLIT_VAL = 'val' +SPLIT_TEST = 'test' + +MERGE_PAIRED = 'paired' +MERGE_UNPAIRED = 'unpaired' +MERGE_NONE = 'none' + +MODEL_STATE_TRAIN = 'train' +MODEL_STATE_EVAL = 'eval' diff --git a/uvcgan_s/data/__init__.py b/uvcgan_s/data/__init__.py new file mode 100644 index 0000000..9b02bd0 --- /dev/null +++ b/uvcgan_s/data/__init__.py @@ -0,0 +1,2 @@ +from .data import construct_data_loaders, construct_datasets + diff --git a/uvcgan_s/data/data.py b/uvcgan_s/data/data.py new file mode 100644 index 0000000..0ddfe66 --- /dev/null +++ b/uvcgan_s/data/data.py @@ -0,0 +1,129 @@ +import os +import torch + +import torchvision + +from uvcgan_s.consts import ( + ROOT_DATA, SPLIT_TRAIN, MERGE_PAIRED, MERGE_UNPAIRED +) +from uvcgan_s.torch.select import extract_name_kwargs + +from .datasets.celeba import CelebaDataset +from .datasets.image_domain_folder import ImageDomainFolder +from .datasets.image_domain_hierarchy import ImageDomainHierarchy +from .datasets.zipper import DatasetZipper +from .datasets.ndarray_domain_hierarchy import NDArrayDomainHierarchy +from .datasets.h5array_domain_hierarchy import H5ArrayDomainHierarchy +from .datasets.toy_mix_blur_dataset import ToyMixBlurDataset + +from .loader_zipper import DataLoaderZipper +from .transforms import select_transform + +def select_dataset(name, path, split, transform, **kwargs): + # pylint: disable=too-many-return-statements + # pylint: disable=too-many-branches + + if name == 'celeba': + return CelebaDataset( + path, transform = transform, split = split, **kwargs + ) + + if name in [ 'cyclegan', 'image-domain-folder' ]: + return ImageDomainFolder( + path, transform = transform, split = split, **kwargs + ) + + if name in [ 'image-domain-hierarchy' ]: + return ImageDomainHierarchy( + path, transform = transform, split = split, **kwargs + ) + + if name == 'imagenet': + return torchvision.datasets.ImageNet( + path, transform = transform, split = split, **kwargs + ) + + if name in [ 'imagedir', 'image-folder' ]: + return torchvision.datasets.ImageFolder( + os.path.join(path, split), transform = transform, **kwargs + ) + + if name == 'ndarray-domain-hierarchy': + return NDArrayDomainHierarchy( + path, transform = transform, split = split, **kwargs + ) + + if name == 'h5array-domain-hierarchy': + return H5ArrayDomainHierarchy( + path, transform = transform, split = split, **kwargs + ) + + if name == 'toy-mix-blur': + return ToyMixBlurDataset( + path, transform = transform, split = split, **kwargs + ) + + raise ValueError(f"Unknown dataset: '{name}'") + +def construct_single_dataset(dataset_config, split): + name, kwargs = extract_name_kwargs(dataset_config.dataset) + path = os.path.join(ROOT_DATA, kwargs.pop('path', name)) + + if split == SPLIT_TRAIN: + transform = 
select_transform(dataset_config.transform_train) + else: + transform = select_transform(dataset_config.transform_test) + + return select_dataset(name, path, split, transform, **kwargs) + +def construct_datasets(data_config, split): + return [ + construct_single_dataset(config, split) + for config in data_config.datasets + ] + +def construct_single_loader( + dataset, batch_size, shuffle, + workers = None, + prefetch_factor = 2, + **kwargs +): + if workers is None: + workers = min(torch.get_num_threads(), 20) + + return torch.utils.data.DataLoader( + dataset, batch_size, + shuffle = shuffle, + num_workers = workers, + prefetch_factor = prefetch_factor, + pin_memory = True, + **kwargs + ) + +def construct_data_loaders(data_config, batch_size, split): + datasets = construct_datasets(data_config, split) + shuffle = (split == SPLIT_TRAIN) + + if data_config.merge_type == MERGE_PAIRED: + dataset = DatasetZipper(datasets) + + return construct_single_loader( + dataset, batch_size, shuffle, data_config.workers, + drop_last = False + ) + + loaders = [ + construct_single_loader( + dataset, batch_size, shuffle, data_config.workers, + drop_last = (data_config.merge_type == MERGE_UNPAIRED) + ) for dataset in datasets + ] + + if data_config.merge_type == MERGE_UNPAIRED: + return DataLoaderZipper(loaders) + + if len(loaders) == 1: + return loaders[0] + + return loaders + diff --git a/uvcgan_s/data/datasets/__init__.py b/uvcgan_s/data/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/uvcgan_s/data/datasets/celeba.py b/uvcgan_s/data/datasets/celeba.py new file mode 100644 index 0000000..d8f1765 --- /dev/null +++ b/uvcgan_s/data/datasets/celeba.py @@ -0,0 +1,112 @@ +import os +import pandas as pd + +from torch.utils.data import Dataset +from torchvision.datasets.folder import default_loader + +from uvcgan_s.consts import SPLIT_TRAIN, SPLIT_VAL, SPLIT_TEST +from uvcgan_s.utils.funcs import check_value_in_range + +FNAME_ATTRS = 'list_attr_celeba.txt' +FNAME_SPLIT = 'list_eval_partition.txt' +SUBDIR_IMG = 'img_align_celeba' + +SPLITS = { + SPLIT_TRAIN : 0, + SPLIT_VAL : 1, + SPLIT_TEST : 2, +} + +DOMAINS = [ 'a', 'b' ] + +class CelebaDataset(Dataset): + + def __init__( + self, path, + attr = 'Young', + domain = 'a', + split = SPLIT_TRAIN, + transform = None, + **kwargs + ): + # pylint: disable=too-many-arguments + check_value_in_range(split, SPLITS, 'CelebaDataset: split') + + if attr is None: + assert domain is None + else: + check_value_in_range(domain, DOMAINS, 'CelebaDataset: domain') + + super().__init__(**kwargs) + + self._path = path + self._root_imgs = os.path.join(path, SUBDIR_IMG) + self._split = split + self._attr = attr + self._domain = domain + self._imgs = [] + self._transform = transform + + self._collect_files() + + def _collect_files(self): + imgs_specs = CelebaDataset.load_image_specs(self._path) + + imgs = CelebaDataset.partition_images( + imgs_specs, self._split, self._attr, self._domain + ) + + self._imgs = [ os.path.join(self._root_imgs, x) for x in imgs ] + + @staticmethod + def load_image_partition(root): + path = os.path.join(root, FNAME_SPLIT) + + return pd.read_csv( + path, sep = r'\s+', header = None, names = [ 'partition', ], + index_col = 0 + ) + + @staticmethod + def load_image_attrs(root): + path = os.path.join(root, FNAME_ATTRS) + + return pd.read_csv( + path, sep = r'\s+', skiprows = 1, header = 0, index_col = 0 + ) + + @staticmethod + def load_image_specs(root): + df_partition = CelebaDataset.load_image_partition(root) + df_attrs = 
CelebaDataset.load_image_attrs(root) + + return df_partition.join(df_attrs) + + @staticmethod + def partition_images(image_specs, split, attr, domain): + part_mask = (image_specs.partition == SPLITS[split]) + + if attr is None: + imgs = image_specs[part_mask].index.to_list() + else: + if domain == 'a': + domain_mask = (image_specs[attr] > 0) + else: + domain_mask = (image_specs[attr] < 0) + + imgs = image_specs[part_mask & domain_mask].index.to_list() + + return imgs + + def __len__(self): + return len(self._imgs) + + def __getitem__(self, index): + path = self._imgs[index] + result = default_loader(path) + + if self._transform is not None: + result = self._transform(result) + + return result + diff --git a/uvcgan_s/data/datasets/funcs.py b/uvcgan_s/data/datasets/funcs.py new file mode 100644 index 0000000..d955f5a --- /dev/null +++ b/uvcgan_s/data/datasets/funcs.py @@ -0,0 +1,11 @@ + +def cantor_pairing(x, y, mod = (1 << 32)): + # https://stackoverflow.com/questions/919612/mapping-two-integers-to-one-in-a-unique-and-deterministic-way + # https://en.wikipedia.org/wiki/Pairing_function#Cantor_pairing_function + result = (x + y) * (x + y + 1) // 2 + y + + if mod is not None: + result = result % mod + + return result + diff --git a/uvcgan_s/data/datasets/h5array_domain_hierarchy.py b/uvcgan_s/data/datasets/h5array_domain_hierarchy.py new file mode 100644 index 0000000..910d84c --- /dev/null +++ b/uvcgan_s/data/datasets/h5array_domain_hierarchy.py @@ -0,0 +1,53 @@ +import os + +import numpy as np +from torch.utils.data import Dataset + +from uvcgan_s.consts import SPLIT_TRAIN + +DSET_INDEX = 'index' +DSET_DATA = 'data' + +H5_EXT = [ '', '.h5', '.hdf5' ] + + +class H5ArrayDomainHierarchy(Dataset): + + def __init__( + self, path, domain, + split = SPLIT_TRAIN, + transform = None, + **kwargs + ): + super().__init__(**kwargs) + + self._path = None + path_base = os.path.join(path, split, domain) + + for ext in H5_EXT: + path = path_base + ext + if os.path.exists(path): + self._path = path + break + else: + raise RuntimeError(f"Failed to find h5 dataset '{path_base}'") + + # pylint: disable=import-outside-toplevel + import h5py + + self._f = h5py.File(self._path, 'r') + self._dset = self._f.get(DSET_DATA) + + self._transform = transform + + def __len__(self): + return len(self._dset) + + def __getitem__(self, index): + result = np.float32(self._dset[index]) + + if self._transform is not None: + result = self._transform(result) + + return result + diff --git a/uvcgan_s/data/datasets/image_domain_folder.py b/uvcgan_s/data/datasets/image_domain_folder.py new file mode 100644 index 0000000..d7fee75 --- /dev/null +++ b/uvcgan_s/data/datasets/image_domain_folder.py @@ -0,0 +1,76 @@ +import os + +from torch.utils.data import Dataset +from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS + +from uvcgan_s.consts import SPLIT_TRAIN + +class ImageDomainFolder(Dataset): + """Dataset structure introduced in a CycleGAN paper. + + This dataset expects images to be arranged into subdirectories + under `path`: `trainA`, `trainB`, `testA`, `testB`. Here, `trainA` + subdirectory contains training images from domain "a", `trainB` + subdirectory contains training images from domain "b", and so on. + + Parameters + ---------- + path : str + Path where the dataset is located. + domain : str + Choices: 'a', 'b'. + split : str + Choices: 'train', 'test', 'val' + transform : Callable or None, + Optional transformation to apply to images. + E.g. torchvision.transforms.RandomCrop. 
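        A typical composed pipeline (illustrative values only) would be:

            torchvision.transforms.Compose([
                torchvision.transforms.Resize(286),
                torchvision.transforms.RandomCrop(256),
                torchvision.transforms.ToTensor(),
            ])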
+ Default: None + """ + + def __init__( + self, path, + domain = 'a', + split = SPLIT_TRAIN, + transform = None, + **kwargs + ): + super().__init__(**kwargs) + + subdir = split + domain.upper() + + self._path = os.path.join(path, subdir) + self._imgs = ImageDomainFolder.find_images_in_dir(self._path) + self._transform = transform + + @staticmethod + def find_images_in_dir(path): + extensions = set(IMG_EXTENSIONS) + + result = [] + for fname in os.listdir(path): + fullpath = os.path.join(path, fname) + + if not os.path.isfile(fullpath): + continue + + ext = os.path.splitext(fname)[1] + if ext not in extensions: + continue + + result.append(fullpath) + + result.sort() + return result + + def __len__(self): + return len(self._imgs) + + def __getitem__(self, index): + path = self._imgs[index] + result = default_loader(path) + + if self._transform is not None: + result = self._transform(result) + + return result + diff --git a/uvcgan_s/data/datasets/image_domain_hierarchy.py b/uvcgan_s/data/datasets/image_domain_hierarchy.py new file mode 100644 index 0000000..bdbeba5 --- /dev/null +++ b/uvcgan_s/data/datasets/image_domain_hierarchy.py @@ -0,0 +1,34 @@ +import os + +from torch.utils.data import Dataset +from torchvision.datasets.folder import default_loader + +from uvcgan_s.consts import SPLIT_TRAIN +from .image_domain_folder import ImageDomainFolder + +class ImageDomainHierarchy(Dataset): + + def __init__( + self, path, domain, + split = SPLIT_TRAIN, + transform = None, + **kwargs + ): + super().__init__(**kwargs) + + self._path = os.path.join(path, split, domain) + self._imgs = ImageDomainFolder.find_images_in_dir(self._path) + self._transform = transform + + def __len__(self): + return len(self._imgs) + + def __getitem__(self, index): + path = self._imgs[index] + result = default_loader(path) + + if self._transform is not None: + result = self._transform(result) + + return result + diff --git a/uvcgan_s/data/datasets/ndarray_domain_hierarchy.py b/uvcgan_s/data/datasets/ndarray_domain_hierarchy.py new file mode 100644 index 0000000..5a05341 --- /dev/null +++ b/uvcgan_s/data/datasets/ndarray_domain_hierarchy.py @@ -0,0 +1,55 @@ +import os + +import numpy as np +from torch.utils.data import Dataset + +from uvcgan_s.consts import SPLIT_TRAIN + +def find_ndarrays_in_dir(path): + result = [] + + for fname in os.listdir(path): + fullpath = os.path.join(path, fname) + + if not os.path.isfile(fullpath): + continue + + ext = os.path.splitext(fname)[1] + if ext != '.npz': + continue + + result.append(fullpath) + + result.sort() + return result + +def load_ndarray(path): + with np.load(path) as f: + return f[f.files[0]] + +class NDArrayDomainHierarchy(Dataset): + + def __init__( + self, path, domain, + split = SPLIT_TRAIN, + transform = None, + **kwargs + ): + super().__init__(**kwargs) + + self._path = os.path.join(path, split, domain) + self._arrays = find_ndarrays_in_dir(self._path) + self._transform = transform + + def __len__(self): + return len(self._arrays) + + def __getitem__(self, index): + path = self._arrays[index] + result = np.float32(load_ndarray(path)) + + if self._transform is not None: + result = self._transform(result) + + return result + diff --git a/uvcgan_s/data/datasets/svhn.py b/uvcgan_s/data/datasets/svhn.py new file mode 100644 index 0000000..ecd38a2 --- /dev/null +++ b/uvcgan_s/data/datasets/svhn.py @@ -0,0 +1,26 @@ +from torchvision.datasets import SVHN + +class SVHNDataset(SVHN): + + def __init__( + self, path, split, transform, return_target = False, **kwargs + ): + # 
pylint: disable=too-many-arguments + super().__init__( + path, + split = split, + transform = transform, + download = True, + **kwargs + ) + + self._return_target = return_target + + def __getitem__(self, index): + item = super().__getitem__(index) + + if self._return_target: + return item + + return item[0] + diff --git a/uvcgan_s/data/datasets/toy_mix_blur_dataset.py b/uvcgan_s/data/datasets/toy_mix_blur_dataset.py new file mode 100644 index 0000000..b80e788 --- /dev/null +++ b/uvcgan_s/data/datasets/toy_mix_blur_dataset.py @@ -0,0 +1,84 @@ +import os +import random + +import torchvision.transforms.functional as TF + +from torch.utils.data import Dataset +from torchvision.datasets.folder import default_loader + +from uvcgan_s.consts import SPLIT_TRAIN +from .image_domain_folder import ImageDomainFolder + + +class ToyMixBlurDataset(Dataset): + # pylint: disable=too-many-instance-attributes + + def __init__( + self, path, + domain_a0 = 'cat', + domain_a1 = 'dog', + split = SPLIT_TRAIN, + alpha = 0.5, + alpha_range = 0.0, + blur_kernel_size = 5, + seed = 42, + transform = None, + **kwargs + ): + # pylint: disable=too-many-arguments + super().__init__(**kwargs) + + self._split = split + self._alpha = alpha + self._alpha_range = alpha_range + self._blur_ks = blur_kernel_size + self._transform = transform + + path_a0 = os.path.join(path, split, domain_a0) + path_a1 = os.path.join(path, split, domain_a1) + + self._imgs_a0 = ImageDomainFolder.find_images_in_dir(path_a0) + self._imgs_a1 = ImageDomainFolder.find_images_in_dir(path_a1) + + self._rng = random.Random(seed) + + self._pairing_indices = list(range(len(self._imgs_a1))) + self._rng.shuffle(self._pairing_indices) + + def __len__(self): + return len(self._imgs_a0) + + def _get_alpha(self): + if self._split == SPLIT_TRAIN: + offset = random.uniform( + -self._alpha_range / 2, self._alpha_range / 2 + ) + return self._alpha + offset + else: + return self._alpha + + def _apply_blur(self, img): + return TF.gaussian_blur(img, kernel_size=self._blur_ks) + + def __getitem__(self, index): + + if self._split == SPLIT_TRAIN: + index_a1 = random.randint(0, len(self._imgs_a1) - 1) + else: + index_a1 = self._pairing_indices[ + index % len(self._pairing_indices) + ] + + img_a0 = default_loader(self._imgs_a0[index]) + img_a1 = default_loader(self._imgs_a1[index_a1]) + + if self._transform is not None: + img_a0 = self._transform(img_a0) + img_a1 = self._transform(img_a1) + + alpha = self._get_alpha() + mixed = alpha * img_a0 + (1.0 - alpha) * img_a1 + mixed = self._apply_blur(mixed) + + return mixed + diff --git a/uvcgan_s/data/datasets/zipper.py b/uvcgan_s/data/datasets/zipper.py new file mode 100644 index 0000000..bb89129 --- /dev/null +++ b/uvcgan_s/data/datasets/zipper.py @@ -0,0 +1,24 @@ +from torch.utils.data import Dataset + +class DatasetZipper(Dataset): + + def __init__(self, datasets, **kwargs): + super().__init__(**kwargs) + + assert len(datasets) > 0, \ + "DatasetZipper does not know how to zip empty list of datasets" + + self._datasets = datasets + self._len = len(datasets[0]) + + lengths = [ len(dset) for dset in datasets ] + + assert all(x == self._len for x in lengths), \ + f"DatasetZipper cannot zip datasets of unequal lengths: {lengths}" + + def __len__(self): + return self._len + + def __getitem__(self, index): + return tuple(d[index] for d in self._datasets) + diff --git a/uvcgan_s/data/loader_zipper.py b/uvcgan_s/data/loader_zipper.py new file mode 100644 index 0000000..db3d3a0 --- /dev/null +++ b/uvcgan_s/data/loader_zipper.py @@ -0,0 
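# Illustrative sketch (hypothetical path and values) of how the
# ToyMixBlurDataset above produces the blurred cat/dog mixtures used in the
# README quick-start example; the transform must convert images to tensors so
# that the alpha-blending arithmetic is defined:
#
#   dset = ToyMixBlurDataset(
#       'data/toy-pets',              # expects data/toy-pets/train/cat, .../train/dog
#       domain_a0 = 'cat',
#       domain_a1 = 'dog',
#       split     = 'train',
#       alpha     = 0.5,
#       alpha_range = 0.2,            # training alpha drawn from [0.4, 0.6]
#       blur_kernel_size = 5,
#       transform = torchvision.transforms.ToTensor(),
#   )
#   mixed = dset[0]                   # a single blurred, roughly 50/50 blend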
+1,12 @@ + +class DataLoaderZipper: + + def __init__(self, loaders): + self._loaders = loaders + + def __len__(self): + return min(len(d) for d in self._loaders) + + def __iter__(self): + return zip(*self._loaders) + diff --git a/uvcgan_s/data/transforms.py b/uvcgan_s/data/transforms.py new file mode 100644 index 0000000..67b7281 --- /dev/null +++ b/uvcgan_s/data/transforms.py @@ -0,0 +1,105 @@ +import torch +import torchvision +from torchvision import transforms + +from uvcgan_s.torch.select import extract_name_kwargs + +FromNumpy = lambda : torch.from_numpy + +TRANSFORM_DICT = { + 'center-crop' : transforms.CenterCrop, + 'color-jitter' : transforms.ColorJitter, + 'random-crop' : transforms.RandomCrop, + 'random-flip-vertical' : transforms.RandomVerticalFlip, + 'random-flip-horizontal' : transforms.RandomHorizontalFlip, + 'random-rotation' : transforms.RandomRotation, + 'random-resize-crop' : transforms.RandomResizedCrop, + 'random-solarize' : transforms.RandomSolarize, + 'random-invert' : transforms.RandomInvert, + 'gaussian-blur' : transforms.GaussianBlur, + 'resize' : transforms.Resize, + 'normalize' : transforms.Normalize, + 'pad' : transforms.Pad, + 'grayscale' : transforms.Grayscale, + 'to-tensor' : transforms.ToTensor, + 'from-numpy' : FromNumpy, + 'CenterCrop' : transforms.CenterCrop, + 'ColorJitter' : transforms.ColorJitter, + 'RandomCrop' : transforms.RandomCrop, + 'RandomVerticalFlip' : transforms.RandomVerticalFlip, + 'RandomHorizontalFlip' : transforms.RandomHorizontalFlip, + 'RandomRotation' : transforms.RandomRotation, + 'Resize' : transforms.Resize, +} + +INTERPOLATION_DICT = { + 'nearest' : transforms.InterpolationMode.NEAREST, + 'bilinear' : transforms.InterpolationMode.BILINEAR, + 'bicubic' : transforms.InterpolationMode.BICUBIC, + 'lanczos' : transforms.InterpolationMode.LANCZOS, +} + +def parse_interpolation(kwargs): + if 'interpolation' in kwargs: + kwargs['interpolation'] = INTERPOLATION_DICT[kwargs['interpolation']] + +def select_single_transform(transform): + name, kwargs = extract_name_kwargs(transform) + + if name == 'random-apply': + transform = select_transform_basic(kwargs.pop('transforms')) + return transforms.RandomApply(transform, **kwargs) + + if name not in TRANSFORM_DICT: + raise ValueError(f"Unknown transform: '{name}'") + + parse_interpolation(kwargs) + + return TRANSFORM_DICT[name](**kwargs) + +def select_transform_basic(transform, compose = False): + result = [] + + if transform is not None: + if not isinstance(transform, (list, tuple)): + transform = [ transform, ] + + result = [ + select_single_transform(x) for x in transform if x != 'none' + ] + + if compose: + if len(result) == 1: + return result[0] + else: + return torchvision.transforms.Compose(result) + else: + return result + +def select_transform(transform, add_to_tensor = True): + if transform == 'none': + return None + + result = select_transform_basic(transform) + + # NOTE: this uglinness is for backward compat + if add_to_tensor: + if not isinstance(transform, (list, tuple)): + transform = [ transform, ] + + need_transform = True + + if any(t == 'to-tensor' for t in transform): + need_transform = False + + if any(t == 'from-numpy' for t in transform): + need_transform = False + + if any(t == 'none' for t in transform): + need_transform = False + + if need_transform: + result.append(torchvision.transforms.ToTensor()) + + return torchvision.transforms.Compose(result) + diff --git a/uvcgan_s/eval/__init__.py b/uvcgan_s/eval/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
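# Illustrative transform specification (hypothetical values) as it would appear
# in a configuration, mapping onto TRANSFORM_DICT above. This assumes that
# `extract_name_kwargs` accepts either a plain string or a dict with a 'name'
# key plus keyword arguments; `select_transform` appends ToTensor()
# automatically unless 'to-tensor', 'from-numpy', or 'none' is listed:
#
#   transform_train = [
#       { 'name' : 'resize',      'size' : 286, 'interpolation' : 'bicubic' },
#       { 'name' : 'random-crop', 'size' : 256 },
#       'random-flip-horizontal',
#   ]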
a/uvcgan_s/eval/funcs.py b/uvcgan_s/eval/funcs.py new file mode 100644 index 0000000..999506c --- /dev/null +++ b/uvcgan_s/eval/funcs.py @@ -0,0 +1,112 @@ +import os +import math +from itertools import islice + +from uvcgan_s.config import Args +from uvcgan_s.consts import ( + MODEL_STATE_TRAIN, MODEL_STATE_EVAL, MERGE_NONE +) +from uvcgan_s.data import construct_data_loaders +from uvcgan_s.torch.funcs import get_torch_device_smart, seed_everything +from uvcgan_s.cgan import construct_model + +def slice_data_loader(loader, batch_size, n_samples = None): + if n_samples is None: + return (loader, len(loader)) + + steps = min(math.ceil(n_samples / batch_size), len(loader)) + sliced_loader = islice(loader, steps) + + return (sliced_loader, steps) + +def tensor_to_image(tensor): + result = tensor.cpu().detach().numpy() + + if tensor.ndim == 4: + result = result.squeeze(0) + + result = result.transpose((1, 2, 0)) + return result + +def override_config(config, config_overrides): + if config_overrides is None: + return + + for (k,v) in config_overrides.items(): + config[k] = v + +def get_evaldir(root, epoch, mkdir = False): + if epoch is None: + result = os.path.join(root, 'evals', 'final') + else: + result = os.path.join(root, 'evals', 'epoch_%04d' % epoch) + + if mkdir: + os.makedirs(result, exist_ok = True) + + return result + +def set_model_state(model, state): + if state == MODEL_STATE_TRAIN: + model.train() + elif state == MODEL_STATE_EVAL: + model.eval() + else: + raise ValueError(f"Unknown model state '{state}'") + +def start_model_eval(path, epoch, model_state, merge_type, **config_overrides): + args = Args.load(path) + device = get_torch_device_smart() + + override_config(args.config, config_overrides) + + if merge_type is not None: + args.config.data.merge_type = merge_type + + model = construct_model( + args.savedir, args.config, + is_train = (model_state == MODEL_STATE_TRAIN), + device = device + ) + + if epoch == -1: + epoch = max(model.find_last_checkpoint_epoch(), 0) + + print("Load checkpoint at epoch %s" % epoch) + + seed_everything(args.config.seed) + model.load(epoch) + + set_model_state(model, model_state) + evaldir = get_evaldir(path, epoch, mkdir = True) + + return (args, model, evaldir) + +def load_eval_model_dset_from_cmdargs( + cmdargs, merge_type = MERGE_NONE, **config_overrides +): + args, model, evaldir = start_model_eval( + cmdargs.model, cmdargs.epoch, cmdargs.model_state, + merge_type = merge_type, + batch_size = cmdargs.batch_size, **config_overrides + ) + + data_it = construct_data_loaders( + args.config.data, args.config.batch_size, split = cmdargs.split + ) + + return (args, model, data_it, evaldir) + +def get_eval_savedir(evaldir, prefix, model_state, split, mkdir = False): + result = os.path.join(evaldir, f'{prefix}_{model_state}-{split}') + + if mkdir: + os.makedirs(result, exist_ok = True) + + return result + +def make_image_subdirs(model, savedir): + for name in model.images: + path = os.path.join(savedir, name) + os.makedirs(path, exist_ok = True) + diff --git a/uvcgan_s/models/__init__.py b/uvcgan_s/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/uvcgan_s/models/discriminator/__init__.py b/uvcgan_s/models/discriminator/__init__.py new file mode 100644 index 0000000..e701fb9 --- /dev/null +++ b/uvcgan_s/models/discriminator/__init__.py @@ -0,0 +1,26 @@ +from uvcgan_s.base.networks import select_base_discriminator +from uvcgan_s.models.funcs import default_model_init + +from .resnet import ResNetDisc, ResNetFFTDisc +from .dcgan 
import DCGANDiscriminator + +DISC_DICT = { + 'resnet' : ResNetDisc, + 'resnet-fft' : ResNetFFTDisc, + 'dcgan' : DCGANDiscriminator, +} + +def select_discriminator(name, **kwargs): + if name in DISC_DICT: + return DISC_DICT[name](**kwargs) + + return select_base_discriminator(name, **kwargs) + +def construct_discriminator(model_config, image_shape, device): + model = select_discriminator( + model_config.model, image_shape = image_shape, + **model_config.model_args + ) + + return default_model_init(model, model_config, device) + diff --git a/uvcgan_s/models/discriminator/dcgan.py b/uvcgan_s/models/discriminator/dcgan.py new file mode 100644 index 0000000..0ae08e4 --- /dev/null +++ b/uvcgan_s/models/discriminator/dcgan.py @@ -0,0 +1,121 @@ +import math +import logging + +from torch import nn +from torchvision.transforms import CenterCrop + +from uvcgan_s.torch.select import get_activ_layer, get_norm_layer + +LOGGER = logging.getLogger('models.discriminator.dcgan') + +DEF_NORM = 'batch' +DEF_ACTIV = { + 'name' : 'leakyrelu', + 'negative_slope' : 0.2, +} + +def math_prod(shape): + result = 1 + + for x in shape: + result *= x + + return result + +def get_padding_layer(input_shape, downscale_factor): + need_pad = False + + if ( + (input_shape[1] < downscale_factor) + or (input_shape[2] < downscale_factor) + ): + need_pad = True + + LOGGER.warning( + "DCGAN input shape '%s' is smaller than the downscale factor '%d'." + " Adding padding.", tuple(input_shape), downscale_factor + ) + + if ( + (input_shape[1] % downscale_factor != 0) + or (input_shape[2] % downscale_factor != 0) + ): + need_pad = True + + LOGGER.warning( + "DCGAN input shape '%s' is not divisible by the downscale " + " factor '%d'. Adding padding.", + tuple(input_shape), downscale_factor + ) + + h = math.ceil(input_shape[1] / downscale_factor) * downscale_factor + w = math.ceil(input_shape[2] / downscale_factor) * downscale_factor + + if not need_pad: + return None, h, w + + return CenterCrop((h, w)), h ,w + +class DCGANDiscriminator(nn.Module): + + def __init__( + self, image_shape, features_list, activ = DEF_ACTIV, norm = DEF_NORM, + ): + # pylint: disable=dangerous-default-value + # pylint: disable=too-many-arguments + super().__init__() + + self._input_shape = image_shape + + curr_features = image_shape[0] + downscale_factor = 1 + layers = [] + + layers.append(nn.Sequential( + nn.Conv2d( + curr_features, features_list[0], kernel_size = 5, + stride = 2, padding = 2 + ), + get_activ_layer(activ) + )) + + downscale_factor *= 2 + curr_features = features_list[0] + + for features in features_list[1:]: + layers.append(nn.Sequential( + nn.Conv2d( + curr_features, features, kernel_size = 5, + stride = 2, padding = 2 + ), + get_norm_layer(norm, features), + get_activ_layer(activ) + )) + + downscale_factor *= 2 + curr_features = features + + padding, h, w = get_padding_layer(image_shape, downscale_factor) + curr_shape = ( + curr_features, h // downscale_factor, w // downscale_factor + ) + + if padding is not None: + layers = [ padding, ] + layers + + self._net_main = nn.Sequential(*layers) + dense_features = math_prod(curr_shape) + + self._net_output = nn.Sequential( + nn.Flatten(), + nn.Linear(dense_features, 1), + ) + + def forward(self, x): + # x : (N, C, H, W) + + # y : (N, Ci, Hi, Wi) + y = self._net_main(x) + + return self._net_output(y) + diff --git a/uvcgan_s/models/discriminator/resnet.py b/uvcgan_s/models/discriminator/resnet.py new file mode 100644 index 0000000..ded2642 --- /dev/null +++ b/uvcgan_s/models/discriminator/resnet.py 
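# Worked example for the DCGANDiscriminator above (hypothetical shapes): with
# image_shape = (1, 28, 28) and features_list = [64, 128, 256], the three
# stride-2 convolutions give a downscale factor of 8. Since 28 is not divisible
# by 8, get_padding_layer returns CenterCrop((32, 32)), which zero-pads the
# smaller input, so the convolutional output has shape (256, 4, 4) and the
# final Linear layer sees 256 * 4 * 4 = 4096 input features.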
@@ -0,0 +1,99 @@ +import torch +from torch import nn + +from uvcgan_s.torch.select import get_activ_layer +from uvcgan_s.torch.layers.resnet import ResNetEncoder + +class ResNetDisc(nn.Module): + + def __init__( + self, image_shape, block_specs, activ, norm, + rezero = True, activ_output = None, reduce_output_channels = True + ): + # pylint: disable=too-many-arguments + super().__init__() + + self._net = ResNetEncoder( + image_shape, block_specs, activ, norm, rezero + ) + + self._output_shape = self._net.output_shape + self._input_shape = image_shape + + output_layers = [] + + if reduce_output_channels: + output_layers.append( + nn.Conv2d(self._output_shape[0], 1, kernel_size = 1) + ) + self._output_shape = (1, *self._output_shape[1:]) + + if activ_output is not None: + output_layers.append(get_activ_layer(activ_output)) + + self._out = nn.Sequential(*output_layers) + + @property + def input_shape(self): + return self._input_shape + + @property + def output_shape(self): + return self._output_shape + + def forward(self, x): + y = self._net(x) + return self._out(y) + +class ResNetFFTDisc(nn.Module): + + def __init__( + self, image_shape, block_specs, activ, norm, + rezero = True, activ_output = None, reduce_output_channels = True + ): + # pylint: disable=too-many-arguments + super().__init__() + + # Doubling channels for FFT amplitudes + input_shape = (2 * image_shape[0], *image_shape[1:]) + + self._net = ResNetEncoder( + input_shape, block_specs, activ, norm, rezero + ) + + self._output_shape = self._net.output_shape + self._input_shape = image_shape + + output_layers = [] + + if reduce_output_channels: + output_layers.append( + nn.Conv2d(self._output_shape[0], 1, kernel_size = 1) + ) + self._output_shape = (1, *self._output_shape[1:]) + + if activ_output is not None: + output_layers.append(get_activ_layer(activ_output)) + + self._out = nn.Sequential(*output_layers) + + @property + def input_shape(self): + return self._input_shape + + @property + def output_shape(self): + return self._output_shape + + def forward(self, x): + # x : (N, C, H, W) + # x_fft : (N, C, H, W) + x_fft = torch.fft.fft2(x).abs() + + # z : (N, 2C, H, W) + z = torch.cat([ x, x_fft ], dim = 1) + + y = self._net(z) + + return self._out(y) + diff --git a/uvcgan_s/models/funcs.py b/uvcgan_s/models/funcs.py new file mode 100644 index 0000000..86e1efb --- /dev/null +++ b/uvcgan_s/models/funcs.py @@ -0,0 +1,17 @@ +from uvcgan_s.base.weight_init import init_weights +from uvcgan_s.torch.funcs import prepare_model +from uvcgan_s.torch.lr_equal import apply_lr_equal +from uvcgan_s.torch.spectr_norm import apply_sn + +def default_model_init(model, model_config, device): + model = prepare_model(model, device) + init_weights(model, model_config.weight_init) + + if model_config.lr_equal: + apply_lr_equal(model) + + if model_config.spectr_norm: + apply_sn(model) + + return model + diff --git a/uvcgan_s/models/generator/__init__.py b/uvcgan_s/models/generator/__init__.py new file mode 100644 index 0000000..e5afc02 --- /dev/null +++ b/uvcgan_s/models/generator/__init__.py @@ -0,0 +1,42 @@ +from uvcgan_s.base.networks import select_base_generator +from uvcgan_s.models.funcs import default_model_init + +from .vitunet import ViTUNetGenerator +from .vitmodnet import ViTModNetGenerator, CViTModNetGenerator +from .resnet import ResNetGen +from .dcgan import DCGANGenerator + +def select_generator(name, **kwargs): + # pylint: disable=too-many-return-statements + + if name == 'vit-unet': + return ViTUNetGenerator(**kwargs) + + if name == 
'vit-modnet':
+        return ViTModNetGenerator(**kwargs)
+
+    if name == 'cvit-modnet':
+        return CViTModNetGenerator(**kwargs)
+
+    if name == 'resnet':
+        return ResNetGen(**kwargs)
+
+    if name == 'dcgan':
+        return DCGANGenerator(**kwargs)
+
+    input_shape  = kwargs.pop('input_shape')
+    output_shape = kwargs.pop('output_shape')
+
+    assert input_shape == output_shape
+    return select_base_generator(name, image_shape = input_shape, **kwargs)
+
+def construct_generator(model_config, input_shape, output_shape, device):
+    model = select_generator(
+        model_config.model,
+        input_shape  = input_shape,
+        output_shape = output_shape,
+        **model_config.model_args
+    )
+
+    return default_model_init(model, model_config, device)
+
diff --git a/uvcgan_s/models/generator/dcgan.py b/uvcgan_s/models/generator/dcgan.py
new file mode 100644
index 0000000..ef88c84
--- /dev/null
+++ b/uvcgan_s/models/generator/dcgan.py
@@ -0,0 +1,98 @@
+# import math
+import logging
+
+from torch import nn
+from torchvision.transforms import CenterCrop
+
+from uvcgan_s.torch.select import get_activ_layer, get_norm_layer
+
+LOGGER = logging.getLogger('models.generator.dcgan')
+
+DEF_FEATURES = 1024
+DEF_NORM  = 'batch'
+DEF_ACTIV = 'relu'
+
+def math_prod(shape):
+    result = 1
+
+    for x in shape:
+        result *= x
+
+    return result
+
+class DCGANGenerator(nn.Module):
+
+    def __init__(
+        self, input_shape, output_shape, features_list,
+        activ        = DEF_ACTIV,
+        norm         = DEF_NORM,
+        activ_output = None
+    ):
+        # pylint: disable=dangerous-default-value
+        # pylint: disable=too-many-arguments
+        super().__init__()
+
+        self._input_shape  = input_shape
+        self._output_shape = output_shape
+
+        # to reshape into (2, 2)
+        dense_features = 4 * features_list[0]
+
+        self._net_in = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(math_prod(input_shape), dense_features),
+            get_activ_layer(activ),
+        )
+
+        layers = []
+
+        curr_shape = (features_list[0], 2, 2)
+
+        for features in features_list[1:]:
+            layers.append(nn.Sequential(
+                nn.ConvTranspose2d(
+                    curr_shape[0], features,
+                    kernel_size = 5, stride = 2, padding = 2,
+                    output_padding = 1,
+                ),
+                get_norm_layer(norm, features),
+                get_activ_layer(activ),
+            ))
+
+            curr_shape = (features, 2 * curr_shape[1], 2 * curr_shape[2])
+
+        layers.append(nn.Sequential(
+            nn.ConvTranspose2d(
+                curr_shape[0], output_shape[0],
+                kernel_size = 5, stride = 2, padding = 2,
+                output_padding = 1,
+            ),
+            get_activ_layer(activ_output),
+        ))
+
+        curr_shape = (output_shape[0], 2 * curr_shape[1], 2 * curr_shape[2])
+
+        if curr_shape != tuple(output_shape):
+            LOGGER.warning(
+                "DCGAN output shape '%s' is not equal to the expected output"
+                " shape '%s'. 
Adding center cropping.", + curr_shape, tuple(output_shape) + ) + + layers.append(CenterCrop(tuple(output_shape[1:]))) + + self._net_main = nn.Sequential(*layers) + + def forward(self, z): + # z : (N, F) + + # x : (N, 4 * C) + x = self._net_in(z) + + # x : (N, C, 2, 2) + x = x.reshape((x.shape[0], -1, 2, 2)) + + result = self._net_main(x) + + return result + diff --git a/uvcgan_s/models/generator/resnet.py b/uvcgan_s/models/generator/resnet.py new file mode 100644 index 0000000..466058b --- /dev/null +++ b/uvcgan_s/models/generator/resnet.py @@ -0,0 +1,40 @@ +from torch import nn + +from uvcgan_s.torch.select import get_activ_layer +from uvcgan_s.torch.layers.resnet import ResNetEncoder + +class ResNetGen(nn.Module): + + def __init__( + self, input_shape, output_shape, block_specs, activ, norm, + rezero = True, activ_output = None + ): + # pylint: disable=too-many-arguments + super().__init__() + + self._net = ResNetEncoder( + input_shape, block_specs, activ, norm, rezero + ) + + self._output_shape = self._net.output_shape + self._input_shape = input_shape + + assert tuple(output_shape) == self._output_shape, ( + f"Output shape {self._output_shape}" + f" != desired shape {output_shape}" + ) + + self._out = get_activ_layer(activ_output) + + @property + def input_shape(self): + return self._input_shape + + @property + def output_shape(self): + return self._output_shape + + def forward(self, x): + y = self._net(x) + return self._out(y) + diff --git a/uvcgan_s/models/generator/vit.py b/uvcgan_s/models/generator/vit.py new file mode 100644 index 0000000..53ad83e --- /dev/null +++ b/uvcgan_s/models/generator/vit.py @@ -0,0 +1,78 @@ +# pylint: disable=too-many-arguments +# pylint: disable=too-many-instance-attributes + +import numpy as np +from torch import nn + +from uvcgan_s.torch.layers.transformer import ( + calc_tokenized_size, ViTInput, TransformerEncoder, img_to_tokens, + img_from_tokens +) + +class ViTGenerator(nn.Module): + + def __init__( + self, features, n_heads, n_blocks, ffn_features, embed_features, + activ, norm, input_shape, output_shape, token_size, + rescale = False, rezero = True, **kwargs + ): + super().__init__(**kwargs) + + assert input_shape == output_shape + image_shape = input_shape + + self.image_shape = image_shape + self.token_size = token_size + self.token_shape = (image_shape[0], *token_size) + self.token_features = np.prod([image_shape[0], *token_size]) + self.N_h, self.N_w = calc_tokenized_size(image_shape, token_size) + self.rescale = rescale + + self.gan_input = ViTInput( + self.token_features, embed_features, features, self.N_h, self.N_w + ) + + self.trans = TransformerEncoder( + features, ffn_features, n_heads, n_blocks, activ, norm, rezero + ) + + self.gan_output = nn.Linear(features, self.token_features) + + # pylint: disable=no-self-use + def calc_scale(self, x): + # x : (N, C, H, W) + return x.abs().mean(dim = (1, 2, 3), keepdim = True) + 1e-8 + + def forward(self, x): + # x : (N, C, H, W) + if self.rescale: + scale = self.calc_scale(x) + x = x / scale + + # itokens : (N, N_h, N_w, C, H_c, W_c) + itokens = img_to_tokens(x, self.token_shape[1:]) + + # itokens : (N, N_h, N_w, C, H_c, W_c) + # -> (N, N_h * N_w, C * H_c * W_c) + # = (N, L, in_features) + itokens = itokens.reshape((itokens.shape[0], self.N_h * self.N_w, -1)) + + # y : (N, L, features) + y = self.gan_input(itokens) + y = self.trans(y) + + # otokens : (N, L, in_features) + otokens = self.gan_output(y) + + # otokens : (N, L, in_features) + # -> (N, N_h, N_w, C, H_c, W_c) + otokens = 
otokens.reshape(( + otokens.shape[0], self.N_h, self.N_w, *self.token_shape + )) + + result = img_from_tokens(otokens) + if self.rescale: + result = result * scale + + return result + diff --git a/uvcgan_s/models/generator/vitgan.py b/uvcgan_s/models/generator/vitgan.py new file mode 100644 index 0000000..0796870 --- /dev/null +++ b/uvcgan_s/models/generator/vitgan.py @@ -0,0 +1,216 @@ +# pylint: disable=too-many-arguments +# pylint: disable=too-many-instance-attributes + +import torch + +import numpy as np +from torch import nn + +from uvcgan_s.torch.select import select_activation +from uvcgan_s.torch.layers.transformer import ( + calc_tokenized_size, img_to_tokens, img_from_tokens, + ViTInput, TransformerEncoder, FourierEmbedding +) + +class ModulatedLinear(nn.Module): + # arXiv: 2011.13775 + + def __init__( + self, in_features, out_features, w_features, activ, eps = 1e-8, + **kwargs + ): + super().__init__(**kwargs) + + self.weight = nn.Parameter(torch.empty((out_features, in_features))) + self.bias = nn.Parameter(torch.empty((out_features,))) + self._eps = eps + + self.A = nn.Linear(w_features, in_features) + self.activ = select_activation(activ) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_normal_(self.weight) + nn.init.zeros_(self.bias) + + def get_modulated_weight(self, s): + # s : (N, in_features) + + # w : (1, out_features, in_features) + w = self.weight.unsqueeze(0) + + # s : (N, in_features) + # -> (N, 1, in_features) + s = s.unsqueeze(1) + + # weight : (N, out_features, in_features) + weight = w * s + + # norm : (N, out_features, 1) + norm = torch.rsqrt( + self._eps + torch.sum(weight**2, dim = 2, keepdim = True) + ) + + return weight * norm + + def forward(self, x, w): + # x : (N, L, in_features) + # w : (N, w_features) + + # s : (N, in_features) + s = self.A(w) + + # weight : (N, out_features, in_features) + weight = self.get_modulated_weight(s) + + # weight : (N, out_features, in_features) + # x : (N, L, in_features) + # result : (N, L, out_features) + result = torch.matmul(weight.unsqueeze(1), x.unsqueeze(-1)).squeeze(-1) + + result = result + self.bias + + return self.activ(result) + +class ViTGANOutput(nn.Module): + + def __init__(self, features, token_shape, activ, **kwargs): + super().__init__(**kwargs) + + self.embed = ViTGANInput(features, token_shape[1], token_shape[2]) + + self.fc1 = ModulatedLinear(features, features, features, activ) + self.fc2 = ModulatedLinear(features, token_shape[0], features, None) + + self._token_shape = token_shape + + def _map_fn(self, x, w): + # x : (N, H_t * W_t, features) + # w : (N, features) + + # result : (N, H_t * W_t, features) + result = self.fc1(x, w) + + # result : (N, H_t * W_t, C) + result = self.fc2(result, w) + + # result : (N, H_t * W_t, C) + # -> (N, C, H_t * W_t) + result = result.permute(0, 2, 1) + + # (N, 1, C, H_t, W_t) + return result.reshape( + (result.shape[0], 1, result.shape[1], *self._token_shape[1:]) + ) + + def forward(self, y): + # y : (N, L, features) + # e : (N, H_t * W_t, features) + e = self.embed(len(y)) + + # result : (N, L, C, H_t, W_t) + result = torch.stack( + [ self._map_fn(e, w) for w in torch.unbind(y, dim = 1) ], + dim = 1 + ) + + return result + +class ViTGANInput(nn.Module): + + def __init__(self, features, height, width, **kwargs): + super().__init__(**kwargs) + self._height = height + self._width = width + + x = torch.arange(width).to(torch.float32) + y = torch.arange(height).to(torch.float32) + + x, y = torch.meshgrid(x, y) + self.x = x.reshape((1, -1)) + 
self.y = y.reshape((1, -1)) + + self.register_buffer('x_const', self.x) + self.register_buffer('y_const', self.y) + + self.embed = FourierEmbedding(features, height, width) + + def forward(self, batch_size = None): + # result : (1, height * width, features) + result = self.embed(self.y_const, self.x_const) + + if batch_size is not None: + # result : (1, height * width, features) + # -> (batch_size, height * width, features) + result = result.expand((batch_size, *result.shape[1:])) + + return result + +class ViTGANGenerator(nn.Module): + + def __init__( + self, features, n_heads, n_blocks, ffn_features, embed_features, + activ, norm, input_shape, output_shape, token_size, rescale = False, + rezero = True, **kwargs + ): + super().__init__(**kwargs) + + assert input_shape == output_shape + image_shape = input_shape + + self.image_shape = image_shape + self.token_size = token_size + self.token_shape = (image_shape[0], *token_size) + self.token_features = np.prod([image_shape[0], *token_size]) + self.N_h, self.N_w = calc_tokenized_size(image_shape, token_size) + self.rescale = rescale + + self.gan_input = ViTInput( + self.token_features, embed_features, features, self.N_h, self.N_w + ) + + self.encoder = TransformerEncoder( + features, ffn_features, n_heads, n_blocks, activ, norm, rezero + ) + + self.gan_output = ViTGANOutput(features, self.token_shape, 'relu') + + # pylint: disable=no-self-use + def calc_scale(self, x): + # x : (N, C, H, W) + return x.abs().mean(dim = (1, 2, 3), keepdim = True) + 1e-8 + + def forward(self, x): + # x : (N, C, H, W) + if self.rescale: + scale = self.calc_scale(x) + x = x / scale + + # itokens : (N, N_h, N_w, C, H_c, W_c) + itokens = img_to_tokens(x, self.token_shape[1:]) + + # itokens : (N, N_h, N_w, C, H_c, W_c) + # -> (N, N_h * N_w, C * H_c * W_c) + # = (N, L, in_features) + itokens = itokens.reshape((itokens.shape[0], self.N_h * self.N_w, -1)) + + # y : (N, L, features) + y = self.gan_input(itokens) + y = self.trans(y) + + # otokens : (N, L, C, H_t, W_t) + otokens = self.gan_output(y) + + # otokens : (N, L, C, H_t, W_t) + # -> (N, N_h, N_w, C, H_c, W_c) + otokens = otokens.reshape( + (otokens.shape[0], self.N_h, self.N_w, *otokens.shape[3:]) + ) + + result = img_from_tokens(otokens) + if self.rescale: + result = result * scale + + return result + diff --git a/uvcgan_s/models/generator/vithybrid.py b/uvcgan_s/models/generator/vithybrid.py new file mode 100644 index 0000000..5c33a56 --- /dev/null +++ b/uvcgan_s/models/generator/vithybrid.py @@ -0,0 +1,137 @@ +# pylint: disable=too-many-arguments +# pylint: disable=too-many-instance-attributes + +from torch import nn + +from uvcgan_s.torch.select import get_norm_layer, get_activ_layer +from uvcgan_s.torch.layers.transformer import PixelwiseViT +from uvcgan_s.torch.layers.cnn import ( + get_downsample_x2_layer, get_upsample_x2_layer +) + +def construct_downsample_stem( + features_list, activ, norm, downsample, image_shape +): + result = nn.Sequential() + + result.add_module( + "downsample_base", + nn.Sequential( + nn.Conv2d( + image_shape[0], features_list[0], kernel_size = 3, padding = 1 + ), + get_activ_layer(activ), + ) + ) + + prev_features = features_list[0] + curr_size = image_shape[1:] + + for idx,features in enumerate(features_list): + layer_down, next_features = \ + get_downsample_x2_layer(downsample, features) + + result.add_module( + "downsample_block_%d" % idx, + nn.Sequential( + get_norm_layer(norm, prev_features), + nn.Conv2d( + prev_features, features, kernel_size = 3, padding = 1 + ), + 
get_activ_layer(activ), + + layer_down, + ) + ) + + prev_features = next_features + curr_size = (curr_size[0] // 2, curr_size[1] // 2) + + return result, prev_features, curr_size + +def construct_upsample_stem( + features_list, input_features, activ, norm, upsample, image_shape +): + result = nn.Sequential() + prev_features = input_features + + for idx,features in reversed(list(enumerate(features_list))): + layer_up, next_features \ + = get_upsample_x2_layer(upsample, prev_features) + + result.add_module( + "upsample_block_%d" % idx, + nn.Sequential( + layer_up, + + get_norm_layer(norm, next_features), + nn.Conv2d( + next_features, features, kernel_size = 3, padding = 1 + ), + get_activ_layer(activ) + ) + ) + + prev_features = features + + + result.add_module( + "upsample_base", + nn.Sequential( + nn.Conv2d(prev_features, image_shape[0], kernel_size = 1), + ) + ) + + return result + +class ViTHybridGenerator(nn.Module): + + def __init__( + self, features, n_heads, n_blocks, ffn_features, embed_features, + activ, norm, input_shape, output_shape, + stem_features_list, stem_activ, stem_norm, + stem_downsample = 'conv', + stem_upsample = 'upsample-conv', + rezero = True, + **kwargs + ): + # pylint: disable=too-many-locals + super().__init__(**kwargs) + + assert input_shape == output_shape + image_shape = input_shape + + self.image_shape = image_shape + + self.stem_down, self.token_features, output_size = \ + construct_downsample_stem( + stem_features_list, stem_activ, stem_norm, stem_downsample, + image_shape + ) + + self.N_h, self.N_w = output_size + + self.bottleneck = PixelwiseViT( + features, n_heads, n_blocks, ffn_features, embed_features, + activ, norm, + image_shape = (self.token_features, self.N_h, self.N_w), + rezero = rezero + ) + + self.stem_up = construct_upsample_stem( + stem_features_list, self.token_features, stem_activ, stem_norm, + stem_upsample, image_shape + ) + + def forward(self, x): + # x : (N, C, H, W) + + # z : (N, token_features, N_h, N_w) + z = self.stem_down(x) + z = self.bottleneck(z) + + # result : (N, C, H, W) + result = self.stem_up(z) + + return result + diff --git a/uvcgan_s/models/generator/vitmodnet.py b/uvcgan_s/models/generator/vitmodnet.py new file mode 100644 index 0000000..54fbf71 --- /dev/null +++ b/uvcgan_s/models/generator/vitmodnet.py @@ -0,0 +1,119 @@ +# pylint: disable=too-many-arguments +# pylint: disable=too-many-instance-attributes + +from torch import nn + +from uvcgan_s.torch.layers.transformer import ( + ExtendedPixelwiseViT, CExtPixelwiseViT +) +from uvcgan_s.torch.layers.modnet import ModNet +from uvcgan_s.torch.select import get_activ_layer + +class ViTModNetGenerator(nn.Module): + + def __init__( + self, features, n_heads, n_blocks, ffn_features, embed_features, + activ, norm, input_shape, output_shape, modnet_features_list, + modnet_activ, + modnet_norm = None, + modnet_downsample = 'conv', + modnet_upsample = 'upsample-conv', + modnet_rezero = False, + modnet_demod = True, + rezero = True, + activ_output = None, + style_rezero = True, + style_bias = True, + n_ext = 1, + **kwargs + ): + # pylint: disable = too-many-locals + super().__init__(**kwargs) + + mod_features = features * n_ext + + self.net = ModNet( + modnet_features_list, modnet_activ, modnet_norm, + input_shape, output_shape, + modnet_downsample, modnet_upsample, mod_features, modnet_rezero, + modnet_demod, style_rezero, style_bias, return_mod = False + ) + + bottleneck = ExtendedPixelwiseViT( + features, n_heads, n_blocks, ffn_features, embed_features, + activ, norm, + 
image_shape = self.net.get_inner_shape(), + rezero = rezero, + n_ext = n_ext, + ) + + self.net.set_bottleneck(bottleneck) + + self.output = get_activ_layer(activ_output) + + def forward(self, x): + # x : (N, C, H, W) + result = self.net(x) + return self.output(result) + +class CViTModNetGenerator(nn.Module): + + def __init__( + self, features, n_heads, n_blocks, ffn_features, embed_features, + activ, norm, input_shape, output_shape, modnet_features_list, + modnet_activ, + modnet_norm = None, + modnet_downsample = 'conv', + modnet_upsample = 'upsample-conv', + modnet_rezero = False, + modnet_demod = True, + rezero = True, + activ_output = None, + style_rezero = True, + style_bias = True, + n_control_in = 100, + n_control_out = 100, + return_feedback = True, + **kwargs + ): + # pylint: disable = too-many-locals + super().__init__(**kwargs) + + self.control_in = nn.Linear(n_control_in, features) + self.control_out = nn.Linear(features, n_control_out) + self.return_feedback = return_feedback + mod_features = features + + self.net = ModNet( + modnet_features_list, modnet_activ, modnet_norm, + input_shape, output_shape, + modnet_downsample, modnet_upsample, mod_features, modnet_rezero, + modnet_demod, style_rezero, style_bias, return_mod = True + ) + + bottleneck = CExtPixelwiseViT( + features, n_heads, n_blocks, ffn_features, embed_features, + activ, norm, + image_shape = self.net.get_inner_shape(), + rezero = rezero, + ) + + self.net.set_bottleneck(bottleneck) + + self.output = get_activ_layer(activ_output) + + def forward(self, x, control): + # x : (N, C, H, W) + + mod = self.control_in(control) + + result, mod = self.net(x, mod) + + result = self.output(result) + + if not self.return_feedback: + return result + + feedback = self.control_out(mod) + return result, feedback + diff --git a/uvcgan_s/models/generator/vitunet.py b/uvcgan_s/models/generator/vitunet.py new file mode 100644 index 0000000..442beac --- /dev/null +++ b/uvcgan_s/models/generator/vitunet.py @@ -0,0 +1,49 @@ +# pylint: disable=too-many-arguments +# pylint: disable=too-many-instance-attributes + +from torch import nn + +from uvcgan_s.torch.layers.transformer import PixelwiseViT +from uvcgan_s.torch.layers.unet import UNet +from uvcgan_s.torch.select import get_activ_layer + +class ViTUNetGenerator(nn.Module): + + def __init__( + self, features, n_heads, n_blocks, ffn_features, embed_features, + activ, norm, input_shape, output_shape, + unet_features_list, unet_activ, unet_norm, + unet_downsample = 'conv', + unet_upsample = 'upsample-conv', + unet_rezero = False, + rezero = True, + activ_output = None, + **kwargs + ): + # pylint: disable = too-many-locals + super().__init__(**kwargs) + + self.image_shape = input_shape + + self.net = UNet( + unet_features_list, unet_activ, unet_norm, input_shape, + unet_downsample, unet_upsample, unet_rezero, + output_shape = output_shape + ) + + bottleneck = PixelwiseViT( + features, n_heads, n_blocks, ffn_features, embed_features, + activ, norm, + image_shape = self.net.get_inner_shape(), + rezero = rezero + ) + + self.net.set_bottleneck(bottleneck) + + self.output = get_activ_layer(activ_output) + + def forward(self, x): + # x : (N, C, H, W) + result = self.net(x) + return self.output(result) + diff --git a/uvcgan_s/torch/__init__.py b/uvcgan_s/torch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/uvcgan_s/torch/background_penalty.py b/uvcgan_s/torch/background_penalty.py new file mode 100644 index 0000000..6dd319f --- /dev/null +++ 
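# Illustrative sketch of a generator specification for the ViTUNetGenerator
# above, using the ModelConfig fields from uvcgan_s/config/model_config.py.
# All hyper-parameter values are placeholders, and the activation / norm names
# are assumed to be ones accepted by get_activ_layer / get_norm_layer:
#
#   generator = {
#       'model'      : 'vit-unet',
#       'model_args' : {
#           'features'           : 384,
#           'n_heads'            : 6,
#           'n_blocks'           : 6,
#           'ffn_features'       : 1536,
#           'embed_features'     : 384,
#           'activ'              : 'gelu',
#           'norm'               : 'layer',
#           'unet_features_list' : [ 48, 96, 192, 384 ],
#           'unet_activ'         : 'leakyrelu',
#           'unet_norm'          : 'instance',
#           'activ_output'       : 'sigmoid',
#       },
#       'optimizer'  : { 'name' : 'AdamW', 'lr' : 1e-4 },
#   }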
b/uvcgan_s/torch/background_penalty.py @@ -0,0 +1,35 @@ +from torch import nn + +class BackgroundPenaltyReduction(nn.Module): + + def __init__(self, epochs_warmup, epochs_anneal, **kwargs): + super().__init__(**kwargs) + + self._epochs_warmup = epochs_warmup + self._epochs_anneal = epochs_anneal + + self._alpha = 0 + + def end_epoch(self, epoch): + if epoch is None: + self._alpha = 1 + return + + if epoch < self._epochs_warmup: + self._alpha = 0 + return + + progression = (epoch - self._epochs_warmup) + + if progression < self._epochs_anneal: + self._alpha = progression / self._epochs_anneal + else: + self._alpha = 1 + + def forward(self, fake, real): + if self._alpha == 1: + return fake + + result = fake - (1 - self._alpha) * fake * (real == 0) + return result + diff --git a/uvcgan_s/torch/data_norm.py b/uvcgan_s/torch/data_norm.py new file mode 100644 index 0000000..4409a41 --- /dev/null +++ b/uvcgan_s/torch/data_norm.py @@ -0,0 +1,220 @@ +import torch +from torchvision.transforms import Normalize +from .select import extract_name_kwargs + +class DataNorm: + + def normalize(self, x): + raise NotImplementedError + + def denormalize(self, y): + raise NotImplementedError + + def normalize_nograd(self, x): + with torch.no_grad(): + return self.normalize(x) + + def denormalize_nograd(self, y): + with torch.no_grad(): + return self.denormalize(y) + +class Standardizer(DataNorm): + + def __init__(self, mean, stdev): + super().__init__() + # fwd: (x - m) / s + # bkw: s * y + m + + bkw_mean = [ -m/s for (m, s) in zip(mean, stdev) ] + bkw_stdev = [ 1/s for s in stdev ] + + self._norm_fwd = Normalize(mean, stdev) + self._norm_bkw = Normalize(bkw_mean, bkw_stdev) + + def normalize(self, x): + return self._norm_fwd(x) + + def denormalize(self, y): + return self._norm_bkw(y) + +class ScaleNorm(DataNorm): + + def __init__(self, scale = 1.0): + super().__init__() + self._scale = scale + + def normalize(self, x): + return self._scale * x + + def denormalize(self, y): + return y / self._scale + +class LogNorm(DataNorm): + + def __init__(self, clip_min = None, bias = None): + super().__init__() + + self._clip_min = clip_min + self._bias = bias + + def normalize(self, x): + if self._clip_min is not None: + x = x.clip(min = self._clip_min, max = None) + + if self._bias is not None: + x = x + self._bias + + return torch.log(x) + + def denormalize(self, y): + x = torch.exp(y) + + if self._bias is not None: + x = x - self._bias + + return x + +class SymLogNorm(DataNorm): + """ + To ensure continuity of the function and its derivative: + y = scale * x if (x <= T) + y = scale * T * (log(x/T) + 1) if (x > T) + + Inverse: + x = y / scale if (y <= scale * T) + x = T * exp(y / (scale * T) - 1) otherwise + """ + + def __init__(self, threshold = 1.0, scale = 1.0): + super().__init__() + + self._scale = scale + self._T = threshold + self._inv_T = scale * threshold + + def normalize(self, x): + x_abs = x.abs() + + y_lin = self._scale * x + y_log = torch.sign(x) * self._inv_T * (torch.log(x_abs / self._T) + 1) + + return torch.where(x_abs > self._T, y_log, y_lin) + + def denormalize(self, y): + y_abs = y.abs() + + x_lin = y / self._scale + x_log = torch.sign(y) * self._T * torch.exp(y_abs / self._inv_T - 1) + + return torch.where(y_abs > self._inv_T, x_log, x_lin) + +class MinMaxScaler(DataNorm): + + def __init__(self, feature_min, feature_max, dim): + super().__init__() + + self._min = feature_min + self._max = feature_max + self._dim = dim + + @staticmethod + def align_shapes(source, target, dim): + if source.ndim == 0: 
+ source = source.unsqueeze(0) + + n_before = dim + n_after = target.ndim - (dim + source.ndim) + + return source.reshape( + ( 1, ) * n_before + source.shape + ( 1, ) * n_after + ) + + def normalize(self, x): + fmin = torch.tensor(self._min, dtype = x.dtype, device = x.device) + fmax = torch.tensor(self._max, dtype = x.dtype, device = x.device) + + fmin = MinMaxScaler.align_shapes(fmin, x, self._dim) + fmax = MinMaxScaler.align_shapes(fmax, x, self._dim) + + return (x - fmin) / (fmax - fmin) + + def denormalize(self, y): + fmin = torch.tensor(self._min, dtype = y.dtype, device = y.device) + fmax = torch.tensor(self._max, dtype = y.dtype, device = y.device) + + fmin = MinMaxScaler.align_shapes(fmin, y, self._dim) + fmax = MinMaxScaler.align_shapes(fmax, y, self._dim) + + return (y * (fmax - fmin) + fmin) + +class DoublePrecision(DataNorm): + + def __init__(self, norm): + self._norm = norm + + def normalize(self, x): + old_dtype = x.dtype + + result = x.double() + result = self._norm.normalize(result) + + return result.to(dtype = old_dtype) + + def denormalize(self, y): + old_dtype = y.dtype + + result = y.double() + result = self._norm.denormalize(result) + + return result.to(dtype = old_dtype) + +class Compose(DataNorm): + + def __init__(self, norms): + self._norms = norms + + def normalize(self, x): + for norm in self._norms: + x = norm.normalize(x) + + return x + + def denormalize(self, y): + for norm in reversed(self._norms): + y = norm.denormalize(y) + + return y + +def select_single_data_normalization(norm): + name, kwargs = extract_name_kwargs(norm) + + if name == 'double-precision': + return DoublePrecision(select_data_normalization(**kwargs)) + + if name == 'scale': + return ScaleNorm(**kwargs) + + if name == 'log': + return LogNorm(**kwargs) + + if name == 'symlog': + return SymLogNorm(**kwargs) + + if name == 'standardize': + return Standardizer(**kwargs) + + if name == 'min-max-scaler': + return MinMaxScaler(**kwargs) + + raise ValueError(f"Unknown data normalization '{name}'") + +def select_data_normalization(norm): + if norm is None: + return None + + if isinstance(norm, (tuple, list)): + norm = [ select_single_data_normalization(n) for n in norm ] + return Compose(norm) + + return select_single_data_normalization(norm) + diff --git a/uvcgan_s/torch/funcs.py b/uvcgan_s/torch/funcs.py new file mode 100644 index 0000000..b7b2b94 --- /dev/null +++ b/uvcgan_s/torch/funcs.py @@ -0,0 +1,68 @@ +import logging +import random +import torch +import numpy as np + +from torch import nn + +LOGGER = logging.getLogger('uvcgan_s.torch') + +def seed_everything(seed): + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + +def get_torch_device_smart(): + if torch.cuda.is_available(): + return 'cuda' + + return 'cpu' + +def prepare_model(model, device): + model = model.to(device) + + if torch.cuda.device_count() > 1: + LOGGER.warning( + "Multiple (%d) GPUs found. Using Data Parallelism", + torch.cuda.device_count() + ) + model = nn.DataParallel(model) + + return model + +@torch.no_grad() +def update_average_model(average_model, model, momentum): + # TODO: Maybe it is better to copy buffers, instead of + # averaging them. + # Think about this later. 
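+    # Exponential moving average of the online model:
+    #   v_avg <- momentum * v_avg + (1 - momentum) * v_online
+    # 0-d tensors are updated with an explicit copy_, the rest via in-place lerp_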
+ online_params = dict(model.named_parameters()) + online_bufs = dict(model.named_buffers()) + + for (k, v) in average_model.named_parameters(): + if v.ndim == 0: + v.copy_(momentum * v + (1 - momentum) * online_params[k]) + else: + v.lerp_(online_params[k], (1 - momentum)) + + for (k, v) in average_model.named_buffers(): + if v.ndim == 0: + v.copy_(momentum * v + (1 - momentum) * online_bufs[k]) + else: + v.lerp_(online_bufs[k], (1 - momentum)) + +def clip_gradients(optimizer, norm = None, value = None): + if (norm is None) and (value is None): + return + + params = [ + param + for param_group in optimizer.param_groups + for param in param_group['params'] + ] + + if norm is not None: + torch.nn.utils.clip_grad_norm_(params, max_norm = norm) + + if value is not None: + torch.nn.utils.clip_grad_value_(params, clip_value = value) + diff --git a/uvcgan_s/torch/gan_losses.py b/uvcgan_s/torch/gan_losses.py new file mode 100644 index 0000000..059bfcb --- /dev/null +++ b/uvcgan_s/torch/gan_losses.py @@ -0,0 +1,135 @@ +import torch +from torch import nn + +from .select import extract_name_kwargs + +def reduce_loss(loss, reduction): + if (reduction is None) or (reduction == 'none'): + return loss + + if reduction == 'mean': + return loss.mean() + + if reduction == 'sum': + return loss.sum() + + raise ValueError(f"Unknown reduction method: '{reduction}'") + +class GANLoss(nn.Module): + + def __init__( + self, label_real = 1, label_fake = 0, reduction = 'mean', + **kwargs + ): + super().__init__(**kwargs) + + self.reduction = reduction + self.register_buffer('label_real', torch.tensor(label_real)) + self.register_buffer('label_fake', torch.tensor(label_fake)) + + def _expand_label_as(self, x, is_real): + result = self.label_real if is_real else self.label_fake + return result.to(dtype = x.dtype).expand_as(x) + + def eval_loss(self, x, is_real, is_generator): + raise NotImplementedError + + def forward(self, x, is_real, is_generator = False): + if isinstance(x, (list, tuple)): + result = sum(self.forward(y, is_real, is_generator) for y in x) + return result / len(x) + + return self.eval_loss(x, is_real, is_generator) + +class LSGANLoss(GANLoss): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.loss = nn.MSELoss(reduction = self.reduction) + + def eval_loss(self, x, is_real, is_generator = False): + label = self._expand_label_as(x, is_real) + return self.loss(x, label) + +class BCEGANLoss(GANLoss): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.loss = nn.BCEWithLogitsLoss(reduction = self.reduction) + + def eval_loss(self, x, is_real, is_generator = False): + label = self._expand_label_as(x, is_real) + return self.loss(x, label) + +class SoftplusGANLoss(GANLoss): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.loss = nn.Softplus() + + def eval_loss(self, x, is_real, is_generator = False): + if is_real: + result = self.loss(x) + else: + result = self.loss(-x) + + return reduce_loss(result, self.reduction) + +class WGANLoss(GANLoss): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + if self.reduction != 'mean': + raise NotImplementedError + + def eval_loss(self, x, is_real, is_generator = False): + if is_real: + result = -x.mean() + else: + result = x.mean() + + return reduce_loss(result, self.reduction) + +class HingeGANLoss(GANLoss): + + def __init__(self, margin = 1, **kwargs): + super().__init__(**kwargs) + self._margin = margin + self._relu = nn.ReLU() + + if self.reduction != 'mean': + raise NotImplementedError + + 
def eval_loss(self, x, is_real, is_generator = False): + if is_generator: + if is_real: + result = -x.mean() + else: + result = x.mean() + else: + if is_real: + result = self._relu(self._margin - x).mean() + else: + result = self._relu(self._margin + x).mean() + + return reduce_loss(result, self.reduction) + +GAN_LOSSES = { + 'lsgan' : LSGANLoss, + 'wgan' : WGANLoss, + 'softplus' : SoftplusGANLoss, + 'hinge' : HingeGANLoss, + 'bce' : BCEGANLoss, + 'vanilla' : BCEGANLoss, +} + +def select_gan_loss(gan_loss): + name, kwargs = extract_name_kwargs(gan_loss) + + if name in GAN_LOSSES: + return GAN_LOSSES[name](**kwargs) + + raise ValueError(f"Unknown gan loss: '{name}'") + + diff --git a/uvcgan_s/torch/gradient_cacher.py b/uvcgan_s/torch/gradient_cacher.py new file mode 100644 index 0000000..1c812f9 --- /dev/null +++ b/uvcgan_s/torch/gradient_cacher.py @@ -0,0 +1,41 @@ + +class GradientCacher: + + def __init__(self, model, loss, cache_period = 10): + self._cache_period = cache_period + self._cache_grad = {} + self._cache_loss = None + self._iter = 0 + self._model = model + self._loss = loss + + def __call__(self, *args, **kwargs): + if ( + (self._iter < self._cache_period) + and (self._cache_loss is not None) + ): + for name, p in self._model.named_parameters(): + cached_grad = self._cache_grad[name] + + if cached_grad is None: + p.grad = None + else: + p.grad = cached_grad.clone() + + result = self._cache_loss + self._iter += 1 + + else: + result = self._loss(*args, **kwargs) + result.backward() + + self._cache_loss = result + self._cache_grad = { + name : p.grad.detach().clone() if p.grad is not None else None + for (name, p) in self._model.named_parameters() + } + + self._iter = 0 + + return result + diff --git a/uvcgan_s/torch/gradient_penalty.py b/uvcgan_s/torch/gradient_penalty.py new file mode 100644 index 0000000..2f82ff2 --- /dev/null +++ b/uvcgan_s/torch/gradient_penalty.py @@ -0,0 +1,134 @@ +import torch + +def mix_tensors(alpha, a, b): + return alpha * a + (1 - alpha) * b + +def recursively_mix_args(alpha, a, b): + if a is None: + assert b is None + return None + + if isinstance(a, (tuple, list)): + return type(a)( + recursively_mix_args(alpha, x, y) for (x, y) in zip(a, b) + ) + + if isinstance(a, dict): + return { + k : recursively_mix_args(alpha, a[k], b[k]) for k in a + } + + if isinstance(a, torch.Tensor): + return mix_tensors(alpha, a, b) + + assert a == b + return a + +def reduce_tensor(y, reduction, reduce_batch = False): + if isinstance(y, (list, tuple)): + return [ reduce_tensor(x, reduction) for x in y ] + + if (reduction is None) or (reduction == 'none'): + return y + + if reduction == 'mean': + if reduce_batch: + return y.mean() + else: + return y.mean(dim = tuple(i for i in range(1, y.dim()))) + + if reduction == 'sum': + if reduce_batch: + return y.sum() + else: + return y.sum(dim = tuple(i for i in range(1, y.dim()))) + + raise ValueError(f"Unknown reduction: '{reduction}'") + +class GradientPenalty: + + def __init__( + self, mix_type, center, lambda_gp, seed = 0, + reduction = 'mean', gp_reduction = 'mean' + ): + # pylint: disable=too-many-arguments + self._mix_type = mix_type + self._center = center + self._reduce = reduction + self._gp_reduce = gp_reduction + self._lambda = lambda_gp + + self._rng = torch.Generator() + self._rng.manual_seed(seed) + + def eval_at(self, model, x, **model_kwargs): + x.requires_grad_(True) + + y = model(x, **model_kwargs) + y = reduce_tensor(y, self._reduce, reduce_batch = False) + + if not isinstance(y, list): + y = [ y, ] + + grad = 
torch.autograd.grad( + outputs = y, + inputs = x, + grad_outputs = [ torch.ones(z.size()).to(z.device) for z in y ], + create_graph = True, + retain_graph = True, + ) + + grad_x = grad[0].reshape((x.shape[0], -1)) + result = torch.square( + torch.norm(grad_x, p = 2, dim = 1) - self._center + ) + + return self._lambda * result + + def mix_gp_args(self, a, b, model_kwargs_a, model_kwargs_b): + alpha = torch.rand(1, generator = self._rng).to(a.device) + result = mix_tensors(alpha, a, b) + + mixed_kwargs = recursively_mix_args( + alpha, model_kwargs_a, model_kwargs_b + ) + return (result, mixed_kwargs) + + def get_eval_point( + self, fake, real, model_kwargs_fake = None, model_kwargs_real = None + ): + if self._mix_type == 'real': + return (real.clone(), model_kwargs_real) + + if self._mix_type == 'fake': + return (fake.clone(), model_kwargs_fake) + + if self._mix_type == 'real-fake': + return self.mix_gp_args( + real, fake, model_kwargs_real, model_kwargs_fake + ) + + if self._mix_type == 'real-or-fake': + alpha = torch.rand(1, generator = self._rng).cpu().item() + if alpha > 0.5: + return (real.clone(), model_kwargs_real) + else: + return (fake.clone(), model_kwargs_fake) + + raise ValueError(f"Unknown mix type: {self._mix_type}") + + def __call__( + self, model, fake, real, + model_kwargs_fake = None, + model_kwargs_real = None, + ): + # pylint: disable=too-many-arguments + x, model_kwargs = self.get_eval_point( + fake, real, model_kwargs_fake, model_kwargs_real + ) + + model_kwargs = model_kwargs or {} + result = self.eval_at(model, x, **model_kwargs) + + return reduce_tensor(result, self._gp_reduce, reduce_batch = True) + diff --git a/uvcgan_s/torch/image_masking.py b/uvcgan_s/torch/image_masking.py new file mode 100644 index 0000000..cffc0c1 --- /dev/null +++ b/uvcgan_s/torch/image_masking.py @@ -0,0 +1,129 @@ +import torch +from torch import nn + +from .select import extract_name_kwargs +from .layers.transformer import calc_tokenized_size + +class SequenceRandomMasking(nn.Module): + + def __init__(self, fraction = 0.4, seed = 0, **kwargs): + super().__init__(**kwargs) + self._fraction = fraction + + self._rng = torch.Generator() + self._rng.manual_seed(seed) + + def forward(self, sequence): + # sequence : (N, L, features) + mask = ( + torch.rand((*sequence.shape[:2], 1), generator = self._rng) + > self._fraction + ) + return mask.to(sequence.device) * sequence + +class ImagePatchRandomMasking(nn.Module): + + def __init__(self, patch_size, fraction = 0.4, seed = 0, **kwargs): + super().__init__(**kwargs) + + self._patch_size = patch_size + self._fraction = fraction + + self._rng = torch.Generator() + self._rng.manual_seed(seed) + + def forward(self, image): + # image : (N, C, H, W) + N_h, N_w = calc_tokenized_size(image.shape[1:], self._patch_size) + + # mask : (N, 1, N_h, N_w) + mask = ( + torch.rand((image.shape[0], 1, N_h, N_w), generator = self._rng) + > self._fraction + ) + + # mask : (N, 1, N_h, N_w) + # -> (N, 1, H, W) + mask = mask.repeat_interleave(self._patch_size[0], dim = 2) + mask = mask.repeat_interleave(self._patch_size[1], dim = 3) + + return mask.to(image.device) * image + +# pylint: disable=trailing-whitespace +# class BlockwiseMasking(ImageMaskingBase): +# # Algorithm 1 of arXiv:2106.08254 +# +# def __init__( +# self, +# mask_ratio = 0.4, +# min_block_size = 16, +# aspect_ratio = 0.3, +# seed = 0, +# ): +# self._mask_ratio = mask_ratio +# self._min_block_size = min_block_size +# self._aspect_ratio = aspect_ratio +# +# self._prg = np.random.default_rng(seed) +# +# def 
get_mask_region(self, image, h, w, masking_threshold, masked_patches): +# min_block_size = self._min_block_size +# max_block_size = \ +# max(min_block_size, masking_threshold - len(masked_patches)) +# +# block_size = self._prg.integers(min_block_size, max_block_size) +# aspect_ratio = self._prg.uniform( +# self._aspect_ratio, 1/self._aspect_ratio +# ) +# +# y_range = int(np.round(np.sqrt(block_size * aspect_ratio))) +# x_range = int(np.round(np.sqrt(block_size / aspect_ratio))) +# +# y_range = min(y_range, h) +# x_range = min(x_range, w) +# +# y0 = self._prg.integers(0, h - y_range) +# x0 = self._prg.integers(0, w - x_range) +# +# return (y0, x0, y_range, x_range) +# +# def mask(self, image): +# # image : (..., H, W) +# h = image.shape[-2] +# w = image.shape[-1] +# +# n_patches = h * w +# masked_patches = set() +# +# masking_threshold = self._mask_ratio * n_patches +# +# while len(masked_patches) < masking_threshold: +# (y0, x0, y_range, x_range) = self.get_mask_region( +# image, h, w, masking_threshold, masked_patches +# ) +# +# for y in range(y0, y0 + y_range): +# for x in range(x0, x0 + x_range): +# coord = (x, y) +# if coord in masked_patches: +# continue +# +# image[..., y, x] = 0 +# masked_patches.add(coord) +# +# return image + +def select_masking(masking): + if masking is None: + return None + + name, kwargs = extract_name_kwargs(masking) + + if name in [ 'transformer-random', 'sequence-random' ]: + return SequenceRandomMasking(**kwargs) + + if name == 'image-patch-random': + return ImagePatchRandomMasking(**kwargs) + + raise ValueError("Unknown masking: '%s'" % name) + diff --git a/uvcgan_s/torch/layers/__init__.py b/uvcgan_s/torch/layers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/uvcgan_s/torch/layers/activation.py b/uvcgan_s/torch/layers/activation.py new file mode 100644 index 0000000..bb14ab8 --- /dev/null +++ b/uvcgan_s/torch/layers/activation.py @@ -0,0 +1,12 @@ +import torch +from torch import nn + +class Exponential(nn.Module): + + def __init__(self, beta = 1): + super().__init__() + self._beta = beta + + def forward(self, x): + return torch.exp(self._beta * x) + diff --git a/uvcgan_s/torch/layers/alt_trans.py b/uvcgan_s/torch/layers/alt_trans.py new file mode 100644 index 0000000..9e9c315 --- /dev/null +++ b/uvcgan_s/torch/layers/alt_trans.py @@ -0,0 +1,125 @@ +# pylint: disable=too-many-arguments +# pylint: disable=too-many-instance-attributes + +import torch +from torch import nn + +from uvcgan_s.torch.select import get_norm_layer +from .attention import select_attention +from .transformer import ( + img_to_pixelwise_tokens, img_from_pixelwise_tokens, + PositionWiseFFN, ViTInput +) + +class AltTransformerBlock(nn.Module): + + def __init__( + self, features, ffn_features, n_heads, activ = 'gelu', norm = None, + attention = 'dot', rezero = True, **kwargs + ): + super().__init__(**kwargs) + + self.norm1 = get_norm_layer(norm, features) + self.atten = select_attention( + attention, + embed_dim = features, + num_heads = n_heads, + batch_first = False, + ) + + self.norm2 = get_norm_layer(norm, features) + self.ffn = PositionWiseFFN(features, ffn_features, activ) + + self.rezero = rezero + + if rezero: + self.re_alpha = nn.Parameter(torch.zeros((1, ))) + else: + self.re_alpha = 1 + + def forward(self, x): + # x: (L, N, features) + + # Step 1: Multi-Head Self Attention + y1 = self.norm1(x) + y1, _atten_weights = self.atten(y1, y1, y1) + + y = x + self.re_alpha * y1 + + # Step 2: PositionWise Feed Forward Network + y2 = self.norm2(y) + y2 = self.ffn(y2) 
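+        # second residual connection, gated by the ReZero weight re_alpha
+        # (a zero-initialized learnable scalar when rezero = True, else 1)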
+ + y = y + self.re_alpha * y2 + + return y + + def extra_repr(self): + return 're_alpha = %e' % (self.re_alpha, ) + +class AltTransformerEncoder(nn.Module): + + def __init__( + self, features, ffn_features, n_heads, n_blocks, activ, norm, + attention = 'dot', rezero = True, **kwargs + ): + super().__init__(**kwargs) + + self.encoder = nn.Sequential(*[ + AltTransformerBlock( + features, ffn_features, n_heads, activ, norm, attention, rezero + ) for _ in range(n_blocks) + ]) + + def forward(self, x): + # x : (N, L, features) + + # y : (L, N, features) + y = x.permute((1, 0, 2)) + y = self.encoder(y) + + # result : (N, L, features) + result = y.permute((1, 0, 2)) + + return result + +class AltPixelwiseViT(nn.Module): + + def __init__( + self, features, n_heads, n_blocks, ffn_features, embed_features, + activ, norm, image_shape, attention = 'dot', rezero = True, **kwargs + ): + super().__init__(**kwargs) + + self.image_shape = image_shape + + self.trans_input = ViTInput( + image_shape[0], embed_features, features, + image_shape[1], image_shape[2], + ) + + self.encoder = AltTransformerEncoder( + features, ffn_features, n_heads, n_blocks, activ, norm, + attention, rezero + ) + + self.trans_output = nn.Linear(features, image_shape[0]) + + def forward(self, x): + # x : (N, C, H, W) + + # itokens : (N, L, C) + itokens = img_to_pixelwise_tokens(x) + + # y : (N, L, features) + y = self.trans_input(itokens) + y = self.encoder(y) + + # otokens : (N, L, C) + otokens = self.trans_output(y) + + # result : (N, C, H, W) + result = img_from_pixelwise_tokens(otokens, self.image_shape) + + return result + diff --git a/uvcgan_s/torch/layers/attention.py b/uvcgan_s/torch/layers/attention.py new file mode 100644 index 0000000..c67adc5 --- /dev/null +++ b/uvcgan_s/torch/layers/attention.py @@ -0,0 +1,160 @@ +import math +from typing import Optional, Tuple + +import torch +from torch import nn +from torch import Tensor + +from einops import rearrange +from uvcgan_s.torch.select import extract_name_kwargs + +def expand_heads(values, n_heads): + return rearrange( + values, 'N L (D_h n_heads) -> (N n_heads) L D_h', n_heads = n_heads + ) + +def contract_heads(values, n_heads): + return rearrange( + values, '(N n_heads) L D_h -> N L (D_h n_heads)', n_heads = n_heads + ) + +def return_result_and_atten_weights( + result, A, need_weights, average_attn_weights, batch_first, n_heads +): + # pylint: disable=too-many-arguments + + # result : (N, L, embed_dim) + # A : ((N n_heads), L, S) + + if not batch_first: + # result : (N, L, embed_dim) + # -> (L, N, embed_dim) + result = result.swapaxes(0, 1) + + if not need_weights: + return (result, None) + + # A : (n_heads, N, L, S) + A = rearrange( + A, '(N n_heads) L S -> N n_heads L S', n_heads = n_heads + ) + + if average_attn_weights: + # A : (N, n_heads, L, S) + # -> (N, L, S) + A = A.mean(dim = 1) + + return (result, A) + +class QuadraticAttention(nn.Module): + + # This is Lipshits continuous when equal_kq + + def __init__( + self, embed_dim, num_heads, + bias = False, + add_bias_kv = False, + kdim = None, + vdim = None, + batch_first = False, + equal_kq = False, + ): + # pylint: disable=too-many-arguments + super().__init__() + + if kdim is None: + kdim = embed_dim + elif equal_kq: + assert kdim == embed_dim + + if vdim is None: + vdim = embed_dim + + self._batch_first = batch_first + self._n_heads = num_heads + self._dh = embed_dim // num_heads + + self.w_q = nn.Linear(embed_dim, embed_dim, bias = bias) + + if equal_kq: + self.w_k = self.w_q + else: + self.w_k = 
nn.Linear(kdim, embed_dim, bias = add_bias_kv) + + self.w_v = nn.Linear(vdim, embed_dim, bias = add_bias_kv) + self.w_o = nn.Linear(embed_dim, embed_dim, bias = bias) + + def compute_attention_matrix(self, w_query, w_key): + # w_query : (N, L, d) + # w_key : (N, S, d) + + # w_query : (N, L, d) + # -> (N, L, 1, d) + w_query = w_query.unsqueeze(dim = 2) + + # w_key : (N, S, d) + # -> (N, 1, S, d) + w_key = w_key.unsqueeze(dim = 1) + + # L : (N, L, S) + L = -torch.norm(w_query - w_key, p = 2, dim = 3) + L = L / math.sqrt(self._dh) + + # result : (N, L, S) + return torch.softmax(L, dim = 2) + + def forward( + self, query: Tensor, key: Tensor, value: Tensor, + key_padding_mask : Optional[Tensor] = None, + need_weights : bool = True, + attn_mask : Optional[Tensor] = None, + average_attn_weights : bool = True + ) -> Tuple[Tensor, Optional[Tensor]]: + # pylint: disable=too-many-arguments + + assert key_padding_mask is None, "key_padding_mask is not supported" + assert attn_mask is None, "attn_mask is not supported" + + if not self._batch_first: + # (k, q, v) : (L, N, ...) + query = query.swapaxes(0, 1) + key = key.swapaxes(0, 1) + value = value.swapaxes(0, 1) + + # (query, key, value) : (N, L, ...) + + # w_query : ((N n_heads), L, D_h) + # w_key : ((N n_heads), S, D_h) + # w_value : ((N n_heads), S, D_h) + w_query = expand_heads(self.w_q(query), self._n_heads) + w_key = expand_heads(self.w_k(key), self._n_heads) + w_value = expand_heads(self.w_v(key), self._n_heads) + + # A : ((N n_heads), L, S) + A = self.compute_attention_matrix(w_query, w_key) + + # w_output : ((N n_heads), L, D_h) + w_output = torch.bmm(A, w_value) + + # output : (N, L, embed_dim) + output = contract_heads(w_output, self._n_heads) + + # result : (N, L, embed_dim) + result = self.w_o(output) + + return return_result_and_atten_weights( + result, A, need_weights, average_attn_weights, + self._batch_first, self._n_heads + ) + +def select_attention(attention, **extra_kwargs): + name, kwargs = extract_name_kwargs(attention) + + if name in [ 'default', 'standard', 'scalar', 'dot' ]: + return nn.MultiheadAttention(**kwargs, **extra_kwargs) + + if name in [ 'quadratic', 'l2' ]: + return QuadraticAttention(**kwargs, **extra_kwargs) + + raise ValueError(f"Unknown attention {name}") + diff --git a/uvcgan_s/torch/layers/batch_head.py b/uvcgan_s/torch/layers/batch_head.py new file mode 100644 index 0000000..6d5464c --- /dev/null +++ b/uvcgan_s/torch/layers/batch_head.py @@ -0,0 +1,331 @@ +import torch +from torch import nn + +from uvcgan_s.torch.select import get_activ_layer, extract_name_kwargs +from .alt_trans import AltTransformerEncoder + +# References: +# arXiv: 1912.0495 +# https://github.com/moono/stylegan2-tf-2.x/blob/master/stylegan2/discriminator.py +# https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py + +class BatchStdev(nn.Module): + + # pylint: disable=useless-super-delegation + def __init__(self, **kwargs): + """ arXiv: 1710.10196 """ + super().__init__(**kwargs) + + @staticmethod + def safe_stdev(x, dim = 0, eps = 1e-6): + var = torch.var(x, dim = dim, unbiased = False, keepdim = True) + stdev = torch.sqrt(var + eps) + + return stdev + + # pylint: disable=no-self-use + def forward(self, x): + """ + NOTE: Reference impl has fixed minibatch size. + + arXiv: 1710.10196 + + 1. We first compute the standard deviation for each feature in each + spatial location over the minibatch. + + 2. We then average these estimates over all features and spatial + locations to arrive at a single value. + + 3. 
We replicate the value and concatenate it to all spatial locations + and over the minibatch, yielding one additional (con-stant) feature + map. + """ + + # x : (N, C, H, W) + # x_stdev : (1, C, H, W) + x_stdev = BatchStdev.safe_stdev(x, dim = 0) + + # x_norm : (1, 1, 1, 1) + x_norm = torch.mean(x_stdev, dim = (1, 2, 3), keepdim = True) + + # x_norm : (N, 1, H, W) + x_norm = x_norm.expand((x.shape[0], 1, *x.shape[2:])) + + # y : (N, C + 1, H, W) + y = torch.cat((x, x_norm), dim = 1) + + return y + +class BatchAttention(nn.Module): + + def __init__( + self, input_features, features, ffn_features, n_heads, n_blocks, activ, + norm, attention = 'dot', rezero = True, + input_flatten_method = 'average', **kwargs + ): + # pylint: disable=too-many-arguments + super().__init__(**kwargs) + + self.embed = nn.Linear(input_features, features) + self.attention = AltTransformerEncoder( + features, ffn_features, n_heads, n_blocks, activ, norm, attention, + rezero + ) + + self.flat_method = input_flatten_method + + assert input_flatten_method in [ 'average', 'flatten' ], \ + f"Unknown input flattening method: {input_flatten_method}" + + def flatten_input(self, x): + # x : (N, C_in, ...) + + if len(x.shape) == 2: + return x + + if self.flat_method == 'average': + return x.mean(dim = tuple(range(2, len(x.shape))), keepdim = False) + + if self.flat_method == 'stdev': + return BatchStdev.safe_stdev( + x, dim = tuple(range(2, len(x.shape))) + ) + + if self.flat_method == 'flatten': + return x.reshape((x.shape[0], -1)) + + raise ValueError(f"Unknown flatten method: {self.flat_method}") + + def forward(self, x): + # x : (N, C_in, ...) + + # x_flat : (N, C) + x_flat = self.flatten_input(x) + + # e : (1, N, features) + e = self.embed(x_flat).unsqueeze(0) + + # y : (1, N, features) + y = self.attention(e) + + # result : (N, features) + return y.squeeze(0) + +class BatchHead1d(nn.Module): + + def __init__( + self, input_features, mid_features = None, output_features = None, + activ = 'relu', activ_output = None, **kwargs + ): + # pylint: disable=too-many-arguments + super().__init__(**kwargs) + + if mid_features is None: + mid_features = input_features + + if output_features is None: + output_features = mid_features + + self.net = nn.Sequential( + nn.Linear(input_features, mid_features), + nn.BatchNorm1d(mid_features), + get_activ_layer(activ), + + nn.Linear(mid_features, output_features), + get_activ_layer(activ_output), + ) + + def forward(self, x): + # x : (N, C) + return self.net(x) + +class BatchHead2d(nn.Module): + + def __init__( + self, input_features, mid_features = None, output_features = None, + activ = 'relu', activ_output = None, n_signal = None, **kwargs + ): + # pylint: disable=too-many-arguments + super().__init__(**kwargs) + + if mid_features is None: + mid_features = input_features + + if output_features is None: + output_features = mid_features + + self._n_signal = n_signal + + self.norm = nn.BatchNorm2d(input_features) + self.net = nn.Sequential( + nn.Conv2d( + input_features, mid_features, kernel_size = 3, padding = 1 + ), + get_activ_layer(activ), + + nn.Conv2d( + mid_features, output_features, kernel_size = 3, padding = 1 + ), + get_activ_layer(activ_output), + ) + + def forward(self, x): + # x : (N, C, H, W) + y = self.norm(x) + + if self._n_signal is not None: + # Drop queue tokens + y = y[:self._n_signal, ...] 
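+            # entries past n_signal still contribute to the BatchNorm
+            # statistics above but are excluded from the conv head below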
+ + return self.net(y) + +class BatchStdevHead(nn.Module): + + def __init__( + self, input_features, mid_features = None, output_features = None, + activ = 'relu', activ_output = None, **kwargs + ): + # pylint: disable=too-many-arguments + super().__init__(**kwargs) + + if mid_features is None: + mid_features = input_features + + if output_features is None: + output_features = mid_features + + self.net = nn.Sequential( + BatchStdev(), + nn.Conv2d( + input_features + 1, mid_features, kernel_size = 3, padding = 1 + ), + get_activ_layer(activ), + + nn.Conv2d( + mid_features, output_features, kernel_size = 3, padding = 1 + ), + get_activ_layer(activ_output), + ) + + def forward(self, x): + # x : (N, C, H, W) + return self.net(x) + +class BatchAverageHead(nn.Module): + + def __init__( + self, input_features, reduce_channels = True, average_spacial = False, + activ_output = None, **kwargs + ): + # pylint: disable=too-many-arguments + super().__init__(**kwargs) + + layers = [] + + if reduce_channels: + layers.append( + nn.Conv2d(input_features, 1, kernel_size = 3, padding = 1) + ) + + if average_spacial: + layers.append(nn.AdaptiveAvgPool2d(1)) + + if activ_output is not None: + layers.append(get_activ_layer(activ_output)) + + self.net = nn.Sequential(*layers) + + def forward(self, x): + # x : (N, C, H, W) + return self.net(x) + +class BatchAttentionHead(nn.Module): + + def __init__( + self, input_features, features, ffn_features, n_heads, n_blocks, activ, + norm, + attention = 'dot', + rezero = True, + input_flatten_method = 'average', + output_features = None, + activ_output = None, + **kwargs + ): + # pylint: disable=too-many-arguments + super().__init__(**kwargs) + + if output_features is None: + output_features = input_features + + self.net = nn.Sequential( + BatchAttention( + input_features, features, ffn_features, n_heads, n_blocks, + activ, norm, attention, rezero, input_flatten_method + ), + + nn.Linear(features, output_features), + get_activ_layer(activ_output), + ) + + def forward(self, x): + # x : (N, C, ...) 
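+        # result : (N, output_features)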
+ return self.net(x) + +class BatchHeadWrapper(nn.Module): + + def __init__(self, body, head, **kwargs): + super().__init__(**kwargs) + self._body = body + self._head = head + + def forward_head(self, x_body): + return self._head(x_body) + + def forward_body(self, x): + return self._body(x) + + def forward(self, x, extra_bodies = None, return_body = False): + y_body = self._body(x) + + if isinstance(y_body, (list, tuple)): + y_body_main = list(y_body[:-1]) + y_body_last = y_body[-1] + else: + y_body_main = tuple() + y_body_last = y_body + + if extra_bodies is not None: + all_bodies = torch.cat((y_body_last, extra_bodies), dim = 0) + y_head = self._head(all_bodies) + else: + y_head = self._head(y_body_last) + + y_head = y_head[:y_body_last.shape[0]] + + if len(y_body_main) == 0: + result = y_head + else: + result = y_body_main + [ y_head, ] + + if return_body: + return (result, y_body_last) + + return result + +BATCH_HEADS = { + 'batch-norm-1d' : BatchHead1d, + 'batch-norm-2d' : BatchHead2d, + 'batch-stdev' : BatchStdevHead, + 'batch-atten' : BatchAttentionHead, + 'simple-average' : BatchAverageHead, + 'idt' : nn.Identity, +} + +def get_batch_head(batch_head): + name, kwargs = extract_name_kwargs(batch_head) + + if name not in BATCH_HEADS: + raise ValueError("Unknown Batch Head: '%s'" % name) + + return BATCH_HEADS[name](**kwargs) + diff --git a/uvcgan_s/torch/layers/cnn.py b/uvcgan_s/torch/layers/cnn.py new file mode 100644 index 0000000..e76e6a8 --- /dev/null +++ b/uvcgan_s/torch/layers/cnn.py @@ -0,0 +1,143 @@ +from itertools import repeat + +from torch import nn +from uvcgan_s.torch.select import extract_name_kwargs + +def calc_conv1d_output_size(input_size, kernel_size, padding, stride): + return (input_size + 2 * padding - kernel_size) // stride + 1 + +def calc_conv_transpose1d_output_size( + input_size, kernel_size, padding, stride +): + return (input_size - 1) * stride - 2 * padding + kernel_size + +def calc_conv_output_size(input_size, kernel_size, padding, stride): + if isinstance(kernel_size, int): + kernel_size = repeat(kernel_size, len(input_size)) + + if isinstance(stride, int): + stride = repeat(stride, len(input_size)) + + if isinstance(padding, int): + padding = repeat(padding, len(input_size)) + + return tuple( + calc_conv1d_output_size(sz, ks, p, s) + for (sz, ks, p, s) in zip(input_size, kernel_size, padding, stride) + ) + +def calc_conv_transpose_output_size(input_size, kernel_size, padding, stride): + if isinstance(kernel_size, int): + kernel_size = repeat(kernel_size, len(input_size)) + + if isinstance(stride, int): + stride = repeat(stride, len(input_size)) + + if isinstance(padding, int): + padding = repeat(padding, len(input_size)) + + return tuple( + calc_conv_transpose1d_output_size(sz, ks, p, s) + for (sz, ks, p, s) in zip(input_size, kernel_size, padding, stride) + ) + +def get_downsample_x2_conv2_layer(features, **kwargs): + return ( + nn.Conv2d(features, features, kernel_size = 2, stride = 2, **kwargs), + features + ) + +def get_downsample_x2_conv3_layer(features, **kwargs): + return ( + nn.Conv2d( + features, features, kernel_size = 3, stride = 2, padding = 1, + **kwargs + ), + features + ) + +def get_downsample_x2_pixelshuffle_layer(features, **kwargs): + out_features = 4 * features + return (nn.PixelUnshuffle(downscale_factor = 2, **kwargs), out_features) + +def get_downsample_x2_pixelshuffle_conv_layer(features, **kwargs): + out_features = features * 4 + + layer = nn.Sequential( + nn.PixelUnshuffle(downscale_factor = 2, **kwargs), + nn.Conv2d( + 
out_features, out_features, kernel_size = 3, padding = 1 + ), + ) + + return (layer, out_features) + +def get_upsample_x2_deconv2_layer(features, **kwargs): + return ( + nn.ConvTranspose2d( + features, features, kernel_size = 2, stride = 2, **kwargs + ), + features + ) + +def get_upsample_x2_upconv_layer(features, **kwargs): + layer = nn.Sequential( + nn.Upsample(scale_factor = 2, **kwargs), + nn.Conv2d(features, features, kernel_size = 3, padding = 1), + ) + + return (layer, features) + +def get_upsample_x2_pixelshuffle_conv_layer(features, **kwargs): + out_features = features // 4 + + layer = nn.Sequential( + nn.PixelShuffle(upscale_factor = 2, **kwargs), + nn.Conv2d(out_features, out_features, kernel_size = 3, padding = 1), + ) + + return (layer, out_features) + +def get_downsample_x2_layer(layer, features): + name, kwargs = extract_name_kwargs(layer) + + if name == 'conv': + return get_downsample_x2_conv2_layer(features, **kwargs) + + if name == 'conv3': + return get_downsample_x2_conv3_layer(features, **kwargs) + + if name == 'avgpool': + return (nn.AvgPool2d(kernel_size = 2, stride = 2, **kwargs), features) + + if name == 'maxpool': + return (nn.MaxPool2d(kernel_size = 2, stride = 2, **kwargs), features) + + if name == 'pixel-unshuffle': + return get_downsample_x2_pixelshuffle_layer(features, **kwargs) + + if name == 'pixel-unshuffle-conv': + return get_downsample_x2_pixelshuffle_conv_layer(features, **kwargs) + + raise ValueError("Unknown Downsample Layer: '%s'" % name) + +def get_upsample_x2_layer(layer, features): + name, kwargs = extract_name_kwargs(layer) + + if name == 'deconv': + return get_upsample_x2_deconv2_layer(features, **kwargs) + + if name == 'upsample': + return (nn.Upsample(scale_factor = 2, **kwargs), features) + + if name == 'upsample-conv': + return get_upsample_x2_upconv_layer(features, **kwargs) + + if name == 'pixel-shuffle': + return (nn.PixelShuffle(upscale_factor = 2, **kwargs), features // 4) + + if name == 'pixel-shuffle-conv': + return get_upsample_x2_pixelshuffle_conv_layer(features, **kwargs) + + raise ValueError("Unknown Upsample Layer: '%s'" % name) + diff --git a/uvcgan_s/torch/layers/modnet.py b/uvcgan_s/torch/layers/modnet.py new file mode 100644 index 0000000..b779d95 --- /dev/null +++ b/uvcgan_s/torch/layers/modnet.py @@ -0,0 +1,422 @@ +# pylint: disable=too-many-arguments +# pylint: disable=too-many-instance-attributes + +import torch +from torch import nn + +from uvcgan_s.torch.select import get_activ_layer +from .cnn import get_upsample_x2_layer +from .unet import UNetEncBlock + +# Ref: https://arxiv.org/pdf/1912.04958.pdf + +def get_demod_scale(mod_scale, weights, eps = 1e-6): + # Ref: https://arxiv.org/pdf/1912.04958.pdf + # + # demod_scale[alpha] = 1 / sqrt(sigma[alpha]^2 + eps) + # + # sigma[alpha]^2 + # = sum_{beta i} (mod_scale[alpha] * weights[alpha, beta, i])^2 + # = sum_{beta} (mod_scale[alpha])^2 * sum_i (weights[alpha, beta, i])^2 + # + + # mod_scale : (N, C_in) + # weights : (C_out, C_in, h, w) + + # w_sq : (C_out, C_in) + w_sq = torch.sum(weights.square(), dim = (2, 3)) + + # w_sq : (C_out, C_in) -> (1, C_in, C_out) + w_sq = torch.swapaxes(w_sq, 0, 1).unsqueeze(0) + + # mod_scale_sq : (N, C_in, 1) + mod_scale_sq = mod_scale.square().unsqueeze(2) + + # sigma : (N, C_out) + sigma_sq = torch.sum(mod_scale_sq * w_sq, dim = 1) + + # result : (N, C_out) + return 1 / torch.sqrt(sigma_sq + eps) + +class ModulatedConv2d(nn.Module): + + def __init__( + self, in_channels, out_channels, kernel_size, + stride = 1, padding = 0, dilation = 1, 
groups = 1, eps = 1e-6, + demod = True, device = None, dtype = None, **kwargs + ): + super().__init__(**kwargs) + + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, device = device, dtype = dtype + ) + + self._eps = eps + self._demod = demod + + def forward(self, x, s): + # x : (N, C_in, H_in, W_in) + # s : (N, C_in) + + # x_mod : (N, C_in, H_in, W_in) + x_mod = x * s.unsqueeze(2).unsqueeze(3) + + # y : (N, C_out, H_out, W_out) + y = self.conv(x_mod) + + if self._demod: + # s_demod : (N, C_out) + s_demod = get_demod_scale(s, self.conv.weight) + y_demod = y * s_demod.unsqueeze(2).unsqueeze(3) + + return y_demod + + return y + +class StyleBlock(nn.Module): + + def __init__( + self, mod_features, style_features, rezero = True, bias = True, + **kwargs + ): + super().__init__(**kwargs) + + self.affine_mod = nn.Linear(mod_features, style_features, bias = bias) + self.rezero = rezero + + if rezero: + self.re_alpha = nn.Parameter(torch.zeros((1, ))) + else: + self.re_alpha = 1 + + def forward(self, mod): + # mod : (N, mod_features) + # s : (N, style_features) + s = 1 + self.re_alpha * self.affine_mod(mod) + + return s + + def extra_repr(self): + return 're_alpha = %e' % (self.re_alpha, ) + +class ModNetBasicBlock(nn.Module): + + def __init__( + self, in_features, out_features, activ, mod_features, + mid_features = None, + demod = True, + style_rezero = True, + style_bias = True, + **kwargs + ): + super().__init__(**kwargs) + + if mid_features is None: + mid_features = out_features + + self.style_block1 = StyleBlock( + mod_features, in_features, style_rezero, style_bias + ) + self.conv_block1 = ModulatedConv2d( + in_features, mid_features, kernel_size = 3, padding = 1, + demod = demod + ) + self.activ1 = get_activ_layer(activ) + + self.style_block2 = StyleBlock( + mod_features, mid_features, style_rezero, style_bias + ) + self.conv_block2 = ModulatedConv2d( + mid_features, out_features, kernel_size = 3, padding = 1, + demod = demod + ) + self.activ2 = get_activ_layer(activ) + + def forward(self, x, mod): + # x : (N, C_in, H_in, W_in) + # mod : (N, mod_features) + + # mod_scale1 : (N, C_out) + mod_scale1 = self.style_block1(mod) + + # y1 : (N, C_mid, H_mid, W_mid) + y1 = self.conv_block1(x, mod_scale1) + y1 = self.activ1(y1) + + # mod_scale2 : (N, C_out) + mod_scale2 = self.style_block2(mod) + + # result : (N, C_out, H_out, W_out) + result = self.conv_block2(y1, mod_scale2) + result = self.activ2(result) + + return result + +class ModNetDecBlock(nn.Module): + + def __init__( + self, input_shape, output_features, skip_features, mod_features, + activ, upsample, + rezero = True, + demod = True, + style_rezero = True, + style_bias = True, + **kwargs + ): + super().__init__(**kwargs) + + (input_features, H, W) = input_shape + self.upsample, input_features = get_upsample_x2_layer( + upsample, input_features + ) + + self.block = ModNetBasicBlock( + skip_features + input_features, output_features, activ, + mod_features = mod_features, + mid_features = max(input_features, input_shape[0]), + demod = demod, + style_rezero = style_rezero, + style_bias = style_bias, + ) + + self._output_shape = (output_features, 2 * H, 2 * W) + + if rezero: + self.re_alpha = nn.Parameter(torch.zeros((1, ))) + else: + self.re_alpha = 1 + + @property + def output_shape(self): + return self._output_shape + + def forward(self, x, r, mod): + # x : (N, C, H_in, W_in) + # r : (N, C, H_out, W_out) + # mod : (N, mod_features) + + # x : (N, C_up, H_out, W_out) + x = self.re_alpha * 
self.upsample(x) + + # y : (N, C + C_up, H_out, W_out) + y = torch.cat([x, r], dim = 1) + + # result : (N, C_out, H_out, W_out) + return self.block(y, mod) + + def extra_repr(self): + return 're_alpha = %e' % (self.re_alpha, ) + +class ModNetBlock(nn.Module): + + def __init__( + self, features, activ, norm, image_shape, downsample, upsample, + mod_features, + rezero = True, + demod = True, + style_rezero = True, + style_bias = True, + **kwargs + ): + super().__init__(**kwargs) + + self.conv = UNetEncBlock( + features, activ, norm, downsample, image_shape + ) + + self.inner_shape = self.conv.output_shape + self.inner_module = None + + self.deconv = ModNetDecBlock( + input_shape = self.inner_shape, + output_features = image_shape[0], + skip_features = self.inner_shape[0], + mod_features = mod_features, + activ = activ, + upsample = upsample, + rezero = rezero, + demod = demod, + style_rezero = style_rezero, + style_bias = style_bias, + ) + + def get_inner_shape(self): + return self.inner_shape + + def set_inner_module(self, module): + self.inner_module = module + + def get_inner_module(self): + return self.inner_module + + def forward(self, x, mod = None): + # x : (N, C, H, W) + + # y : (N, C_inner, H_inner, W_inner) + # r : (N, C_inner, H, W) + (y, r) = self.conv(x) + + # y : (N, C_inner, H_inner, W_inner) + # mod : (N, mod_features) + if mod is None: + y, mod = self.inner_module(y) + else: + y, mod = self.inner_module(y, mod) + + # y : (N, C, H, W) + y = self.deconv(y, r, mod) + + return (y, mod) + +class ModNetLinearDecoder(nn.Module): + + def __init__( + self, features_list, input_shape, output_shape, skip_shapes, + mod_features, activ, upsample, + rezero = True, + demod = True, + style_rezero = True, + style_bias = True, + **kwargs + ): + # pylint: disable = too-many-locals + super().__init__(**kwargs) + + self.net = nn.ModuleList() + self._input_shape = input_shape + self._output_shape = output_shape + curr_shape = input_shape + + for features, skip_shape in zip( + features_list[::-1], skip_shapes[::-1] + ): + layer = ModNetDecBlock( + input_shape = curr_shape, + output_features = features, + skip_features = skip_shape[0], + mod_features = mod_features, + activ = activ, + upsample = upsample, + rezero = rezero, + demod = demod, + style_rezero = style_rezero, + style_bias = style_bias, + ) + curr_shape = layer.output_shape + + self.net.append(layer) + + self.output = nn.Conv2d( + curr_shape[0], output_shape[0], kernel_size = 1 + ) + curr_shape = (output_shape[0], *curr_shape[1:]) + + assert tuple(output_shape) == tuple(curr_shape) + + @property + def input_shape(self): + return self._input_shape + + @property + def output_shape(self): + return self._output_shape + + def forward(self, x, skip_list, mod): + # x : (N, C, H, W) + # mod : (N, mod_features) + # skip_list : List[(N, C_i, H_i, W_i)] + + y = x + + for layer, skip in zip(self.net, skip_list[::-1]): + y = layer(y, skip, mod) + + return self.output(y) + +class ModNet(nn.Module): + + def __init__( + self, features_list, activ, norm, input_shape, output_shape, + downsample, upsample, + mod_features, + rezero = True, + demod = True, + style_rezero = True, + style_bias = True, + return_mod = False, + **kwargs + ): + # pylint: disable = too-many-locals + super().__init__(**kwargs) + assert tuple(input_shape[1:]) == tuple(output_shape[1:]) + + self.features_list = features_list + self.input_shape = input_shape + self.output_shape = output_shape + self.return_mod = return_mod + + self._construct_input_layer(activ) + 
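+        # output layer: a 1x1 convolution back to the requested channel count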
self._construct_output_layer() + + unet_layers = [] + curr_image_shape = (features_list[0], *input_shape[1:]) + + for features in features_list: + layer = ModNetBlock( + features, activ, norm, curr_image_shape, downsample, upsample, + mod_features, rezero, demod, style_rezero, style_bias + ) + curr_image_shape = layer.get_inner_shape() + unet_layers.append(layer) + + for idx in range(len(unet_layers)-1): + unet_layers[idx].set_inner_module(unet_layers[idx+1]) + + self.modnet = unet_layers[0] + + def _construct_input_layer(self, activ): + self.layer_input = nn.Sequential( + nn.Conv2d( + self.input_shape[0], self.features_list[0], + kernel_size = 3, padding = 1 + ), + get_activ_layer(activ), + ) + + def _construct_output_layer(self): + self.layer_output = nn.Conv2d( + self.features_list[0], self.output_shape[0], kernel_size = 1 + ) + + def get_innermost_block(self): + result = self.modnet + + for _ in range(len(self.features_list)-1): + result = result.get_inner_module() + + return result + + def set_bottleneck(self, module): + self.get_innermost_block().set_inner_module(module) + + def get_bottleneck(self): + return self.get_innermost_block().get_inner_module() + + def get_inner_shape(self): + return self.get_innermost_block().get_inner_shape() + + def forward(self, x, mod = None): + # x : (N, C, H, W) + + y = self.layer_input(x) + + y, mod = self.modnet(y, mod) + + y = self.layer_output(y) + + if self.return_mod: + return (y, mod) + + return y + diff --git a/uvcgan_s/torch/layers/resnet.py b/uvcgan_s/torch/layers/resnet.py new file mode 100644 index 0000000..607ecb9 --- /dev/null +++ b/uvcgan_s/torch/layers/resnet.py @@ -0,0 +1,480 @@ +import torch +from torch import nn + +from uvcgan_s.torch.select import get_norm_layer, get_activ_layer +from .cnn import calc_conv_output_size + +class ResNetBlock(nn.Module): + + def __init__( + self, features, activ, norm, rezero = False, + kernel_size = 3, bottlneck_features = None, **kwargs + ): + # pylint: disable=too-many-arguments + super().__init__(**kwargs) + + if bottlneck_features is None: + bottlneck_features = features + + self.block = nn.Sequential( + nn.Conv2d( + features, bottlneck_features, + kernel_size = kernel_size, + padding = 'same', + stride = 1, + ), + get_norm_layer(norm, bottlneck_features), + get_activ_layer(activ), + + nn.Conv2d( + bottlneck_features, features, + kernel_size = kernel_size, + padding = 'same', + stride = 1 + ), + get_norm_layer(norm, features), + ) + + self.block_out = get_activ_layer(activ) + + if rezero: + self.re_alpha = nn.Parameter(torch.zeros((1, ))) + else: + self.re_alpha = 1 + + def forward(self, x): + # x : (N, C, H, W) + y = self.block(x) + z = x + self.re_alpha * y + + return self.block_out(z) + + def extra_repr(self): + return 're_alpha = %e' % (self.re_alpha, ) + +class ResNetBlockv2(nn.Module): + + def __init__( + self, features, activ, norm, rezero = False, + kernel_size = 3, bottlneck_features = None, **kwargs + ): + # pylint: disable=too-many-arguments + super().__init__(**kwargs) + + if bottlneck_features is None: + bottlneck_features = features + + self.block = nn.Sequential( + get_norm_layer(norm, bottlneck_features), + get_activ_layer(activ), + + nn.Conv2d( + features, bottlneck_features, + kernel_size = kernel_size, + padding = 'same', + stride = 1, + ), + + get_norm_layer(norm, features), + get_activ_layer(activ), + + nn.Conv2d( + bottlneck_features, features, + kernel_size = kernel_size, + padding = 'same', + stride = 1 + ), + ) + + if rezero: + self.re_alpha = nn.Parameter(torch.zeros((1, 
))) + else: + self.re_alpha = 1 + + def forward(self, x): + # x : (N, C, H, W) + y = self.block(x) + z = x + self.re_alpha * y + + return self.block_out(z) + + def extra_repr(self): + return 're_alpha = %e' % (self.re_alpha, ) + +class BigGanResDownBlock(nn.Module): + + def __init__( + self, input_shape, features, activ, norm, rezero = False, + kernel_size = 3, bottlneck_features = None, n_blocks = 1, **kwargs + ): + # pylint: disable=too-many-arguments + super().__init__(**kwargs) + + if bottlneck_features is None: + bottlneck_features = features + + layers = [] + + curr_features = input_shape[0] + + for _ in range(n_blocks): + block = nn.Sequential( + get_norm_layer(norm, curr_features), + get_activ_layer(activ), + nn.Conv2d( + curr_features, bottlneck_features, + kernel_size = kernel_size, padding = 'same', stride = 1, + ), + get_norm_layer(norm, bottlneck_features), + get_activ_layer(activ), + nn.Conv2d( + bottlneck_features, features, + kernel_size = kernel_size, padding = 'same', stride = 1 + ) + ) + + layers.append(block) + curr_features = features + + layers.append(nn.AvgPool2d(kernel_size = 2, stride = 2)) + + self.net_main = nn.Sequential(*layers) + self.net_res = nn.Sequential( + nn.Conv2d(input_shape[0], features, kernel_size = 1), + nn.AvgPool2d(kernel_size = 2, stride = 2) + ) + + self._input_shape = input_shape + self._output_shape = ( + features, input_shape[1] // 2, input_shape[2] // 2 + ) + + if rezero: + self.re_alpha = nn.Parameter(torch.zeros((1, ))) + else: + self.re_alpha = 1 + + def forward(self, x): + # x : (N, C, H, W) + main = self.net_main(x) + res = self.net_res(x) + + return res + self.re_alpha * main + + @property + def input_shape(self): + return self._input_shape + + @property + def output_shape(self): + return self._output_shape + + def extra_repr(self): + return 're_alpha = %e' % (self.re_alpha, ) + +class BigGanDeepResDownBlock(nn.Module): + + def __init__( + self, input_shape, features, activ, norm, rezero = False, + kernel_size = 3, bottlneck_features = None, n_blocks = 1, **kwargs + ): + # pylint: disable=too-many-arguments + super().__init__(**kwargs) + + if bottlneck_features is None: + bottlneck_features = features + + layers = [] + layers.append(nn.Sequential( + get_norm_layer(norm, input_shape[0]), + get_activ_layer(activ), + nn.Conv2d(input_shape[0], features, kernel_size = 1) + )) + + for _ in range(n_blocks): + block = nn.Sequential( + get_norm_layer(norm, features), + get_activ_layer(activ), + nn.Conv2d( + features, bottlneck_features, + kernel_size = kernel_size, padding = 'same', stride = 1, + ), + get_norm_layer(norm, bottlneck_features), + get_activ_layer(activ), + nn.Conv2d( + bottlneck_features, features, + kernel_size = kernel_size, padding = 'same', stride = 1 + ) + ) + + layers.append(block) + + layers.append(nn.Sequential( + nn.AvgPool2d(kernel_size = 2, stride = 2), + nn.Conv2d(features, features, kernel_size = 1) + )) + + self.net_main = nn.Sequential(*layers) + self.net_res_stem = nn.AvgPool2d(kernel_size = 2, stride = 2) + + if features > input_shape[0]: + self.net_res_side = nn.Conv2d( + input_shape[0], features - input_shape[0], kernel_size = 1 + ) + else: + self.net_res_side = None + + self._input_shape = input_shape + self._output_shape = ( + features, input_shape[1] // 2, input_shape[2] // 2 + ) + + if rezero: + self.re_alpha = nn.Parameter(torch.zeros((1, ))) + else: + self.re_alpha = 1 + + def forward(self, x): + # x : (N, C, H, W) + main = self.net_main(x) + res_stem = self.net_res_stem(x) + + if self.net_res_side is 
None: + res = res_stem + else: + res_side = self.net_res_side(res_stem) + res = torch.cat((res_stem, res_side), dim = 1) + + return res + self.re_alpha * main + + @property + def input_shape(self): + return self._input_shape + + @property + def output_shape(self): + return self._output_shape + + def extra_repr(self): + return 're_alpha = %e' % (self.re_alpha, ) + + +class ResNetStem(nn.Module): + + def __init__( + self, input_shape, features, norm = None, + kernel_size = 4, padding = 0, stride = 4 + ): + # pylint: disable=too-many-arguments + super().__init__() + + self.net = nn.Sequential( + nn.Conv2d( + input_shape[0], features, + kernel_size = kernel_size, padding = padding, stride = stride + ), + get_norm_layer(norm, features), + ) + + self._input_shape = input_shape + self._output_shape = ( + features, + *calc_conv_output_size( + input_shape[1:], kernel_size, padding, stride + ) + ) + + @property + def input_shape(self): + return self._input_shape + + @property + def output_shape(self): + return self._output_shape + + def forward(self, x): + return self.net(x) + +class ResNetEncoder(nn.Module): + # pylint: disable=too-many-instance-attributes + + def _make_stem_block(self, block_spec, curr_shape): + # pylint: disable=no-self-use + block = ResNetStem(curr_shape, **block_spec) + curr_shape = block.output_shape + + return (block, curr_shape) + + def _make_resample_block(self, block_spec, curr_shape): + # pylint: disable=no-self-use + if isinstance(block_spec, (list, tuple)): + size, kwargs = block_spec + else: + size = block_spec + kwargs = {} + + if isinstance(size, int): + size = (size, size) + + features = curr_shape[0] + block = nn.Upsample(size, **kwargs) + curr_shape = (features, *size) + + return (block, curr_shape) + + def _make_resnet_block(self, block_spec, curr_shape): + if isinstance(block_spec, (list, tuple)): + n_blocks, kwargs = block_spec + else: + n_blocks = block_spec + kwargs = {} + + features = curr_shape[0] + block = nn.Sequential( + *[ ResNetBlock( + features, + activ = self._activ, + norm = self._norm, + rezero = self._rezero, + **kwargs + ) + for _ in range(n_blocks) ] + ) + + return (block, curr_shape) + + def _make_resnet_block_v2(self, block_spec, curr_shape): + if isinstance(block_spec, (list, tuple)): + n_blocks, kwargs = block_spec + else: + n_blocks = block_spec + kwargs = {} + + features = curr_shape[0] + block = nn.Sequential( + ResNetBlockv2( + features, + activ = self._activ, + norm = self._norm, + rezero = self._rezero, + **kwargs + ) for _ in range(n_blocks) + ) + + return (block, curr_shape) + + def _make_biggan_resdown_block(self, block_spec, curr_shape): + # pylint: disable=no-self-use + block = BigGanResDownBlock( + curr_shape, + activ = self._activ, + norm = self._norm, + rezero = self._rezero, + **block_spec + ) + return (block, block.output_shape) + + def _make_biggan_deep_resdown_block(self, block_spec, curr_shape): + # pylint: disable=no-self-use + block = BigGanDeepResDownBlock( + curr_shape, + activ = self._activ, + norm = self._norm, + rezero = self._rezero, + **block_spec + ) + return (block, block.output_shape) + + def __init__( + self, input_shape, block_specs, activ, norm, rezero = True + ): + # pylint: disable=too-many-arguments + # pylint: disable=too-many-locals + super().__init__() + + curr_shape = input_shape + self.blocks = nn.ModuleList() + + self._activ = activ + self._norm = norm + self._rezero = rezero + + block_idx = 0 + skip_indices = set() + skip_shapes = [] + + for (block_type, block_spec) in block_specs: + if block_type 
== 'stem': + block, curr_shape \ + = self._make_stem_block(block_spec, curr_shape) + + elif block_type == 'resample': + block, curr_shape \ + = self._make_resample_block(block_spec, curr_shape) + + elif block_type == 'resnet': + block, curr_shape \ + = self._make_resnet_block(block_spec, curr_shape) + + elif block_type == 'resnet-v2': + block, curr_shape \ + = self._make_resnet_block_v2(block_spec, curr_shape) + + elif block_type == 'biggan-resdown': + block, curr_shape \ + = self._make_biggan_resdown_block(block_spec, curr_shape) + + elif block_type == 'biggan-deep-resdown': + block, curr_shape = self._make_biggan_deep_resdown_block( + block_spec, curr_shape + ) + + elif block_type == 'skip': + skip_indices.add(block_idx) + skip_shapes.append(curr_shape) + continue + + else: + raise ValueError(f"Unknown block type: {block_type}") + + self.blocks.append(block) + block_idx += 1 + + self._input_shape = input_shape + self._output_shape = curr_shape + self._skip_shapes = skip_shapes + self._skip_indices = skip_indices + + @property + def output_shape(self): + return self._output_shape + + @property + def input_shape(self): + return self._input_shape + + @property + def skip_indices(self): + return self._skip_indices + + @property + def skip_shapes(self): + return self._skip_shapes + + def forward(self, x, return_skips = False): + if return_skips: + skips = [] + + y = x + + for idx, block in enumerate(self.blocks): + if return_skips and (idx in self._skip_indices): + skips.append(y) + + y = block(y) + + if return_skips: + return (y, skips) + + return y + diff --git a/uvcgan_s/torch/layers/transformer.py b/uvcgan_s/torch/layers/transformer.py new file mode 100644 index 0000000..52ac722 --- /dev/null +++ b/uvcgan_s/torch/layers/transformer.py @@ -0,0 +1,415 @@ +# pylint: disable=too-many-arguments +# pylint: disable=too-many-instance-attributes + +import torch +from torch import nn + +from uvcgan_s.torch.select import get_norm_layer, get_activ_layer + +def calc_tokenized_size(image_shape, token_size): + # image_shape : (C, H, W) + # token_size : (H_t, W_t) + if image_shape[1] % token_size[0] != 0: + raise ValueError( + "Token width %d does not divide image width %d" % ( + token_size[0], image_shape[1] + ) + ) + + if image_shape[2] % token_size[1] != 0: + raise ValueError( + "Token height %d does not divide image height %d" % ( + token_size[1], image_shape[2] + ) + ) + + # result : (N_h, N_w) + return (image_shape[1] // token_size[0], image_shape[2] // token_size[1]) + +def img_to_tokens(image_batch, token_size): + # image_batch : (N, C, H, W) + # token_size : (H_t, W_t) + + # result : (N, C, N_h, H_t, W) + result = image_batch.view( + (*image_batch.shape[:2], -1, token_size[0], image_batch.shape[3]) + ) + + # result : (N, C, N_h, H_t, W ) + # -> (N, C, N_h, H_t, N_w, W_t) + result = result.view((*result.shape[:4], -1, token_size[1])) + + # result : (N, C, N_h, H_t, N_w, W_t) + # -> (N, N_h, N_w, C, H_t, W_t) + result = result.permute((0, 2, 4, 1, 3, 5)) + + return result + +def img_from_tokens(tokens): + # tokens : (N, N_h, N_w, C, H_t, W_t) + # result : (N, C, N_h, H_t, N_w, W_t) + result = tokens.permute((0, 3, 1, 4, 2, 5)) + + # result : (N, C, N_h, H_t, N_w, W_t) + # -> (N, C, N_h, H_t, N_w * W_t) + # = (N, C, N_h, H_t, W) + result = result.reshape((*result.shape[:4], -1)) + + # result : (N, C, N_h, H_t, W) + # -> (N, C, N_h * H_t, W) + # = (N, C, H, W) + result = result.reshape((*result.shape[:2], -1, result.shape[4])) + + return result + +def img_to_pixelwise_tokens(image): + # image : (N, 
C, H, W) + + # result : (N, C, H * W) + result = image.view(*image.shape[:2], -1) + + # result : (N, C, H * W) + # -> (N, H * W, C ) + # = (N, L, C) + result = result.permute((0, 2, 1)) + + # (N, L, C) + return result + +def img_from_pixelwise_tokens(tokens, image_shape): + # tokens : (N, L, C) + # image_shape : (3, ) + + # tokens : (N, L, C) + # -> (N, C, L) + # = (N, C, H * W) + tokens = tokens.permute((0, 2, 1)) + + # (N, C, H, W) + return tokens.view(*tokens.shape[:2], *image_shape[1:]) + +class PositionWiseFFN(nn.Module): + + def __init__(self, features, ffn_features, activ = 'gelu', **kwargs): + super().__init__(**kwargs) + + self.net = nn.Sequential( + nn.Linear(features, ffn_features), + get_activ_layer(activ), + nn.Linear(ffn_features, features), + ) + + def forward(self, x): + return self.net(x) + +class TransformerBlock(nn.Module): + + def __init__( + self, features, ffn_features, n_heads, activ = 'gelu', norm = None, + rezero = True, **kwargs + ): + super().__init__(**kwargs) + + self.norm1 = get_norm_layer(norm, features) + self.atten = nn.MultiheadAttention(features, n_heads) + + self.norm2 = get_norm_layer(norm, features) + self.ffn = PositionWiseFFN(features, ffn_features, activ) + + self.rezero = rezero + + if rezero: + self.re_alpha = nn.Parameter(torch.zeros((1, ))) + else: + self.re_alpha = 1 + + def forward(self, x): + # x: (L, N, features) + + # Step 1: Multi-Head Self Attention + y1 = self.norm1(x) + y1, _atten_weights = self.atten(y1, y1, y1) + + y = x + self.re_alpha * y1 + + # Step 2: PositionWise Feed Forward Network + y2 = self.norm2(y) + y2 = self.ffn(y2) + + y = y + self.re_alpha * y2 + + return y + + def extra_repr(self): + return 're_alpha = %e' % (self.re_alpha, ) + +class TransformerEncoder(nn.Module): + + def __init__( + self, features, ffn_features, n_heads, n_blocks, activ, norm, + rezero = True, **kwargs + ): + super().__init__(**kwargs) + + self.encoder = nn.Sequential(*[ + TransformerBlock( + features, ffn_features, n_heads, activ, norm, rezero + ) for _ in range(n_blocks) + ]) + + def forward(self, x): + # x : (N, L, features) + + # y : (L, N, features) + y = x.permute((1, 0, 2)) + y = self.encoder(y) + + # result : (N, L, features) + result = y.permute((1, 0, 2)) + + return result + +class FourierEmbedding(nn.Module): + # arXiv: 2011.13775 + + def __init__(self, features, height, width, **kwargs): + super().__init__(**kwargs) + self.projector = nn.Linear(2, features) + self._height = height + self._width = width + + def forward(self, y, x): + # x : (N, L) + # y : (N, L) + x_norm = 2 * x / (self._width - 1) - 1 + y_norm = 2 * y / (self._height - 1) - 1 + + # z : (N, L, 2) + z = torch.cat((x_norm.unsqueeze(2), y_norm.unsqueeze(2)), dim = 2) + + return torch.sin(self.projector(z)) + +class ViTInput(nn.Module): + + def __init__( + self, input_features, embed_features, features, height, width, + **kwargs + ): + super().__init__(**kwargs) + self._height = height + self._width = width + + x = torch.arange(width).to(torch.float32) + y = torch.arange(height).to(torch.float32) + + x, y = torch.meshgrid(x, y) + self.x = x.reshape((1, -1)) + self.y = y.reshape((1, -1)) + + self.register_buffer('x_const', self.x) + self.register_buffer('y_const', self.y) + + self.embed = FourierEmbedding(embed_features, height, width) + self.output = nn.Linear(embed_features + input_features, features) + + def forward(self, x): + # x : (N, L, input_features) + # embed : (1, height * width, embed_features) + # = (1, L, embed_features) + embed = self.embed(self.y_const, 
self.x_const) + + # embed : (1, L, embed_features) + # -> (N, L, embed_features) + embed = embed.expand((x.shape[0], *embed.shape[1:])) + + # result : (N, L, embed_features + input_features) + result = torch.cat([embed, x], dim = 2) + + # (N, L, features) + return self.output(result) + +class PixelwiseViT(nn.Module): + + def __init__( + self, features, n_heads, n_blocks, ffn_features, embed_features, + activ, norm, image_shape, rezero = True, **kwargs + ): + super().__init__(**kwargs) + + self.image_shape = image_shape + + self.trans_input = ViTInput( + image_shape[0], embed_features, features, + image_shape[1], image_shape[2], + ) + + self.encoder = TransformerEncoder( + features, ffn_features, n_heads, n_blocks, activ, norm, rezero + ) + + self.trans_output = nn.Linear(features, image_shape[0]) + + def forward(self, x): + # x : (N, C, H, W) + + # itokens : (N, C, H * W) + itokens = x.view(*x.shape[:2], -1) + + # itokens : (N, C, H * W) + # -> (N, H * W, C ) + # = (N, L, C) + itokens = itokens.permute((0, 2, 1)) + + # y : (N, L, features) + y = self.trans_input(itokens) + y = self.encoder(y) + + # otokens : (N, L, C) + otokens = self.trans_output(y) + + # otokens : (N, L, C) + # -> (N, C, L) + # = (N, C, H * W) + otokens = otokens.permute((0, 2, 1)) + + # result : (N, C, H, W) + result = otokens.view(*otokens.shape[:2], *self.image_shape[1:]) + + return result + +class ExtendedTransformerEncoder(nn.Module): + + def __init__( + self, features, n_heads, n_blocks, ffn_features, activ, norm, + rezero = True, n_ext = 1, **kwargs + ): + super().__init__(**kwargs) + + self.encoder = TransformerEncoder( + features, ffn_features, n_heads, n_blocks, activ, norm, rezero + ) + + self.extra_tokens = nn.Parameter(torch.empty((1, n_ext, features))) + torch.nn.init.normal_(self.extra_tokens) + + def forward(self, x): + # x : (N, L, C) + N, L, _C = x.shape + + # i_extra_tokens : (N, n_extra, C) + i_extra_tokens = self.extra_tokens.tile(N, 1, 1) + + # y : (N, L + n_extra, C) + y = torch.cat([ x, i_extra_tokens ], dim = 1) + y = self.encoder(y) + + # o_extra_tokens : (N, n_extra, features) + o_extra_tokens = y[:, L:, :] + + # result : (N, L, C) + result = y[:, :L, :] + + return (result, o_extra_tokens.reshape(N, -1)) + +class ExtendedPixelwiseViT(nn.Module): + + def __init__( + self, features, n_heads, n_blocks, ffn_features, embed_features, + activ, norm, image_shape, rezero = True, n_ext = 1, **kwargs + ): + super().__init__(**kwargs) + + self.image_shape = image_shape + + self.trans_input = ViTInput( + image_shape[0], embed_features, features, + image_shape[1], image_shape[2], + ) + + self.encoder = TransformerEncoder( + features, ffn_features, n_heads, n_blocks, activ, norm, rezero + ) + + self.extra_tokens = nn.Parameter(torch.empty((1, n_ext, features))) + torch.nn.init.normal_(self.extra_tokens) + + self.trans_output = nn.Linear(features, image_shape[0]) + + def forward(self, x): + # x : (N, C, H, W) + + # itokens : (N, L, C) + itokens = img_to_pixelwise_tokens(x) + (N, L, _C) = itokens.shape + + # i_extra_tokens : (N, n_extra, C) + i_extra_tokens = self.extra_tokens.tile(itokens.shape[0], 1, 1) + + # y : (N, L, features) + y = self.trans_input(itokens) + + # y : (N, L + n_extra, C) + y = torch.cat([ y, i_extra_tokens ], dim = 1) + y = self.encoder(y) + + # o_extra_tokens : (N, n_extra, features) + o_extra_tokens = y[:, L:, :] + + # otokens : (N, L, C) + otokens = self.trans_output(y[:, :L, :]) + + # result : (N, C, H, W) + result = img_from_pixelwise_tokens(otokens, self.image_shape) + + return 
(result, o_extra_tokens.reshape(N, -1)) + +class CExtPixelwiseViT(nn.Module): + + def __init__( + self, features, n_heads, n_blocks, ffn_features, embed_features, + activ, norm, image_shape, rezero = True, **kwargs + ): + super().__init__(**kwargs) + + self.image_shape = image_shape + + self.trans_input = ViTInput( + image_shape[0], embed_features, features, + image_shape[1], image_shape[2], + ) + + self.encoder = TransformerEncoder( + features, ffn_features, n_heads, n_blocks, activ, norm, rezero + ) + + self.trans_output = nn.Linear(features, image_shape[0]) + + def forward(self, x, control): + # x : (N, C, H, W) + + # itokens : (N, L, C) + itokens = img_to_pixelwise_tokens(x) + (_N, L, _C) = itokens.shape + + # control : (N, C) + # i_control : (N, 1, C) + i_control = control.unsqueeze(1) + + # y : (N, L, features) + y = self.trans_input(itokens) + + # y : (N, L + 1, C) + y = torch.cat([ y, i_control ], dim = 1) + y = self.encoder(y) + + # o_control : (N, 1, features) + o_control = y[:, L:, :] + + # otokens : (N, L, C) + otokens = self.trans_output(y[:, :L, :]) + + # result : (N, C, H, W) + result = img_from_pixelwise_tokens(otokens, self.image_shape) + + return (result, o_control.squeeze(1)) + diff --git a/uvcgan_s/torch/layers/unet.py b/uvcgan_s/torch/layers/unet.py new file mode 100644 index 0000000..c8ede23 --- /dev/null +++ b/uvcgan_s/torch/layers/unet.py @@ -0,0 +1,331 @@ +# pylint: disable=too-many-arguments +# pylint: disable=too-many-instance-attributes + +import torch +from torch import nn + +from uvcgan_s.torch.select import get_norm_layer, get_activ_layer + +from .cnn import get_downsample_x2_layer, get_upsample_x2_layer + +class UnetBasicBlock(nn.Module): + + def __init__( + self, in_features, out_features, activ, norm, mid_features = None, + **kwargs + ): + super().__init__(**kwargs) + + if mid_features is None: + mid_features = out_features + + self.block = nn.Sequential( + get_norm_layer(norm, in_features), + nn.Conv2d(in_features, mid_features, kernel_size = 3, padding = 1), + get_activ_layer(activ), + + get_norm_layer(norm, mid_features), + nn.Conv2d( + mid_features, out_features, kernel_size = 3, padding = 1 + ), + get_activ_layer(activ), + ) + + def forward(self, x): + return self.block(x) + +class UNetEncBlock(nn.Module): + + def __init__( + self, features, activ, norm, downsample, input_shape, **kwargs + ): + super().__init__(**kwargs) + + self.downsample, output_features = \ + get_downsample_x2_layer(downsample, features) + + (C, H, W) = input_shape + self.block = UnetBasicBlock(C, features, activ, norm) + + self._output_shape = (output_features, H//2, W//2) + + @property + def output_shape(self): + return self._output_shape + + def forward(self, x): + r = self.block(x) + y = self.downsample(r) + return (y, r) + +class UNetDecBlock(nn.Module): + + def __init__( + self, input_shape, output_features, skip_features, activ, norm, + upsample, rezero = False, **kwargs + ): + super().__init__(**kwargs) + + (input_features, H, W) = input_shape + self.upsample, input_features = get_upsample_x2_layer( + upsample, input_features + ) + + self.block = UnetBasicBlock( + input_features + skip_features, output_features, activ, norm, + mid_features = max(skip_features, input_features, output_features) + ) + + self._output_shape = (output_features, 2 * H, 2 * W) + + if rezero: + self.re_alpha = nn.Parameter(torch.zeros((1, ))) + else: + self.re_alpha = 1 + + @property + def output_shape(self): + return self._output_shape + + def forward(self, x, r): + # x : (N, C, H_in, W_in) + # r 
: (N, C_skip, H_out, W_out) + + # x : (N, C_up, H_out, W_out) + x = self.re_alpha * self.upsample(x) + + # y : (N, C_skip + C_up, H_out, W_out) + y = torch.cat([x, r], dim = 1) + + # result : (N, C_out, H_out, W_out) + return self.block(y) + + def extra_repr(self): + return 're_alpha = %e' % (self.re_alpha, ) + +class UNetBlock(nn.Module): + + def __init__( + self, features, activ, norm, image_shape, downsample, upsample, + rezero = True, output_features = None, **kwargs + ): + super().__init__(**kwargs) + + self.conv = UNetEncBlock( + features, activ, norm, downsample, image_shape + ) + + self.inner_shape = self.conv.output_shape + self.inner_module = None + + if output_features is None: + output_features = image_shape[0] + + self.deconv = UNetDecBlock( + self.inner_shape, output_features, self.inner_shape[0], + activ, norm, upsample, rezero + ) + + def get_inner_shape(self): + return self.inner_shape + + def set_inner_module(self, module): + self.inner_module = module + + def get_inner_module(self): + return self.inner_module + + def forward(self, x): + # x : (N, C, H, W) + + # y : (N, C_inner, H_inner, W_inner) + # r : (N, C_inner, H, W) + (y, r) = self.conv(x) + + # y : (N, C_inner, H_inner, W_inner) + y = self.inner_module(y) + + # y : (N, C, H, W) + y = self.deconv(y, r) + + return y + +class UNetLinearEncoder(nn.Module): + + def __init__( + self, features_list, image_shape, activ, norm, downsample, **kwargs + ): + # pylint: disable = too-many-locals + super().__init__(**kwargs) + + self.features_list = features_list + self.image_shape = image_shape + self._skip_shapes = [] + + self.net = nn.ModuleList() + curr_shape = image_shape + + for features in features_list: + layer = UNetEncBlock(features, activ, norm, downsample, curr_shape) + self.net.append(layer) + + curr_shape = layer.output_shape + self._skip_shapes.append(curr_shape) + + self._output_shape = curr_shape + + @property + def output_shape(self): + return self._output_shape + + @property + def skip_shapes(self): + return self._skip_shapes + + def forward(self, x, return_skips = True): + # x : (N, C, H, W) + + skips = [ ] + y = x + + for layer in self.net: + y, r = layer(y) + if return_skips: + skips.append(r) + + if return_skips: + return y, skips + else: + return y + +class UNetLinearDecoder(nn.Module): + + def __init__( + self, features_list, input_shape, output_shape, skip_shapes, + activ, norm, upsample, **kwargs + ): + # pylint: disable = too-many-locals + super().__init__(**kwargs) + + self.net = nn.ModuleList() + self._input_shape = input_shape + self._output_shape = output_shape + curr_shape = input_shape + + for features, skip_shape in zip( + features_list[::-1], skip_shapes[::-1] + ): + layer = UNetDecBlock( + curr_shape, features, skip_shape[0], activ, norm, upsample + ) + curr_shape = layer.output_shape + + self.net.append(layer) + + if output_shape[0] == curr_shape[0]: + self.output = nn.Identity() + else: + self.output = nn.Conv2d( + curr_shape[0], output_shape[0], kernel_size = 1 + ) + + curr_shape = (output_shape[0], *curr_shape[1:]) + assert tuple(output_shape) == tuple(curr_shape) + + @property + def input_shape(self): + return self._input_shape + + @property + def output_shape(self): + return self._output_shape + + def forward(self, x, skip_list): + # x : (N, C, H, W) + # skip_list : List[(N, C_i, H_i, W_i)] + + y = x + + for layer, skip in zip(self.net, skip_list[::-1]): + y = layer(y, skip) + + return self.output(y) + +class UNet(nn.Module): + + def __init__( + self, features_list, activ, norm, 
image_shape, downsample, upsample, + rezero = True, output_shape = None, **kwargs + ): + # pylint: disable = too-many-locals + super().__init__(**kwargs) + + self.features_list = features_list + self.image_shape = image_shape + + if output_shape is None: + output_shape = image_shape + + assert tuple(output_shape[1:]) == tuple(image_shape[1:]) + + self.output_shape = output_shape + + self._construct_input_layer(activ) + self._construct_output_layer() + + unet_layers = [] + curr_image_shape = (features_list[0], *image_shape[1:]) + + for features in features_list: + layer = UNetBlock( + features, activ, norm, curr_image_shape, downsample, upsample, + rezero + ) + curr_image_shape = layer.get_inner_shape() + unet_layers.append(layer) + + for idx in range(len(unet_layers)-1): + unet_layers[idx].set_inner_module(unet_layers[idx+1]) + + self.unet = unet_layers[0] + + def _construct_input_layer(self, activ): + self.layer_input = nn.Sequential( + nn.Conv2d( + self.image_shape[0], self.features_list[0], + kernel_size = 3, padding = 1 + ), + get_activ_layer(activ), + ) + + def _construct_output_layer(self): + self.layer_output = nn.Conv2d( + self.features_list[0], self.output_shape[0], kernel_size = 1 + ) + + def get_innermost_block(self): + result = self.unet + + for _ in range(len(self.features_list)-1): + result = result.get_inner_module() + + return result + + def set_bottleneck(self, module): + self.get_innermost_block().set_inner_module(module) + + def get_bottleneck(self): + return self.get_innermost_block().get_inner_module() + + def get_inner_shape(self): + return self.get_innermost_block().get_inner_shape() + + def forward(self, x): + # x : (N, C, H, W) + + y = self.layer_input(x) + y = self.unet(y) + y = self.layer_output(y) + + return y + diff --git a/uvcgan_s/torch/lr_equal.py b/uvcgan_s/torch/lr_equal.py new file mode 100644 index 0000000..4ff4819 --- /dev/null +++ b/uvcgan_s/torch/lr_equal.py @@ -0,0 +1,59 @@ +import math + +import torch +from torch import nn + +from torch.nn.init import _calculate_correct_fan, calculate_gain + +class LearningRateEqualizer(nn.Module): + + def __init__(self, mode = 'fan_in', nonlinearity = 'leaky_relu', a = 0): + super().__init__() + + self._mode = mode + self._activ = nonlinearity + self._param = a + + def _calc_scale(self, w): + fan = _calculate_correct_fan(w, self._mode) + gain = calculate_gain(self._activ, self._param) + std = gain / math.sqrt(fan) + + return std + + def forward(self, w): + scale = self._calc_scale(w) + return w * scale + +# NOTE: +# WARNING: +# Behavior of parametrization changes between pytorch 1.9 and 1.10. +# In pytorch 1.9 original_weight = unparametrized weight +# In pytorch >1.9 original_weight = right_inverse(unparametrized weight) +# To make the behavior consistent -- the right_inverse is masked for now. 
+# +# def right_inverse(self, w): +# scale = self._calc_scale(w) +# return w / scale + +def apply_lr_equal_to_module(module, name, param): + if isinstance(module, torch.nn.utils.parametrize.ParametrizationList): + return + + if not hasattr(module, name): + return + + w = getattr(module, name) + + if (w is None) or len(w.shape) < 2: + return + + torch.nn.utils.parametrize.register_parametrization(module, name, param) + +def apply_lr_equal(module, tensor_name = "weight", **kwargs): + parametrization = LearningRateEqualizer(**kwargs) + submodule_list = [ x[1] for x in module.named_modules() ] + + for m in submodule_list: + apply_lr_equal_to_module(m, tensor_name, parametrization) + diff --git a/uvcgan_s/torch/queue.py b/uvcgan_s/torch/queue.py new file mode 100644 index 0000000..0434dba --- /dev/null +++ b/uvcgan_s/torch/queue.py @@ -0,0 +1,81 @@ +from collections import deque +import torch + +class Queue: + + def __init__(self, size): + self._queue = deque(maxlen = size) + + def __len__(self): + return len(self._queue) + + def push(self, x): + self._queue.append(x.detach()) + + def query(self): + return tuple(self._queue) + +# NOTE: FastQueue differs from Queue +# Queue size means the number of objects to store +# FastQueue size means the full length of the queue tensor + +class FastQueue: + + def __init__(self, size, device): + self._queue = None + self._size = size + self._device = device + self._curr_idx = 0 + self._full = False + + def __len__(self): + if self._full: + return self._size + + return self._curr_idx + + def lazy_init_queue(self, x): + if self._queue is not None: + return + + self._queue = torch.empty( + (self._size, *x.shape[1:]), dtype = x.dtype, device = self._device + ) + + def push(self, x): + self.lazy_init_queue(x) + + n = x.shape[0] + n_avail_to_end = self._size - self._curr_idx + + if n > self._size: + x = x[-self._size:, ...] + n = self._size + + if n_avail_to_end <= n: + self._queue[self._curr_idx:self._size, ...] \ + = x[:n_avail_to_end, ...].detach().to( + self._device, non_blocking = True + ) + + self._curr_idx = 0 + self._full = True + + if n_avail_to_end < n: + self.push(x[n_avail_to_end:, ...]) + + else: + self._queue[self._curr_idx:self._curr_idx + n, ...] \ + = x.detach().to(self._device) + + self._curr_idx += n + + def query(self): + if self._queue is None: + return None + + if self._full: + return self._queue + + return self._queue[:self._curr_idx, ...] 
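The `NOTE` above spells out the difference in pool semantics: `Queue(size)` keeps the last `size` pushed objects, whereas `FastQueue(size, device)` preallocates a single tensor ring buffer with `size` rows in total. A minimal usage sketch follows; the batch shapes, the `fake_images` tensor, and the discriminator-pool use case are illustrative assumptions, not taken from this diff:

```python
import torch

from uvcgan_s.torch.queue import FastQueue, Queue

# FastQueue: rows from all pushed batches share one preallocated buffer.
pool = FastQueue(size = 64, device = 'cpu')

for _ in range(10):
    fake_images = torch.randn(8, 3, 32, 32)  # hypothetical generator output
    pool.push(fake_images)                   # detached and copied into the buffer

history = pool.query()                       # single tensor of the stored rows
print(len(pool), history.shape)              # -> 64 torch.Size([64, 3, 32, 32])

# Queue: stores whole objects; len() counts pushed objects, not tensor rows.
obj_pool = Queue(size = 3)
for _ in range(5):
    obj_pool.push(torch.randn(8, 3, 32, 32))

print(len(obj_pool))                         # -> 3, only the last three batches kept
```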
+ diff --git a/uvcgan_s/torch/select.py b/uvcgan_s/torch/select.py new file mode 100644 index 0000000..8381da5 --- /dev/null +++ b/uvcgan_s/torch/select.py @@ -0,0 +1,107 @@ +import copy +import torch +from torch import nn + +from .layers.activation import Exponential + +def extract_name_kwargs(obj): + if isinstance(obj, dict): + obj = copy.copy(obj) + name = obj.pop('name') + kwargs = obj + else: + name = obj + kwargs = {} + + return (name, kwargs) + +def get_norm_layer(norm, features): + name, kwargs = extract_name_kwargs(norm) + + if name is None: + return nn.Identity(**kwargs) + + if name == 'layer': + return nn.LayerNorm((features,), **kwargs) + + if name == 'batch': + return nn.BatchNorm2d(features, **kwargs) + + if name == 'group': + return nn.GroupNorm(num_channels = features, **kwargs) + + if name == 'instance': + return nn.InstanceNorm2d(features, **kwargs) + + raise ValueError("Unknown Layer: '%s'" % name) + +def get_norm_layer_fn(norm): + return lambda features : get_norm_layer(norm, features) + +def get_activ_layer(activ): + # pylint: disable=too-many-return-statements + name, kwargs = extract_name_kwargs(activ) + + if (name is None) or (name == 'linear'): + return nn.Identity() + + if name == 'gelu': + return nn.GELU(**kwargs) + + if name == 'selu': + return nn.SELU(**kwargs) + + if name == 'relu': + return nn.ReLU(inplace = True, **kwargs) + + if name == 'leakyrelu': + return nn.LeakyReLU(inplace = True, **kwargs) + + if name == 'tanh': + return nn.Tanh() + + if name == 'sigmoid': + return nn.Sigmoid() + + if name == 'exp': + return Exponential(**kwargs) + + raise ValueError("Unknown activation: '%s'" % name) + +def select_activation(activation): + return get_activ_layer(activation) + +def select_optimizer(parameters, optimizer): + name, kwargs = extract_name_kwargs(optimizer) + + if name == 'AdamW': + return torch.optim.AdamW(parameters, **kwargs) + + if name == 'SGD': + return torch.optim.SGD(parameters, **kwargs) + + if name == 'Adam': + return torch.optim.Adam(parameters, **kwargs) + + raise ValueError("Unknown optimizer: '%s'" % name) + +def select_loss(loss, reduction = None): + name, kwargs = extract_name_kwargs(loss) + + if reduction is not None: + kwargs['reduction'] = reduction + + if name.lower() in [ 'l1', 'mae' ]: + return nn.L1Loss(**kwargs) + + if name.lower() in [ 'l2', 'mse' ]: + return nn.MSELoss(**kwargs) + + if name.lower() in [ 'bce', 'binary-cross-entropy' ]: + return nn.BCELoss(**kwargs) + + if name.lower() in [ 'bce-logits', ]: + return nn.BCEWithLogitsLoss(**kwargs) + + raise ValueError("Unknown loss: '%s'" % name) + diff --git a/uvcgan_s/torch/spectr_norm.py b/uvcgan_s/torch/spectr_norm.py new file mode 100644 index 0000000..f714204 --- /dev/null +++ b/uvcgan_s/torch/spectr_norm.py @@ -0,0 +1,23 @@ +import torch +from torch.nn.utils.parametrizations import spectral_norm + +def apply_sn_to_module(module, name, n_power_iterations = 1): + if isinstance(module, torch.nn.utils.parametrize.ParametrizationList): + return + + if not hasattr(module, name): + return + + w = getattr(module, name) + + if (w is None) or len(w.shape) < 2: + return + + spectral_norm(module, name, n_power_iterations) + +def apply_sn(module, tensor_name = 'weight', n_power_iterations = 1): + submodule_list = list(module.modules()) + + for m in submodule_list: + apply_sn_to_module(m, tensor_name, n_power_iterations) + diff --git a/uvcgan_s/train/__init__.py b/uvcgan_s/train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/uvcgan_s/train/callbacks/__init__.py 
b/uvcgan_s/train/callbacks/__init__.py new file mode 100644 index 0000000..92e5104 --- /dev/null +++ b/uvcgan_s/train/callbacks/__init__.py @@ -0,0 +1 @@ +from .history import TrainingHistory diff --git a/uvcgan_s/train/callbacks/history.py b/uvcgan_s/train/callbacks/history.py new file mode 100644 index 0000000..4da6b92 --- /dev/null +++ b/uvcgan_s/train/callbacks/history.py @@ -0,0 +1,39 @@ +import os +import pandas as pd + +HISTORY_NAME = 'history.csv' + +class TrainingHistory: + + def __init__(self, savedir): + self._history = None + self._savedir = savedir + + def end_epoch(self, epoch, metrics): + values = metrics.values + values['epoch'] = epoch + values['time'] = pd.Timestamp.utcnow() + + if self._history is None: + self._history = pd.DataFrame([ values, ]) + else: + self._history = self._history.append([ values, ]) + + self.save() + + def save(self): + self._history.to_csv( + os.path.join(self._savedir, HISTORY_NAME), index = False + ) + + def load(self): + path = os.path.join(self._savedir, HISTORY_NAME) + + if os.path.exists(path): + self._history = pd.read_csv(path, parse_dates = [ 'time', ]) + + @property + def history(self): + return self._history + + diff --git a/uvcgan_s/train/metrics/__init__.py b/uvcgan_s/train/metrics/__init__.py new file mode 100644 index 0000000..f6a8572 --- /dev/null +++ b/uvcgan_s/train/metrics/__init__.py @@ -0,0 +1,2 @@ +from .loss_metrics import LossMetrics + diff --git a/uvcgan_s/train/metrics/loss_metrics.py b/uvcgan_s/train/metrics/loss_metrics.py new file mode 100644 index 0000000..2381e84 --- /dev/null +++ b/uvcgan_s/train/metrics/loss_metrics.py @@ -0,0 +1,37 @@ +import copy + +class LossMetrics: + + def __init__(self, values = None, n = 0): + self._values = values + self._n = n + + @property + def values(self): + if self._values is None: + return None + + return { k : v / self._n for (k,v) in self._values.items() } + + def update(self, values): + if self._values is None: + self._values = copy.deepcopy(values) + else: + for k,v in values.items(): + self._values[k] += v + + self._n += 1 + + def join(self, other, other_prefix = None): + self_dict = self.values + other_dict = other.values + + if other_prefix is not None: + other_dict = { + other_prefix + k : v for (k, v) in other_dict.items() + } + + values_dict = { **self_dict, **other_dict } + + return LossMetrics(values_dict, n = 1) + diff --git a/uvcgan_s/train/train.py b/uvcgan_s/train/train.py new file mode 100644 index 0000000..4428891 --- /dev/null +++ b/uvcgan_s/train/train.py @@ -0,0 +1,84 @@ +from itertools import islice +import tqdm + +from uvcgan_s.config import Args +from uvcgan_s.data import construct_data_loaders +from uvcgan_s.torch.funcs import get_torch_device_smart, seed_everything +from uvcgan_s.cgan import construct_model +from uvcgan_s.utils.log import setup_logging + +from .metrics import LossMetrics +from .callbacks import TrainingHistory +from .transfer import transfer + +def training_epoch(it_train, model, title, steps_per_epoch): + model.train() + + steps = len(it_train) + if steps_per_epoch is not None: + steps = min(steps, steps_per_epoch) + + progbar = tqdm.tqdm(desc = title, total = steps, dynamic_ncols = True) + metrics = LossMetrics() + + for batch in islice(it_train, steps): + model.set_input(batch) + model.optimization_step() + + metrics.update(model.get_current_losses()) + + progbar.set_postfix(metrics.values, refresh = False) + progbar.update() + + progbar.close() + return metrics + +def try_continue_training(args, model): + history = 
TrainingHistory(args.savedir) + + start_epoch = model.find_last_checkpoint_epoch() + model.load(start_epoch) + + if start_epoch > 0: + history.load() + + start_epoch = max(start_epoch, 0) + + return (start_epoch, history) + +def train(args_dict): + args = Args.from_args_dict(**args_dict) + + setup_logging(args.log_level) + seed_everything(args.config.seed) + + device = get_torch_device_smart() + it_train = construct_data_loaders( + args.config.data, args.config.batch_size, split = 'train' + ) + + print("Starting training...") + print(args.config.to_json(indent = 4)) + + model = construct_model( + args.savedir, args.config, is_train = True, device = device + ) + start_epoch, history = try_continue_training(args, model) + + if (start_epoch == 0) and (args.transfer is not None): + transfer(model, args.transfer) + + for epoch in range(start_epoch + 1, args.epochs + 1): + title = 'Epoch %d / %d' % (epoch, args.epochs) + metrics = training_epoch( + it_train, model, title, args.config.steps_per_epoch + ) + + history.end_epoch(epoch, metrics) + model.end_epoch(epoch) + + if epoch % args.checkpoint == 0: + model.save(epoch) + + model.save(epoch = None) + diff --git a/uvcgan_s/train/transfer.py b/uvcgan_s/train/transfer.py new file mode 100644 index 0000000..165a4df --- /dev/null +++ b/uvcgan_s/train/transfer.py @@ -0,0 +1,189 @@ +import os +import logging + +import torch + +from uvcgan_s.consts import ROOT_OUTDIR +from uvcgan_s.config import Args +from uvcgan_s.cgan import construct_model + +LOGGER = logging.getLogger('uvcgan_s.train') + +def load_base_model(model, transfer_config): + try: + model.load(epoch = None) + return + + except IOError as e: + if not transfer_config.allow_partial: + raise IOError( + "Failed to find fully trained model in '%s' for transfer: %s"\ + % (model.savedir, e) + ) from e + + LOGGER.warning( + ( + "Failed to find fully trained model in '%s' for transfer." + " Trying to load from a checkpoint..." + ), model.savedir + ) + + epoch = model.find_last_checkpoint_epoch() + + if epoch > 0: + LOGGER.warning("Load transfer model from a checkpoint '%d'", epoch) + else: + raise RuntimeError("Failed to find transfer model checkpoints.") + + model.load(epoch) + +def get_base_model(transfer_config, device): + base_path = os.path.join(ROOT_OUTDIR, transfer_config.base_model) + base_args = Args.load(base_path) + + model = construct_model( + base_args.savedir, base_args.config, is_train = True, device = device + ) + + load_base_model(model, transfer_config) + + return model + +def transfer_from_larger_model(module, state_dict, strict): + source_keys = set(state_dict.keys()) + target_keys = set(k for (k, _p) in module.named_parameters()) + + matching_dict = { k : state_dict[k] for k in target_keys } + module.load_state_dict(matching_dict, strict = strict) + + LOGGER.warning( + "Transfer from a large model. 
Transferred %d / %d parameters", + len(target_keys), len(source_keys) + ) + +def collect_keys_for_transfer_to_wider_model(module, state_dict, strict): + source_keys = set(state_dict.keys()) + target_keys = set() + matching_keys = set() + wider_keys = set() + narrower_keys = set() + + for (k, p_target) in module.state_dict().items(): + target_keys.add(k) + + if strict: + assert k in source_keys + + p_source = state_dict[k] + + shape_source = p_source.shape + shape_target = p_target.shape + + if (shape_target is None) or (shape_target == shape_source): + matching_keys.add(k) + + elif ( + (len(shape_target) == len(shape_source)) + and all(t >= s for (t, s) in zip(shape_target, shape_source)) + ): + LOGGER.warning( + "Weight Transfer. Found wider parameter: '%s'. " + "%s vs %s", + k, shape_target, shape_source + ) + + wider_keys.add(k) + + elif ( + (len(shape_target) == len(shape_source)) + and all(t <= s for (t, s) in zip(shape_target, shape_source)) + ): + LOGGER.warning( + "Weight Transfer. Found narrower parameter: '%s'. " + "%s vs %s", + k, shape_target, shape_source + ) + + narrower_keys.add(k) + + else: + raise ValueError( + "Weight Transfer. " + f"Cannot transfer parameter '{k}' due to mismatching" + f" shapes {shape_target} vs {shape_source}" + ) + + if strict and (source_keys != target_keys): + keys_diff = source_keys.symmetric_difference(target_keys) + + raise RuntimeError( + "Transfer to wide model. Strict transfer failed due to" + f" mismatching keys {keys_diff}" + ) + + return matching_keys, wider_keys, narrower_keys + +def transfer_to_wider_model(module, state_dict, strict): + matching_keys, wider_keys, _ = \ + collect_keys_for_transfer_to_wider_model(module, state_dict, strict) + + matching_dict = { k : state_dict[k] for k in matching_keys } + module.load_state_dict(matching_dict, strict = False) + + for k, p in module.named_parameters(): + if k not in wider_keys: + continue + + source_tensor = state_dict[k] + target_slice = tuple(slice(0, s) for s in source_tensor.shape) + + with torch.no_grad(): + p[target_slice] = source_tensor + +def fully_fuzzy_transfer(module, state_dict, strict): + matching_keys, _wider_keys, _narrower_keys = \ + collect_keys_for_transfer_to_wider_model(module, state_dict, strict) + + matching_dict = { k : state_dict[k] for k in matching_keys } + module.load_state_dict(matching_dict, strict = False) + +def transfer_state_dict(module, state_dict, fuzzy, strict): + if (fuzzy is None) or (fuzzy == 'none'): + module.load_state_dict(state_dict, strict = strict) + + elif fuzzy == 'from-larger-model': + transfer_from_larger_model(module, state_dict, strict) + + elif fuzzy == 'to-wider-model': + transfer_to_wider_model(module, state_dict, strict) + + elif fuzzy == 'full': + fully_fuzzy_transfer(module, state_dict, strict) + + else: + raise ValueError(f"Unknown fuzzy transfer type: {fuzzy}") + +def transfer_parameters(model, base_model, transfer_config): + for (dst, src) in transfer_config.transfer_map.items(): + transfer_state_dict( + model.models[dst], base_model.models[src].state_dict(), + transfer_config.fuzzy, transfer_config.strict + ) + +def transfer(model, transfer_config): + if transfer_config is None: + return + + if isinstance(transfer_config, list): + for conf in transfer_config: + transfer(model, conf) + return + + LOGGER.info( + "Initiating parameter transfer : '%s'", transfer_config.to_dict() + ) + + base_model = get_base_model(transfer_config, model.device) + transfer_parameters(model, base_model, transfer_config) + + diff --git 
a/uvcgan_s/utils/__init__.py b/uvcgan_s/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/uvcgan_s/utils/funcs.py b/uvcgan_s/utils/funcs.py new file mode 100644 index 0000000..f69d522 --- /dev/null +++ b/uvcgan_s/utils/funcs.py @@ -0,0 +1,58 @@ +import copy + +import torch +from torchvision.datasets.folder import default_loader + +def recursive_update_dict(base_dict, new_dict): + if new_dict is None: + return + + for k,v in new_dict.items(): + if ( + isinstance(v, dict) + and k in base_dict + and isinstance(base_dict[k], dict) + ): + recursive_update_dict(base_dict[k], v) + else: + base_dict[k] = copy.deepcopy(v) + +def join_dicts(*dicts_list): + base_dict = {} + + for d in dicts_list: + recursive_update_dict(base_dict, d) + + return base_dict + +def check_value_in_range(value, value_range, hint = None): + if value in value_range: + return + + msg = '' + + if hint is not None: + msg = hint + ' ' + + msg += f"value '{value}' is not range {value_range}" + + raise ValueError(msg) + +@torch.no_grad() +def load_image_fuzzy(path, transforms, device): + result = None + + try: + result = default_loader(path) + except FileNotFoundError: + base_path = path.split('.', maxsplit = 1)[0] + + for ext in [ '.png', '.jpg', '.jpg.png' ]: + try: + result = default_loader(base_path + ext) + except FileNotFoundError: + continue + + result = transforms(result) + return result.to(device).unsqueeze(0) + diff --git a/uvcgan_s/utils/log.py b/uvcgan_s/utils/log.py new file mode 100644 index 0000000..c03f80c --- /dev/null +++ b/uvcgan_s/utils/log.py @@ -0,0 +1,23 @@ +import logging + +def reduce_pil_verbosity(level): + """A hack to stop PIL dumping large amounts of useless DEBUG info""" + logger = logging.getLogger('PIL') + logger.setLevel(max(logging.WARNING, level)) + +def setup_logging(level = logging.DEBUG): + """Setup logging.""" + logger = logging.getLogger() + + formatter = logging.Formatter( + '[%(asctime)s] [%(name)s]: %(levelname)s %(message)s' + ) + + logger.setLevel(level) + + handler = logging.StreamHandler() + handler.setFormatter(formatter) + + logger.addHandler(handler) + + reduce_pil_verbosity(logger.level) diff --git a/uvcgan_s/utils/parsers.py b/uvcgan_s/utils/parsers.py new file mode 100644 index 0000000..149ec82 --- /dev/null +++ b/uvcgan_s/utils/parsers.py @@ -0,0 +1,110 @@ +from uvcgan_s.consts import ( + MODEL_STATE_TRAIN, MODEL_STATE_EVAL, SPLIT_TRAIN, SPLIT_TEST, SPLIT_VAL +) + +def add_model_state_parser(parser): + parser.add_argument( + '--model-state', + choices = [ MODEL_STATE_TRAIN, MODEL_STATE_EVAL ], + default = 'eval', + dest = 'model_state', + help = "evaluate model in 'train' or 'eval' states", + type = str, + ) + +def add_plot_extension_parser(parser, default = ( 'png', )): + parser.add_argument( + '-e', '--ext', + default = None if default is None else list(default), + dest = 'ext', + help = 'plot extensions', + type = str, + nargs = '+', + ) + +def add_batch_size_parser(parser, default = 1): + parser.add_argument( + '--batch-size', + default = default, + dest = 'batch_size', + help = 'batch size to use for evaluation', + type = int, + ) + +def add_n_eval_samples_parser(parser, default = None): + parser.add_argument( + '-n', + default = default, + dest = 'n_eval', + help = 'number of samples to use for evaluation', + type = int, + ) + +def add_eval_type_parser(parser, default = 'transfer'): + parser.add_argument( + '--type', + choices = [ 'transfer', 'reco', 'masked', 'simple-reco' ], + default = default, + dest = 'eval_type', + help = 'type of 
evaluation', + type = str, + ) + +def add_split_parser(parser, default = SPLIT_TEST): + parser.add_argument( + '--split', + choices = [ SPLIT_TRAIN, SPLIT_TEST, SPLIT_VAL ], + default = default, + dest = 'split', + help = 'data split', + type = str, + ) + +def add_eval_epoch_parser(parser, default = None): + parser.add_argument( + '--epoch', + default = default, + dest = 'epoch', + help = ( + 'checkpoint epoch to evaluate.' + ' If not specified, then the evaluation will be performed for' + ' the final model. If epoch is -1, then the evaluation will' + ' be performed for the last checkpoint.' + ), + type = int, + ) + +def add_model_directory_parser(parser): + parser.add_argument( + 'model', + help = 'directory containing model to evaluate', + metavar = 'MODEL', + type = str, + ) + +def add_preset_name_parser( + parser, name, presets, default = None, help_msg = None, +): + parser.add_argument( + f'--{name}', + default = default, + dest = name, + choices = list(presets), + help = help_msg or name, + type = str, + ) + +def add_standard_eval_parsers( + parser, + default_batch_size = 1, + default_epoch = None, + default_n_eval = None, +): + add_model_directory_parser(parser) + add_model_state_parser(parser) + add_split_parser(parser) + + add_batch_size_parser(parser, default_batch_size) + add_eval_epoch_parser(parser, default_epoch) + add_n_eval_samples_parser(parser, default_n_eval) +
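The helpers in `uvcgan_s/utils/parsers.py` are building blocks for evaluation scripts. Below is a sketch of how they could be combined into an `argparse` front end; the script itself is illustrative, and only the helper functions and their defaults come from the module above:

```python
import argparse

from uvcgan_s.utils.parsers import (
    add_standard_eval_parsers, add_eval_type_parser, add_plot_extension_parser
)

def parse_cmdargs():
    parser = argparse.ArgumentParser(description = 'Evaluate a trained model')

    # positional MODEL directory plus --model-state, --split, --batch-size,
    # --epoch and -n options
    add_standard_eval_parsers(parser, default_batch_size = 32)

    # optional extras defined in the same module
    add_eval_type_parser(parser, default = 'transfer')
    add_plot_extension_parser(parser, default = ('png', 'pdf'))

    return parser.parse_args()

if __name__ == '__main__':
    cmdargs = parse_cmdargs()
    print(cmdargs.model, cmdargs.split, cmdargs.batch_size, cmdargs.eval_type)
```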