Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions .github/workflows/pythonapp.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: Python conda test

on: [push]

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- uses: conda-incubator/setup-miniconda@v2
with:
miniconda-version: "latest"
activate-environment: ingan
environment-file: environment.yml
python-version: 3.8
auto-activate-base: false
- shell: bash -l {0}
run: |
conda info
conda list
- name: Test with pytest
shell: bash -l {0}
run: |
pytest .
- name: yapf
id: yapf
uses: diegovalenzuelaiturra/yapf-action@v0.0.1
with:
args: . --recursive --diff
- name: Fail if yapf made changes
if: steps.yapf.outputs.exit-code == 2
run: exit 1
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
examples/
results/
.idea/
2 changes: 2 additions & 0 deletions .yapfignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
results/
examples/
127 changes: 76 additions & 51 deletions InGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ def __init__(self, conf):
self.real_example = torch.FloatTensor(1, 3, conf.output_crop_size, conf.output_crop_size).cuda()

# Define networks
self.G = networks.Generator(conf.G_base_channels, conf.G_num_resblocks, conf.G_num_downscales, conf.G_use_bias,
conf.G_skip)
self.D = networks.MultiScaleDiscriminator(conf.output_crop_size, self.conf.D_max_num_scales,
self.conf.D_scale_factor, self.conf.D_base_channels)
self.G = networks.Generator(
conf.G_base_channels, conf.G_num_resblocks, conf.G_num_downscales, conf.G_use_bias, conf.G_skip
)
self.D = networks.MultiScaleDiscriminator(
conf.output_crop_size, self.conf.D_max_num_scales, self.conf.D_scale_factor, self.conf.D_base_channels
)
self.GAN_loss_layer = networks.GANLoss()
self.Reconstruct_loss = networks.WeightedMSELoss(use_L1=conf.use_L1)
self.RandCrop = networks.RandomCrop([conf.input_crop_size, conf.input_crop_size], must_divide=conf.must_divide)
Expand Down Expand Up @@ -89,15 +91,18 @@ def save(self, citer=None):
filename = citer
else:
filename = 'snapshot-{:05d}.pth.tar'.format(citer)
torch.save({'G': self.G.state_dict(),
'D': self.D.state_dict(),
'optim_G': self.optimizer_G.state_dict(),
'optim_D': self.optimizer_D.state_dict(),
'sched_G': self.lr_scheduler_G.state_dict(),
'sched_D': self.lr_scheduler_D.state_dict(),
'loss': self.GAN_loss_layer.state_dict(),
'iter': citer if citer else self.cur_iter},
os.path.join(self.conf.output_dir_path, filename))
torch.save(
{
'G': self.G.state_dict(),
'D': self.D.state_dict(),
'optim_G': self.optimizer_G.state_dict(),
'optim_D': self.optimizer_D.state_dict(),
'sched_G': self.lr_scheduler_G.state_dict(),
'sched_D': self.lr_scheduler_D.state_dict(),
'loss': self.GAN_loss_layer.state_dict(),
'iter': citer if citer else self.cur_iter
}, os.path.join(self.conf.output_dir_path, filename)
)

def resume(self, resume_path, test_flag=False):
resume = torch.load(resume_path, map_location={'cuda:5': 'cuda:0'})
Expand Down Expand Up @@ -134,34 +139,44 @@ def resume(self, resume_path, test_flag=False):
if len(missing):
warnings.warn('Missing the following state dicts from checkpoint: {}'.format(', '.join(missing)))

print('resuming checkpoint {}'.format(self.conf.resume))
print(('resuming checkpoint {}'.format(self.conf.resume)))

def test(self, input_tensor, output_size, rand_affine, input_size, run_d_pred=True, run_reconstruct=True):
with torch.no_grad():
self.G_pred = self.G.forward(Variable(input_tensor.detach()), output_size=output_size, random_affine=rand_affine)
self.G_pred = self.G.forward(
Variable(input_tensor.detach()), output_size=output_size, random_affine=rand_affine
)
if run_d_pred:
scale_weights_for_output = get_scale_weights(i=self.cur_iter,
max_i=self.conf.D_scale_weights_iter_for_even_scales,
start_factor=self.conf.D_scale_weights_sigma,
input_shape=self.G_pred.shape[2:],
min_size=self.conf.D_min_input_size,
num_scales_limit=self.conf.D_max_num_scales,
scale_factor=self.conf.D_scale_factor)
scale_weights_for_input = get_scale_weights(i=self.cur_iter,
max_i=self.conf.D_scale_weights_iter_for_even_scales,
start_factor=self.conf.D_scale_weights_sigma,
input_shape=input_tensor.shape[2:],
min_size=self.conf.D_min_input_size,
num_scales_limit=self.conf.D_max_num_scales,
scale_factor=self.conf.D_scale_factor)
self.D_preds = [self.D.forward(Variable(input_tensor.detach()), scale_weights_for_input),
self.D.forward(Variable(self.G_pred.detach()), scale_weights_for_output)]
scale_weights_for_output = get_scale_weights(
i=self.cur_iter,
max_i=self.conf.D_scale_weights_iter_for_even_scales,
start_factor=self.conf.D_scale_weights_sigma,
input_shape=self.G_pred.shape[2:],
min_size=self.conf.D_min_input_size,
num_scales_limit=self.conf.D_max_num_scales,
scale_factor=self.conf.D_scale_factor
)
scale_weights_for_input = get_scale_weights(
i=self.cur_iter,
max_i=self.conf.D_scale_weights_iter_for_even_scales,
start_factor=self.conf.D_scale_weights_sigma,
input_shape=input_tensor.shape[2:],
min_size=self.conf.D_min_input_size,
num_scales_limit=self.conf.D_max_num_scales,
scale_factor=self.conf.D_scale_factor
)
self.D_preds = [
self.D.forward(Variable(input_tensor.detach()), scale_weights_for_input),
self.D.forward(Variable(self.G_pred.detach()), scale_weights_for_output)
]
else:
self.D_preds = None

self.G_preds = [input_tensor, self.G_pred]

self.reconstruct = self.G.forward(self.G_pred, output_size=input_size, random_affine=-rand_affine) if run_reconstruct else None
self.reconstruct = self.G.forward(
self.G_pred, output_size=input_size, random_affine=-rand_affine
) if run_reconstruct else None

return self.G_preds, self.D_preds, self.reconstruct

Expand All @@ -171,14 +186,16 @@ def train_g(self):
self.optimizer_D.zero_grad()

# Determine output size of G (dynamic change)
output_size, random_affine = random_size(orig_size=self.input_tensor.shape[2:],
curriculum=self.conf.curriculum,
i=self.cur_iter,
iter_for_max_range=self.conf.iter_for_max_range,
must_divide=self.conf.must_divide,
min_scale=self.conf.min_scale,
max_scale=self.conf.max_scale,
max_transform_magniutude=self.conf.max_transform_magnitude)
output_size, random_affine = random_size(
orig_size=self.input_tensor.shape[2:],
curriculum=self.conf.curriculum,
i=self.cur_iter,
iter_for_max_range=self.conf.iter_for_max_range,
must_divide=self.conf.must_divide,
min_scale=self.conf.min_scale,
max_scale=self.conf.max_scale,
max_transform_magniutude=self.conf.max_transform_magnitude
)

# Add noise to G input for better generalization (make it ignore the 1/255 binning)
self.input_tensor_noised = self.input_tensor + (torch.rand_like(self.input_tensor) - 0.5) * 2.0 / 255
Expand All @@ -187,18 +204,22 @@ def train_g(self):
self.G_pred = self.G.forward(self.input_tensor_noised, output_size=output_size, random_affine=random_affine)

# Run generator result through discriminator forward pass
self.scale_weights = get_scale_weights(i=self.cur_iter,
max_i=self.conf.D_scale_weights_iter_for_even_scales,
start_factor=self.conf.D_scale_weights_sigma,
input_shape=self.G_pred.shape[2:],
min_size=self.conf.D_min_input_size,
num_scales_limit=self.conf.D_max_num_scales,
scale_factor=self.conf.D_scale_factor)
self.scale_weights = get_scale_weights(
i=self.cur_iter,
max_i=self.conf.D_scale_weights_iter_for_even_scales,
start_factor=self.conf.D_scale_weights_sigma,
input_shape=self.G_pred.shape[2:],
min_size=self.conf.D_min_input_size,
num_scales_limit=self.conf.D_max_num_scales,
scale_factor=self.conf.D_scale_factor
)
d_pred_fake = self.D.forward(self.G_pred, self.scale_weights)

# If reconstruction-loss is used, run through decoder to reconstruct, then calculate reconstruction loss
if self.conf.reconstruct_loss_stop_iter > self.cur_iter:
self.reconstruct = self.G.forward(self.G_pred, output_size=self.input_tensor.shape[2:], random_affine=-random_affine)
self.reconstruct = self.G.forward(
self.G_pred, output_size=self.input_tensor.shape[2:], random_affine=-random_affine
)
self.loss_G_reconstruct = self.criterionReconstruction(self.reconstruct, self.input_tensor, self.loss_mask)

# Calculate generator loss, based on discriminator prediction on generator result
Expand All @@ -224,9 +245,13 @@ def train_g(self):
if self.cur_iter > self.conf.G_extra_inverse_train_start_iter:
for _ in range(self.conf.G_extra_inverse_train):
self.optimizer_G.zero_grad()
self.inverse = self.G.forward(self.G_pred.detach(), output_size=self.input_tensor.shape[2:], random_affine=-random_affine)
self.loss_G_inverse = (self.criterionReconstruction(self.inverse, self.input_tensor, self.loss_mask) *
self.conf.G_extra_inverse_train_ratio)
self.inverse = self.G.forward(
self.G_pred.detach(), output_size=self.input_tensor.shape[2:], random_affine=-random_affine
)
self.loss_G_inverse = (
self.criterionReconstruction(self.inverse, self.input_tensor, self.loss_mask) *
self.conf.G_extra_inverse_train_ratio
)
self.loss_G_inverse.backward()
self.optimizer_G.step()

Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

Project page: http://www.wisdom.weizmann.ac.il/~vision/ingan/ (See our results and visual comparison to other methods)

This version was ported to Python 3.8 and PyTorch 1.9 by [Bartolo1024](https://github.com/Bartolo1024) and is used in [Level generation and style enhancement - deep learning for game development overview](https://arxiv.org/abs/2107.07397).

**Accepted ICCV'19 (Oral)**
----------
![](/figs/fruits.gif)
Expand Down
Loading