Skip to content
This repository was archived by the owner on Jan 5, 2024. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions candidate_models/base_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ def load_preprocess_images(image_filepaths):
return wrapper


def spiking_vgg16():
from .spiking_vgg import create_model
return create_model(architecture='VGG16')


class BaseModelPool(UniqueKeyDict):
"""
Provides a set of standard models.
Expand Down Expand Up @@ -273,6 +278,8 @@ def __init__(self):
'fixres_resnext101_32x48d_wsl': lambda: fixres(
'resnext101_32x48d_wsl',
'https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNeXt_101_32x48d.pth'),

'spiking-vgg16': spiking_vgg16,
}
# MobileNets
for version, multiplier, image_size in [
Expand Down
1 change: 1 addition & 0 deletions candidate_models/base_models/models.csv
Original file line number Diff line number Diff line change
Expand Up @@ -1361,3 +1361,4 @@ fixres_resnext101_32x48d_wsl,http://openaccess.thecvf.com/content_ECCV_2018/html
year = {2019},
month = {jun},
}",0.863,
spiking-vgg16,TODO: paper link,TODO: paper bibtex,0.651,0.817
47 changes: 47 additions & 0 deletions candidate_models/base_models/spiking_vgg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
import torch
from torch import nn
from torchvision import transforms

from model_tools.activations.core import ActivationsExtractorHelper
from model_tools.activations.pytorch import load_images
from .model import VGG_SNN_STDB


def create_model(timesteps=500, leak_mem=1.0, scaling_threshold=0.8, reset_threshold=0.0, default_threshold=1.0,
activation='STDB', architecture='VGG16', labels=1000, pretrained=True):
model = VGG_SNN_STDB(vgg_name=architecture, activation=activation, labels=labels, timesteps=timesteps,
leak_mem=leak_mem)
if pretrained:
pretrained_state = './snn_vgg16_imagenet.pth'
state = torch.load(pretrained_state, map_location='cpu')
model.load_state_dict(state['state_dict'])
model = nn.DataParallel(model)
if torch.cuda.is_available():
model.cuda()
# VGG16 Imagenet thresholds
ann_thresholds = [10.16, 11.49, 2.65, 2.30, 0.77, 2.75, 1.33, 0.67, 1.13, 1.12, 0.43, 0.73, 1.08, 0.16, 0.58]
model.module.threshold_init(scaling_threshold=scaling_threshold, reset_threshold=reset_threshold,
thresholds=ann_thresholds[:], default_threshold=default_threshold)
model.eval()
model.module.network_init(timesteps)

normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
normalize,
])

def load_and_preprocess(image_filepaths):
images = load_images(image_filepaths)
images = [transform(image) for image in images]
return np.concatenate(images)

def get_activations(preprocessed_inputs, layer_names):
... # TODO
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is meant to output activations (aka firing rates) for preprocessed images for a given set of layer names. For standard vgg, this would output activations at the different blocks. I don't know how exactly your spiking network works, but is there a way for you to store spikes and their timing information in response to images?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added a spike_count variable in the forward method of the model to keep track of the number of spikes. Do you think that will work?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah that sounds good. As long as you can produce spike rates for a given millisecond time-bin, this should work


model = ActivationsExtractorHelper(identifier='spiking-vgg',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure what the right name is

get_activations=get_activations, preprocessing=load_and_preprocess,
batch_size=35)
return model
198 changes: 198 additions & 0 deletions candidate_models/base_models/spiking_vgg/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import copy

import torch
import torch.nn as nn

torch.manual_seed(2)

cfg = {
'VGG5': [64, 'A', 128, 'D', 128, 'A'],
'VGG9': [64, 'A', 128, 'D', 128, 'A', 256, 'D', 256, 'A', 512, 'D', 512, 'D'],
'VGG11': [64, 'A', 128, 'D', 256, 'A', 512, 'D', 512, 'D', 512, 'A', 512, 'D', 512, 'D'],
'VGG16': [64, 'D', 64, 'A', 128, 'D', 128, 'A', 256, 'D', 256, 'D', 256, 'A', 512, 'D', 512, 'D', 512, 'A', 512,
'D', 512, 'D', 512, 'A']
}


class PoissonGenerator(nn.Module):

def __init__(self):
super().__init__()

def forward(self, input):
out = torch.mul(torch.le(torch.rand_like(input), torch.abs(input) * 1.0).float(), torch.sign(input))
return out


class STDPSpike(torch.autograd.Function):

@staticmethod
def forward(ctx, input, last_spike):
ctx.save_for_backward(last_spike)
out = torch.zeros_like(input).cuda()
out[input > 0] = 1.0
return out


class VGG_SNN_STDB(nn.Module):

def __init__(self, vgg_name, activation='STDB', labels=1000, timesteps=75, leak_mem=0.99, drop=0.2):
super().__init__()

self.timesteps = timesteps
self.vgg_name = vgg_name
self.labels = labels
self.leak_mem = leak_mem
if activation == 'Linear':
self.act_func = LinearSpike.apply
elif activation == 'STDB':
self.act_func = STDPSpike.apply
self.input_layer = PoissonGenerator()

self.features, self.classifier = self._make_layers(cfg[self.vgg_name])

def threshold_init(self, scaling_threshold=1.0, reset_threshold=0.0, thresholds=[], default_threshold=1.0):

# Initialize thresholds
self.scaling_threshold = scaling_threshold
self.reset_threshold = reset_threshold

self.threshold = {}
# print('\nThresholds:')

for pos in range(len(self.features)):
if isinstance(self.features[pos], nn.Conv2d):
self.threshold[pos] = thresholds.pop(
0) * self.scaling_threshold + self.reset_threshold * default_threshold
# print('\t Layer{} : {:.2f}'.format(pos, self.threshold[pos]))

prev = len(self.features)

for pos in range(len(self.classifier) - 1):
if isinstance(self.classifier[pos], nn.Linear):
self.threshold[prev + pos] = thresholds.pop(
0) * self.scaling_threshold + self.reset_threshold * default_threshold
# print('\t Layer{} : {:.2f}'.format(prev+pos, self.threshold[prev+pos]))

def _make_layers(self, cfg):
layers = []
in_channels = 3

for x in (cfg):
stride = 1

if x == 'A':
layers += [nn.AvgPool2d(kernel_size=2, stride=2)]
elif x == 'D':
layers += [nn.Dropout(0.2)]
else:
layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1, stride=stride, bias=False),
nn.ReLU(inplace=True)
]
in_channels = x

features = nn.Sequential(*layers)

layers = []
layers += [nn.Linear(512 * 7 * 7, 4096, bias=False)]
layers += [nn.ReLU(inplace=True)]
layers += [nn.Dropout(0.2)]
layers += [nn.Linear(4096, 4096, bias=False)]
layers += [nn.ReLU(inplace=True)]
layers += [nn.Dropout(0.2)]
layers += [nn.Linear(4096, self.labels, bias=False)]

classifer = nn.Sequential(*layers)
return (features, classifer)

def network_init(self, update_interval):
self.update_interval = update_interval

def neuron_init(self, x):

self.batch_size = x.size(0)
self.width = x.size(2)
self.height = x.size(3)

self.mem = {}
self.spike = {}
self.mask = {}
self.spike_count = {}

for l in range(len(self.features)):

if isinstance(self.features[l], nn.Conv2d):
self.mem[l] = torch.zeros(self.batch_size, self.features[l].out_channels, self.width, self.height)
self.spike_count[l] = torch.zeros(self.mem[l].size())

elif isinstance(self.features[l], nn.Dropout):
self.mask[l] = self.features[l](torch.ones(self.mem[l - 2].shape))

elif isinstance(self.features[l], nn.AvgPool2d):
self.width = self.width // 2
self.height = self.height // 2

prev = len(self.features)

for l in range(len(self.classifier)):

if isinstance(self.classifier[l], nn.Linear):
self.mem[prev + l] = torch.zeros(self.batch_size, self.classifier[l].out_features)
self.spike_count[prev+l] = torch.zeros(self.mem[prev+l].size())

elif isinstance(self.classifier[l], nn.Dropout):
self.mask[prev + l] = self.classifier[l](torch.ones(self.mem[prev + l - 2].shape))

self.spike = copy.deepcopy(self.mem)
for key, values in self.spike.items():
for value in values:
value.fill_(-1000)

def forward(self, x):

self.neuron_init(x)

for t in range(self.update_interval):

out_prev = self.input_layer(x)

for l in range(len(self.features)):

if isinstance(self.features[l], (nn.Conv2d)):
mem_thr = (self.mem[l] / self.threshold[l]) - 1.0
out = self.act_func(mem_thr, (t - 1 - self.spike[l]))
rst = self.threshold[l] * (mem_thr > 0).float()
self.spike[l] = self.spike[l].masked_fill(out.bool(), t - 1)
self.spike_count[l][out.bool()] = self.spike_count[l][out.bool()] + 1

self.mem[l] = self.leak_mem * self.mem[l] + self.features[l](out_prev) - rst
out_prev = out.clone()

elif isinstance(self.features[l], nn.AvgPool2d):
out_prev = self.features[l](out_prev)

elif isinstance(self.features[l], nn.Dropout):
out_prev = out_prev * self.mask[l]

out_prev = out_prev.reshape(self.batch_size, -1)
prev = len(self.features)

for l in range(len(self.classifier) - 1):

if isinstance(self.classifier[l], (nn.Linear)):
mem_thr = (self.mem[prev + l] / self.threshold[prev + l]) - 1.0
out = self.act_func(mem_thr, (t - 1 - self.spike[prev + l]))
rst = self.threshold[prev + l] * (mem_thr > 0).float()
self.spike[prev + l] = self.spike[prev + l].masked_fill(out.bool(), t - 1)
self.spike_count[prev+l][out.bool()] = self.spike_count[prev+l][out.bool()] + 1

self.mem[prev + l] = self.leak_mem * self.mem[prev + l] + self.classifier[l](out_prev) - rst
out_prev = out.clone()

elif isinstance(self.classifier[l], nn.Dropout):
out_prev = out_prev * self.mask[prev + l]

# Compute the classification layer outputs
self.mem[prev + l + 1] = self.mem[prev + l + 1] + self.classifier[l + 1](out_prev)

return self.mem[prev + l + 1]
1 change: 1 addition & 0 deletions candidate_models/model_commitments/ml_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def __init__(self):
'resnext101_32x32d_wsl': self._resnext101_layers(),
'resnext101_32x48d_wsl': self._resnext101_layers(),
'fixres_resnext101_32x48d_wsl': self._resnext101_layers(),
'spiking-vgg16': ['spike'], # TODO
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

define the names of the layers that you would like to test. Usually, we use the last defined layer for behavioral heads (e.g. classifier)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my model, the final layer does not consist of spiking neurons. In the final layer the membrane potential is simply accumulated and the cost function is defined on the accumulated potential. I am not sure if you need the spike counts or the membrane potential will work for testing

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the ideal layer here contains a general set of features that is broadly applicable, i.e. not just for ImageNet but also other tasks, and categories should be linearly decodable from that layer. In e.g. standard VGG, we use the last convolutional layer before the fully-connected, (in our mind) ImageNet-specific decoder.

}
for basemodel_identifier, default_layers in layers.items():
self[basemodel_identifier] = default_layers
Expand Down
2 changes: 2 additions & 0 deletions tests/test_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class TestImagenet:
('resnext101_32x48d_wsl', .854),
# FixRes: from https://arxiv.org/pdf/1906.06423.pdf, Table 8
('fixres_resnext101_32x48d_wsl', .863),
# spiking-vgg: from <TODO: paper source>
('spiking-vgg16', .651),
])
def test_top1(self, model, expected_top1):
# clear tf graph
Expand Down