From 3116e1b6ea7cebf1c0cddd46d5e95d43d9f14926 Mon Sep 17 00:00:00 2001
From: Khaled K Shehada
Date: Wed, 8 Nov 2023 15:10:48 -0500
Subject: [PATCH 1/8] Added support for bidirectional LM models via masked processing

---
 .../model_helpers/huggingface.py | 128 ++++++++++++++++--
 1 file changed, 115 insertions(+), 13 deletions(-)

diff --git a/brainscore_language/model_helpers/huggingface.py b/brainscore_language/model_helpers/huggingface.py
index 8e309b7a..3602a9df 100644
--- a/brainscore_language/model_helpers/huggingface.py
+++ b/brainscore_language/model_helpers/huggingface.py
@@ -1,6 +1,7 @@
 from collections import OrderedDict
 import functools
+import itertools
 import logging
 import numpy as np
 import re
@@ -9,7 +10,7 @@
 from numpy.core import defchararray
 from torch.utils.hooks import RemovableHandle
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
+from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoTokenizer, BatchEncoding
 from transformers.modeling_outputs import CausalLMOutput
 from typing import Union, List, Tuple, Dict, Callable
@@ -26,6 +27,7 @@ def __init__(
             region_layer_mapping: dict,
             model=None,
             tokenizer=None,
+            bidirectional=False,
             task_heads: Union[None, Dict[ArtificialSubject.Task, Callable]] = None,
     ):
         """
@@ -34,6 +36,7 @@ def __init__(
             This can be left empty, but the model will not be able to be tested on neural benchmarks
         :param model: the model to run inference from. Using `AutoModelForCausalLM.from_pretrained` if `None`.
         :param tokenizer: the model's associated tokenizer. Using `AutoTokenizer.from_pretrained` if `None`.
+        :param bidirectional: whether to use bidirectional (masked) modeling [default: False]
        :param task_heads: a mapping from one or multiple tasks
            (:class:`~brainscore_language.artificial_subject.ArtificialSubject.Task`) to a function outputting the
            requested task output, given the basemodel's base output
@@ -42,7 +45,18 @@ def __init__(
         self._logger = logging.getLogger(fullname(self))
         self.model_id = model_id
         self.region_layer_mapping = region_layer_mapping
-        self.basemodel = (model if model is not None else AutoModelForCausalLM.from_pretrained(self.model_id))
+        self.bidirectional = bidirectional
+
+        if model is not None:
+            self.basemodel = model
+        elif self.bidirectional:
+            self.basemodel = AutoModelForMaskedLM.from_pretrained(self.model_id)
+        else:
+            self.basemodel = AutoModelForCausalLM.from_pretrained(self.model_id)
+
+        # Context window = # positional embeddings - # special tokens [CLS, SEP]
+        self.context_window = getattr(self.basemodel.config, "max_position_embeddings", 0) - 2
+
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.basemodel.to(self.device)
         self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(self.model_id,
@@ -71,15 +85,7 @@ def start_neural_recording(self,
                                recording_type: ArtificialSubject.RecordingType):
         self.neural_recordings.append((recording_target, recording_type))
 
-    def digest_text(self, text: Union[str, List[str]]) -> Dict[str, DataAssembly]:
-        """
-        :param text: the text to be used for inference e.g. "the quick brown fox"
-        :return: assembly of either behavioral output or internal neural representations
-        """
-
-        if type(text) == str:
-            text = [text]
-
+    def _causal_inference(self, text):
         output = {'behavior': [], 'neural': []}
         number_of_tokens = 0
 
@@ -115,6 +121,97 @@
             if self.neural_recordings:
                 representations = self.output_to_representations(layer_representations, stimuli_coords=stimuli_coords)
                 output['neural'].append(representations)
+        return output
+
+    def _masked_inference(self, text):
+        output = {'behavior': [], 'neural': []}
+        text_tokens = []
+
+        # preprocessing: get the tokens for each text part
+        remaining_tokens = 0
+        for text_part in text:
+            part_tokens = self.tokenizer.tokenize(text_part)
+            remaining_tokens += len(part_tokens)
+            text_tokens.append(part_tokens)
+
+        start_part = 0
+        cur_number_of_tokens = 0
+        for part_number, text_part in enumerate(tqdm(text)):
+            cur_number_of_tokens += len(text_tokens[part_number])
+            while cur_number_of_tokens > (self.context_window / 2) and (start_part < part_number):
+                cur_number_of_tokens -= len(text_tokens[start_part])
+                start_part += 1
+
+            end_part = part_number + 1
+            if self.behavioral_task == ArtificialSubject.Task.reading_times:
+                # For reading time estimation, the input should be masked, otherwise
+                # surprisal will be very low.
+                end_part = part_number
+
+            context_tokens = list(itertools.chain.from_iterable(text_tokens[start_part:end_part]))
+            context = prepare_context(text[start_part: part_number + 1])
+
+            # Add MASK tokens to the second half of the context
+            context_tokens += [self.tokenizer.mask_token] * min(remaining_tokens, self.context_window - len(context_tokens))
+            masked_part = self.tokenizer.convert_tokens_to_string(context_tokens)
+            part_inputs = self.tokenizer(masked_part, return_tensors="pt", return_overflowing_tokens=self._tokenizer_returns_overflow)
+            part_inputs.to(self.device)
+
+            # prepare recording hooks
+            hooks, layer_representations = self._setup_hooks()
+
+            # predicted_logits = logits[-self.current_tokens['input_ids'].shape[-1] - 1: - 1, :].contiguous()
+            self.current_tokens = self.tokenizer(text_part, return_tensors="pt", add_special_tokens=False, return_overflowing_tokens=self._tokenizer_returns_overflow)
+            self.current_tokens.to(self.device)
+
+            # run and remove hooks
+            try:
+                with torch.no_grad():
+                    base_output = self.basemodel(**part_inputs)
+                for hook in hooks:
+                    hook.remove()
+            except:
+                breakpoint()
+
+            mask_token_index = torch.where(part_inputs["input_ids"] == self.tokenizer.mask_token_id)[1][0]
+            base_output.logits = base_output.logits[:, :mask_token_index + 1]
+            if self.tokenizer.cls_token is not None:
+                base_output.logits = base_output.logits[:, 1:]
+
+            remaining_tokens -= len(text_tokens[part_number])
+
+            # format output
+            stimuli_coords = {
+                'stimulus': ('presentation', [text_part]),
+                'context': ('presentation', [context]),
+                'part_number': ('presentation', [part_number]),
+            }
+            if self.behavioral_task:
+                behavioral_output = self.output_to_behavior(base_output=base_output)
+                behavior = BehavioralAssembly(
+                    [behavioral_output],
+                    coords=stimuli_coords,
+                    dims=['presentation']
+                )
+                output['behavior'].append(behavior)
+            if self.neural_recordings:
+                representations = self.output_to_representations(layer_representations, stimuli_coords=stimuli_coords)
+                output['neural'].append(representations)
+        return output
+
+    def digest_text(self, text: Union[str, List[str]]) -> Dict[str, DataAssembly]:
+        """
+        :param text: the text to be used for inference e.g. "the quick brown fox"
+        :return: assembly of either behavioral output or internal neural representations
+        """
+
+        if type(text) == str:
+            text = [text]
+
+        if self.bidirectional:
+            output = self._masked_inference(text)
+        else:
+            output = self._causal_inference(text)
 
         # merge over text parts
         self._logger.debug("Merging outputs")
@@ -198,10 +295,13 @@ def _setup_hooks(self):
         return hooks, layer_representations
 
     def output_to_representations(self, layer_representations: Dict[Tuple[str, str, str], np.ndarray], stimuli_coords):
+        # Choose the first token [CLS] in bidirectional models, the last token for causal models, to represent passage
         representation_values = np.concatenate([
-            # Choose to use last token (-1) of values[batch, token, unit] to represent passage.
-            values[:, -1:, :].squeeze(0).cpu() for values in layer_representations.values()],
+            values[:, :1, :].squeeze(0).cpu() if self.bidirectional
+            else values[:, -1:, :].squeeze(0).cpu()
+            for values in layer_representations.values()],
             axis=-1)  # concatenate along neuron axis
+
         neuroid_coords = {
             'layer': ('neuroid', np.concatenate([[layer] * values.shape[-1]
                                                  for (recording_target, recording_type, layer), values
@@ -288,6 +388,8 @@ def _register_hook(self,
         def hook_function(_layer: torch.nn.Module, _input, output: torch.Tensor, key=key):
             # fix for when taking out only the hidden state, this is different from dropout because of residual state
             # see: https://github.com/huggingface/transformers/blob/c06d55564740ebdaaf866ffbbbabf8843b34df4b/src/transformers/models/gpt2/modeling_gpt2.py#L428
+            if isinstance(output, tuple):
+                output = output[0]
             output = output[0] if len(output) > 1 else output
             target_dict[key] = output

From f0f2c68fb003c48919948942ecfda87322c21bc0 Mon Sep 17 00:00:00 2001
From: Khaled K Shehada
Date: Wed, 8 Nov 2023 15:12:54 -0500
Subject: [PATCH 2/8] Added base BERT huggingface subject

---
 brainscore_language/models/bert/__init__.py | 6 ++++++
 1 file changed, 6 insertions(+)
 create mode 100644 brainscore_language/models/bert/__init__.py

diff --git a/brainscore_language/models/bert/__init__.py b/brainscore_language/models/bert/__init__.py
new file mode 100644
index 00000000..ae430603
--- /dev/null
+++ b/brainscore_language/models/bert/__init__.py
@@ -0,0 +1,6 @@
+from brainscore_language import model_registry
+from brainscore_language import ArtificialSubject
+from brainscore_language.model_helpers.huggingface import HuggingfaceSubject
+
+model_registry['bert-base-uncased'] = lambda: HuggingfaceSubject(model_id='bert-base-uncased', region_layer_mapping={
+    ArtificialSubject.RecordingTarget.language_system: 'bert.encoder.layer.4'}, bidirectional=True)
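For orientation before the unit tests in the next patch: the masked processing added above scores a word by replacing it with the tokenizer's mask token and reading the model's logits at that position. The snippet below is not part of the patch series; it is a minimal standalone sketch with hardcoded, illustrative names ('bert-base-uncased', the target word 'fox'), assuming only the torch and transformers packages. The patch itself additionally slides the context window, appends multiple [MASK] tokens, and routes everything through HuggingfaceSubject._masked_inference.

    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')

    # score "fox" given its left context by masking it out
    inputs = tokenizer(f'the quick brown {tokenizer.mask_token}', return_tensors='pt')
    with torch.no_grad():
        logits = model(**inputs).logits

    # position of the [MASK] token, then surprisal (-log p) of the target word at that position
    mask_position = (inputs['input_ids'] == tokenizer.mask_token_id).nonzero()[0, 1]
    target_id = tokenizer.convert_tokens_to_ids('fox')
    log_probs = torch.log_softmax(logits[0, mask_position], dim=-1)
    surprisal = -log_probs[target_id].item()
    print(surprisal)

This is also why the reading-time branch above sets end_part = part_number: if the current part's own tokens were left visible to the bidirectional encoder, its surprisal would come out artificially low.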
From 3e7c3f2fc31157c5d147752ba8ac544d8d205873 Mon Sep 17 00:00:00 2001
From: Khaled K Shehada
Date: Wed, 8 Nov 2023 15:35:13 -0500
Subject: [PATCH 3/8] Added unit tests for tasks with the bidirectional model

---
 brainscore_language/models/bert/test.py | 45 +++++++++++++++++++++++++
 1 file changed, 45 insertions(+)
 create mode 100644 brainscore_language/models/bert/test.py

diff --git a/brainscore_language/models/bert/test.py b/brainscore_language/models/bert/test.py
new file mode 100644
index 00000000..91d81dd5
--- /dev/null
+++ b/brainscore_language/models/bert/test.py
@@ -0,0 +1,45 @@
+import numpy as np
+import pytest
+
+from brainscore_language import load_model
+from brainscore_language.artificial_subject import ArtificialSubject
+
+
+@pytest.mark.parametrize('model_identifier, expected_reading_times', [
+    ('bert-base-uncased', [np.nan, 15.068062, 13.729589, 16.449226,
+                           18.178684, 18.060932, 17.804218, 26.74436]),
+])
+def test_reading_times(model_identifier, expected_reading_times):
+    model = load_model(model_identifier)
+    text = ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy']
+    model.start_behavioral_task(task=ArtificialSubject.Task.reading_times)
+    reading_times = model.digest_text(text)['behavior']
+    np.testing.assert_allclose(
+        reading_times,
+        expected_reading_times,
+        atol=0.0001)
+
+
+@pytest.mark.parametrize('model_identifier, expected_next_words', [
+    ('bert-base-uncased', ['eyes', 'into', 'fallen', 'again']),
+])
+def test_next_word(model_identifier, expected_next_words):
+    model = load_model(model_identifier)
+    text = ['The quick brown', 'fox jumps', 'over the', 'lazy dog']
+    model.start_behavioral_task(task=ArtificialSubject.Task.next_word)
+    next_word_predictions = model.digest_text(text)['behavior']
+    np.testing.assert_array_equal(next_word_predictions, expected_next_words)
+
+
+@pytest.mark.parametrize('model_identifier, feature_size', [
+    ('bert-base-uncased', 768),
+])
+def test_neural(model_identifier, feature_size):
+    model = load_model(model_identifier)
+    text = ['the quick brown fox', 'jumps over', 'the lazy dog']
+    model.start_neural_recording(recording_target=ArtificialSubject.RecordingTarget.language_system,
+                                 recording_type=ArtificialSubject.RecordingType.fMRI)
+    representations = model.digest_text(text)['neural']
+    assert len(representations['presentation']) == 3
+    np.testing.assert_array_equal(representations['stimulus'], text)
+    assert len(representations['neuroid']) == feature_size

From d375b1100de6697eab718b18ce4da2a20c799ca2 Mon Sep 17 00:00:00 2001
From: Khaled Shehada <45083797+shehadak@users.noreply.github.com>
Date: Thu, 9 Nov 2023 11:33:54 -0500
Subject: [PATCH 4/8] Update brainscore_language/model_helpers/huggingface.py

Co-authored-by: Martin Schrimpf
---
 brainscore_language/model_helpers/huggingface.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/brainscore_language/model_helpers/huggingface.py b/brainscore_language/model_helpers/huggingface.py
index 3602a9df..eb0e5c1f 100644
--- a/brainscore_language/model_helpers/huggingface.py
+++ b/brainscore_language/model_helpers/huggingface.py
@@ -297,8 +297,9 @@ def _setup_hooks(self):
     def output_to_representations(self, layer_representations: Dict[Tuple[str, str, str], np.ndarray], stimuli_coords):
         # Choose the first token [CLS] in bidirectional models, the last token for causal models, to represent passage
         representation_values = np.concatenate([
-            values[:, :1, :].squeeze(0).cpu() if self.bidirectional
-            else values[:, -1:, :].squeeze(0).cpu()
+            # values are [batch, token, unit]
+            values[:, -1:, :].squeeze(0).cpu() if not self.bidirectional  # use last token (-1) to represent passage
+            else values[:, :1, :].squeeze(0).cpu()  # for bidirectional models, use first token
             for values in layer_representations.values()],
             axis=-1)  # concatenate along neuron axis
 

From 73fdcec1abe594fa5275d321864b60f0f9b2d695 Mon Sep 17 00:00:00 2001
From: Khaled K Shehada
Date: Thu, 9 Nov 2023 14:11:42 -0500
Subject: [PATCH 5/8] Refactored common functionality in _masked_inference and _causal_inference

---
 .../model_helpers/huggingface.py | 137 ++++++++----------
 1 file changed, 62 insertions(+), 75 deletions(-)

diff --git a/brainscore_language/model_helpers/huggingface.py b/brainscore_language/model_helpers/huggingface.py
index eb0e5c1f..ab67b06c 100644
--- a/brainscore_language/model_helpers/huggingface.py
+++ b/brainscore_language/model_helpers/huggingface.py
@@ -89,45 +89,24 @@ def _causal_inference(self, text):
         output = {'behavior': [], 'neural': []}
         number_of_tokens = 0
 
-        text_iterator = tqdm(text, desc='digest text') if len(text) > 100 else text  # show progress bar if many parts
-        for part_number, text_part in enumerate(text_iterator):
+        for part_number, text_part in enumerate(tqdm(text, desc='digest text')):
             # prepare string representation of context
             context = prepare_context(text[:part_number + 1])
             context_tokens, number_of_tokens = self._tokenize(context, number_of_tokens)
+
+            # setup hooks in the model's layers and perform inference on `context_tokens`
+            base_output, layer_representations = self._run_model_with_hooks(context_tokens)
 
-            # prepare recording hooks
-            hooks, layer_representations = self._setup_hooks()
-
-            # run and remove hooks
-            with torch.no_grad():
-                base_output = self.basemodel(**context_tokens)
-            for hook in hooks:
-                hook.remove()
-
-            # format output
-            stimuli_coords = {
-                'stimulus': ('presentation', [text_part]),
-                'context': ('presentation', [context]),
-                'part_number': ('presentation', [part_number]),
-            }
-            if self.behavioral_task:
-                behavioral_output = self.output_to_behavior(base_output=base_output)
-                behavior = BehavioralAssembly(
-                    [behavioral_output],
-                    coords=stimuli_coords,
-                    dims=['presentation']
-                )
-                output['behavior'].append(behavior)
-            if self.neural_recordings:
-                representations = self.output_to_representations(layer_representations, stimuli_coords=stimuli_coords)
-                output['neural'].append(representations)
+
+            # update output dict with new behavioral output and/or neural representations
+            output = self._format_output(base_output, layer_representations, text_part, context, part_number, output)
+
         return output
 
     def _masked_inference(self, text):
         output = {'behavior': [], 'neural': []}
         text_tokens = []
 
-        # preprocessing: get the tokens for each text part
+        # Preprocessing: get the tokens for each text part
         remaining_tokens = 0
         for text_part in text:
             part_tokens = self.tokenizer.tokenize(text_part)
@@ -135,69 +114,77 @@
             remaining_tokens += len(part_tokens)
             text_tokens.append(part_tokens)
 
         start_part = 0
-        cur_number_of_tokens = 0
-        for part_number, text_part in enumerate(tqdm(text)):
-            cur_number_of_tokens += len(text_tokens[part_number])
-            while cur_number_of_tokens > (self.context_window / 2) and (start_part < part_number):
-                cur_number_of_tokens -= len(text_tokens[start_part])
+        number_of_tokens = 0
+        for part_number, text_part in enumerate(tqdm(text, desc='digest text')):
+            number_of_tokens += len(text_tokens[part_number])
+            while number_of_tokens > (self.context_window / 2) and (start_part < part_number):
+                number_of_tokens -= len(text_tokens[start_part])
                 start_part += 1
 
             end_part = part_number + 1
             if self.behavioral_task == ArtificialSubject.Task.reading_times:
-                # For reading time estimation, the input should be masked, otherwise
-                # surprisal will be very low.
+                # For reading time estimation, this part should be masked, otherwise
+                # surprisal will be very low since the model will have seen the tokens.
                 end_part = part_number
 
-            context_tokens = list(itertools.chain.from_iterable(text_tokens[start_part:end_part]))
+            unmasked_context_tokens = list(itertools.chain.from_iterable(text_tokens[start_part:end_part]))
             context = prepare_context(text[start_part: part_number + 1])
 
-            # Add MASK tokens to the second half of the context
-            context_tokens += [self.tokenizer.mask_token] * min(remaining_tokens, self.context_window - len(context_tokens))
-            masked_part = self.tokenizer.convert_tokens_to_string(context_tokens)
-            part_inputs = self.tokenizer(masked_part, return_tensors="pt", return_overflowing_tokens=self._tokenizer_returns_overflow)
-            part_inputs.to(self.device)
+            # Add [MASK] tokens to the second half of the context
+            unmasked_context_tokens += [self.tokenizer.mask_token] * min(remaining_tokens, self.context_window - len(unmasked_context_tokens))
+            masked_part = self.tokenizer.convert_tokens_to_string(unmasked_context_tokens)
+            context_tokens = self.tokenizer(masked_part, return_tensors="pt", return_overflowing_tokens=self._tokenizer_returns_overflow)
+            context_tokens.to(self.device)
+            remaining_tokens -= len(text_tokens[part_number])
 
-            # prepare recording hooks
-            hooks, layer_representations = self._setup_hooks()
-
-            # predicted_logits = logits[-self.current_tokens['input_ids'].shape[-1] - 1: - 1, :].contiguous()
+            # Tokenize the text part without masking for comparison with model logits (e.g. for reading time estimates)
             self.current_tokens = self.tokenizer(text_part, return_tensors="pt", add_special_tokens=False, return_overflowing_tokens=self._tokenizer_returns_overflow)
             self.current_tokens.to(self.device)
 
-            # run and remove hooks
-            try:
-                with torch.no_grad():
-                    base_output = self.basemodel(**part_inputs)
-                for hook in hooks:
-                    hook.remove()
-            except:
-                breakpoint()
-
-            mask_token_index = torch.where(part_inputs["input_ids"] == self.tokenizer.mask_token_id)[1][0]
+            # Setup hooks in the model's layers and perform inference on `context_tokens`
+            base_output, layer_representations = self._run_model_with_hooks(context_tokens)
+
+            # Post processing model output: removing logits for [MASK] tokens in the future context
+            mask_token_index = torch.where(context_tokens["input_ids"] == self.tokenizer.mask_token_id)[1][0]
             base_output.logits = base_output.logits[:, :mask_token_index + 1]
             if self.tokenizer.cls_token is not None:
                 base_output.logits = base_output.logits[:, 1:]
 
-            remaining_tokens -= len(text_tokens[part_number])
-
-            # format output
-            stimuli_coords = {
-                'stimulus': ('presentation', [text_part]),
-                'context': ('presentation', [context]),
-                'part_number': ('presentation', [part_number]),
-            }
-            if self.behavioral_task:
-                behavioral_output = self.output_to_behavior(base_output=base_output)
-                behavior = BehavioralAssembly(
-                    [behavioral_output],
-                    coords=stimuli_coords,
-                    dims=['presentation']
-                )
-                output['behavior'].append(behavior)
-            if self.neural_recordings:
-                representations = self.output_to_representations(layer_representations, stimuli_coords=stimuli_coords)
-                output['neural'].append(representations)
+            # Update output dict with new behavioral output and/or neural representations
+            output = self._format_output(base_output, layer_representations, text_part, context, part_number, output)
+        return output
+
+    def _format_output(self, base_output, layer_representations, text_part, context, part_number, output):
+        if not output:
+            output = {'behavior': [], 'neural': []}
+
+        stimuli_coords = {
+            'stimulus': ('presentation', [text_part]),
+            'context': ('presentation', [context]),
+            'part_number': ('presentation', [part_number]),
+        }
+        if self.behavioral_task:
+            behavioral_output = self.output_to_behavior(base_output=base_output)
+            behavior = BehavioralAssembly(
+                [behavioral_output],
+                coords=stimuli_coords,
+                dims=['presentation']
+            )
+            output["behavior"].append(behavior)
+        if self.neural_recordings:
+            representations = self.output_to_representations(layer_representations, stimuli_coords=stimuli_coords)
+            output["neural"].append(representations)
         return output
 
+    def _run_model_with_hooks(self, context_tokens):
+        # prepare recording hooks
+        hooks, layer_representations = self._setup_hooks()
+        with torch.no_grad():
+            base_output = self.basemodel(**context_tokens)
+        for hook in hooks:
+            hook.remove()
+        return base_output, layer_representations
 
     def digest_text(self, text: Union[str, List[str]]) -> Dict[str, DataAssembly]:
         """
         :param text: the text to be used for inference e.g. "the quick brown fox"
         :return: assembly of either behavioral output or internal neural representations
         """
 
         if type(text) == str:
             text = [text]
 
         if self.bidirectional:
             output = self._masked_inference(text)
         else:
             output = self._causal_inference(text)
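As background for the recording path that _run_model_with_hooks now wraps: activations are captured with PyTorch forward hooks on the layer named in region_layer_mapping, and (per PATCH 4/8) bidirectional models represent a passage by the first token, [CLS], rather than the last token. A rough standalone sketch of that mechanism, using plain torch/transformers calls and illustrative names rather than the helper's actual internals, might look like:

    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')
    captured = {}

    def hook_function(module, inputs, output):
        # encoder layers return a tuple; the hidden states are the first element
        captured['bert.encoder.layer.4'] = output[0] if isinstance(output, tuple) else output

    layer = model.get_submodule('bert.encoder.layer.4')
    handle = layer.register_forward_hook(hook_function)
    with torch.no_grad():
        model(**tokenizer('the quick brown fox', return_tensors='pt'))
    handle.remove()

    # captured activations are [batch, token, unit]; take the first ([CLS]) token as the passage vector
    cls_vector = captured['bert.encoder.layer.4'][:, 0, :]
    print(cls_vector.shape)  # torch.Size([1, 768]) for bert-base-uncased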
From 66311de32f5d4b1204ce1bbc344afd68c5b92a1d Mon Sep 17 00:00:00 2001
From: Khaled K Shehada
Date: Thu, 9 Nov 2023 14:27:14 -0500
Subject: [PATCH 6/8] Added unit tests for bidirectional huggingface models

---
 tests/test_model_helpers/test_huggingface.py | 45 ++++++++++++++------
 1 file changed, 33 insertions(+), 12 deletions(-)

diff --git a/tests/test_model_helpers/test_huggingface.py b/tests/test_model_helpers/test_huggingface.py
index b72bac4d..7fae3266 100644
--- a/tests/test_model_helpers/test_huggingface.py
+++ b/tests/test_model_helpers/test_huggingface.py
@@ -11,36 +11,38 @@ class TestNextWord:
-    @pytest.mark.parametrize('model_identifier, expected_next_word', [
-        pytest.param('bert-base-uncased', '.', marks=pytest.mark.memory_intense),
-        pytest.param('gpt2-xl', 'jumps', marks=pytest.mark.memory_intense),
-        ('distilgpt2', 'es'),
+    @pytest.mark.parametrize('model_identifier, expected_next_word, bidirectional', [
+        pytest.param('bert-base-uncased', 'and', True, marks=pytest.mark.memory_intense),
+        pytest.param('bert-base-uncased', '.', False, marks=pytest.mark.memory_intense),
+        pytest.param('gpt2-xl', 'jumps', False, marks=pytest.mark.memory_intense),
+        ('distilgpt2', 'es', False),
     ])
-    def test_single_string(self, model_identifier, expected_next_word):
+    def test_single_string(self, model_identifier, expected_next_word, bidirectional):
         """
         This is a simple test that takes in text = 'the quick brown fox', and tests the next word.
         This test is a stand-in prototype to check if our model definitions are correct.
         """
-        model = HuggingfaceSubject(model_id=model_identifier, region_layer_mapping={})
+        model = HuggingfaceSubject(model_id=model_identifier, region_layer_mapping={}, bidirectional=bidirectional)
         text = 'the quick brown fox'
         _logger.info(f'Running {model.identifier()} with text "{text}"')
         model.start_behavioral_task(task=ArtificialSubject.Task.next_word)
         next_word = model.digest_text(text)['behavior'].values
         assert next_word == expected_next_word
 
-    @pytest.mark.parametrize('model_identifier, expected_next_words', [
-        pytest.param('bert-base-uncased', ['.', '.', '.'], marks=pytest.mark.memory_intense),
-        pytest.param('gpt2-xl', ['jumps', 'the', 'dog'], marks=pytest.mark.memory_intense),
-        ('distilgpt2', ['es', 'the', 'fox']),
+    @pytest.mark.parametrize('model_identifier, expected_next_words, bidirectional', [
+        pytest.param('bert-base-uncased', [';', 'the', 'water'], True, marks=pytest.mark.memory_intense),
+        pytest.param('bert-base-uncased', ['.', '.', '.'], False, marks=pytest.mark.memory_intense),
+        pytest.param('gpt2-xl', ['jumps', 'the', 'dog'], False, marks=pytest.mark.memory_intense),
+        ('distilgpt2', ['es', 'the', 'fox'], False),
     ])
-    def test_list_input(self, model_identifier, expected_next_words):
+    def test_list_input(self, model_identifier, expected_next_words, bidirectional):
         """
         This is a simple test that takes in text = ['the quick brown fox', 'jumps over', 'the lazy'],
         and tests the next word for each text part in the list.
         This test is a stand-in prototype to check if our model definitions are correct.
         """
-        model = HuggingfaceSubject(model_id=model_identifier, region_layer_mapping={})
+        model = HuggingfaceSubject(model_id=model_identifier, region_layer_mapping={}, bidirectional=bidirectional)
         text = ['the quick brown fox', 'jumps over', 'the lazy']
         _logger.info(f'Running {model.identifier()} with text "{text}"')
         model.start_behavioral_task(task=ArtificialSubject.Task.next_word)
@@ -173,6 +175,25 @@ def test_one_text_single_target(self):
         assert len(representations['neuroid']) == 768
         _logger.info(f'representation shape is correct: {representations.shape}')
 
+    @pytest.mark.memory_intense
+    def test_one_text_single_target_bidirectional(self):
+        """
+        This is a simple test that takes in text = 'the quick brown fox', and asserts that a bidirectional BERT model
+        layer indexed by `representation_layer` has 1 text presentation and 768 neurons. This test is a stand-in prototype
+        to check if our model definitions are correct.
+        """
+        model = HuggingfaceSubject(model_id='bert-base-uncased', region_layer_mapping={
+            ArtificialSubject.RecordingTarget.language_system: 'bert.encoder.layer.4'})
+        text = 'the quick brown fox'
+        _logger.info(f'Running {model.identifier()} with text "{text}"')
+        model.start_neural_recording(recording_target=ArtificialSubject.RecordingTarget.language_system,
+                                     recording_type=ArtificialSubject.RecordingType.fMRI)
+        representations = model.digest_text(text)['neural']
+        assert len(representations['presentation']) == 1
+        assert representations['stimulus'].squeeze() == text
+        assert len(representations['neuroid']) == 768
+        _logger.info(f'representation shape is correct: {representations.shape}')
+
     @pytest.mark.memory_intense
     def test_one_text_two_targets(self):
         model = HuggingfaceSubject(model_id='distilgpt2', region_layer_mapping={

From 2a1d9ed1aee9b4804360e8a77df7e2e5312c738b Mon Sep 17 00:00:00 2001
From: Khaled Shehada <45083797+shehadak@users.noreply.github.com>
Date: Fri, 10 Nov 2023 12:51:02 -0500
Subject: [PATCH 7/8] Update brainscore_language/model_helpers/huggingface.py

---
 brainscore_language/model_helpers/huggingface.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/brainscore_language/model_helpers/huggingface.py b/brainscore_language/model_helpers/huggingface.py
index ab67b06c..7a73ad0a 100644
--- a/brainscore_language/model_helpers/huggingface.py
+++ b/brainscore_language/model_helpers/huggingface.py
@@ -376,9 +376,7 @@ def _register_hook(self,
         def hook_function(_layer: torch.nn.Module, _input, output: torch.Tensor, key=key):
             # fix for when taking out only the hidden state, this is different from dropout because of residual state
             # see: https://github.com/huggingface/transformers/blob/c06d55564740ebdaaf866ffbbbabf8843b34df4b/src/transformers/models/gpt2/modeling_gpt2.py#L428
-            if isinstance(output, tuple):
-                output = output[0]
-            output = output[0] if len(output) > 1 else output
+            output = output[0] if isinstance(output, tuple) else output
             target_dict[key] = output
 
         hook = layer.register_forward_hook(hook_function)

From 08433e3e263fe32c4fbacde05abbab4182d83e05 Mon Sep 17 00:00:00 2001
From: Khaled K Shehada
Date: Fri, 10 Nov 2023 16:28:00 -0500
Subject: [PATCH 8/8] added comment explaining BERT layer assignment

---
 brainscore_language/models/bert/__init__.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/brainscore_language/models/bert/__init__.py b/brainscore_language/models/bert/__init__.py
index ae430603..050e6dcc 100644
--- a/brainscore_language/models/bert/__init__.py
+++ b/brainscore_language/models/bert/__init__.py
@@ -2,5 +2,10 @@
 from brainscore_language import ArtificialSubject
 from brainscore_language.model_helpers.huggingface import HuggingfaceSubject
 
+# layer assignment was determined by scoring each transformer layer against three neural
+# benchmarks: Pereira2018.243sentences-linear, Pereira2018.384sentences-linear, and
+# Blank2014-linear, and choosing the layer with the highest average score.
+
+# BERT
 model_registry['bert-base-uncased'] = lambda: HuggingfaceSubject(model_id='bert-base-uncased', region_layer_mapping={
     ArtificialSubject.RecordingTarget.language_system: 'bert.encoder.layer.4'}, bidirectional=True)
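With all eight patches applied, the new registry entry can be exercised end to end; the sketch below simply mirrors the usage in brainscore_language/models/bert/test.py and assumes the patched brainscore_language package is installed:

    from brainscore_language import load_model
    from brainscore_language.artificial_subject import ArtificialSubject

    model = load_model('bert-base-uncased')

    # behavioral task: per-word reading times (the first word has no context, hence nan in the tests)
    model.start_behavioral_task(task=ArtificialSubject.Task.reading_times)
    reading_times = model.digest_text(['the', 'quick', 'brown', 'fox'])['behavior']

    # neural recording: language-system representations read from bert.encoder.layer.4 (768 neuroids)
    model.start_neural_recording(recording_target=ArtificialSubject.RecordingTarget.language_system,
                                 recording_type=ArtificialSubject.RecordingType.fMRI)
    representations = model.digest_text(['the quick brown fox', 'jumps over', 'the lazy dog'])['neural']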