diff --git a/candidate_models/base_models/__init__.py b/candidate_models/base_models/__init__.py
index 7094b85..44a4e84 100644
--- a/candidate_models/base_models/__init__.py
+++ b/candidate_models/base_models/__init__.py
@@ -194,6 +194,11 @@ def load_preprocess_images(image_filepaths):
     return wrapper
 
 
+def spiking_vgg16():
+    from .spiking_vgg import create_model
+    return create_model(architecture='VGG16')
+
+
 class BaseModelPool(UniqueKeyDict):
     """
     Provides a set of standard models.
@@ -273,6 +278,8 @@ def __init__(self):
             'fixres_resnext101_32x48d_wsl': lambda: fixres(
                 'resnext101_32x48d_wsl',
                 'https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNeXt_101_32x48d.pth'),
+
+            'spiking-vgg16': spiking_vgg16,
         }
         # MobileNets
         for version, multiplier, image_size in [
diff --git a/candidate_models/base_models/models.csv b/candidate_models/base_models/models.csv
index ed5814a..13a7855 100644
--- a/candidate_models/base_models/models.csv
+++ b/candidate_models/base_models/models.csv
@@ -1361,3 +1361,4 @@ fixres_resnext101_32x48d_wsl,http://openaccess.thecvf.com/content_ECCV_2018/html
   year = {2019},
   month = {jun},
 }",0.863,
+spiking-vgg16,TODO: paper link,TODO: paper bibtex,0.651,0.817
diff --git a/candidate_models/base_models/spiking_vgg/__init__.py b/candidate_models/base_models/spiking_vgg/__init__.py
new file mode 100644
index 0000000..ddc5246
--- /dev/null
+++ b/candidate_models/base_models/spiking_vgg/__init__.py
@@ -0,0 +1,65 @@
+import numpy as np
+import torch
+from torch import nn
+from torchvision import transforms
+
+from model_tools.activations.core import ActivationsExtractorHelper
+from model_tools.activations.pytorch import load_images
+from .model import VGG_SNN_STDB
+
+
+def create_model(timesteps=500, leak_mem=1.0, scaling_threshold=0.8, reset_threshold=0.0, default_threshold=1.0,
+                 activation='STDB', architecture='VGG16', labels=1000, pretrained=True):
+    """
+    Build the spiking VGG and wrap it in an activations extractor.
+
+    :param timesteps: number of simulation steps per forward pass
+    :param leak_mem: factor applied to the membrane potentials at every step
+    :param scaling_threshold: factor applied to the ANN-derived thresholds
+    :param reset_threshold: factor applied to `default_threshold` and added to every threshold
+    :param default_threshold: see `reset_threshold`
+    :param activation: spike function, passed on to :class:`VGG_SNN_STDB`
+    :param architecture: key into the `cfg` table of the model module
+    :param labels: number of output classes
+    :param pretrained: if True, load weights from `./snn_vgg16_imagenet.pth` (relative to the working directory)
+    :return: an `ActivationsExtractorHelper`; its `get_activations` is not implemented yet
+    """
+    model = VGG_SNN_STDB(vgg_name=architecture, activation=activation, labels=labels, timesteps=timesteps,
+                         leak_mem=leak_mem)
+    if pretrained:
+        pretrained_state = './snn_vgg16_imagenet.pth'
+        state = torch.load(pretrained_state, map_location='cpu')
+        model.load_state_dict(state['state_dict'])
+    model = nn.DataParallel(model)
+    if torch.cuda.is_available():
+        model.cuda()
+    # VGG16 Imagenet thresholds
+    ann_thresholds = [10.16, 11.49, 2.65, 2.30, 0.77, 2.75, 1.33, 0.67, 1.13, 1.12, 0.43, 0.73, 1.08, 0.16, 0.58]
+    model.module.threshold_init(scaling_threshold=scaling_threshold, reset_threshold=reset_threshold,
+                                thresholds=ann_thresholds[:], default_threshold=default_threshold)
+    model.eval()
+    model.module.network_init(timesteps)
+
+    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+    transform = transforms.Compose([
+        # fixed square size: the classifier expects 512 x 7 x 7 features and batches must stack
+        transforms.Resize((224, 224)),
+        transforms.ToTensor(),
+        normalize,
+    ])
+
+    def load_and_preprocess(image_filepaths):
+        images = load_images(image_filepaths)
+        images = [transform(image).numpy() for image in images]
+        # stack (not concatenate) so that the result is (batch, channel, height, width)
+        return np.stack(images)
+
+    def get_activations(preprocessed_inputs, layer_names):
+        # TODO: record spikes per layer; fail loudly rather than silently returning None
+        raise NotImplementedError("activations extraction is not implemented for the spiking VGG yet")
+
+    identifier = 'spiking-' + architecture.lower()  # matches the pool key, e.g. 'spiking-vgg16'
+    extractor = ActivationsExtractorHelper(identifier=identifier,
+                                           get_activations=get_activations, preprocessing=load_and_preprocess,
+                                           batch_size=35)
+    return extractor
diff --git a/candidate_models/base_models/spiking_vgg/model.py b/candidate_models/base_models/spiking_vgg/model.py
new file mode 100644
index 0000000..b21068e
--- /dev/null
+++ b/candidate_models/base_models/spiking_vgg/model.py
@@ -0,0 +1,205 @@
+import copy
+
+import torch
+import torch.nn as nn
+
+torch.manual_seed(2)
+
+cfg = {
+    'VGG5': [64, 'A', 128, 'D', 128, 'A'],
+    'VGG9': [64, 'A', 128, 'D', 128, 'A', 256, 'D', 256, 'A', 512, 'D', 512, 'D'],
+    'VGG11': [64, 'A', 128, 'D', 256, 'A', 512, 'D', 512, 'D', 512, 'A', 512, 'D', 512, 'D'],
+    'VGG16': [64, 'D', 64, 'A', 128, 'D', 128, 'A', 256, 'D', 256, 'D', 256, 'A', 512, 'D', 512, 'D', 512, 'A', 512,
+              'D', 512, 'D', 512, 'A']
+}
+
+
+class PoissonGenerator(nn.Module):
+    """Turns an input into a random spike train: emits sign(input) with probability min(|input|, 1)."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input):
+        out = torch.mul(torch.le(torch.rand_like(input), torch.abs(input) * 1.0).float(), torch.sign(input))
+        return out
+
+
+class STDPSpike(torch.autograd.Function):
+    """Spike function: 1 where the input is positive, 0 elsewhere. No backward pass is defined."""
+
+    @staticmethod
+    def forward(ctx, input, last_spike):
+        ctx.save_for_backward(last_spike)
+        # zeros_like already matches the device of `input`; forcing .cuda() breaks CPU-only hosts
+        out = torch.zeros_like(input)
+        out[input > 0] = 1.0
+        return out
+
+
+class VGG_SNN_STDB(nn.Module):
+    """Spiking VGG simulated for `update_interval` steps; `forward` returns the accumulated output potentials."""
+
+    def __init__(self, vgg_name, activation='STDB', labels=1000, timesteps=75, leak_mem=0.99, drop=0.2):
+        super().__init__()
+
+        self.timesteps = timesteps
+        self.vgg_name = vgg_name
+        self.labels = labels
+        self.leak_mem = leak_mem
+        self.drop = drop
+        if activation == 'STDB':
+            self.act_func = STDPSpike.apply
+        else:
+            # 'Linear' used to reference LinearSpike, which is not defined in this module
+            raise ValueError("unsupported activation '{}', only 'STDB' is available".format(activation))
+        self.input_layer = PoissonGenerator()
+
+        self.features, self.classifier = self._make_layers(cfg[self.vgg_name])
+
+    def threshold_init(self, scaling_threshold=1.0, reset_threshold=0.0, thresholds=None, default_threshold=1.0):
+        """
+        Assign one firing threshold per Conv2d layer and per non-final Linear layer.
+
+        `thresholds` is consumed in layer order; the caller's list is left untouched.
+        """
+        # Initialize thresholds
+        self.scaling_threshold = scaling_threshold
+        self.reset_threshold = reset_threshold
+
+        # copy: popping from the argument itself would empty the caller's list
+        thresholds = list(thresholds) if thresholds is not None else []
+        self.threshold = {}
+
+        for pos in range(len(self.features)):
+            if isinstance(self.features[pos], nn.Conv2d):
+                self.threshold[pos] = thresholds.pop(
+                    0) * self.scaling_threshold + self.reset_threshold * default_threshold
+
+        prev = len(self.features)
+
+        for pos in range(len(self.classifier) - 1):
+            if isinstance(self.classifier[pos], nn.Linear):
+                self.threshold[prev + pos] = thresholds.pop(
+                    0) * self.scaling_threshold + self.reset_threshold * default_threshold
+
+    def _make_layers(self, cfg):
+        layers = []
+        in_channels = 3
+
+        for x in cfg:
+            if x == 'A':
+                layers += [nn.AvgPool2d(kernel_size=2, stride=2)]
+            elif x == 'D':
+                layers += [nn.Dropout(self.drop)]
+            else:
+                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1, stride=1, bias=False),
+                           nn.ReLU(inplace=True)]
+                in_channels = x
+
+        features = nn.Sequential(*layers)
+
+        # the classifier input size is fixed to 512 x 7 x 7
+        classifier = nn.Sequential(
+            nn.Linear(512 * 7 * 7, 4096, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Dropout(self.drop),
+            nn.Linear(4096, 4096, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Dropout(self.drop),
+            nn.Linear(4096, self.labels, bias=False),
+        )
+        return features, classifier
+
+    def network_init(self, update_interval):
+        self.update_interval = update_interval
+
+    def neuron_init(self, x):
+        """Reset membrane potentials, spike times, spike counts and dropout masks for the batch `x`."""
+        self.batch_size = x.size(0)
+        self.width = x.size(2)
+        self.height = x.size(3)
+
+        self.mem = {}
+        self.spike = {}
+        self.mask = {}
+        self.spike_count = {}
+
+        for l in range(len(self.features)):
+
+            if isinstance(self.features[l], nn.Conv2d):
+                # allocate the state next to the input; CPU state cannot be combined with GPU activations
+                self.mem[l] = torch.zeros(self.batch_size, self.features[l].out_channels, self.width, self.height,
+                                          device=x.device)
+                self.spike_count[l] = torch.zeros_like(self.mem[l])
+
+            elif isinstance(self.features[l], nn.Dropout):
+                self.mask[l] = self.features[l](torch.ones_like(self.mem[l - 2]))
+
+            elif isinstance(self.features[l], nn.AvgPool2d):
+                self.width = self.width // 2
+                self.height = self.height // 2
+
+        prev = len(self.features)
+
+        for l in range(len(self.classifier)):
+
+            if isinstance(self.classifier[l], nn.Linear):
+                self.mem[prev + l] = torch.zeros(self.batch_size, self.classifier[l].out_features, device=x.device)
+                self.spike_count[prev + l] = torch.zeros_like(self.mem[prev + l])
+
+            elif isinstance(self.classifier[l], nn.Dropout):
+                self.mask[prev + l] = self.classifier[l](torch.ones_like(self.mem[prev + l - 2]))
+
+        self.spike = copy.deepcopy(self.mem)
+        for last_spike in self.spike.values():
+            last_spike.fill_(-1000)
+
+    def forward(self, x):
+
+        self.neuron_init(x)
+
+        for t in range(self.update_interval):
+
+            out_prev = self.input_layer(x)
+
+            for l in range(len(self.features)):
+
+                if isinstance(self.features[l], (nn.Conv2d)):
+                    mem_thr = (self.mem[l] / self.threshold[l]) - 1.0
+                    out = self.act_func(mem_thr, (t - 1 - self.spike[l]))
+                    rst = self.threshold[l] * (mem_thr > 0).float()
+                    self.spike[l] = self.spike[l].masked_fill(out.bool(), t - 1)
+                    self.spike_count[l][out.bool()] = self.spike_count[l][out.bool()] + 1
+
+                    self.mem[l] = self.leak_mem * self.mem[l] + self.features[l](out_prev) - rst
+                    out_prev = out.clone()
+
+                elif isinstance(self.features[l], nn.AvgPool2d):
+                    out_prev = self.features[l](out_prev)
+
+                elif isinstance(self.features[l], nn.Dropout):
+                    out_prev = out_prev * self.mask[l]
+
+            out_prev = out_prev.reshape(self.batch_size, -1)
+            prev = len(self.features)
+
+            for l in range(len(self.classifier) - 1):
+
+                if isinstance(self.classifier[l], (nn.Linear)):
+                    mem_thr = (self.mem[prev + l] / self.threshold[prev + l]) - 1.0
+                    out = self.act_func(mem_thr, (t - 1 - self.spike[prev + l]))
+                    rst = self.threshold[prev + l] * (mem_thr > 0).float()
+                    self.spike[prev + l] = self.spike[prev + l].masked_fill(out.bool(), t - 1)
+                    self.spike_count[prev + l][out.bool()] = self.spike_count[prev + l][out.bool()] + 1
+
+                    self.mem[prev + l] = self.leak_mem * self.mem[prev + l] + self.classifier[l](out_prev) - rst
+                    out_prev = out.clone()
+
+                elif isinstance(self.classifier[l], nn.Dropout):
+                    out_prev = out_prev * self.mask[prev + l]
+
+            # Compute the classification layer outputs
+            self.mem[prev + l + 1] = self.mem[prev + l + 1] + self.classifier[l + 1](out_prev)
+
+        return self.mem[prev + l + 1]
diff --git a/candidate_models/model_commitments/ml_pool.py b/candidate_models/model_commitments/ml_pool.py
index ae884e1..e2963e8 100644
--- a/candidate_models/model_commitments/ml_pool.py
+++ b/candidate_models/model_commitments/ml_pool.py
@@ -171,6 +171,7 @@ def __init__(self):
             'resnext101_32x32d_wsl': self._resnext101_layers(),
             'resnext101_32x48d_wsl': self._resnext101_layers(),
             'fixres_resnext101_32x48d_wsl': self._resnext101_layers(),
+            'spiking-vgg16': ['spike'],  # TODO
         }
         for basemodel_identifier, default_layers in layers.items():
             self[basemodel_identifier] = default_layers
diff --git a/tests/test_imagenet.py b/tests/test_imagenet.py
index db47ae5..1a084d7 100644
--- a/tests/test_imagenet.py
+++ b/tests/test_imagenet.py
@@ -95,6 +95,8 @@ class TestImagenet:
         ('resnext101_32x48d_wsl', .854),
         # FixRes: from https://arxiv.org/pdf/1906.06423.pdf, Table 8
         ('fixres_resnext101_32x48d_wsl', .863),
+        # spiking-vgg: TODO add the source of this number
+        ('spiking-vgg16', .651),
     ])
     def test_top1(self, model, expected_top1):
         # clear tf graph