Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions brainscore_vision/models/AT_efficientnet_b4/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from brainscore_vision import model_registry
from .model import get_layers,get_model


model_registry['AT_efficientnet-b4'] = \
lambda: ModelCommitment(identifier='AT_efficientnet-b4', activations_model=get_model('AT_efficientnet-b4'), layers=get_layers('AT_efficientnet-b4'))
67 changes: 67 additions & 0 deletions brainscore_vision/models/AT_efficientnet_b4/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import functools
from efficientnet_pytorch import EfficientNet
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images
from brainscore_vision.model_helpers.check_submission import check_models


def get_model(name):
assert name == 'AT_efficientnet-b4'
model = EfficientNet.from_pretrained("efficientnet-b4", advprop=True)
model.set_swish(memory_efficient=False)
preprocessing = functools.partial(load_preprocess_images, image_size=224, normalize_mean=(0.5, 0.5, 0.5), normalize_std=(0.5, 0.5, 0.5))
wrapper = PytorchWrapper(identifier=name, model=model, preprocessing=preprocessing)
from types import MethodType
def _output_layer(self):
return self._model._fc

wrapper._output_layer = MethodType(_output_layer, wrapper)
wrapper.image_size = 224
return wrapper


def get_layers(name):
assert name == 'AT_efficientnet-b4'
return [
'_blocks.0',
'_blocks.1',
'_blocks.2',
'_blocks.3',
'_blocks.4',
'_blocks.5',
'_blocks.6',
'_blocks.7',
'_blocks.8',
'_blocks.9',
'_blocks.10',
'_blocks.11',
'_blocks.12',
'_blocks.13',
'_blocks.14',
'_blocks.15',
'_blocks.16',
'_blocks.17',
'_blocks.18',
'_blocks.19',
'_blocks.20',
'_blocks.21',
'_blocks.22',
'_blocks.23',
'_blocks.24',
'_blocks.25',
'_blocks.26',
'_blocks.27',
'_blocks.28',
'_blocks.29',
'_blocks.30',
'_blocks.31',
]

def get_bibtex(model_identifier):
"""
A method returning the bibtex reference of the requested model as a string.
"""
return ''

if __name__ == '__main__':
check_models.check_base_models(__name__)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"IT": "_blocks.22",
"V4": "_blocks.10",
"V2": "_blocks.10",
"V1": "_blocks.10"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
efficientnet_pytorch
8 changes: 8 additions & 0 deletions brainscore_vision/models/AT_efficientnet_b4/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import brainscore_vision
import pytest


@pytest.mark.travis_slow
def test_has_identifier():
model = brainscore_vision.load_model('AT_efficientnet-b4')
assert model.identifier == 'AT_efficientnet-b4'
Loading