From 6c83601fb8e570a491ea4b7d931c524be66c58e6 Mon Sep 17 00:00:00 2001 From: Nitin Rathi Date: Thu, 19 Sep 2019 16:13:01 -0400 Subject: [PATCH 1/8] added snn files --- .../model_commitments/imagenet_test_bscore.py | 146 +++++++++++++ .../model_commitments/spiking_model_bscore.py | 201 ++++++++++++++++++ 2 files changed, 347 insertions(+) create mode 100644 candidate_models/model_commitments/imagenet_test_bscore.py create mode 100644 candidate_models/model_commitments/spiking_model_bscore.py diff --git a/candidate_models/model_commitments/imagenet_test_bscore.py b/candidate_models/model_commitments/imagenet_test_bscore.py new file mode 100644 index 0000000..6a2ef4e --- /dev/null +++ b/candidate_models/model_commitments/imagenet_test_bscore.py @@ -0,0 +1,146 @@ +#--------------------------------------------------- +# Imports +#--------------------------------------------------- +from __future__ import print_function +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms, models +from torch.utils.data.dataloader import DataLoader +from matplotlib import pyplot as plt +from matplotlib.gridspec import GridSpec +import numpy as np +import datetime +import pdb +from spiking_model_bscore import * +import sys +import os + + +use_cuda = True + +torch.manual_seed(2) +if torch.cuda.is_available() and use_cuda: + print ("\n \t ------- Running on GPU -------") + torch.set_default_tensor_type('torch.cuda.FloatTensor') + + +def test(loader): + + with torch.no_grad(): + model.eval() + total_loss = 0 + correct = 0 + is_best = False + print_accuracy_every_batch = True + global max_correct + + model.module.network_init(timesteps) + + for batch_idx, (data, target) in enumerate(loader): + + if torch.cuda.is_available() and use_cuda: + data, target = data.cuda(), target.cuda() + + output = model(data) + output = output/(timesteps) + + loss = F.cross_entropy(output,target) + total_loss += loss.item() + pred = output.max(1, keepdim=True)[1] + correct += pred.eq(target.data.view_as(pred)).cpu().sum() + + if print_accuracy_every_batch: + + print('\nAccuracy: {}/{}({:.2f}%)'.format( + correct.item(), + (batch_idx+1)*data.size(0), + 100. * correct.item() / ((batch_idx+1)*data.size(0)) + ) + ) + + print('\nTest set: Loss: {:.6f}, Current: {:.2f}%, Best: {:.2f}%\n'. format( + total_loss/(batch_idx+1), + 100. * correct.item() / len(test_loader.dataset), + 100. 
* max_correct.item() / len(test_loader.dataset) + ) + ) + + +dataset = 'IMAGENET' +batch_size = 35 +timesteps = 500 +num_workers = 4 +leak_mem = 1.0 +scaling_threshold = 0.8 +reset_threshold = 0.0 +default_threshold = 1.0 +activation = 'STDB' +architecture = 'VGG16' +pretrained = True +pretrained_state = './snn_vgg16_imagenet.pth' + +if dataset == 'IMAGENET': + labels = 1000 + traindir = os.path.join('/data2/backup/imagenet2012/', 'train') + valdir = os.path.join('/data2/backup/imagenet2012/', 'val') + normalize = transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]) + trainset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + testset = datasets.ImageFolder( + valdir, + transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])) + +train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) +test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) + +model = VGG_SNN_STDB(vgg_name = architecture, activation = activation, labels=labels, timesteps=timesteps, leak_mem=leak_mem) + +if pretrained: + state = torch.load(pretrained_state, map_location='cpu') + model.load_state_dict(state['state_dict']) + +model = nn.DataParallel(model) + +if torch.cuda.is_available() and use_cuda: + model.cuda() + + +criterion = nn.CrossEntropyLoss() + +print('Dataset :{} '.format(dataset)) +print('Batch Size :{} '.format(batch_size)) +print('Timesteps :{} '.format(timesteps)) + +print('Membrane Leak :{} '.format(leak_mem)) +print('Scaling Threshold :{} '.format(scaling_threshold)) +print('Activation :{} '.format(activation)) +print('Architecture :{} '.format(architecture)) +if pretrained: + print('Pretrained Weight File :{} '.format(pretrained_state)) + +print('Criterion :{} '.format(criterion)) +print('\n{}'.format(model)) + +start_time = datetime.datetime.now() + +#VGG16 Imagenet thresholds +ann_thresholds = [10.16, 11.49, 2.65, 2.30, 0.77, 2.75, 1.33, 0.67, 1.13, 1.12, 0.43, 0.73, 1.08, 0.16, 0.58] + +model.module.threshold_init(scaling_threshold=scaling_threshold, reset_threshold=reset_threshold, thresholds = ann_thresholds[:], default_threshold=default_threshold) + + +test(test_loader) diff --git a/candidate_models/model_commitments/spiking_model_bscore.py b/candidate_models/model_commitments/spiking_model_bscore.py new file mode 100644 index 0000000..5ce1480 --- /dev/null +++ b/candidate_models/model_commitments/spiking_model_bscore.py @@ -0,0 +1,201 @@ +#--------------------------------------------------- +# Imports +#--------------------------------------------------- +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import pdb +import math +from collections import OrderedDict +from matplotlib import pyplot as plt +import copy + +torch.manual_seed(2) + +cfg = { + 'VGG5' : [64, 'A', 128, 'D', 128, 'A'], + 'VGG9': [64, 'A', 128, 'D', 128, 'A', 256, 'D', 256, 'A', 512, 'D', 512, 'D'], + 'VGG11': [64, 'A', 128, 'D', 256, 'A', 512, 'D', 512, 'D', 512, 'A', 512, 'D', 512, 'D'], + 'VGG16': [64, 'D', 64, 'A', 128, 'D', 128, 'A', 256, 'D', 256, 'D', 256, 'A', 512, 'D', 512, 'D', 512, 'A', 512, 'D', 512, 'D', 512, 'A'] +} +class PoissonGenerator(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self,input): + + out = 
torch.mul(torch.le(torch.rand_like(input), torch.abs(input)*1.0).float(),torch.sign(input)) + return out + +class STDPSpike(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, last_spike): + + ctx.save_for_backward(last_spike) + out = torch.zeros_like(input).cuda() + out[input > 0] = 1.0 + return out + +class VGG_SNN_STDB(nn.Module): + + def __init__(self, vgg_name, activation='STDB', labels=1000, timesteps=75, leak_mem=0.99, drop=0.2): + super().__init__() + + self.timesteps = timesteps + self.vgg_name = vgg_name + self.labels = labels + self.leak_mem = leak_mem + if activation == 'Linear': + self.act_func = LinearSpike.apply + elif activation == 'STDB': + self.act_func = STDPSpike.apply + self.input_layer = PoissonGenerator() + + self.features, self.classifier = self._make_layers(cfg[self.vgg_name]) + + def threshold_init(self, scaling_threshold=1.0, reset_threshold=0.0, thresholds=[], default_threshold=1.0): + + # Initialize thresholds + self.scaling_threshold = scaling_threshold + self.reset_threshold = reset_threshold + + self.threshold = {} + #print('\nThresholds:') + + for pos in range(len(self.features)): + if isinstance(self.features[pos], nn.Conv2d): + self.threshold[pos] = thresholds.pop(0) * self.scaling_threshold + self.reset_threshold * default_threshold + #print('\t Layer{} : {:.2f}'.format(pos, self.threshold[pos])) + + prev = len(self.features) + + for pos in range(len(self.classifier)-1): + if isinstance(self.classifier[pos], nn.Linear): + self.threshold[prev+pos] = thresholds.pop(0) * self.scaling_threshold + self.reset_threshold * default_threshold + #print('\t Layer{} : {:.2f}'.format(prev+pos, self.threshold[prev+pos])) + + + def _make_layers(self, cfg): + layers = [] + in_channels = 3 + + for x in (cfg): + stride = 1 + + if x == 'A': + layers += [nn.AvgPool2d(kernel_size=2, stride=2)] + elif x == 'D': + layers += [nn.Dropout(0.2)] + else: + layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1, stride=stride, bias=False), + nn.ReLU(inplace=True) + ] + in_channels = x + + features = nn.Sequential(*layers) + + layers = [] + layers += [nn.Linear(512*7*7, 4096, bias=False)] + layers += [nn.ReLU(inplace=True)] + layers += [nn.Dropout(0.2)] + layers += [nn.Linear(4096, 4096, bias=False)] + layers += [nn.ReLU(inplace=True)] + layers += [nn.Dropout(0.2)] + layers += [nn.Linear(4096, self.labels, bias=False)] + + classifer = nn.Sequential(*layers) + return (features, classifer) + + + def network_init(self, update_interval): + self.update_interval = update_interval + + def neuron_init(self, x): + + self.batch_size = x.size(0) + self.width = x.size(2) + self.height = x.size(3) + + self.mem = {} + self.spike = {} + self.mask = {} + + for l in range(len(self.features)): + + if isinstance(self.features[l], nn.Conv2d): + self.mem[l] = torch.zeros(self.batch_size, self.features[l].out_channels, self.width, self.height) + + elif isinstance(self.features[l], nn.Dropout): + self.mask[l] = self.features[l](torch.ones(self.mem[l-2].shape)) + + elif isinstance(self.features[l], nn.AvgPool2d): + self.width = self.width//2 + self.height = self.height//2 + + prev = len(self.features) + + for l in range(len(self.classifier)): + + if isinstance(self.classifier[l], nn.Linear): + self.mem[prev+l] = torch.zeros(self.batch_size, self.classifier[l].out_features) + + elif isinstance(self.classifier[l], nn.Dropout): + self.mask[prev+l] = self.classifier[l](torch.ones(self.mem[prev+l-2].shape)) + + + self.spike = copy.deepcopy(self.mem) + for key, values in 
self.spike.items(): + for value in values: + value.fill_(-1000) + + + + def forward(self, x): + + self.neuron_init(x) + + for t in range(self.update_interval): + + out_prev = self.input_layer(x) + + for l in range(len(self.features)): + + if isinstance(self.features[l], (nn.Conv2d)): + mem_thr = (self.mem[l]/self.threshold[l]) - 1.0 + out = self.act_func(mem_thr, (t-1-self.spike[l])) + rst = self.threshold[l] * (mem_thr>0).float() + self.spike[l] = self.spike[l].masked_fill(out.byte(),t-1) + + self.mem[l] = self.leak_mem*self.mem[l] + self.features[l](out_prev) - rst + out_prev = out.clone() + + elif isinstance(self.features[l], nn.AvgPool2d): + out_prev = self.features[l](out_prev) + + elif isinstance(self.features[l], nn.Dropout): + out_prev = out_prev * self.mask[l] + + out_prev = out_prev.reshape(self.batch_size, -1) + prev = len(self.features) + + for l in range(len(self.classifier)-1): + + if isinstance(self.classifier[l], (nn.Linear)): + mem_thr = (self.mem[prev+l]/self.threshold[prev+l]) - 1.0 + out = self.act_func(mem_thr, (t-1-self.spike[prev+l])) + rst = self.threshold[prev+l] * (mem_thr>0).float() + self.spike[prev+l] = self.spike[prev+l].masked_fill(out.byte(),t-1) + + self.mem[prev+l] = self.leak_mem*self.mem[prev+l] + self.classifier[l](out_prev) - rst + out_prev = out.clone() + + elif isinstance(self.classifier[l], nn.Dropout): + out_prev = out_prev * self.mask[prev+l] + + # Compute the classification layer outputs + self.mem[prev+l+1] = self.mem[prev+l+1] + self.classifier[l+1](out_prev) + + return self.mem[prev+l+1] \ No newline at end of file From 01efd599f6a50c09f291c70659dcb6ffc4985952 Mon Sep 17 00:00:00 2001 From: Martin Schrimpf Date: Tue, 24 Sep 2019 13:43:21 -0400 Subject: [PATCH 2/8] move to sub-package --- .../base_models/spiking_vgg/__init__.py | 0 .../spiking_vgg}/imagenet_test_bscore.py | 0 .../base_models/spiking_vgg/model.py | 201 ++++++++++++++++++ .../model_commitments/spiking_model_bscore.py | 201 ------------------ 4 files changed, 201 insertions(+), 201 deletions(-) create mode 100644 candidate_models/base_models/spiking_vgg/__init__.py rename candidate_models/{model_commitments => base_models/spiking_vgg}/imagenet_test_bscore.py (100%) create mode 100644 candidate_models/base_models/spiking_vgg/model.py delete mode 100644 candidate_models/model_commitments/spiking_model_bscore.py diff --git a/candidate_models/base_models/spiking_vgg/__init__.py b/candidate_models/base_models/spiking_vgg/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/candidate_models/model_commitments/imagenet_test_bscore.py b/candidate_models/base_models/spiking_vgg/imagenet_test_bscore.py similarity index 100% rename from candidate_models/model_commitments/imagenet_test_bscore.py rename to candidate_models/base_models/spiking_vgg/imagenet_test_bscore.py diff --git a/candidate_models/base_models/spiking_vgg/model.py b/candidate_models/base_models/spiking_vgg/model.py new file mode 100644 index 0000000..5926527 --- /dev/null +++ b/candidate_models/base_models/spiking_vgg/model.py @@ -0,0 +1,201 @@ +# --------------------------------------------------- +# Imports +# --------------------------------------------------- +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import pdb +import math +from collections import OrderedDict +from matplotlib import pyplot as plt +import copy + +torch.manual_seed(2) + +cfg = { + 'VGG5': [64, 'A', 128, 'D', 128, 'A'], + 'VGG9': [64, 'A', 128, 'D', 128, 'A', 256, 'D', 256, 'A', 512, 'D', 512, 
'D'], + 'VGG11': [64, 'A', 128, 'D', 256, 'A', 512, 'D', 512, 'D', 512, 'A', 512, 'D', 512, 'D'], + 'VGG16': [64, 'D', 64, 'A', 128, 'D', 128, 'A', 256, 'D', 256, 'D', 256, 'A', 512, 'D', 512, 'D', 512, 'A', 512, + 'D', 512, 'D', 512, 'A'] +} + + +class PoissonGenerator(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.mul(torch.le(torch.rand_like(input), torch.abs(input) * 1.0).float(), torch.sign(input)) + return out + + +class STDPSpike(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, last_spike): + ctx.save_for_backward(last_spike) + out = torch.zeros_like(input).cuda() + out[input > 0] = 1.0 + return out + + +class VGG_SNN_STDB(nn.Module): + + def __init__(self, vgg_name, activation='STDB', labels=1000, timesteps=75, leak_mem=0.99, drop=0.2): + super().__init__() + + self.timesteps = timesteps + self.vgg_name = vgg_name + self.labels = labels + self.leak_mem = leak_mem + if activation == 'Linear': + self.act_func = LinearSpike.apply + elif activation == 'STDB': + self.act_func = STDPSpike.apply + self.input_layer = PoissonGenerator() + + self.features, self.classifier = self._make_layers(cfg[self.vgg_name]) + + def threshold_init(self, scaling_threshold=1.0, reset_threshold=0.0, thresholds=[], default_threshold=1.0): + + # Initialize thresholds + self.scaling_threshold = scaling_threshold + self.reset_threshold = reset_threshold + + self.threshold = {} + # print('\nThresholds:') + + for pos in range(len(self.features)): + if isinstance(self.features[pos], nn.Conv2d): + self.threshold[pos] = thresholds.pop( + 0) * self.scaling_threshold + self.reset_threshold * default_threshold + # print('\t Layer{} : {:.2f}'.format(pos, self.threshold[pos])) + + prev = len(self.features) + + for pos in range(len(self.classifier) - 1): + if isinstance(self.classifier[pos], nn.Linear): + self.threshold[prev + pos] = thresholds.pop( + 0) * self.scaling_threshold + self.reset_threshold * default_threshold + # print('\t Layer{} : {:.2f}'.format(prev+pos, self.threshold[prev+pos])) + + def _make_layers(self, cfg): + layers = [] + in_channels = 3 + + for x in (cfg): + stride = 1 + + if x == 'A': + layers += [nn.AvgPool2d(kernel_size=2, stride=2)] + elif x == 'D': + layers += [nn.Dropout(0.2)] + else: + layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1, stride=stride, bias=False), + nn.ReLU(inplace=True) + ] + in_channels = x + + features = nn.Sequential(*layers) + + layers = [] + layers += [nn.Linear(512 * 7 * 7, 4096, bias=False)] + layers += [nn.ReLU(inplace=True)] + layers += [nn.Dropout(0.2)] + layers += [nn.Linear(4096, 4096, bias=False)] + layers += [nn.ReLU(inplace=True)] + layers += [nn.Dropout(0.2)] + layers += [nn.Linear(4096, self.labels, bias=False)] + + classifer = nn.Sequential(*layers) + return (features, classifer) + + def network_init(self, update_interval): + self.update_interval = update_interval + + def neuron_init(self, x): + + self.batch_size = x.size(0) + self.width = x.size(2) + self.height = x.size(3) + + self.mem = {} + self.spike = {} + self.mask = {} + + for l in range(len(self.features)): + + if isinstance(self.features[l], nn.Conv2d): + self.mem[l] = torch.zeros(self.batch_size, self.features[l].out_channels, self.width, self.height) + + elif isinstance(self.features[l], nn.Dropout): + self.mask[l] = self.features[l](torch.ones(self.mem[l - 2].shape)) + + elif isinstance(self.features[l], nn.AvgPool2d): + self.width = self.width // 2 + self.height = self.height // 2 + + prev = len(self.features) 
+ + for l in range(len(self.classifier)): + + if isinstance(self.classifier[l], nn.Linear): + self.mem[prev + l] = torch.zeros(self.batch_size, self.classifier[l].out_features) + + elif isinstance(self.classifier[l], nn.Dropout): + self.mask[prev + l] = self.classifier[l](torch.ones(self.mem[prev + l - 2].shape)) + + self.spike = copy.deepcopy(self.mem) + for key, values in self.spike.items(): + for value in values: + value.fill_(-1000) + + def forward(self, x): + + self.neuron_init(x) + + for t in range(self.update_interval): + + out_prev = self.input_layer(x) + + for l in range(len(self.features)): + + if isinstance(self.features[l], (nn.Conv2d)): + mem_thr = (self.mem[l] / self.threshold[l]) - 1.0 + out = self.act_func(mem_thr, (t - 1 - self.spike[l])) + rst = self.threshold[l] * (mem_thr > 0).float() + self.spike[l] = self.spike[l].masked_fill(out.byte(), t - 1) + + self.mem[l] = self.leak_mem * self.mem[l] + self.features[l](out_prev) - rst + out_prev = out.clone() + + elif isinstance(self.features[l], nn.AvgPool2d): + out_prev = self.features[l](out_prev) + + elif isinstance(self.features[l], nn.Dropout): + out_prev = out_prev * self.mask[l] + + out_prev = out_prev.reshape(self.batch_size, -1) + prev = len(self.features) + + for l in range(len(self.classifier) - 1): + + if isinstance(self.classifier[l], (nn.Linear)): + mem_thr = (self.mem[prev + l] / self.threshold[prev + l]) - 1.0 + out = self.act_func(mem_thr, (t - 1 - self.spike[prev + l])) + rst = self.threshold[prev + l] * (mem_thr > 0).float() + self.spike[prev + l] = self.spike[prev + l].masked_fill(out.byte(), t - 1) + + self.mem[prev + l] = self.leak_mem * self.mem[prev + l] + self.classifier[l](out_prev) - rst + out_prev = out.clone() + + elif isinstance(self.classifier[l], nn.Dropout): + out_prev = out_prev * self.mask[prev + l] + + # Compute the classification layer outputs + self.mem[prev + l + 1] = self.mem[prev + l + 1] + self.classifier[l + 1](out_prev) + + return self.mem[prev + l + 1] \ No newline at end of file diff --git a/candidate_models/model_commitments/spiking_model_bscore.py b/candidate_models/model_commitments/spiking_model_bscore.py deleted file mode 100644 index 5ce1480..0000000 --- a/candidate_models/model_commitments/spiking_model_bscore.py +++ /dev/null @@ -1,201 +0,0 @@ -#--------------------------------------------------- -# Imports -#--------------------------------------------------- -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -import pdb -import math -from collections import OrderedDict -from matplotlib import pyplot as plt -import copy - -torch.manual_seed(2) - -cfg = { - 'VGG5' : [64, 'A', 128, 'D', 128, 'A'], - 'VGG9': [64, 'A', 128, 'D', 128, 'A', 256, 'D', 256, 'A', 512, 'D', 512, 'D'], - 'VGG11': [64, 'A', 128, 'D', 256, 'A', 512, 'D', 512, 'D', 512, 'A', 512, 'D', 512, 'D'], - 'VGG16': [64, 'D', 64, 'A', 128, 'D', 128, 'A', 256, 'D', 256, 'D', 256, 'A', 512, 'D', 512, 'D', 512, 'A', 512, 'D', 512, 'D', 512, 'A'] -} -class PoissonGenerator(nn.Module): - - def __init__(self): - super().__init__() - - def forward(self,input): - - out = torch.mul(torch.le(torch.rand_like(input), torch.abs(input)*1.0).float(),torch.sign(input)) - return out - -class STDPSpike(torch.autograd.Function): - - @staticmethod - def forward(ctx, input, last_spike): - - ctx.save_for_backward(last_spike) - out = torch.zeros_like(input).cuda() - out[input > 0] = 1.0 - return out - -class VGG_SNN_STDB(nn.Module): - - def __init__(self, vgg_name, activation='STDB', labels=1000, 
timesteps=75, leak_mem=0.99, drop=0.2): - super().__init__() - - self.timesteps = timesteps - self.vgg_name = vgg_name - self.labels = labels - self.leak_mem = leak_mem - if activation == 'Linear': - self.act_func = LinearSpike.apply - elif activation == 'STDB': - self.act_func = STDPSpike.apply - self.input_layer = PoissonGenerator() - - self.features, self.classifier = self._make_layers(cfg[self.vgg_name]) - - def threshold_init(self, scaling_threshold=1.0, reset_threshold=0.0, thresholds=[], default_threshold=1.0): - - # Initialize thresholds - self.scaling_threshold = scaling_threshold - self.reset_threshold = reset_threshold - - self.threshold = {} - #print('\nThresholds:') - - for pos in range(len(self.features)): - if isinstance(self.features[pos], nn.Conv2d): - self.threshold[pos] = thresholds.pop(0) * self.scaling_threshold + self.reset_threshold * default_threshold - #print('\t Layer{} : {:.2f}'.format(pos, self.threshold[pos])) - - prev = len(self.features) - - for pos in range(len(self.classifier)-1): - if isinstance(self.classifier[pos], nn.Linear): - self.threshold[prev+pos] = thresholds.pop(0) * self.scaling_threshold + self.reset_threshold * default_threshold - #print('\t Layer{} : {:.2f}'.format(prev+pos, self.threshold[prev+pos])) - - - def _make_layers(self, cfg): - layers = [] - in_channels = 3 - - for x in (cfg): - stride = 1 - - if x == 'A': - layers += [nn.AvgPool2d(kernel_size=2, stride=2)] - elif x == 'D': - layers += [nn.Dropout(0.2)] - else: - layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1, stride=stride, bias=False), - nn.ReLU(inplace=True) - ] - in_channels = x - - features = nn.Sequential(*layers) - - layers = [] - layers += [nn.Linear(512*7*7, 4096, bias=False)] - layers += [nn.ReLU(inplace=True)] - layers += [nn.Dropout(0.2)] - layers += [nn.Linear(4096, 4096, bias=False)] - layers += [nn.ReLU(inplace=True)] - layers += [nn.Dropout(0.2)] - layers += [nn.Linear(4096, self.labels, bias=False)] - - classifer = nn.Sequential(*layers) - return (features, classifer) - - - def network_init(self, update_interval): - self.update_interval = update_interval - - def neuron_init(self, x): - - self.batch_size = x.size(0) - self.width = x.size(2) - self.height = x.size(3) - - self.mem = {} - self.spike = {} - self.mask = {} - - for l in range(len(self.features)): - - if isinstance(self.features[l], nn.Conv2d): - self.mem[l] = torch.zeros(self.batch_size, self.features[l].out_channels, self.width, self.height) - - elif isinstance(self.features[l], nn.Dropout): - self.mask[l] = self.features[l](torch.ones(self.mem[l-2].shape)) - - elif isinstance(self.features[l], nn.AvgPool2d): - self.width = self.width//2 - self.height = self.height//2 - - prev = len(self.features) - - for l in range(len(self.classifier)): - - if isinstance(self.classifier[l], nn.Linear): - self.mem[prev+l] = torch.zeros(self.batch_size, self.classifier[l].out_features) - - elif isinstance(self.classifier[l], nn.Dropout): - self.mask[prev+l] = self.classifier[l](torch.ones(self.mem[prev+l-2].shape)) - - - self.spike = copy.deepcopy(self.mem) - for key, values in self.spike.items(): - for value in values: - value.fill_(-1000) - - - - def forward(self, x): - - self.neuron_init(x) - - for t in range(self.update_interval): - - out_prev = self.input_layer(x) - - for l in range(len(self.features)): - - if isinstance(self.features[l], (nn.Conv2d)): - mem_thr = (self.mem[l]/self.threshold[l]) - 1.0 - out = self.act_func(mem_thr, (t-1-self.spike[l])) - rst = self.threshold[l] * (mem_thr>0).float() - 
self.spike[l] = self.spike[l].masked_fill(out.byte(),t-1) - - self.mem[l] = self.leak_mem*self.mem[l] + self.features[l](out_prev) - rst - out_prev = out.clone() - - elif isinstance(self.features[l], nn.AvgPool2d): - out_prev = self.features[l](out_prev) - - elif isinstance(self.features[l], nn.Dropout): - out_prev = out_prev * self.mask[l] - - out_prev = out_prev.reshape(self.batch_size, -1) - prev = len(self.features) - - for l in range(len(self.classifier)-1): - - if isinstance(self.classifier[l], (nn.Linear)): - mem_thr = (self.mem[prev+l]/self.threshold[prev+l]) - 1.0 - out = self.act_func(mem_thr, (t-1-self.spike[prev+l])) - rst = self.threshold[prev+l] * (mem_thr>0).float() - self.spike[prev+l] = self.spike[prev+l].masked_fill(out.byte(),t-1) - - self.mem[prev+l] = self.leak_mem*self.mem[prev+l] + self.classifier[l](out_prev) - rst - out_prev = out.clone() - - elif isinstance(self.classifier[l], nn.Dropout): - out_prev = out_prev * self.mask[prev+l] - - # Compute the classification layer outputs - self.mem[prev+l+1] = self.mem[prev+l+1] + self.classifier[l+1](out_prev) - - return self.mem[prev+l+1] \ No newline at end of file From c94312b1395cbca7954f8cc527d3ae812b9599d7 Mon Sep 17 00:00:00 2001 From: Martin Schrimpf Date: Tue, 24 Sep 2019 13:52:30 -0400 Subject: [PATCH 3/8] combine model creation in __init__ --- .../base_models/spiking_vgg/__init__.py | 56 +++++++ .../spiking_vgg/imagenet_test_bscore.py | 146 ------------------ 2 files changed, 56 insertions(+), 146 deletions(-) delete mode 100644 candidate_models/base_models/spiking_vgg/imagenet_test_bscore.py diff --git a/candidate_models/base_models/spiking_vgg/__init__.py b/candidate_models/base_models/spiking_vgg/__init__.py index e69de29..eec2800 100644 --- a/candidate_models/base_models/spiking_vgg/__init__.py +++ b/candidate_models/base_models/spiking_vgg/__init__.py @@ -0,0 +1,56 @@ +import torch +from torch import nn +from torchvision import transforms + +from .model import VGG_SNN_STDB + + +def create_model(): + batch_size = 35 + timesteps = 500 + num_workers = 4 + leak_mem = 1.0 + scaling_threshold = 0.8 + reset_threshold = 0.0 + default_threshold = 1.0 + activation = 'STDB' + architecture = 'VGG16' + pretrained = True + pretrained_state = './snn_vgg16_imagenet.pth' + + normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + trainset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + testset = datasets.ImageFolder( + valdir, + transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])) + + model = VGG_SNN_STDB(vgg_name=architecture, activation=activation, labels=labels, timesteps=timesteps, + leak_mem=leak_mem) + + if pretrained: + state = torch.load(pretrained_state, map_location='cpu') + model.load_state_dict(state['state_dict']) + + model = nn.DataParallel(model) + + if torch.cuda.is_available() and use_cuda: + model.cuda() + + # VGG16 Imagenet thresholds + ann_thresholds = [10.16, 11.49, 2.65, 2.30, 0.77, 2.75, 1.33, 0.67, 1.13, 1.12, 0.43, 0.73, 1.08, 0.16, 0.58] + model.module.threshold_init(scaling_threshold=scaling_threshold, reset_threshold=reset_threshold, + thresholds=ann_thresholds[:], default_threshold=default_threshold) + model.eval() + model.module.network_init(timesteps) diff --git a/candidate_models/base_models/spiking_vgg/imagenet_test_bscore.py 
b/candidate_models/base_models/spiking_vgg/imagenet_test_bscore.py deleted file mode 100644 index 6a2ef4e..0000000 --- a/candidate_models/base_models/spiking_vgg/imagenet_test_bscore.py +++ /dev/null @@ -1,146 +0,0 @@ -#--------------------------------------------------- -# Imports -#--------------------------------------------------- -from __future__ import print_function -import argparse -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torchvision import datasets, transforms, models -from torch.utils.data.dataloader import DataLoader -from matplotlib import pyplot as plt -from matplotlib.gridspec import GridSpec -import numpy as np -import datetime -import pdb -from spiking_model_bscore import * -import sys -import os - - -use_cuda = True - -torch.manual_seed(2) -if torch.cuda.is_available() and use_cuda: - print ("\n \t ------- Running on GPU -------") - torch.set_default_tensor_type('torch.cuda.FloatTensor') - - -def test(loader): - - with torch.no_grad(): - model.eval() - total_loss = 0 - correct = 0 - is_best = False - print_accuracy_every_batch = True - global max_correct - - model.module.network_init(timesteps) - - for batch_idx, (data, target) in enumerate(loader): - - if torch.cuda.is_available() and use_cuda: - data, target = data.cuda(), target.cuda() - - output = model(data) - output = output/(timesteps) - - loss = F.cross_entropy(output,target) - total_loss += loss.item() - pred = output.max(1, keepdim=True)[1] - correct += pred.eq(target.data.view_as(pred)).cpu().sum() - - if print_accuracy_every_batch: - - print('\nAccuracy: {}/{}({:.2f}%)'.format( - correct.item(), - (batch_idx+1)*data.size(0), - 100. * correct.item() / ((batch_idx+1)*data.size(0)) - ) - ) - - print('\nTest set: Loss: {:.6f}, Current: {:.2f}%, Best: {:.2f}%\n'. format( - total_loss/(batch_idx+1), - 100. * correct.item() / len(test_loader.dataset), - 100. 
* max_correct.item() / len(test_loader.dataset) - ) - ) - - -dataset = 'IMAGENET' -batch_size = 35 -timesteps = 500 -num_workers = 4 -leak_mem = 1.0 -scaling_threshold = 0.8 -reset_threshold = 0.0 -default_threshold = 1.0 -activation = 'STDB' -architecture = 'VGG16' -pretrained = True -pretrained_state = './snn_vgg16_imagenet.pth' - -if dataset == 'IMAGENET': - labels = 1000 - traindir = os.path.join('/data2/backup/imagenet2012/', 'train') - valdir = os.path.join('/data2/backup/imagenet2012/', 'val') - normalize = transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]) - trainset = datasets.ImageFolder( - traindir, - transforms.Compose([ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ])) - testset = datasets.ImageFolder( - valdir, - transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ])) - -train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) -test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) - -model = VGG_SNN_STDB(vgg_name = architecture, activation = activation, labels=labels, timesteps=timesteps, leak_mem=leak_mem) - -if pretrained: - state = torch.load(pretrained_state, map_location='cpu') - model.load_state_dict(state['state_dict']) - -model = nn.DataParallel(model) - -if torch.cuda.is_available() and use_cuda: - model.cuda() - - -criterion = nn.CrossEntropyLoss() - -print('Dataset :{} '.format(dataset)) -print('Batch Size :{} '.format(batch_size)) -print('Timesteps :{} '.format(timesteps)) - -print('Membrane Leak :{} '.format(leak_mem)) -print('Scaling Threshold :{} '.format(scaling_threshold)) -print('Activation :{} '.format(activation)) -print('Architecture :{} '.format(architecture)) -if pretrained: - print('Pretrained Weight File :{} '.format(pretrained_state)) - -print('Criterion :{} '.format(criterion)) -print('\n{}'.format(model)) - -start_time = datetime.datetime.now() - -#VGG16 Imagenet thresholds -ann_thresholds = [10.16, 11.49, 2.65, 2.30, 0.77, 2.75, 1.33, 0.67, 1.13, 1.12, 0.43, 0.73, 1.08, 0.16, 0.58] - -model.module.threshold_init(scaling_threshold=scaling_threshold, reset_threshold=reset_threshold, thresholds = ann_thresholds[:], default_threshold=default_threshold) - - -test(test_loader) From c998bfef1fce66ea80e651ff6f83d88c6f679a2b Mon Sep 17 00:00:00 2001 From: Martin Schrimpf Date: Tue, 24 Sep 2019 14:06:59 -0400 Subject: [PATCH 4/8] remove unused imports --- candidate_models/base_models/spiking_vgg/model.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/candidate_models/base_models/spiking_vgg/model.py b/candidate_models/base_models/spiking_vgg/model.py index 5926527..956a089 100644 --- a/candidate_models/base_models/spiking_vgg/model.py +++ b/candidate_models/base_models/spiking_vgg/model.py @@ -1,15 +1,7 @@ -# --------------------------------------------------- -# Imports -# --------------------------------------------------- +import copy + import torch import torch.nn as nn -import torch.nn.functional as F -import numpy as np -import pdb -import math -from collections import OrderedDict -from matplotlib import pyplot as plt -import copy torch.manual_seed(2) @@ -198,4 +190,4 @@ def forward(self, x): # Compute the classification layer outputs self.mem[prev + l + 1] = self.mem[prev + l + 1] + self.classifier[l + 1](out_prev) - return self.mem[prev + l + 1] \ No 
newline at end of file + return self.mem[prev + l + 1] From 64a0aa9f4132db00f9bba16f902abb8a36887c3f Mon Sep 17 00:00:00 2001 From: Martin Schrimpf Date: Tue, 24 Sep 2019 14:08:39 -0400 Subject: [PATCH 5/8] skeleton-wrap model; add to basemodel-pool; add skeletons for layers, csv, imagenet test --- candidate_models/base_models/__init__.py | 7 +++ candidate_models/base_models/models.csv | 1 + .../base_models/spiking_vgg/__init__.py | 63 ++++++++----------- candidate_models/model_commitments/ml_pool.py | 1 + tests/test_imagenet.py | 2 + 5 files changed, 38 insertions(+), 36 deletions(-) diff --git a/candidate_models/base_models/__init__.py b/candidate_models/base_models/__init__.py index 7094b85..44a4e84 100644 --- a/candidate_models/base_models/__init__.py +++ b/candidate_models/base_models/__init__.py @@ -194,6 +194,11 @@ def load_preprocess_images(image_filepaths): return wrapper +def spiking_vgg16(): + from .spiking_vgg import create_model + return create_model(architecture='VGG16') + + class BaseModelPool(UniqueKeyDict): """ Provides a set of standard models. @@ -273,6 +278,8 @@ def __init__(self): 'fixres_resnext101_32x48d_wsl': lambda: fixres( 'resnext101_32x48d_wsl', 'https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNeXt_101_32x48d.pth'), + + 'spiking-vgg16': spiking_vgg16, } # MobileNets for version, multiplier, image_size in [ diff --git a/candidate_models/base_models/models.csv b/candidate_models/base_models/models.csv index ed5814a..0d9aea4 100644 --- a/candidate_models/base_models/models.csv +++ b/candidate_models/base_models/models.csv @@ -1361,3 +1361,4 @@ fixres_resnext101_32x48d_wsl,http://openaccess.thecvf.com/content_ECCV_2018/html year = {2019}, month = {jun}, }",0.863, +spiking-vgg16,TODO: paper link,TODO: paper bibtex,TODO: ImageNet top-1,TODO: ImageNet top-5 diff --git a/candidate_models/base_models/spiking_vgg/__init__.py b/candidate_models/base_models/spiking_vgg/__init__.py index eec2800..ddc5246 100644 --- a/candidate_models/base_models/spiking_vgg/__init__.py +++ b/candidate_models/base_models/spiking_vgg/__init__.py @@ -1,56 +1,47 @@ +import numpy as np import torch from torch import nn from torchvision import transforms +from model_tools.activations.core import ActivationsExtractorHelper +from model_tools.activations.pytorch import load_images from .model import VGG_SNN_STDB -def create_model(): - batch_size = 35 - timesteps = 500 - num_workers = 4 - leak_mem = 1.0 - scaling_threshold = 0.8 - reset_threshold = 0.0 - default_threshold = 1.0 - activation = 'STDB' - architecture = 'VGG16' - pretrained = True - pretrained_state = './snn_vgg16_imagenet.pth' - - normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - trainset = datasets.ImageFolder( - traindir, - transforms.Compose([ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ])) - testset = datasets.ImageFolder( - valdir, - transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ])) - +def create_model(timesteps=500, leak_mem=1.0, scaling_threshold=0.8, reset_threshold=0.0, default_threshold=1.0, + activation='STDB', architecture='VGG16', labels=1000, pretrained=True): model = VGG_SNN_STDB(vgg_name=architecture, activation=activation, labels=labels, timesteps=timesteps, leak_mem=leak_mem) - if pretrained: + pretrained_state = './snn_vgg16_imagenet.pth' state = torch.load(pretrained_state, map_location='cpu') 
model.load_state_dict(state['state_dict']) - model = nn.DataParallel(model) - - if torch.cuda.is_available() and use_cuda: + if torch.cuda.is_available(): model.cuda() - # VGG16 Imagenet thresholds ann_thresholds = [10.16, 11.49, 2.65, 2.30, 0.77, 2.75, 1.33, 0.67, 1.13, 1.12, 0.43, 0.73, 1.08, 0.16, 0.58] model.module.threshold_init(scaling_threshold=scaling_threshold, reset_threshold=reset_threshold, thresholds=ann_thresholds[:], default_threshold=default_threshold) model.eval() model.module.network_init(timesteps) + + normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + transform = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + normalize, + ]) + + def load_and_preprocess(image_filepaths): + images = load_images(image_filepaths) + images = [transform(image) for image in images] + return np.concatenate(images) + + def get_activations(preprocessed_inputs, layer_names): + ... # TODO + + model = ActivationsExtractorHelper(identifier='spiking-vgg', + get_activations=get_activations, preprocessing=load_and_preprocess, + batch_size=35) + return model diff --git a/candidate_models/model_commitments/ml_pool.py b/candidate_models/model_commitments/ml_pool.py index ae884e1..e2963e8 100644 --- a/candidate_models/model_commitments/ml_pool.py +++ b/candidate_models/model_commitments/ml_pool.py @@ -171,6 +171,7 @@ def __init__(self): 'resnext101_32x32d_wsl': self._resnext101_layers(), 'resnext101_32x48d_wsl': self._resnext101_layers(), 'fixres_resnext101_32x48d_wsl': self._resnext101_layers(), + 'spiking-vgg16': ['spike'], # TODO } for basemodel_identifier, default_layers in layers.items(): self[basemodel_identifier] = default_layers diff --git a/tests/test_imagenet.py b/tests/test_imagenet.py index db47ae5..f66a7a8 100644 --- a/tests/test_imagenet.py +++ b/tests/test_imagenet.py @@ -95,6 +95,8 @@ class TestImagenet: ('resnext101_32x48d_wsl', .854), # FixRes: from https://arxiv.org/pdf/1906.06423.pdf, Table 8 ('fixres_resnext101_32x48d_wsl', .863), + # spiking-vgg: from + ('spiking-vgg16', ...), ]) def test_top1(self, model, expected_top1): # clear tf graph From 246a25905444acedb4cd1c9187d3c91cfbffbe05 Mon Sep 17 00:00:00 2001 From: nitin-rathi Date: Mon, 30 Sep 2019 18:49:46 -0400 Subject: [PATCH 6/8] Added spike count for each neuron The variable spike_count counts the number of spikes for each neuron over all the time-steps --- candidate_models/base_models/spiking_vgg/model.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/candidate_models/base_models/spiking_vgg/model.py b/candidate_models/base_models/spiking_vgg/model.py index 956a089..b21068e 100644 --- a/candidate_models/base_models/spiking_vgg/model.py +++ b/candidate_models/base_models/spiking_vgg/model.py @@ -117,11 +117,13 @@ def neuron_init(self, x): self.mem = {} self.spike = {} self.mask = {} + self.spike_count = {} for l in range(len(self.features)): if isinstance(self.features[l], nn.Conv2d): self.mem[l] = torch.zeros(self.batch_size, self.features[l].out_channels, self.width, self.height) + self.spike_count[l] = torch.zeros(self.mem[l].size()) elif isinstance(self.features[l], nn.Dropout): self.mask[l] = self.features[l](torch.ones(self.mem[l - 2].shape)) @@ -136,6 +138,7 @@ def neuron_init(self, x): if isinstance(self.classifier[l], nn.Linear): self.mem[prev + l] = torch.zeros(self.batch_size, self.classifier[l].out_features) + self.spike_count[prev+l] = torch.zeros(self.mem[prev+l].size()) elif isinstance(self.classifier[l], nn.Dropout): self.mask[prev + 
l] = self.classifier[l](torch.ones(self.mem[prev + l - 2].shape)) @@ -159,7 +162,8 @@ def forward(self, x): mem_thr = (self.mem[l] / self.threshold[l]) - 1.0 out = self.act_func(mem_thr, (t - 1 - self.spike[l])) rst = self.threshold[l] * (mem_thr > 0).float() - self.spike[l] = self.spike[l].masked_fill(out.byte(), t - 1) + self.spike[l] = self.spike[l].masked_fill(out.bool(), t - 1) + self.spike_count[l][out.bool()] = self.spike_count[l][out.bool()] + 1 self.mem[l] = self.leak_mem * self.mem[l] + self.features[l](out_prev) - rst out_prev = out.clone() @@ -179,7 +183,8 @@ def forward(self, x): mem_thr = (self.mem[prev + l] / self.threshold[prev + l]) - 1.0 out = self.act_func(mem_thr, (t - 1 - self.spike[prev + l])) rst = self.threshold[prev + l] * (mem_thr > 0).float() - self.spike[prev + l] = self.spike[prev + l].masked_fill(out.byte(), t - 1) + self.spike[prev + l] = self.spike[prev + l].masked_fill(out.bool(), t - 1) + self.spike_count[prev+l][out.bool()] = self.spike_count[prev+l][out.bool()] + 1 self.mem[prev + l] = self.leak_mem * self.mem[prev + l] + self.classifier[l](out_prev) - rst out_prev = out.clone() From 9da15e1f951d90a5bc9d8a14f29d08409f0d76bf Mon Sep 17 00:00:00 2001 From: nitin-rathi Date: Mon, 30 Sep 2019 19:20:09 -0400 Subject: [PATCH 7/8] added ImageNet accuracy --- candidate_models/base_models/models.csv | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candidate_models/base_models/models.csv b/candidate_models/base_models/models.csv index 0d9aea4..13a7855 100644 --- a/candidate_models/base_models/models.csv +++ b/candidate_models/base_models/models.csv @@ -1361,4 +1361,4 @@ fixres_resnext101_32x48d_wsl,http://openaccess.thecvf.com/content_ECCV_2018/html year = {2019}, month = {jun}, }",0.863, -spiking-vgg16,TODO: paper link,TODO: paper bibtex,TODO: ImageNet top-1,TODO: ImageNet top-5 +spiking-vgg16,TODO: paper link,TODO: paper bibtex,0.651,0.817 From c0b0032f5c88f451178047706fe7b936ed6ad37a Mon Sep 17 00:00:00 2001 From: nitin-rathi Date: Mon, 30 Sep 2019 19:31:05 -0400 Subject: [PATCH 8/8] added top-1 accuracy --- tests/test_imagenet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_imagenet.py b/tests/test_imagenet.py index f66a7a8..1a084d7 100644 --- a/tests/test_imagenet.py +++ b/tests/test_imagenet.py @@ -96,7 +96,7 @@ class TestImagenet: # FixRes: from https://arxiv.org/pdf/1906.06423.pdf, Table 8 ('fixres_resnext101_32x48d_wsl', .863), # spiking-vgg: from - ('spiking-vgg16', ...), + ('spiking-vgg16', .651), ]) def test_top1(self, model, expected_top1): # clear tf graph
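The wrapper added in patch 5 leaves `get_activations` in `spiking_vgg/__init__.py` as a TODO, and the layer list in `ml_pool.py` as a `'spike'` placeholder. Below is a minimal sketch of one way the stub inside `create_model` could be filled in, reading out the per-layer `spike_count` tensors introduced in patch 6 as the layer activations. The `features.<index>` / `classifier.<index>` layer naming, the use of spike counts (rather than membrane potentials) as activations, and the assumed `(N, 3, 224, 224)` input shape are assumptions for illustration, not something the patches prescribe; `model` and `torch` are taken from the enclosing `create_model` scope.

    def get_activations(preprocessed_inputs, layer_names):
        # Sketch only: `model` is the nn.DataParallel-wrapped VGG_SNN_STDB built
        # earlier in create_model(). Reading spike counts off model.module assumes
        # a CPU or single-GPU run; DataParallel replicas would not write them back.
        images = torch.from_numpy(preprocessed_inputs)  # assumed shape (N, 3, 224, 224)
        if torch.cuda.is_available():
            images = images.cuda()
        with torch.no_grad():
            model(images)  # forward() integrates over model.module.update_interval timesteps

        spike_counts = model.module.spike_count  # dict: layer index -> per-neuron spike counts
        prev = len(model.module.features)
        activations = {}
        for layer_name in layer_names:
            prefix, index = layer_name.split('.')  # hypothetical 'features.3' / 'classifier.0' naming
            index = int(index)
            key = index if prefix == 'features' else prev + index
            activations[layer_name] = spike_counts[key].cpu().numpy()
        return activations

This would replace the `...  # TODO` body so that the existing `ActivationsExtractorHelper` can call it per batch; the layer names registered in `ml_pool.py` would then have to match whatever naming convention is chosen here.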