This repository was archived by the owner on Jan 5, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 26
TemporalModelCommitments #3
Open
fksato
wants to merge
59
commits into
brain-score:master
Choose a base branch
from
fksato:master
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
59 commits
Select commit
Hold shift + click to select a range
55fa08b
add temporal model commitment
fksato 4189ae4
complete temporal map
fksato 64790c7
complete temporal map
fksato b7f2480
complete temporal map
fksato e418b30
complete temporal map
fksato dc4c152
code clean-up for TemporalModelCommitments
fksato f5ecd21
code clean-up TemporalModelCommitment
fksato 7418a86
code clean-up TemporalModelCommitment
fksato 4aee6bb
code clean-up TemporalModelCommitment
fksato 76e9c8b
code clean-up TemporalModelCommitment
fksato 0e245b4
code clean-up TemporalModelCommitment
fksato dc16eb6
code clean-up TemporalModelCommitment
fksato 8d272c9
code clean-up TemporalModelCommitment
fksato 40fd5d1
code clean-up TemporalModelCommitment
fksato 262e0ae
code clean-up TemporalModelCommitment
fksato 79a4a63
code clean-up TemporalModelCommitment
fksato 45056d6
code clean-up TemporalModelCommitment
fksato 1ece6b0
code clean-up TemporalModelCommitment
fksato 0060c57
code clean-up TemporalModelCommitment
fksato 163bb74
code clean-up TemporalModelCommitment
fksato e97ca1e
code clean-up TemporalModelCommitment
fksato a34cf2f
code clean-up TemporalModelCommitment
fksato 25bac61
code clean-up TemporalModelCommitment
fksato 815db33
code clean-up TemporalModelCommitment
fksato 5e8fe13
code clean-up TemporalModelCommitment
fksato 1ef87c6
code clean-up TemporalModelCommitment
fksato ac30543
code clean-up TemporalModelCommitment
fksato d5dfc33
code clean-up TemporalModelCommitment
fksato 2cb3219
add result caching parameters to temporal model commitments
fksato 135c93a
add result caching parameters to temporal model commitments
fksato f6e8037
add result caching parameters to temporal model commitments
fksato 95ebcbf
add result caching parameters to temporal model commitments
fksato bc134db
add result caching parameters to temporal model commitments
fksato ad01a0a
add result caching parameters to temporal model commitments
fksato f2b75e5
implement result caching for temporal model commitment data
fksato 203deb4
code/repository clean up
fksato fe76c1c
add temporal_map.py
fksato 9b6913f
add temporal_maps unit test
fksato 250a679
fix merge conflict, add temporal_maps pytest
fksato 0471651
Merge branch 'master' of github.com:brain-score/model-tools
fksato bfd21f9
merge with upstream
fksato f88c2b6
Merge branch 'master' of github.com:brain-score/model-tools
fksato d69b0a7
Delete .gitignore
fksato 66dcb79
Delete conftest.cpython-36-PYTEST.pyc
fksato 5f20b6e
Delete __init__.cpython-36.pyc
fksato e540772
first pass code cleanup
fksato 849b86a
add testing assemblies for temporal tests
fksato 2cea388
Merge remote-tracking branch 'upstream/master'
fksato ca3782e
add simplified temporal assemblies
fksato 521181c
add pls_regression comparison to test temporal maps
fksato 1ba7cbd
add correct stimulus set paths to temporal testing assemblies
fksato bc7e9f9
add correct stimulus set paths to temporal testing assemblies
fksato c5d20ef
Merge remote-tracking branch 'upstream/master'
fksato 9396676
remove unnecessary imports
fksato 07a8238
Merge remote-tracking branch 'upstream/master'
fksato 0ecb126
add new assembly/stimulus for temporal mapping tests
fksato 11454ea
finish pytest
fksato 115b03e
add absolute image paths to testing stimulus set
fksato d4c8d54
rename pytorch model in temporal testing
fksato File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,111 @@ | ||
| from typing import Optional | ||
|
|
||
| from brainio_base.assemblies import merge_data_arrays | ||
|
|
||
| from model_tools.brain_transformation import ModelCommitment | ||
|
|
||
| from brainscore.model_interface import BrainModel | ||
| from brainscore.metrics.regression import pls_regression | ||
|
|
||
| from result_caching import store, store_dict | ||
|
|
||
| class TemporalModelCommitment(BrainModel): | ||
| def __init__(self, identifier, base_model, layers, region_layer_map: Optional[dict] = None): | ||
| self.layers = layers | ||
| self.identifier = identifier | ||
| self.base_model = base_model | ||
| # | ||
| self.model_commitment = ModelCommitment(self.identifier, self.base_model, self.layers) | ||
| self.commit_region = self.model_commitment.commit_region | ||
| self.region_assemblies = self.model_commitment.region_assemblies | ||
| self.region_layer_map = self.model_commitment.layer_model.region_layer_map | ||
| self.recorded_regions = [] | ||
|
|
||
| self.time_bins = None | ||
| self._temporal_maps = {} | ||
| self._layer_regions = None | ||
|
|
||
| def make_temporal(self, assembly): | ||
| if not self.region_layer_map: | ||
| for region in self.region_assemblies.keys(): | ||
| self.model_commitment.do_commit_region(region) | ||
| # assert self.region_layer_map # force commit_region to come before | ||
| assert len(set(assembly.time_bin.values)) > 1 # force temporal recordings/assembly | ||
|
|
||
| temporal_mapped_regions = set(assembly['region'].values) | ||
|
|
||
| temporal_mapped_regions = list(set(self.region_layer_map.keys()).intersection(self.region_layer_map.keys())) | ||
| layer_regions = {self.region_layer_map[region]: region for region in temporal_mapped_regions} | ||
|
|
||
| stimulus_set = assembly.stimulus_set | ||
|
|
||
| activations = self.base_model(stimulus_set, layers=list(layer_regions.keys())) | ||
| activations = self._set_region_coords(activations, layer_regions) | ||
|
|
||
| self._temporal_maps = self._set_temporal_maps(self.identifier, temporal_mapped_regions, activations, assembly) | ||
|
|
||
| def look_at(self, stimuli): | ||
| layer_regions = {self.region_layer_map[region]: region for region in self.recorded_regions} | ||
| assert len(layer_regions) == len(self.recorded_regions), f"duplicate layers for {self.recorded_regions}" | ||
| activations = self.base_model(stimuli, layers=list(layer_regions.keys())) | ||
|
|
||
| activations = self._set_region_coords(activations ,layer_regions) | ||
| return self._temporal_activations(self.identifier, activations) | ||
|
|
||
| @store(identifier_ignore=['assembly']) | ||
| def _temporal_activations(self, identifier, assembly): | ||
| temporal_assembly = [] | ||
| for region in self.recorded_regions: | ||
| temporal_regressors = self._temporal_maps[region] | ||
| region_activations = assembly.sel(region=region) | ||
| for time_bin in self.time_bins: | ||
| regressor = temporal_regressors[time_bin] | ||
| regressed_act = regressor.predict(region_activations) | ||
| regressed_act = self._package_temporal(time_bin, region, regressed_act) | ||
| temporal_assembly.append(regressed_act) | ||
| temporal_assembly = merge_data_arrays(temporal_assembly) | ||
| return temporal_assembly | ||
|
|
||
| @store_dict(dict_key='temporal_mapped_regions', identifier_ignore=['temporal_mapped_regions', 'activations' ,'assembly']) | ||
| def _set_temporal_maps(self, identifier, temporal_mapped_regions, activations, assembly): | ||
| temporal_maps = {} | ||
| for region in temporal_mapped_regions: | ||
| time_bin_regressor = {} | ||
| region_activations = activations.sel(region=region) | ||
| for time_bin in assembly.time_bin.values: | ||
| target_assembly = assembly.sel(region=region, time_bin=time_bin) | ||
| regressor = pls_regression(neuroid_coord=('neuroid_id' ,'layer' ,'region')) | ||
| regressor.fit(region_activations, target_assembly) | ||
| time_bin_regressor[time_bin] = regressor | ||
| temporal_maps[region] = time_bin_regressor | ||
| return temporal_maps | ||
|
|
||
| def _set_region_coords(self, activations, layer_regions): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll work on a fix for this soon, so that you hopefully won't need this here anymore (brain-score/brainio_base#1) |
||
| coords = { 'region' : (('neuroid'), [layer_regions[layer] for layer in activations['layer'].values]) } | ||
| activations = activations.assign_coords(**coords) | ||
| activations = activations.set_index({'neuroid' :'region'}, append=True) | ||
| return activations | ||
|
|
||
| def _package_temporal(self, time_bin, region, assembly): | ||
| assert len(time_bin) == 2 | ||
| assembly = assembly.expand_dims('time_bin', axis=-1) | ||
| coords = { | ||
| 'time_bin_start': (('time_bin'), [time_bin[0]]) | ||
| , 'time_bin_end': (('time_bin'), [time_bin[1]]) | ||
| , 'region' : (('neuroid'), [region] * assembly.shape[1]) | ||
| } | ||
| assembly = assembly.assign_coords(**coords) | ||
| assembly = assembly.set_index(time_bin=['time_bin_start', 'time_bin_end'], neuroid='region', append=True) | ||
| return assembly | ||
|
|
||
| def start_recording(self, recording_target, time_bins: Optional[list] = None): | ||
| self.model_commitment.start_recording(recording_target) | ||
| assert self._temporal_maps | ||
| assert self.region_layer_map | ||
| assert recording_target in self._temporal_maps.keys() | ||
| if self.time_bins is None: | ||
| self.time_bins = self._temporal_maps[recording_target].keys() | ||
| else: | ||
| assert set(self._temporal_maps[recording_target].keys()).issuperset(set(time_bins)) | ||
| self.recorded_regions = [recording_target] | ||
| self.time_bins = time_bins | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,141 @@ | ||
| import pytest | ||
| import functools | ||
| import numpy as np | ||
| from os import path | ||
|
|
||
| from model_tools.activations import PytorchWrapper | ||
| from model_tools.brain_transformation.temporal_map import TemporalModelCommitment | ||
|
|
||
| from brainscore.metrics.regression import pls_regression | ||
| from brainscore.assemblies.public import load_assembly | ||
|
|
||
| from xarray import DataArray | ||
| from pandas import DataFrame | ||
|
|
||
| from brainio_base.stimuli import StimulusSet | ||
| from brainio_base.assemblies import NeuronRecordingAssembly | ||
|
|
||
| def load_test_assemblies(variation, region): | ||
| image_dir = path.join(path.dirname(path.abspath(__file__)), 'test_temporal_stimulus') | ||
| if type(variation) is not list: | ||
| variation = [variation] | ||
|
|
||
| num_stim = 5 | ||
| neuroid_cnt = 168 | ||
| time_bin_cnt = 5 | ||
| resp = np.random.rand(num_stim, neuroid_cnt, time_bin_cnt) | ||
|
|
||
| dims = ['presentation', 'neuroid', 'time_bin'] | ||
| coords = { | ||
| 'image_id': ('presentation', range(num_stim)), | ||
| 'y': ('presentation', range(num_stim)), | ||
| 'neuroid_id': ('neuroid', [f'{i}' for i in range(neuroid_cnt)]), | ||
| 'region': ('neuroid', ['IT'] * neuroid_cnt), | ||
| 'x': ('neuroid', range(neuroid_cnt)), | ||
| 'time_bin_start': ('time_bin', range(-10, 40, 10)), | ||
| 'time_bin_end': ('time_bin', range(0, 50, 10)) | ||
| } | ||
|
|
||
| assembly = DataArray(data=resp, dims=dims, coords=coords) | ||
| assembly = assembly.set_index(presentation=['image_id', 'y'], | ||
| neuroid=['neuroid_id','region', 'x'], | ||
| time_bin=['time_bin_start', 'time_bin_end'], | ||
| append=True) | ||
|
|
||
| stim_meta = [{'id': k} for k in range(num_stim)] | ||
| image_paths = {} | ||
| for i in range(num_stim): | ||
| f_name = f"im_{i:05}.jpg" | ||
| im_path = path.join(image_dir, f_name) | ||
|
|
||
| meta = stim_meta[i] | ||
| meta['image_id'] = f'{i}' | ||
| meta['image_file_name'] = f_name | ||
| image_paths[f'{i}'] = im_path | ||
|
|
||
| stim_set = DataFrame(stim_meta) | ||
|
|
||
| stim_set = StimulusSet(stim_set) | ||
| stim_set.image_paths = image_paths | ||
| stim_set.name = f'testing_temporal_stims_{region}_var{"".join(str(v) for v in variation)}' | ||
|
|
||
| assembly = NeuronRecordingAssembly(assembly) | ||
|
|
||
| assembly.attrs['stimulus_set'] = stim_set | ||
| assembly.attrs['stimulus_set_name'] = stim_set.name | ||
| return assembly | ||
|
|
||
| def pytorch_custom(): | ||
mschrimpf marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| import torch | ||
| from torch import nn | ||
| from model_tools.activations.pytorch import load_preprocess_images | ||
|
|
||
| class MyModel_Temporal(nn.Module): | ||
| def __init__(self): | ||
| super(MyModel_Temporal, self).__init__() | ||
| self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3) | ||
| self.relu1 = torch.nn.ReLU() | ||
| linear_input_size = np.power((224 - 3 + 2 * 0) / 1 + 1, 2) * 2 | ||
| self.linear = torch.nn.Linear(int(linear_input_size), 1000) | ||
| self.relu2 = torch.nn.ReLU() # can't get named ReLU output otherwise | ||
|
|
||
| def forward(self, x): | ||
| x = self.conv1(x) | ||
| x = self.relu1(x) | ||
| x = x.view(x.size(0), -1) | ||
| x = self.linear(x) | ||
| x = self.relu2(x) | ||
| return x | ||
|
|
||
| preprocessing = functools.partial(load_preprocess_images, image_size=224) | ||
| return PytorchWrapper(model=MyModel_Temporal(), preprocessing=preprocessing) | ||
|
|
||
| class TestTemporalModelCommitment: | ||
| test_data = [(pytorch_custom, ['linear', 'relu2'], 'IT')] | ||
| @pytest.mark.parametrize("model_ctr, layers, region", test_data) | ||
| def test(self, model_ctr, layers, region): | ||
| commit_assembly = load_assembly(name='dicarlo.Majaj2015.lowvar.IT', | ||
| **{'average_repetition': False}) | ||
|
|
||
| training_assembly = load_test_assemblies([0,3], region) | ||
| validation_assembly = load_test_assemblies(6, region) | ||
|
|
||
| expected_region = region if type(region)==list else [region] | ||
| expected_region_count = len(expected_region) | ||
| expected_time_bin_count = len(training_assembly.time_bin.values) | ||
|
|
||
| extractor = model_ctr() | ||
|
|
||
| t_bins = [t for t in training_assembly.time_bin.values if 0 <= t[0] < 30] | ||
| expected_recorded_time_count = len(t_bins) | ||
|
|
||
| temporal_model = TemporalModelCommitment('', extractor, layers) | ||
| # commit region: | ||
| temporal_model.commit_region(region, commit_assembly) | ||
| # make temporal: | ||
| temporal_model.make_temporal(training_assembly) | ||
| assert len(temporal_model._temporal_maps.keys()) == expected_region_count | ||
| assert len(temporal_model._temporal_maps[region].keys()) == expected_time_bin_count | ||
| # start recording: | ||
| temporal_model.start_recording(region, t_bins) | ||
| assert temporal_model.recorded_regions == expected_region | ||
| # look at: | ||
| stim = validation_assembly.stimulus_set | ||
| temporal_activations = temporal_model.look_at(stim) | ||
| assert set(temporal_activations.region.values) == set(expected_region) | ||
| assert len(set(temporal_activations.time_bin.values)) == expected_recorded_time_count | ||
| # | ||
| test_layer = temporal_model.region_layer_map[region] | ||
| train_stim_set = training_assembly.stimulus_set | ||
| for time_test in t_bins: | ||
| target_assembly = training_assembly.sel(time_bin=time_test, region=region) | ||
| region_activations = extractor(train_stim_set, [test_layer]) | ||
| regressor = pls_regression(neuroid_coord=('neuroid_id', 'layer', 'region')) | ||
| regressor.fit(region_activations, target_assembly) | ||
| # | ||
| test_activations = extractor(stim, [test_layer]) | ||
| test_predictions = regressor.predict(test_activations).values | ||
| # | ||
| temporal_model_prediction = temporal_activations.sel(region=region, time_bin=time_test).values | ||
| assert temporal_model_prediction == pytest.approx(test_predictions, rel=1e-3, abs=1e-6) | ||
|
|
||
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.