This is the Code submission accompanying the SIAM SIMODS 2024 Submission "The Optimal Discriminator in Higher-order Gradient-regularized Generative Adversarial Networks." All codes are in Tensorflow2.0 Keras, and can be implemented on either GPU, CPU and on Colab or locally. A public version of this code will be released with the camera-ready version of the paper
The central component of PolyGANs is the radial basis function implementation, and can be found in the gan_topics.py file, defined under the RBFSolver and RBFLayer classes. A reviewer interested in getting to the core of these methods is requested to check this part of the code to get insights on the RBF disciminator implementation, weight and centre computation, lambda_d computation, etc.
A slice of the RBF Network code is given below. In TF, the code for builing the RBF would look like this:
def discriminator_model_RBF(self):
if self.gan not in ['WAE'] and self.data not in ['g2','gmm8', 'gN']:
inputs = tf.keras.Input(shape=(self.output_size,self.output_size,1))
inputs_res = tf.keras.layers.Reshape(target_shape=[self.output_size*self.output_size])(inputs)
else:
inputs = tf.keras.Input(shape=(self.latent_dims,))
inputs_res = inputs
### A FLAG to control 'N', the number of centers
num_centers = 2*self.N_centers
### A Custom RBFLayer Implementation
### Please Check L167 of ``gan_topics.py``
### Note the argument ``order_m = self.rbf_m``, which is the m in k=2m-n
Out = RBFLayer(num_centres=num_centers, output_dim=1, order_m = self.rbf_m, batch_size = self.batch_size)(inputs_res)
### \lambda_d^* can be computed with another RBF layer
### Note the new arguement rbf_pow = -self.rbf_n (cf. Appendix B)
lambda_term = RBFLayer(num_centres=num_centers, output_dim=1, order_m = self.rbf_m, batch_size = self.batch_size, rbf_pow = -self.rbf_n)(inputs_res)
model = tf.keras.Model(inputs=inputs, outputs= [Out,lambda_term])
return model
This codebase consists of the TensorFlow2.x implementation PolyGAN and PolyGAN-WAE. The baseline comparisons are with multiple WGAN, GMMN and WAE variants as described in the paper. All results were optained when training on TF2.5. Other version of TF might cause instability due to library deprications. Use with care.
Dependencies can be installed via anaconda. The PolyGAN_GPU_TF25.yml file list the dependencies to setup the GPU system based environment. To install from the yml file, please run conda env create --file PolyGAN_GPU_TF25.yml :
GPU accelerated TensorFlow2.0 Environment:
dependencies:
- cudatoolkit=11.3.1
- cudnn=8.2.1=cuda11.3_0
- opencv=3.4.2
- pip=20.0.2
- python=3.6.10
- pytorch=1.4.0=py3.6_cpu_0
- pip:
- absl-py==0.9.0
- clean-fid==0.1.15
- h5py==2.10.0
- matplotlib==3.1.3
- numpy==1.19.5
- pot==0.8.0
- tensorflow-addons==0.13.0
- tensorflow-datasets==4.4.0
- tensorflow-estimator==2.5.0
- tensorflow-gpu==2.5.0
- tensorflow-probability==0.13.0
- torch==1.10.0+cu113
- torchaudio==0.10.0+cu113
- torchvision==0.11.1+cu113
- tqdm==4.42.1
If a GPU is unavailable, the CPU only environment can be built with PolyGAN_CPU.yml. This setting is meant to run evaluation code;. Training on the CPU environment is not advisable.
CPU based TensorFlow2.0 Environment:
- pip=20.0.2
- python=3.6.10
- opencv=3.4.2
- pytorch=1.4.0=py3.6_cpu_0
- pip:
- absl-py==0.9.0
- h5py==2.10.0
- clean-fid==0.1.15
- ipython==7.15.0
- ipython-genutils==0.2.0
- matplotlib==3.1.3
- numpy==1.18.1
- scikit-learn==0.22.1
- scipy==1.4.1
- pot==0.8.0
- tensorboard==2.0.2
- tensorflow-addons==0.6.0
- tensorflow-datasets==3.0.1
- tensorflow-estimator==2.0.1
- tensorflow==2.0.0
- tensorflow-probability==0.8.0
- tqdm==4.42.1
- gdown==3.12
If a Conda environment with the required dependecies already exists, or you wish to use your own environment for any particular reason, we suggest making a clone called PolyGAN to maintain consistent code-running experiemnt with the exisitng bash files: conda create --name PolyGAN --clone <Your_Env_Name>
The clean-fid and pot libraries are used for metric computation (cf. Appendix E).
Codes were tested locally and on Google Colab. Implementation Jupyter Notebooks to recrete the results shown in the paper will be included with the camera-ready version of the paper. The current version of the code allows implementing locally, the training codes for the various models considered.
In case of runtime issues occuring due to hardware/driver incompatibitlity, please refer the associated user-manuals of NVIDIA CUDA, CudNN, PyTorch or TensorFlow to install dependecies from source.
MNIST and CIFAR-10 are loaded from TensorFlow-Datasets. The CelebA datasets must be downloaded by running the following code:
python download_celeba.py
Code for downloading LSUN is well document in the official repository: https://github.com/fyu/lsun
. The datasets are relative large. Please keep trach of internet bandwidth availability.
The code provides training procedures on MNIST, CIFAR-10, CelebA and LSUN-Churches datasets. We include training codes for all basic experinets presented in the paper.
- Running
train_*.shbash files: There are two bash files to run an example instacne of trainging:train_PolyGAN_GPU.shandtrain_Baseline_GPU.sh. The fastest was to train a model is by running these bash files. Around 3-4 sample functions calls are provided in the above bash files. Uncomment the desired command to train for the associated testcase.
bash bash_files/train/train_PolyGAN_GPU.sh
bash bash_files/train/train_Baseline_GPU.sh
(While code is provided for running experiments on GPU, their CPU counterparts can be made by seting the appropriate flags, as discussed in the next section.)
2) Manually running gan_main.py: Aternatively, you can train any model of your choice by running gan_main.py with custom flags and modifiers. The list of flags and their defauly values are are defined in gan_main.py.
- Training on Colab: The experiments in the submission was trained on Google Colab (although it is not optimal for training on large models for long durations, pandemic related closures of other compute facilities made the use of Colab unavoidable). For those familiar with the approach, this repository could be cloned to your google drive and steps (1) or (2) could be used for training. CelebA must be downloaded to you current instance on Colab as reading data from GoogleDrive currently causes a Timeout error. Setting the flags
--colab 1,--latex_plot_flag 0, and--pbar_flag 0is advisable. Thecolabflag modifies CelebA data-handling code to read data from the local folder, rather thanPolyGANs/data/CelebA/. Thelatex_plot_flagflag removes code dependency on latex for plot labels, since the Colab isntance does not native include this. (Alternatively, you could install texlive_full in your colab instance). Lastly, turning off thepbar_flagwas found to prevent the browser from eating too much RAM when training the model. The .ipynb file for training on Colab, along with instructions of modifications necessary for running on Colab will be included with the public release of the paper.
The default values of the various flags used in the bash files are discussed in gan_main.py. While most flags pertain to training parameters, a few control the actula GAN trining algorithm:
--gan: Can be set toWGANorWAEbased on which set of experiments are to be executed. The corresponding file frommodels/GANs/ormodels/WAEs/will be imported.--topic: The kind of GAN varinat to consider.Basecalls the class associated with trainig baselines WGAN variants (GP, LP, Rd (or R1 as the oirignal authors call it), Rg (or R2), and ALP.PolyGANresults in calling those classes associated with PolyGAN training.GMMNcan be called for GMMN-IMQ and GMMN-RBFG variants. To run WAE variants, we set the topic toPolyGANorWAEMMDfor the proposed, or baselines variants, respectively.--loss: Choice of GAN loss. For WGAN, there isbase,GP,LP,ALP,R1andR2available for training. The proposed loss variant isRBF. For GMMN, the options areRBFG(Gaussian Kernel) andIMQ(Inverse Multiquadratic Kernel). For WAE, the choices areRBF,RBFG,IMQ,SW(Sliced Wassersten) andCW(Cramer Wold).--mode: Choose betweentrain,test, ormetrics. Whiletestgenetares interpolation and reconstruction results, themetricsmode allows for evaluating FID, sharpness, etc.--metrics: A comma seperated string of metrics to compute. Currently, it supportsFID,sharpness. When passing multiple, please avoid blank spaces.--data: Target data to train on. Supportmnist,cifar10,celeba, andchurch. (All lowercase only)--latent_dims: The letent space dimensionality to use in WAE variants. for 2D and nD Gaussians, this is atuomatically assumed to be the dimensionality of the data itself.--rbf_m: The gradient pentaty order, brought to life via the order of the polyharmonic RBF order 2m-n. Optimal performance when m = ceil(n/2)--GPU: A list indicating the visble GPU devices to CUDA. Set to0,0,1,<empyt>, etc. based on compute available.--device: The device that TensorFlow places its variables on. Begins iteration from0all the way up ton, whenndevies are proved in--GPU. Set--GPUas empty string and--deviceto-1to train on CPU.