From 343ab859578841193d2d71d650e1cdef3cc42f05 Mon Sep 17 00:00:00 2001 From: alexandrosstergiou Date: Thu, 3 Nov 2022 11:24:32 +0100 Subject: [PATCH] multi-gpu support --- README.md | 3 ++- deepinversion.py | 44 ++++++++++++++++++++++++++++++++----------- imagenet_inversion.py | 6 ++++-- 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index ed0bb16..7e6783b 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Code was tested on NVIDIA V100 GPU and Titan X Pascal. This snippet will generate 84 images by inverting resnet50 model from torchvision package. -`python imagenet_inversion.py --bs=84 --do_flip --exp_name="rn50_inversion" --r_feature=0.01 --arch_name="resnet50" --verifier --adi_scale=0.0 --setting_id=0 --lr 0.25` +`python imagenet_inversion.py --bs=84 --do_flip --exp_name="rn50_inversion" --r_feature=0.01 --arch_name="resnet50" --verifier --adi_scale=0.0 --setting_id=0 --lr 0.25 --gpu_ids 0` Arguments: @@ -61,6 +61,7 @@ Useful to observe generalizability of generated images. - `setting_id` - settings for optimization: 0 - multi resolution scheme, 1 - 2k iterations full resolution, 2 - 20k iterations (the closes to ResNet50 experiments in the paper). Recommended to use setting_id={0, 1} - `adi_scale` - competition coefficient. With positive value will lead to images that are good for the original model, but bad for verifier. Value 0.2 was used in the paper. - `random_label` - randomly select classes for inversion. Without this argument the code will generate hand picked classes. +- `gpu_ids` - device ids for using single/multi-gpu training. After 3k iterations (~6 mins on NVIDIA V100) generation is done: `Verifier accuracy: 91.6...%` (experiment with >98% verifier accuracy can be found `/example_logs`). We generated images by inverting vanilla ResNet50 (not trained for image generation) and classification accuracy by MobileNetv2 is >90%. 
A grid of images look like (from `/final_images/`, reduced quality due to JPEG compression. ) ![Generated grid of images](example_logs/fp32_set0_rn50_first_bn_scaled.jpg "ResNet50 Inverted images") diff --git a/deepinversion.py b/deepinversion.py index 2dcdb30..86f9830 100644 --- a/deepinversion.py +++ b/deepinversion.py @@ -22,6 +22,7 @@ import torchvision.utils as vutils from PIL import Image import numpy as np +import sys from utils.utils import lr_cosine_policy, lr_policy, beta_policy, mom_cosine_policy, clip, denormalize, create_folder @@ -33,19 +34,26 @@ class DeepInversionFeatureHook(): ''' def __init__(self, module): self.hook = module.register_forward_hook(self.hook_fn) + + self.r_feature_glob = {} def hook_fn(self, module, input, output): # hook co compute deepinversion's feature distribution regularization nch = input[0].shape[1] mean = input[0].mean([0, 2, 3]) var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False) - + #forcing mean and variance to match between two distributions #other ways might work better, i.g. 
KL divergence + + # single-gpu r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm( module.running_mean.data - mean, 2) - self.r_feature = r_feature + + # multi-gpu + self.r_feature_glob[r_feature.device] = r_feature + # must have no output def close(self): @@ -76,7 +84,8 @@ def __init__(self, bs=84, criterion=None, coefficients=dict(), network_output_function=lambda x: x, - hook_for_display = None): + hook_for_display = None, + gpus=[i for i in range(torch.cuda.device_count())]): ''' :param bs: batch size per GPU for image generation :param use_fp16: use FP16 (or APEX AMP) for model inversion, uses less memory and is faster for GPUs with Tensor Cores @@ -105,13 +114,14 @@ def __init__(self, bs=84, "adi_scale" - coefficient for Adaptive DeepInversion, competition, def =0 means no competition network_output_function: function to be applied to the output of the network to get the output hook_for_display: function to be executed at every print/save call, useful to check accuracy of verifier + gpus: list containing the ids of the gpus to be used ''' print("Deep inversion class generation") # for reproducibility torch.manual_seed(torch.cuda.current_device()) + self.gpus = gpus - self.net_teacher = net_teacher if "resolution" in parameters.keys(): self.image_resolution = parameters["resolution"] @@ -168,13 +178,15 @@ def __init__(self, bs=84, ## Create hooks for feature statistics self.loss_r_feature_layers = [] - for module in self.net_teacher.modules(): + for module in net_teacher.modules(): if isinstance(module, nn.BatchNorm2d): self.loss_r_feature_layers.append(DeepInversionFeatureHook(module)) self.hook_for_display = None if hook_for_display is not None: self.hook_for_display = hook_for_display + + self.net_teacher = torch.nn.DataParallel(net_teacher, device_ids=self.gpus) def get_images(self, net_student=None, targets=None): print("get_images call") @@ -204,8 +216,7 @@ def get_images(self, net_student=None, targets=None): img_original = 
self.image_resolution data_type = torch.half if use_fp16 else torch.float - inputs = torch.randn((self.bs, 3, img_original, img_original), requires_grad=True, device='cuda', - dtype=data_type) + inputs = torch.randn((self.bs, 3, img_original, img_original), requires_grad=True, device='cuda', dtype=data_type) pooling_function = nn.modules.pooling.AvgPool2d(kernel_size=2) if self.setting_id==0: @@ -271,7 +282,7 @@ def get_images(self, net_student=None, targets=None): # forward pass optimizer.zero_grad() net_teacher.zero_grad() - + outputs = net_teacher(inputs_jit) outputs = self.network_output_function(outputs) @@ -280,10 +291,20 @@ def get_images(self, net_student=None, targets=None): # R_prior losses loss_var_l1, loss_var_l2 = get_image_prior_losses(inputs_jit) - + # R_feature loss rescale = [self.first_bn_multiplier] + [1. for _ in range(len(self.loss_r_feature_layers)-1)] - loss_r_feature = sum([mod.r_feature * rescale[idx] for (idx, mod) in enumerate(self.loss_r_feature_layers)]) + + loss_r_feature = 0 + for idx,layer in enumerate(self.loss_r_feature_layers): + l_g = layer.r_feature_glob + r_feature_layer_loss = [] + for value in l_g.values(): + r_feature_layer_loss.append(value.to('cuda') * rescale[idx]) + r_feature_layer_loss = sum(r_feature_layer_loss) + loss_r_feature = loss_r_feature + r_feature_layer_loss + + # R_ADI loss_verifier_cig = torch.zeros(1) @@ -316,7 +337,7 @@ def get_images(self, net_student=None, targets=None): # l2 loss on images loss_l2 = torch.norm(inputs_jit.view(self.bs, -1), dim=1).mean() - + # combining losses loss_aux = self.var_scale_l2 * loss_var_l2 + \ self.var_scale_l1 * loss_var_l1 + \ @@ -399,6 +420,7 @@ def generate_batch(self, net_student=None, targets=None): # fix net_student if not (net_student is None): net_student = net_student.eval() + net_student = torch.nn.DataParallel(net_student, device_ids=self.gpus) if targets is not None: targets = torch.from_numpy(np.array(targets).squeeze()).cuda() diff --git
a/imagenet_inversion.py b/imagenet_inversion.py index 0d09fbc..5395ad4 100644 --- a/imagenet_inversion.py +++ b/imagenet_inversion.py @@ -168,7 +168,8 @@ def run(args): criterion=criterion, coefficients = coefficients, network_output_function = network_output_function, - hook_for_display = hook_for_display) + hook_for_display = hook_for_display, + gpus = args.gpu_ids) net_student=None if args.adi_scale != 0: net_student = net_verifier @@ -186,7 +187,7 @@ def main(): parser.add_argument('--bs', default=64, type=int, help='batch size') parser.add_argument('--jitter', default=30, type=int, help='batch size') parser.add_argument('--comment', default='', type=str, help='batch size') - parser.add_argument('--arch_name', default='resnet50', type=str, help='model name from torchvision or resnet50v15') + parser.add_argument('--arch_name', default='resnet101', type=str, help='model name from torchvision or resnet50v15') parser.add_argument('--fp16', action='store_true', help='use FP16 for optimization') parser.add_argument('--exp_name', type=str, default='test', help='where to store experimental data') @@ -204,6 +205,7 @@ def main(): parser.add_argument('--l2', type=float, default=0.00001, help='l2 loss on the image') parser.add_argument('--main_loss_multiplier', type=float, default=1.0, help='coefficient for the main loss in optimization') parser.add_argument('--store_best_images', action='store_true', help='save best images as separate files') + parser.add_argument('--gpu_ids', nargs='+', type=int, help='ids of gpus to be used') args = parser.parse_args() print(args)