added snn files #41
base: master

Changes from all commits: 6c83601, 01efd59, c94312b, c998bfe, 64a0aa9, 246a259, 9da15e1, c0b0032
The first new file wraps the spiking network as a model-tools activations extractor:

@@ -0,0 +1,47 @@

```python
import numpy as np
import torch
from torch import nn
from torchvision import transforms

from model_tools.activations.core import ActivationsExtractorHelper
from model_tools.activations.pytorch import load_images
from .model import VGG_SNN_STDB


def create_model(timesteps=500, leak_mem=1.0, scaling_threshold=0.8, reset_threshold=0.0, default_threshold=1.0,
                 activation='STDB', architecture='VGG16', labels=1000, pretrained=True):
    model = VGG_SNN_STDB(vgg_name=architecture, activation=activation, labels=labels, timesteps=timesteps,
                         leak_mem=leak_mem)
    if pretrained:
        pretrained_state = './snn_vgg16_imagenet.pth'
        state = torch.load(pretrained_state, map_location='cpu')
        model.load_state_dict(state['state_dict'])
    model = nn.DataParallel(model)
    if torch.cuda.is_available():
        model.cuda()
    # VGG16 ImageNet thresholds
    ann_thresholds = [10.16, 11.49, 2.65, 2.30, 0.77, 2.75, 1.33, 0.67, 1.13, 1.12, 0.43, 0.73, 1.08, 0.16, 0.58]
    model.module.threshold_init(scaling_threshold=scaling_threshold, reset_threshold=reset_threshold,
                                thresholds=ann_thresholds[:], default_threshold=default_threshold)
    model.eval()
    model.module.network_init(timesteps)

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        normalize,
    ])

    def load_and_preprocess(image_filepaths):
        images = load_images(image_filepaths)
        # add a batch dimension so the concatenation yields shape (N, C, H, W)
        images = [transform(image).unsqueeze(0) for image in images]
        return np.concatenate(images)

    def get_activations(preprocessed_inputs, layer_names):
        ...  # TODO
```
Member: This method is meant to output activations (aka firing rates) for preprocessed images for a given set of layer names. For standard VGG, this would output activations at the different blocks. I don't know how exactly your spiking network works, but is there a way for you to store spikes and their timing information in response to images?

Author: I have added a spike_count variable in the forward method of the model to keep track of the number of spikes. Do you think that will work?

Member: Yeah, that sounds good. As long as you can produce spike rates for a given millisecond time-bin, this should work.
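For concreteness, here is a minimal sketch of what the TODO stub could become under the spike_count approach agreed on in this thread. This is hypothetical code, not part of the PR: it would live inside create_model (closing over model), it assumes single-device execution (attributes set in forward would not propagate back from nn.DataParallel replicas on multiple GPUs), and it assumes layers are addressed by the stringified indices that spike_count uses as keys.

```python
from collections import OrderedDict

def get_activations(preprocessed_inputs, layer_names):
    # run the spiking network once; forward() fills model.module.spike_count
    inputs = torch.from_numpy(preprocessed_inputs)
    with torch.no_grad():
        model(inputs)
    snn = model.module  # unwrap nn.DataParallel
    # convert accumulated spike counts to rates (spikes per timestep)
    rates = {str(index): (count / snn.update_interval).cpu().numpy()
             for index, count in snn.spike_count.items()}
    return OrderedDict((name, rates[name]) for name in layer_names)
```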
```python
    model = ActivationsExtractorHelper(identifier='spiking-vgg',
                                       get_activations=get_activations, preprocessing=load_and_preprocess,
                                       batch_size=35)
    return model
```

Member (on the identifier 'spiking-vgg'): Not sure what the right name is.
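For orientation, the returned helper would then be invoked roughly like this. A hypothetical usage sketch: the file paths and layer name are placeholders, and it assumes ActivationsExtractorHelper's (stimuli, layers) call convention.

```python
stimuli_paths = ['image1.png', 'image2.png']  # placeholder filepaths
extractor = create_model(timesteps=500)
# layer names must match the keys that get_activations reports
activations = extractor(stimuli_paths, layers=['spike'])
```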
The second new file, model.py (imported above as .model), defines the spiking VGG:

@@ -0,0 +1,198 @@
```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(2)

cfg = {
    'VGG5': [64, 'A', 128, 'D', 128, 'A'],
    'VGG9': [64, 'A', 128, 'D', 128, 'A', 256, 'D', 256, 'A', 512, 'D', 512, 'D'],
    'VGG11': [64, 'A', 128, 'D', 256, 'A', 512, 'D', 512, 'D', 512, 'A', 512, 'D', 512, 'D'],
    'VGG16': [64, 'D', 64, 'A', 128, 'D', 128, 'A', 256, 'D', 256, 'D', 256, 'A', 512, 'D', 512, 'D', 512, 'A', 512,
              'D', 512, 'D', 512, 'A']
}


class PoissonGenerator(nn.Module):
    """Converts an analog input into a Poisson spike train: a spike is emitted
    with probability proportional to the input magnitude, carrying its sign."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        out = torch.mul(torch.le(torch.rand_like(input), torch.abs(input) * 1.0).float(), torch.sign(input))
        return out


class STDPSpike(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, last_spike):
        ctx.save_for_backward(last_spike)
        # zeros_like keeps the input's device; a hard-coded .cuda() here would break CPU runs
        out = torch.zeros_like(input)
        out[input > 0] = 1.0
        return out


class VGG_SNN_STDB(nn.Module):

    def __init__(self, vgg_name, activation='STDB', labels=1000, timesteps=75, leak_mem=0.99, drop=0.2):
        super().__init__()

        self.timesteps = timesteps
        self.vgg_name = vgg_name
        self.labels = labels
        self.leak_mem = leak_mem
        if activation == 'Linear':
            self.act_func = LinearSpike.apply  # note: LinearSpike is not defined in this diff
        elif activation == 'STDB':
            self.act_func = STDPSpike.apply
        self.input_layer = PoissonGenerator()

        self.features, self.classifier = self._make_layers(cfg[self.vgg_name])

    def threshold_init(self, scaling_threshold=1.0, reset_threshold=0.0, thresholds=[], default_threshold=1.0):
        # Initialize thresholds
        self.scaling_threshold = scaling_threshold
        self.reset_threshold = reset_threshold

        self.threshold = {}
        for pos in range(len(self.features)):
            if isinstance(self.features[pos], nn.Conv2d):
                self.threshold[pos] = thresholds.pop(
                    0) * self.scaling_threshold + self.reset_threshold * default_threshold

        prev = len(self.features)
        for pos in range(len(self.classifier) - 1):
            if isinstance(self.classifier[pos], nn.Linear):
                self.threshold[prev + pos] = thresholds.pop(
                    0) * self.scaling_threshold + self.reset_threshold * default_threshold

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3

        for x in cfg:
            stride = 1
            if x == 'A':
                layers += [nn.AvgPool2d(kernel_size=2, stride=2)]
            elif x == 'D':
                layers += [nn.Dropout(0.2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1, stride=stride, bias=False),
                           nn.ReLU(inplace=True)]
                in_channels = x

        features = nn.Sequential(*layers)

        layers = []
        layers += [nn.Linear(512 * 7 * 7, 4096, bias=False)]
        layers += [nn.ReLU(inplace=True)]
        layers += [nn.Dropout(0.2)]
        layers += [nn.Linear(4096, 4096, bias=False)]
        layers += [nn.ReLU(inplace=True)]
        layers += [nn.Dropout(0.2)]
        layers += [nn.Linear(4096, self.labels, bias=False)]

        classifier = nn.Sequential(*layers)
        return features, classifier

    def network_init(self, update_interval):
        self.update_interval = update_interval

    def neuron_init(self, x):
        self.batch_size = x.size(0)
        self.width = x.size(2)
        self.height = x.size(3)

        self.mem = {}
        self.spike = {}
        self.mask = {}
        self.spike_count = {}

        # allocate neuron state on the same device as the input
        for l in range(len(self.features)):
            if isinstance(self.features[l], nn.Conv2d):
                self.mem[l] = torch.zeros(self.batch_size, self.features[l].out_channels, self.width, self.height,
                                          device=x.device)
                self.spike_count[l] = torch.zeros(self.mem[l].size(), device=x.device)
            elif isinstance(self.features[l], nn.Dropout):
                self.mask[l] = self.features[l](torch.ones(self.mem[l - 2].shape, device=x.device))
            elif isinstance(self.features[l], nn.AvgPool2d):
                self.width = self.width // 2
                self.height = self.height // 2

        prev = len(self.features)

        for l in range(len(self.classifier)):
            if isinstance(self.classifier[l], nn.Linear):
                self.mem[prev + l] = torch.zeros(self.batch_size, self.classifier[l].out_features, device=x.device)
                self.spike_count[prev + l] = torch.zeros(self.mem[prev + l].size(), device=x.device)
            elif isinstance(self.classifier[l], nn.Dropout):
                self.mask[prev + l] = self.classifier[l](torch.ones(self.mem[prev + l - 2].shape, device=x.device))

        # last-spike times, initialized far in the past
        self.spike = copy.deepcopy(self.mem)
        for key, values in self.spike.items():
            for value in values:
                value.fill_(-1000)

    def forward(self, x):
        self.neuron_init(x)

        for t in range(self.update_interval):
            out_prev = self.input_layer(x)

            for l in range(len(self.features)):
                if isinstance(self.features[l], nn.Conv2d):
                    mem_thr = (self.mem[l] / self.threshold[l]) - 1.0
                    out = self.act_func(mem_thr, (t - 1 - self.spike[l]))
                    rst = self.threshold[l] * (mem_thr > 0).float()
                    self.spike[l] = self.spike[l].masked_fill(out.bool(), t - 1)
                    self.spike_count[l][out.bool()] = self.spike_count[l][out.bool()] + 1

                    self.mem[l] = self.leak_mem * self.mem[l] + self.features[l](out_prev) - rst
                    out_prev = out.clone()
                elif isinstance(self.features[l], nn.AvgPool2d):
                    out_prev = self.features[l](out_prev)
                elif isinstance(self.features[l], nn.Dropout):
                    out_prev = out_prev * self.mask[l]

            out_prev = out_prev.reshape(self.batch_size, -1)
            prev = len(self.features)

            for l in range(len(self.classifier) - 1):
                if isinstance(self.classifier[l], nn.Linear):
                    mem_thr = (self.mem[prev + l] / self.threshold[prev + l]) - 1.0
                    out = self.act_func(mem_thr, (t - 1 - self.spike[prev + l]))
                    rst = self.threshold[prev + l] * (mem_thr > 0).float()
                    self.spike[prev + l] = self.spike[prev + l].masked_fill(out.bool(), t - 1)
                    self.spike_count[prev + l][out.bool()] = self.spike_count[prev + l][out.bool()] + 1

                    self.mem[prev + l] = self.leak_mem * self.mem[prev + l] + self.classifier[l](out_prev) - rst
                    out_prev = out.clone()
                elif isinstance(self.classifier[l], nn.Dropout):
                    out_prev = out_prev * self.mask[prev + l]

            # Accumulate the (non-spiking) classification layer outputs
            self.mem[prev + l + 1] = self.mem[prev + l + 1] + self.classifier[l + 1](out_prev)

        return self.mem[prev + l + 1]
```
Finally, the new model is registered in the existing layer-name mapping:

@@ -171,6 +171,7 @@ def __init__(self):

```python
            'resnext101_32x32d_wsl': self._resnext101_layers(),
            'resnext101_32x48d_wsl': self._resnext101_layers(),
            'fixres_resnext101_32x48d_wsl': self._resnext101_layers(),
            'spiking-vgg16': ['spike'],  # TODO
```
Member: Define the names of the layers that you would like to test. Usually, we use the last defined layer for behavioral heads (e.g. the classifier).

Author: In my model, the final layer does not consist of spiking neurons. In the final layer the membrane potential is simply accumulated, and the cost function is defined on the accumulated potential. I am not sure if you need the spike counts, or if the membrane potential will work for testing.

Member: The ideal layer here contains a general set of features that is broadly applicable, i.e. not just for ImageNet but also other tasks, and categories should be linearly decodable from that layer. In e.g. standard VGG, we use the last convolutional layer before the fully-connected, (in our mind) ImageNet-specific decoder. See the sketch after this diff for one way to enumerate such layers.
```python
        }
        for basemodel_identifier, default_layers in layers.items():
            self[basemodel_identifier] = default_layers
```
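One hypothetical way to follow that suggestion is to enumerate the conv layers straight from the model definition and use the last one for the behavioral head. This is a sketch, not part of the PR: it assumes layers are named features.<index>, and the import path is a placeholder.

```python
import torch

# hypothetical: the import path depends on where this PR's model.py lands
from model import VGG_SNN_STDB

snn = VGG_SNN_STDB(vgg_name='VGG16')
# collect the indices of the spiking conv layers inside model.features
conv_indices = [i for i, layer in enumerate(snn.features) if isinstance(layer, torch.nn.Conv2d)]
layer_names = [f'features.{i}' for i in conv_indices]
# last conv layer, before the (ImageNet-specific) fully-connected decoder
behavioral_layer = layer_names[-1]
```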