Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 71 additions & 111 deletions synthesis_pipeline/condGAN/cond_DCGAN_network.py
Original file line number Diff line number Diff line change
@@ -1,157 +1,117 @@
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import math

# ** GENERATOR MODEL (takes noise; returns tensor) **
class Generator(nn.Module):
def __init__(self,
n_categories,
noise_size,
channels=3,
feature_map_size=64,
embedding_size=50):
embedding_size=50,
img_size=64): # Assuming default power of 2 as 64x64 image size

super(Generator, self).__init__()

ngf = feature_map_size
power = int(math.log2(img_size))

self.label_embedding = nn.Embedding(n_categories, embedding_size)

self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d(
in_channels=embedding_size + noise_size,
out_channels=ngf * 8,
kernel_size=4,
stride=1,
padding=0,
bias=False
),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# state size. (gen_dimesions*8) x 4 x 4
nn.ConvTranspose2d(
in_channels=ngf * 8,
out_channels=ngf * 4,
kernel_size=4,
stride=2,
padding=1, bias=False
),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# state size. (gen_dimesions*4) x 8 x 8
nn.ConvTranspose2d(
in_channels=ngf * 4,
out_channels=ngf * 2,
kernel_size=4,
stride=2,
padding=1, bias=False
),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# state size. (gen_dimesions*2) x 16 x 16
nn.ConvTranspose2d(
in_channels=ngf * 2,
out_channels=ngf,
kernel_size=4,
stride= 2,
padding=1,
bias=False
),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# state size. (gen_dimesions) x 32 x 32
nn.ConvTranspose2d(
in_channels=ngf,
out_channels=channels,
layers = []
prev_channels = noise_size + embedding_size # Initial input channels
out_channels = (ngf * 8)

for i in range(power-1):
layers.append(nn.ConvTranspose2d(
in_channels=prev_channels,
out_channels=out_channels, # out_channels is still itself
kernel_size=4,
stride=2,
padding=1,
bias=False
),
nn.Tanh()
# state size. (nc) x 64 x 64
)
))
layers.append(nn.BatchNorm2d(out_channels))
layers.append(nn.ReLU(True))
prev_channels = out_channels # Update input channels for the next layer
out_channels //= 2

# Output layer
layers.append(nn.ConvTranspose2d(
in_channels=prev_channels,
out_channels=channels,
kernel_size=4,
stride=2,
padding=1,
bias=False
))
layers.append(nn.Tanh())

self.main = nn.Sequential(*layers)

def forward(self, noise, label):
label_embedding = self.label_embedding(label).view(label.size(0), -1, 1, 1)
combined_input = torch.cat((noise, label_embedding), 1)
output = self.main(combined_input)
return output

# Discriminator currently has same isues*
# Kernel size 4x4 can't be larger than input 1x1

# ** DISCRIMINATOR MODEL (takes tensor; returns probability) **
class Discriminator(nn.Module):
def __init__(self,
n_categories,
channels=3,
feature_map_size=64,
embedding_size=50):
embedding_size=50,
img_size=64):
super(Discriminator, self).__init__()

ndf = feature_map_size
power = int(math.log2(img_size))

self.label_embedding = nn.Embedding(n_categories, embedding_size)

self.main = nn.Sequential(
# input is (nc) x 64 x 64
nn.Conv2d(
in_channels=channels + embedding_size,
out_channels=ndf,
kernel_size=4,
stride=2,
padding=1,
bias=False
),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndvf) x 32 x 32
nn.Conv2d(
in_channels=ndf,
out_channels=ndf * 2,
layers = []
prev_channels = channels + embedding_size # Initial input channels
for i in range(power-1):
out_channels = (ndf * 8) // (2 ** i)
layers.append(nn.Conv2d(
in_channels=prev_channels,
out_channels=out_channels,
kernel_size=4,
stride=2,
padding=1,
bias=False
),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*2) x 16 x 16
nn.Conv2d(
in_channels=ndf * 2,
out_channels=ndf * 4,
kernel_size=4,
stride=2,
padding=1,
bias=False
),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*4) x 8 x 8
nn.Conv2d(
in_channels=ndf * 4,
out_channels=ndf * 8,
kernel_size=4,
stride=2,
padding=1,
bias=False
),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*8) x 4 x 4
nn.Conv2d(
in_channels=ndf * 8,
out_channels=1,
kernel_size=4,
stride=1,
padding=0,
bias=False
),
nn.Sigmoid()
))
layers.append(nn.LeakyReLU(0.2, inplace=True))
prev_channels = out_channels # Update input channels for the next layer

# Output layer
self.output_conv = nn.Conv2d(
in_channels=prev_channels,
out_channels=1,
kernel_size=2,
stride=2,
padding=0,
)

self.sigmoid = nn.Sigmoid()

self.main = nn.Sequential(*layers)

def forward(self, input, label):
label_embedding = self.label_embedding(label).view(label.size(0), -1, 1, 1)
expanded_labels = label_embedding.expand(label.size(0), label_embedding.size(1), input.size(2), input.size(3))
combined_input = torch.cat((input, expanded_labels), 1)
return self.main(combined_input)

label_embedding = label_embedding.expand(-1, -1, input.size(2), input.size(3)) # Expand along height and width dimensions
combined_input = torch.cat((input, label_embedding), 1)
# print('before main', combined_input.shape)
x = self.main(combined_input)
# print('after main', x.shape)
x = self.output_conv(x)
x = x.flatten()
# print('after output conv', x.shape)
x = self.sigmoid(x)
# print('after sigmoid', x.shape)
return x
9 changes: 6 additions & 3 deletions synthesis_pipeline/condGAN/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
batch_size = args.batch_size

dataset = MultiClassDataset(data_path=args.data,
category_max=20,
category_max=batch_size,
transform=transforms.Compose([
transforms.Resize(args.image_size),
transforms.CenterCrop(args.image_size),
Expand All @@ -67,8 +67,11 @@
beta1 = 0.5

# initialize networks
netG = Generator(n_categories=dataset.num_labels, noise_size=args.noise_size).to(device)
netD = Discriminator(n_categories=dataset.num_labels).to(device)
netG = Generator(n_categories=dataset.num_labels, noise_size=args.noise_size, img_size=args.image_size).to(device)
netD = Discriminator(n_categories=dataset.num_labels, img_size=args.image_size).to(device)

print(netG)
print(netD)

# randomly initialize network weights
netG.apply(torch_utils.weights_init)
Expand Down