diff --git a/synthesis_pipeline/condGAN/cond_DCGAN_network.py b/synthesis_pipeline/condGAN/cond_DCGAN_network.py
index ec36f2e..f5d97be 100644
--- a/synthesis_pipeline/condGAN/cond_DCGAN_network.py
+++ b/synthesis_pipeline/condGAN/cond_DCGAN_network.py
@@ -1,77 +1,55 @@
 import torch
 import torch.nn as nn
 import torch.nn.parallel
+import torch.nn.functional as F
+import math
 
-# ** GENERATOR MODEL (takes noise; returns tensor) **
 class Generator(nn.Module):
     def __init__(self,
                  n_categories,
                  noise_size,
                  channels=3,
                  feature_map_size=64,
-                 embedding_size=50):
+                 embedding_size=50,
+                 img_size=64):  # Assumed to be a power of 2; defaults to 64x64
         super(Generator, self).__init__()
-
         ngf = feature_map_size
+        power = int(math.log2(img_size))
         self.label_embedding = nn.Embedding(n_categories, embedding_size)
-        self.main = nn.Sequential(
-            # input is Z, going into a convolution
-            nn.ConvTranspose2d(
-                in_channels=embedding_size + noise_size,
-                out_channels=ngf * 8,
-                kernel_size=4,
-                stride=1,
-                padding=0,
-                bias=False
-            ),
-            nn.BatchNorm2d(ngf * 8),
-            nn.ReLU(True),
-            # state size. (gen_dimesions*8) x 4 x 4
-            nn.ConvTranspose2d(
-                in_channels=ngf * 8,
-                out_channels=ngf * 4,
-                kernel_size=4,
-                stride=2,
-                padding=1, bias=False
-            ),
-            nn.BatchNorm2d(ngf * 4),
-            nn.ReLU(True),
-            # state size. (gen_dimesions*4) x 8 x 8
-            nn.ConvTranspose2d(
-                in_channels=ngf * 4,
-                out_channels=ngf * 2,
-                kernel_size=4,
-                stride=2,
-                padding=1, bias=False
-            ),
-            nn.BatchNorm2d(ngf * 2),
-            nn.ReLU(True),
-            # state size. (gen_dimesions*2) x 16 x 16
-            nn.ConvTranspose2d(
-                in_channels=ngf * 2,
-                out_channels=ngf,
-                kernel_size=4,
-                stride= 2,
-                padding=1,
-                bias=False
-            ),
-            nn.BatchNorm2d(ngf),
-            nn.ReLU(True),
-            # state size. (gen_dimesions) x 32 x 32
-            nn.ConvTranspose2d(
-                in_channels=ngf,
-                out_channels=channels,
+        layers = []
+        prev_channels = noise_size + embedding_size  # Initial input channels
+        out_channels = ngf * 8
+
+        for i in range(power - 1):
+            layers.append(nn.ConvTranspose2d(
+                in_channels=prev_channels,
+                out_channels=out_channels,  # halved below for the next block
                 kernel_size=4,
                 stride=2,
                 padding=1,
                 bias=False
-            ),
-            nn.Tanh()
-            # state size. (nc) x 64 x 64
-        )
+            ))
+            layers.append(nn.BatchNorm2d(out_channels))
+            layers.append(nn.ReLU(True))
+            prev_channels = out_channels  # Update input channels for the next layer
+            out_channels //= 2
+
+        # Output layer
+        layers.append(nn.ConvTranspose2d(
+            in_channels=prev_channels,
+            out_channels=channels,
+            kernel_size=4,
+            stride=2,
+            padding=1,
+            bias=False
+        ))
+        layers.append(nn.Tanh())
+
+        self.main = nn.Sequential(*layers)
 
     def forward(self, noise, label):
         label_embedding = self.label_embedding(label).view(label.size(0), -1, 1, 1)
@@ -79,79 +57,61 @@ def forward(self, noise, label):
         output = self.main(combined_input)
         return output
 
+# Discriminator currently has the same issue:
+# Kernel size 4x4 can't be larger than input 1x1
 
-# ** DISCRIMINATOR MODEL (takes tensor; returns probability) **
 class Discriminator(nn.Module):
     def __init__(self,
                  n_categories,
                  channels=3,
                  feature_map_size=64,
-                 embedding_size=50):
+                 embedding_size=50,
+                 img_size=64):
         super(Discriminator, self).__init__()
-
         ndf = feature_map_size
+        power = int(math.log2(img_size))
         self.label_embedding = nn.Embedding(n_categories, embedding_size)
-        self.main = nn.Sequential(
-            # input is (nc) x 64 x 64
-            nn.Conv2d(
-                in_channels=channels + embedding_size,
-                out_channels=ndf,
-                kernel_size=4,
-                stride=2,
-                padding=1,
-                bias=False
-            ),
-            nn.LeakyReLU(0.2, inplace=True),
-            # state size. (ndf) x 32 x 32
-            nn.Conv2d(
-                in_channels=ndf,
-                out_channels=ndf * 2,
+        layers = []
+        prev_channels = channels + embedding_size  # Initial input channels
+        for i in range(power - 1):
+            out_channels = (ndf * 8) // (2 ** i)
+            layers.append(nn.Conv2d(
+                in_channels=prev_channels,
+                out_channels=out_channels,
                 kernel_size=4,
                 stride=2,
                 padding=1,
                 bias=False
-            ),
-            nn.BatchNorm2d(ndf * 2),
-            nn.LeakyReLU(0.2, inplace=True),
-            # state size. (ndf*2) x 16 x 16
-            nn.Conv2d(
-                in_channels=ndf * 2,
-                out_channels=ndf * 4,
-                kernel_size=4,
-                stride=2,
-                padding=1,
-                bias=False
-            ),
-            nn.BatchNorm2d(ndf * 4),
-            nn.LeakyReLU(0.2, inplace=True),
-            # state size. (ndf*4) x 8 x 8
-            nn.Conv2d(
-                in_channels=ndf * 4,
-                out_channels=ndf * 8,
-                kernel_size=4,
-                stride=2,
-                padding=1,
-                bias=False
-            ),
-            nn.BatchNorm2d(ndf * 8),
-            nn.LeakyReLU(0.2, inplace=True),
-            # state size. (ndf*8) x 4 x 4
-            nn.Conv2d(
-                in_channels=ndf * 8,
-                out_channels=1,
-                kernel_size=4,
-                stride=1,
-                padding=0,
-                bias=False
-            ),
-            nn.Sigmoid()
+            ))
+            layers.append(nn.LeakyReLU(0.2, inplace=True))
+            prev_channels = out_channels  # Update input channels for the next layer
+
+        # Output layer
+        self.output_conv = nn.Conv2d(
+            in_channels=prev_channels,
+            out_channels=1,
+            kernel_size=2,
+            stride=2,
+            padding=0,
         )
+
+        self.sigmoid = nn.Sigmoid()
+
+        self.main = nn.Sequential(*layers)
 
     def forward(self, input, label):
         label_embedding = self.label_embedding(label).view(label.size(0), -1, 1, 1)
-        expanded_labels = label_embedding.expand(label.size(0), label_embedding.size(1), input.size(2), input.size(3))
-        combined_input = torch.cat((input, expanded_labels), 1)
-        return self.main(combined_input)
-
+        label_embedding = label_embedding.expand(-1, -1, input.size(2), input.size(3))  # Expand along height and width dimensions
+        combined_input = torch.cat((input, label_embedding), 1)
+        # print('before main', combined_input.shape)
+        x = self.main(combined_input)
+        # print('after main', x.shape)
+        x = self.output_conv(x)
+        x = x.flatten()
+        # print('after output conv', x.shape)
+        x = self.sigmoid(x)
+        # print('after sigmoid', x.shape)
+        return x
diff --git a/synthesis_pipeline/condGAN/train.py b/synthesis_pipeline/condGAN/train.py
index ae16f7e..36e0883 100644
--- a/synthesis_pipeline/condGAN/train.py
+++ b/synthesis_pipeline/condGAN/train.py
@@ -51,7 +51,7 @@ batch_size = args.batch_size
 
 dataset = MultiClassDataset(data_path=args.data,
-                            category_max=20,
+                            category_max=batch_size,
                             transform=transforms.Compose([
                                 transforms.Resize(args.image_size),
                                 transforms.CenterCrop(args.image_size),
@@ -67,8 +67,11 @@ beta1 = 0.5
 
 # initialize networks
-netG = Generator(n_categories=dataset.num_labels, noise_size=args.noise_size).to(device)
-netD = Discriminator(n_categories=dataset.num_labels).to(device)
+netG = Generator(n_categories=dataset.num_labels, noise_size=args.noise_size, img_size=args.image_size).to(device)
+netD = Discriminator(n_categories=dataset.num_labels, img_size=args.image_size).to(device)
+
+print(netG)
+print(netD)
 
 # randomly initialize network weights
 netG.apply(torch_utils.weights_init)
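
A dummy forward pass is a quick way to sanity-check the refactored, size-parameterised models: the generator should emit img_size x img_size images and the discriminator should return one probability per sample. The following is a minimal sketch, assuming the file is importable as cond_DCGAN_network and using illustrative values (10 categories, noise size 100, batch of 8) rather than the actual training arguments:

import torch
from cond_DCGAN_network import Generator, Discriminator

n_categories, noise_size, img_size, batch = 10, 100, 64, 8
netG = Generator(n_categories=n_categories, noise_size=noise_size, img_size=img_size)
netD = Discriminator(n_categories=n_categories, img_size=img_size)

noise = torch.randn(batch, noise_size, 1, 1)       # latent vectors, one per sample
labels = torch.randint(0, n_categories, (batch,))  # integer class labels for the embedding

fake = netG(noise, labels)    # expected shape: [8, 3, 64, 64]
score = netD(fake, labels)    # expected shape: [8], values in (0, 1) after the sigmoid
print(fake.shape, score.shape)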