Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Code was tested on NVIDIA V100 GPU and Titan X Pascal.

This snippet will generate 84 images by inverting resnet50 model from torchvision package.

`python imagenet_inversion.py --bs=84 --do_flip --exp_name="rn50_inversion" --r_feature=0.01 --arch_name="resnet50" --verifier --adi_scale=0.0 --setting_id=0 --lr 0.25`
`python imagenet_inversion.py --bs=84 --do_flip --exp_name="rn50_inversion" --r_feature=0.01 --arch_name="resnet50" --verifier --adi_scale=0.0 --setting_id=0 --lr 0.25 --gpu_ids 0`

Arguments:

Expand All @@ -61,6 +61,7 @@ Useful to observe generalizability of generated images.
- `setting_id` - settings for optimization: 0 - multi resolution scheme, 1 - 2k iterations full resolution, 2 - 20k iterations (the closest to the ResNet50 experiments in the paper). Recommended to use setting_id={0, 1}
- `adi_scale` - competition coefficient. With positive value will lead to images that are good for the original model, but bad for verifier. Value 0.2 was used in the paper.
- `random_label` - randomly select classes for inversion. Without this argument the code will generate hand picked classes.
- `gpu_ids` - device ids for using single/multi-gpu training.

After 3k iterations (~6 mins on NVIDIA V100) generation is done: `Verifier accuracy: 91.6...%` (an experiment with >98% verifier accuracy can be found in `/example_logs`). We generated images by inverting vanilla ResNet50 (not trained for image generation) and classification accuracy by MobileNetv2 is >90%. A grid of images looks like (from `/final_images/`, reduced quality due to JPEG compression.)
![Generated grid of images](example_logs/fp32_set0_rn50_first_bn_scaled.jpg "ResNet50 Inverted images")
Expand Down
44 changes: 33 additions & 11 deletions deepinversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import torchvision.utils as vutils
from PIL import Image
import numpy as np
import sys

from utils.utils import lr_cosine_policy, lr_policy, beta_policy, mom_cosine_policy, clip, denormalize, create_folder

Expand All @@ -33,19 +34,26 @@ class DeepInversionFeatureHook():
'''
def __init__(self, module):
    """Attach a forward hook to `module` (a BatchNorm2d layer) so that the
    feature-statistics regularization loss is recomputed on every forward pass."""
    self.hook = module.register_forward_hook(self.hook_fn)

    # Per-device loss storage for multi-GPU (DataParallel) runs: maps
    # torch.device -> the r_feature loss tensor computed on that replica.
    self.r_feature_glob = {}

def hook_fn(self, module, input, output):
    # hook to compute DeepInversion's feature-distribution regularization:
    # penalize the distance between the batch statistics of the current
    # activations and the BN layer's stored running statistics.
    nch = input[0].shape[1]
    mean = input[0].mean([0, 2, 3])
    # per-channel variance over (batch, H, W); unbiased=False matches how
    # BatchNorm computes batch statistics
    var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False)

    # forcing mean and variance to match between the two distributions;
    # other ways might work better, e.g. KL divergence

    # single-gpu: L2 distance between running stats and batch stats
    r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(
        module.running_mean.data - mean, 2)

    self.r_feature = r_feature

    # multi-gpu: under DataParallel each replica runs this hook on its own
    # device, so stash the loss keyed by device for later aggregation
    self.r_feature_glob[r_feature.device] = r_feature

    # must have no output (a forward hook returning a value would replace
    # the module's output)

def close(self):
Expand Down Expand Up @@ -76,7 +84,8 @@ def __init__(self, bs=84,
criterion=None,
coefficients=dict(),
network_output_function=lambda x: x,
hook_for_display = None):
hook_for_display = None,
gpus=[i for i in range(torch.cuda.device_count())]):
'''
:param bs: batch size per GPU for image generation
:param use_fp16: use FP16 (or APEX AMP) for model inversion, uses less memory and is faster for GPUs with Tensor Cores
Expand Down Expand Up @@ -105,13 +114,14 @@ def __init__(self, bs=84,
"adi_scale" - coefficient for Adaptive DeepInversion, competition, def =0 means no competition
network_output_function: function to be applied to the output of the network to get the output
hook_for_display: function to be executed at every print/save call, useful to check accuracy of verifier
gpus: list containing the ids of the gpus to be used
'''

print("Deep inversion class generation")
# for reproducibility
torch.manual_seed(torch.cuda.current_device())
self.gpus = gpus

self.net_teacher = net_teacher

if "resolution" in parameters.keys():
self.image_resolution = parameters["resolution"]
Expand Down Expand Up @@ -168,13 +178,15 @@ def __init__(self, bs=84,
## Create hooks for feature statistics
self.loss_r_feature_layers = []

for module in self.net_teacher.modules():
for module in net_teacher.modules():
if isinstance(module, nn.BatchNorm2d):
self.loss_r_feature_layers.append(DeepInversionFeatureHook(module))

self.hook_for_display = None
if hook_for_display is not None:
self.hook_for_display = hook_for_display

self.net_teacher = torch.nn.DataParallel(net_teacher, device_ids=self.gpus)

def get_images(self, net_student=None, targets=None):
print("get_images call")
Expand Down Expand Up @@ -204,8 +216,7 @@ def get_images(self, net_student=None, targets=None):
img_original = self.image_resolution

data_type = torch.half if use_fp16 else torch.float
inputs = torch.randn((self.bs, 3, img_original, img_original), requires_grad=True, device='cuda',
dtype=data_type)
inputs = torch.randn((self.bs, 3, img_original, img_original), requires_grad=True, device='cuda', dtype=data_type)
pooling_function = nn.modules.pooling.AvgPool2d(kernel_size=2)

if self.setting_id==0:
Expand Down Expand Up @@ -271,7 +282,7 @@ def get_images(self, net_student=None, targets=None):
# forward pass
optimizer.zero_grad()
net_teacher.zero_grad()

outputs = net_teacher(inputs_jit)
outputs = self.network_output_function(outputs)

Expand All @@ -280,10 +291,20 @@ def get_images(self, net_student=None, targets=None):

# R_prior losses
loss_var_l1, loss_var_l2 = get_image_prior_losses(inputs_jit)

# R_feature loss
rescale = [self.first_bn_multiplier] + [1. for _ in range(len(self.loss_r_feature_layers)-1)]
loss_r_feature = sum([mod.r_feature * rescale[idx] for (idx, mod) in enumerate(self.loss_r_feature_layers)])

loss_r_feature = 0
for idx,layer in enumerate(self.loss_r_feature_layers):
l_g = layer.r_feature_glob
r_feature_layer_loss = []
for value in l_g.values():
r_feature_layer_loss.append(value.to('cuda') * rescale[idx])
r_feature_layer_loss = sum(r_feature_layer_loss)
loss_r_feature = loss_r_feature + r_feature_layer_loss



# R_ADI
loss_verifier_cig = torch.zeros(1)
Expand Down Expand Up @@ -316,7 +337,7 @@ def get_images(self, net_student=None, targets=None):

# l2 loss on images
loss_l2 = torch.norm(inputs_jit.view(self.bs, -1), dim=1).mean()

# combining losses
loss_aux = self.var_scale_l2 * loss_var_l2 + \
self.var_scale_l1 * loss_var_l1 + \
Expand Down Expand Up @@ -399,6 +420,7 @@ def generate_batch(self, net_student=None, targets=None):
# fix net_student
if not (net_student is None):
net_student = net_student.eval()
net_student = torch.DataParallel(net_student, device_ids=[i for i in range(self.gpus)])

if targets is not None:
targets = torch.from_numpy(np.array(targets).squeeze()).cuda()
Expand Down
6 changes: 4 additions & 2 deletions imagenet_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ def run(args):
criterion=criterion,
coefficients = coefficients,
network_output_function = network_output_function,
hook_for_display = hook_for_display)
hook_for_display = hook_for_display,
gpus = args.gpu_ids)
net_student=None
if args.adi_scale != 0:
net_student = net_verifier
Expand All @@ -186,7 +187,7 @@ def main():
parser.add_argument('--bs', default=64, type=int, help='batch size')
parser.add_argument('--jitter', default=30, type=int, help='batch size')
parser.add_argument('--comment', default='', type=str, help='batch size')
parser.add_argument('--arch_name', default='resnet50', type=str, help='model name from torchvision or resnet50v15')
parser.add_argument('--arch_name', default='resnet101', type=str, help='model name from torchvision or resnet50v15')

parser.add_argument('--fp16', action='store_true', help='use FP16 for optimization')
parser.add_argument('--exp_name', type=str, default='test', help='where to store experimental data')
Expand All @@ -204,6 +205,7 @@ def main():
parser.add_argument('--l2', type=float, default=0.00001, help='l2 loss on the image')
parser.add_argument('--main_loss_multiplier', type=float, default=1.0, help='coefficient for the main loss in optimization')
parser.add_argument('--store_best_images', action='store_true', help='save best images as separate files')
parser.add_argument('--gpu_ids', nargs='+', type=int, help='ids of gpus to be used')

args = parser.parse_args()
print(args)
Expand Down