Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion brainscore_vision/benchmark_helpers/neural_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@


class NeuralBenchmark(BenchmarkBase):
def __init__(self, identifier, assembly, similarity_metric, visual_degrees, number_of_trials, **kwargs):
def __init__(self, identifier, assembly, similarity_metric,
visual_degrees, number_of_trials,
inter_subject_ceiling_func=None,
**kwargs):
super(NeuralBenchmark, self).__init__(identifier=identifier, **kwargs)
self._assembly = assembly
self._similarity_metric = similarity_metric
Expand All @@ -19,6 +22,7 @@ def __init__(self, identifier, assembly, similarity_metric, visual_degrees, numb
self.timebins = timebins
self._visual_degrees = visual_degrees
self._number_of_trials = number_of_trials
self._inter_subject_ceiling_func = inter_subject_ceiling_func

def __call__(self, candidate: BrainModel):
candidate.start_recording(self.region, time_bins=self.timebins)
Expand All @@ -31,6 +35,10 @@ def __call__(self, candidate: BrainModel):
ceiled_score = explained_variance(raw_score, self.ceiling)
return ceiled_score

@property
def inter_subject_ceiling(self):
return self._inter_subject_ceiling_func()


def timebins_from_assembly(assembly):
timebins = assembly['time_bin'].values
Expand Down Expand Up @@ -63,6 +71,7 @@ def avg_repr(assembly):

return apply_keep_attrs(assembly, avg_repr)


def apply_keep_attrs(assembly, fnc): # workaround to keeping attrs
attrs = assembly.attrs
assembly = fnc(assembly)
Expand Down
19 changes: 11 additions & 8 deletions brainscore_vision/benchmarks/majajhong2015/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,52 +26,55 @@
spantime_pls_metric = lambda: load_metric('spantime_pls', crossvalidation_kwargs=crossvalidation_kwargs)

def _DicarloMajajHong2015Region(region: str, access: str, identifier_metric_suffix: str,
similarity_metric: Metric, ceiler: Ceiling, time_interval: float = None):
similarity_metric: Metric, intra_subject_ceiler: Ceiling, time_interval: float = None):
assembly_repetition = load_assembly(average_repetitions=False, region=region, access=access, time_interval=time_interval)
assembly = load_assembly(average_repetitions=True, region=region, access=access, time_interval=time_interval)
inter_subject_ceiler = load_ceiling('inter_subject_consistency',
metric=similarity_metric, subject_column='animal')
benchmark_identifier = f'MajajHong2015.{region}' + ('.public' if access == 'public' else '')
return NeuralBenchmark(identifier=f'{benchmark_identifier}-{identifier_metric_suffix}', version=3,
assembly=assembly, similarity_metric=similarity_metric,
visual_degrees=VISUAL_DEGREES, number_of_trials=NUMBER_OF_TRIALS,
ceiling_func=lambda: ceiler(assembly_repetition),
ceiling_func=lambda: intra_subject_ceiler(assembly_repetition),
inter_subject_ceiling_func=lambda: inter_subject_ceiler(assembly),
parent=region,
bibtex=BIBTEX)


def DicarloMajajHong2015V4PLS():
return _DicarloMajajHong2015Region(region='V4', access='private', identifier_metric_suffix='pls',
similarity_metric=pls_metric(),
ceiler=load_ceiling('internal_consistency'))
intra_subject_ceiler=load_ceiling('internal_consistency'))


def DicarloMajajHong2015ITPLS():
return _DicarloMajajHong2015Region(region='IT', access='private', identifier_metric_suffix='pls',
similarity_metric=pls_metric(),
ceiler=load_ceiling('internal_consistency'))
intra_subject_ceiler=load_ceiling('internal_consistency'))


def MajajHongV4PublicBenchmark():
return _DicarloMajajHong2015Region(region='V4', access='public', identifier_metric_suffix='pls',
similarity_metric=pls_metric(),
ceiler=load_ceiling('internal_consistency'))
intra_subject_ceiler=load_ceiling('internal_consistency'))


def MajajHongITPublicBenchmark():
return _DicarloMajajHong2015Region(region='IT', access='public', identifier_metric_suffix='pls',
similarity_metric=pls_metric(),
ceiler=load_ceiling('internal_consistency'))
intra_subject_ceiler=load_ceiling('internal_consistency'))


def MajajHongV4TemporalPublicBenchmark(time_interval: float = None):
return _DicarloMajajHong2015Region(region='V4', access='public', identifier_metric_suffix='pls',
similarity_metric=spantime_pls_metric(), time_interval=time_interval,
ceiler=load_ceiling('internal_consistency_temporal'))
intra_subject_ceiler=load_ceiling('internal_consistency_temporal'))


def MajajHongITTemporalPublicBenchmark(time_interval: float = None):
return _DicarloMajajHong2015Region(region='IT', access='public', identifier_metric_suffix='pls',
similarity_metric=spantime_pls_metric(), time_interval=time_interval,
ceiler=load_ceiling('internal_consistency_temporal'))
intra_subject_ceiler=load_ceiling('internal_consistency_temporal'))


def load_assembly(average_repetitions: bool, region: str, access: str = 'private', time_interval: float = None):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from brainscore_vision import metric_registry
from .ceiling import InterSubjectConsistency

# Register the leave-one-subject-out ceiling under a stable identifier so it
# can be retrieved elsewhere via `load_ceiling('inter_subject_consistency', ...)`.
metric_registry['inter_subject_consistency'] = InterSubjectConsistency
41 changes: 41 additions & 0 deletions brainscore_vision/metrics/inter_subject_consistency/ceiling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from brainio.assemblies import DataAssembly
from brainscore_core import Metric, Score

from brainscore_vision.metric_helpers.transformations import apply_aggregate
from brainscore_vision.metrics import Ceiling


class InterSubjectConsistency(Ceiling):
    """
    Leave-one-subject-out ceiling: each subject's neuroids are predicted from the
    pooled neuroids of all remaining subjects, and the per-subject scores are
    averaged into a single ceiling estimate.
    """

    def __init__(self,
                 metric: Metric,
                 subject_column='subject'):
        """
        :param metric: the metric used to compare the pooled source subjects against
            the held-out target subject. Typically the same metric as is used for
            model-data comparisons.
        :param subject_column: name of the neuroid-level coordinate that identifies
            which subject each neuroid belongs to.
        """
        self._metric = metric
        self._subject_column = subject_column

    def __call__(self, assembly: DataAssembly) -> Score:
        subject_values = assembly[self._subject_column].values
        all_subjects = sorted(set(subject_values))
        per_subject_scores = []
        for held_out in all_subjects:
            # neuroids belonging to the held-out (target) subject
            target_assembly = assembly[{'neuroid': [subject == held_out
                                                    for subject in subject_values]}]
            # predictor pool: neuroids of every other subject
            pool_assembly = assembly[{'neuroid': [subject != held_out
                                                  for subject in subject_values]}]
            score = self._metric(pool_assembly, target_assembly)
            # attach the held-out subject as a coordinate; propagate into score.raw
            # only if that column is not already part of the raw score
            raw_lacks_subject = not (hasattr(score, 'raw') and hasattr(score.raw, self._subject_column))
            propagate_raw = 'raw' in score.attrs and raw_lacks_subject
            score = score.expand_dims(self._subject_column, _apply_raw=propagate_raw)
            score.__setitem__(self._subject_column, [held_out], _apply_raw=propagate_raw)
            per_subject_scores.append(score)

        merged = Score.merge(*per_subject_scores)
        # aggregate over subjects while keeping per-subject values in the raw score
        return apply_aggregate(lambda scores: scores.mean(self._subject_column), merged)
25 changes: 25 additions & 0 deletions brainscore_vision/metrics/inter_subject_consistency/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from string import ascii_lowercase as alphabet

import numpy as np
from brainio.assemblies import NeuroidAssembly
from pytest import approx

from brainscore_vision import load_ceiling, load_metric
from numpy.random import RandomState


class TestInterSubjectConsistency:
    def test_dummy_data(self):
        # Two identical "subjects" (5 neuroids each, duplicated columns) should
        # yield a perfect inter-subject consistency ceiling of 1.
        rng = RandomState(0)
        half = rng.rand(7, 5)
        stimuli = list(alphabet[:7])
        assembly = NeuroidAssembly(
            np.concatenate([half, half], axis=1),
            coords={'stimulus_id': ('presentation', stimuli),
                    'image_meta': ('presentation', stimuli),
                    'neuroid_id': ('neuroid', np.arange(10)),
                    'neuroid_meta': ('neuroid', np.arange(10)),
                    'subject': ('neuroid', np.repeat(['A', 'B'], 5))},
            dims=['presentation', 'neuroid'])
        ceiler = load_ceiling('inter_subject_consistency', metric=load_metric('rdm'))
        ceiling = ceiler(assembly)
        assert ceiling.item() == approx(1, abs=1e-8)