From 0e6fc16a948290fa62c1c1b8d251987cdce64207 Mon Sep 17 00:00:00 2001 From: Martin Schrimpf Date: Tue, 25 Feb 2025 22:40:35 +0100 Subject: [PATCH] implement Inter Subject Consistency --- .../benchmark_helpers/neural_common.py | 11 ++++- .../benchmarks/majajhong2015/benchmark.py | 19 +++++---- .../inter_subject_consistency/__init__.py | 4 ++ .../inter_subject_consistency/ceiling.py | 41 +++++++++++++++++++ .../metrics/inter_subject_consistency/test.py | 25 +++++++++++ 5 files changed, 91 insertions(+), 9 deletions(-) create mode 100644 brainscore_vision/metrics/inter_subject_consistency/__init__.py create mode 100644 brainscore_vision/metrics/inter_subject_consistency/ceiling.py create mode 100644 brainscore_vision/metrics/inter_subject_consistency/test.py diff --git a/brainscore_vision/benchmark_helpers/neural_common.py b/brainscore_vision/benchmark_helpers/neural_common.py index 3648b7a018..8fb4dfa66e 100644 --- a/brainscore_vision/benchmark_helpers/neural_common.py +++ b/brainscore_vision/benchmark_helpers/neural_common.py @@ -8,7 +8,10 @@ class NeuralBenchmark(BenchmarkBase): - def __init__(self, identifier, assembly, similarity_metric, visual_degrees, number_of_trials, **kwargs): + def __init__(self, identifier, assembly, similarity_metric, + visual_degrees, number_of_trials, + inter_subject_ceiling_func=None, + **kwargs): super(NeuralBenchmark, self).__init__(identifier=identifier, **kwargs) self._assembly = assembly self._similarity_metric = similarity_metric @@ -19,6 +22,7 @@ def __init__(self, identifier, assembly, similarity_metric, visual_degrees, numb self.timebins = timebins self._visual_degrees = visual_degrees self._number_of_trials = number_of_trials + self._inter_subject_ceiling_func = inter_subject_ceiling_func def __call__(self, candidate: BrainModel): candidate.start_recording(self.region, time_bins=self.timebins) @@ -31,6 +35,10 @@ def __call__(self, candidate: BrainModel): ceiled_score = explained_variance(raw_score, self.ceiling) return 
ceiled_score + @property + def inter_subject_ceiling(self): + return self._inter_subject_ceiling_func() + def timebins_from_assembly(assembly): timebins = assembly['time_bin'].values @@ -63,6 +71,7 @@ def avg_repr(assembly): return apply_keep_attrs(assembly, avg_repr) + def apply_keep_attrs(assembly, fnc): # workaround to keeping attrs attrs = assembly.attrs assembly = fnc(assembly) diff --git a/brainscore_vision/benchmarks/majajhong2015/benchmark.py b/brainscore_vision/benchmarks/majajhong2015/benchmark.py index 5270ab7aff..a39ce3970a 100644 --- a/brainscore_vision/benchmarks/majajhong2015/benchmark.py +++ b/brainscore_vision/benchmarks/majajhong2015/benchmark.py @@ -26,14 +26,17 @@ spantime_pls_metric = lambda: load_metric('spantime_pls', crossvalidation_kwargs=crossvalidation_kwargs) def _DicarloMajajHong2015Region(region: str, access: str, identifier_metric_suffix: str, - similarity_metric: Metric, ceiler: Ceiling, time_interval: float = None): + similarity_metric: Metric, intra_subject_ceiler: Ceiling, time_interval: float = None): assembly_repetition = load_assembly(average_repetitions=False, region=region, access=access, time_interval=time_interval) assembly = load_assembly(average_repetitions=True, region=region, access=access, time_interval=time_interval) + inter_subject_ceiler = load_ceiling('inter_subject_consistency', + metric=similarity_metric, subject_column='animal') benchmark_identifier = f'MajajHong2015.{region}' + ('.public' if access == 'public' else '') return NeuralBenchmark(identifier=f'{benchmark_identifier}-{identifier_metric_suffix}', version=3, assembly=assembly, similarity_metric=similarity_metric, visual_degrees=VISUAL_DEGREES, number_of_trials=NUMBER_OF_TRIALS, - ceiling_func=lambda: ceiler(assembly_repetition), + ceiling_func=lambda: intra_subject_ceiler(assembly_repetition), + inter_subject_ceiling_func=lambda: inter_subject_ceiler(assembly), parent=region, bibtex=BIBTEX) @@ -41,37 +44,37 @@ def _DicarloMajajHong2015Region(region: 
str, access: str, identifier_metric_suff def DicarloMajajHong2015V4PLS(): return _DicarloMajajHong2015Region(region='V4', access='private', identifier_metric_suffix='pls', similarity_metric=pls_metric(), - ceiler=load_ceiling('internal_consistency')) + intra_subject_ceiler=load_ceiling('internal_consistency')) def DicarloMajajHong2015ITPLS(): return _DicarloMajajHong2015Region(region='IT', access='private', identifier_metric_suffix='pls', similarity_metric=pls_metric(), - ceiler=load_ceiling('internal_consistency')) + intra_subject_ceiler=load_ceiling('internal_consistency')) def MajajHongV4PublicBenchmark(): return _DicarloMajajHong2015Region(region='V4', access='public', identifier_metric_suffix='pls', similarity_metric=pls_metric(), - ceiler=load_ceiling('internal_consistency')) + intra_subject_ceiler=load_ceiling('internal_consistency')) def MajajHongITPublicBenchmark(): return _DicarloMajajHong2015Region(region='IT', access='public', identifier_metric_suffix='pls', similarity_metric=pls_metric(), - ceiler=load_ceiling('internal_consistency')) + intra_subject_ceiler=load_ceiling('internal_consistency')) def MajajHongV4TemporalPublicBenchmark(time_interval: float = None): return _DicarloMajajHong2015Region(region='V4', access='public', identifier_metric_suffix='pls', similarity_metric=spantime_pls_metric(), time_interval=time_interval, - ceiler=load_ceiling('internal_consistency_temporal')) + intra_subject_ceiler=load_ceiling('internal_consistency_temporal')) def MajajHongITTemporalPublicBenchmark(time_interval: float = None): return _DicarloMajajHong2015Region(region='IT', access='public', identifier_metric_suffix='pls', similarity_metric=spantime_pls_metric(), time_interval=time_interval, - ceiler=load_ceiling('internal_consistency_temporal')) + intra_subject_ceiler=load_ceiling('internal_consistency_temporal')) def load_assembly(average_repetitions: bool, region: str, access: str = 'private', time_interval: float = None): diff --git 
a/brainscore_vision/metrics/inter_subject_consistency/__init__.py b/brainscore_vision/metrics/inter_subject_consistency/__init__.py new file mode 100644 index 0000000000..3083e42b3b --- /dev/null +++ b/brainscore_vision/metrics/inter_subject_consistency/__init__.py @@ -0,0 +1,4 @@ +from brainscore_vision import metric_registry +from .ceiling import InterSubjectConsistency + +metric_registry['inter_subject_consistency'] = InterSubjectConsistency diff --git a/brainscore_vision/metrics/inter_subject_consistency/ceiling.py b/brainscore_vision/metrics/inter_subject_consistency/ceiling.py new file mode 100644 index 0000000000..3bf5c5e097 --- /dev/null +++ b/brainscore_vision/metrics/inter_subject_consistency/ceiling.py @@ -0,0 +1,41 @@ +from brainio.assemblies import DataAssembly +from brainscore_core import Metric, Score + +from brainscore_vision.metric_helpers.transformations import apply_aggregate +from brainscore_vision.metrics import Ceiling + + +class InterSubjectConsistency(Ceiling): + def __init__(self, + metric: Metric, + subject_column='subject'): + """ + :param metric: The metric to compare a held-out target subject against the pool of all other subjects. 
Typically same as is used for model-data comparisons + """ + self._metric = metric + self._subject_column = subject_column + + def __call__(self, assembly: DataAssembly) -> Score: + scores = [] + subjects = list(sorted(set(assembly[self._subject_column].values))) + for target_subject in subjects: + target_assembly = assembly[{'neuroid': [subject == target_subject + for subject in + assembly[self._subject_column].values]}] + # predictor are all other subjects + source_subjects = set(subjects) - {target_subject} + pool_assembly = assembly[{'neuroid': [subject in source_subjects + for subject in + assembly[self._subject_column].values]}] + score = self._metric(pool_assembly, target_assembly) + # store scores + score_has_subject = hasattr(score, 'raw') and hasattr(score.raw, self._subject_column) + apply_raw = 'raw' in score.attrs and \ + not score_has_subject # only propagate if column not already part of score + score = score.expand_dims(self._subject_column, _apply_raw=apply_raw) + score.__setitem__(self._subject_column, [target_subject], _apply_raw=apply_raw) + scores.append(score) + + scores = Score.merge(*scores) + scores = apply_aggregate(lambda scores: scores.mean(self._subject_column), scores) + return scores diff --git a/brainscore_vision/metrics/inter_subject_consistency/test.py b/brainscore_vision/metrics/inter_subject_consistency/test.py new file mode 100644 index 0000000000..515cdc72b6 --- /dev/null +++ b/brainscore_vision/metrics/inter_subject_consistency/test.py @@ -0,0 +1,25 @@ +from string import ascii_lowercase as alphabet + +import numpy as np +from brainio.assemblies import NeuroidAssembly +from pytest import approx + +from brainscore_vision import load_ceiling, load_metric +from numpy.random import RandomState + + +class TestInterSubjectConsistency: + def test_dummy_data(self): + rnd = RandomState(0) + subject_matrix = rnd.rand(7, 5) + data = NeuroidAssembly(np.concatenate([subject_matrix, subject_matrix], axis=1), + coords={'stimulus_id': 
('presentation', list(alphabet)[:7]), + 'image_meta': ('presentation', list(alphabet)[:7]), + 'neuroid_id': ('neuroid', np.arange(10)), + 'neuroid_meta': ('neuroid', np.arange(10)), + 'subject': ('neuroid', np.repeat(['A', 'B'], 5))}, + dims=['presentation', 'neuroid']) + metric = load_metric('rdm') + ceiler = load_ceiling('inter_subject_consistency', metric=metric) + ceiling = ceiler(data) + assert ceiling.item() == approx(1, abs=1e-8)