168 changes: 128 additions & 40 deletions brainscore_language/model_helpers/huggingface.py
@@ -1,6 +1,7 @@
from collections import OrderedDict

import functools
import itertools
import logging
import numpy as np
import re
@@ -9,7 +10,7 @@
from numpy.core import defchararray
from torch.utils.hooks import RemovableHandle
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoTokenizer, BatchEncoding
from transformers.modeling_outputs import CausalLMOutput
from typing import Union, List, Tuple, Dict, Callable

@@ -26,6 +27,7 @@ def __init__(
region_layer_mapping: dict,
model=None,
tokenizer=None,
bidirectional=False,
task_heads: Union[None, Dict[ArtificialSubject.Task, Callable]] = None,
):
"""
@@ -34,6 +36,7 @@ def __init__(
This can be left empty, but the model will not be able to be tested on neural benchmarks
:param model: the model to run inference from. Using `AutoModelForCausalLM.from_pretrained` if `None`.
:param tokenizer: the model's associated tokenizer. Using `AutoTokenizer.from_pretrained` if `None`.
:param bidirectional: whether to use bidirectional (masked) modeling [default: False]
:param task_heads: a mapping from one or multiple tasks
(:class:`~brainscore_language.artificial_subject.ArtificialSubject.Task`) to a function outputting the
requested task output, given the basemodel's base output
@@ -42,7 +45,18 @@ def __init__(
self._logger = logging.getLogger(fullname(self))
self.model_id = model_id
self.region_layer_mapping = region_layer_mapping
self.basemodel = (model if model is not None else AutoModelForCausalLM.from_pretrained(self.model_id))
self.bidirectional = bidirectional

if model is not None:
self.basemodel = model
elif self.bidirectional:
self.basemodel = AutoModelForMaskedLM.from_pretrained(self.model_id)
else:
self.basemodel = AutoModelForCausalLM.from_pretrained(self.model_id)

# Context window = number of positional embeddings minus the two special tokens ([CLS], [SEP])
self.context_window = getattr(self.basemodel.config, "max_position_embeddings", 0) - 2

self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.basemodel.to(self.device)
self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(self.model_id,
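For reference, a minimal sketch of where the context-window arithmetic above comes from, using `bert-base-uncased` as an illustrative example: the usable window is the model's positional-embedding count minus the two slots reserved for `[CLS]` and `[SEP]`.

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained('bert-base-uncased')
# BERT ships 512 positional embeddings; [CLS] and [SEP] occupy two of them,
# leaving 510 positions for actual text tokens.
context_window = config.max_position_embeddings - 2
print(context_window)  # 510
```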
@@ -71,6 +85,107 @@ def start_neural_recording(self,
recording_type: ArtificialSubject.RecordingType):
self.neural_recordings.append((recording_target, recording_type))

def _causal_inference(self, text):
output = {'behavior': [], 'neural': []}
number_of_tokens = 0

for part_number, text_part in enumerate(tqdm(text, desc='digest text')):
# prepare string representation of context
context = prepare_context(text[:part_number + 1])
context_tokens, number_of_tokens = self._tokenize(context, number_of_tokens)

# setup hooks in the model's layers and perform inference on `context_tokens`
base_output, layer_representations = self._run_model_with_hooks(context_tokens)

# update output dict with new behavioral output and/or neural representations
output = self._format_output(base_output, layer_representations, text_part, context, part_number, output)

return output
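The causal path in condensed form, as a runnable sketch; `distilgpt2` and the plain whitespace join (a stand-in for `prepare_context`) are illustrative assumptions:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
model = AutoModelForCausalLM.from_pretrained('distilgpt2')

text = ['the quick brown fox', 'jumps over', 'the lazy']
for part_number, text_part in enumerate(text):
    # the context for part n is the concatenation of parts 0..n
    context = ' '.join(text[:part_number + 1])
    tokens = tokenizer(context, return_tensors='pt')
    with torch.no_grad():
        logits = model(**tokens).logits  # [batch, token, vocab]
    # greedy next-word prediction from the final position
    next_token = logits[0, -1].argmax().item()
    print(text_part, '->', tokenizer.decode(next_token))
```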

def _masked_inference(self, text):
output = {'behavior': [], 'neural': []}
text_tokens = []

# Preprocessing: get the tokens for each text part
remaining_tokens = 0
for text_part in text:
part_tokens = self.tokenizer.tokenize(text_part)
remaining_tokens += len(part_tokens)
text_tokens.append(part_tokens)

start_part = 0
number_of_tokens = 0
for part_number, text_part in enumerate(tqdm(text, desc='digest text')):
number_of_tokens += len(text_tokens[part_number])
while number_of_tokens > (self.context_window / 2) and (start_part < part_number):
number_of_tokens -= len(text_tokens[start_part])
start_part += 1

end_part = part_number + 1
if self.behavioral_task == ArtificialSubject.Task.reading_times:
# For reading time estimation, this part should be masked, otherwise
# surprisal will be very low since the model will have seen the tokens.
end_part = part_number

unmasked_context_tokens = list(itertools.chain.from_iterable(text_tokens[start_part:end_part]))
context = prepare_context(text[start_part: part_number + 1])

# Add [MASK] tokens to the second half of the context
unmasked_context_tokens += [self.tokenizer.mask_token] * min(remaining_tokens, self.context_window - len(unmasked_context_tokens))
masked_part = self.tokenizer.convert_tokens_to_string(unmasked_context_tokens)
context_tokens = self.tokenizer(masked_part, return_tensors="pt", return_overflowing_tokens=self._tokenizer_returns_overflow)
context_tokens.to(self.device)
remaining_tokens -= len(text_tokens[part_number])

# Tokenize the text part without masking for comparison with model logits (e.g. for reading time estimates)
self.current_tokens = self.tokenizer(text_part, return_tensors="pt", add_special_tokens=False, return_overflowing_tokens=self._tokenizer_returns_overflow)
self.current_tokens.to(self.device)

# Setup hooks in the model's layers and perform inference on `context_tokens`
base_output, layer_representations = self._run_model_with_hooks(context_tokens)

# Post processing model output: removing logits for [MASK] tokens in the future context
mask_token_index = torch.where(context_tokens["input_ids"] == self.tokenizer.mask_token_id)[1][0]
base_output.logits = base_output.logits[:, :mask_token_index + 1]
if self.tokenizer.cls_token is not None:
base_output.logits = base_output.logits[:, 1:]

# Update output dict with new behavioral output and/or neural representations
output = self._format_output(base_output, layer_representations, text_part, context, part_number, output)

return output
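The masking trick in isolation, as a runnable sketch. The fixed run of three `[MASK]` tokens and `bert-base-uncased` are illustrative; the real method sizes the masked tail from `remaining_tokens` and slides `start_part` so the unmasked prefix stays under half the context window:

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')

seen = 'the quick brown fox jumps over the'    # context observed so far
future = ' '.join([tokenizer.mask_token] * 3)  # stand-in for the unseen text
inputs = tokenizer(f'{seen} {future}', return_tensors='pt')
with torch.no_grad():
    logits = model(**inputs).logits
# keep logits up to and including the first [MASK], then drop the leading
# [CLS] logit, mirroring the post-processing above
first_mask = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1][0]
logits = logits[:, :first_mask + 1, :][:, 1:, :]
```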

def _format_output(self, base_output, layer_representations, text_part, context, part_number, output):
if not output:
output = {'behavior': [], 'neural': []}

stimuli_coords = {
'stimulus': ('presentation', [text_part]),
'context': ('presentation', [context]),
'part_number': ('presentation', [part_number]),
}
if self.behavioral_task:
behavioral_output = self.output_to_behavior(base_output=base_output)
behavior = BehavioralAssembly(
[behavioral_output],
coords=stimuli_coords,
dims=['presentation']
)
output["behavior"].append(behavior)
if self.neural_recordings:
representations = self.output_to_representations(layer_representations, stimuli_coords=stimuli_coords)
output["neural"].append(representations)
return output

def _run_model_with_hooks(self, context_tokens):
# prepare recording hooks
hooks, layer_representations = self._setup_hooks()
with torch.no_grad():
base_output = self.basemodel(**context_tokens)
for hook in hooks:
hook.remove()
return base_output, layer_representations

def digest_text(self, text: Union[str, List[str]]) -> Dict[str, DataAssembly]:
"""
:param text: the text to be used for inference e.g. "the quick brown fox"
@@ -80,41 +195,10 @@ def digest_text(self, text: Union[str, List[str]]) -> Dict[str, DataAssembly]:
if type(text) == str:
text = [text]

output = {'behavior': [], 'neural': []}
number_of_tokens = 0

text_iterator = tqdm(text, desc='digest text') if len(text) > 100 else text # show progress bar if many parts
for part_number, text_part in enumerate(text_iterator):
# prepare string representation of context
context = prepare_context(text[:part_number + 1])
context_tokens, number_of_tokens = self._tokenize(context, number_of_tokens)

# prepare recording hooks
hooks, layer_representations = self._setup_hooks()

# run and remove hooks
with torch.no_grad():
base_output = self.basemodel(**context_tokens)
for hook in hooks:
hook.remove()

# format output
stimuli_coords = {
'stimulus': ('presentation', [text_part]),
'context': ('presentation', [context]),
'part_number': ('presentation', [part_number]),
}
if self.behavioral_task:
behavioral_output = self.output_to_behavior(base_output=base_output)
behavior = BehavioralAssembly(
[behavioral_output],
coords=stimuli_coords,
dims=['presentation']
)
output['behavior'].append(behavior)
if self.neural_recordings:
representations = self.output_to_representations(layer_representations, stimuli_coords=stimuli_coords)
output['neural'].append(representations)
if self.bidirectional:
output = self._masked_inference(text)
else:
output = self._causal_inference(text)

# merge over text parts
self._logger.debug("Merging outputs")
@@ -198,10 +282,14 @@ def _setup_hooks(self):
return hooks, layer_representations

def output_to_representations(self, layer_representations: Dict[Tuple[str, str, str], np.ndarray], stimuli_coords):
# Choose the first token ([CLS]) in bidirectional models and the last token in causal models to represent the passage
representation_values = np.concatenate([
# Choose to use last token (-1) of values[batch, token, unit] to represent passage.
values[:, -1:, :].squeeze(0).cpu() for values in layer_representations.values()],
# values are [batch, token, unit]
values[:, -1:, :].squeeze(0).cpu() if not self.bidirectional # use last token (-1) to represent passage
else values[:, :1, :].squeeze(0).cpu() # for bidirectional models, use first token
for values in layer_representations.values()],
axis=-1) # concatenate along neuron axis
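The pooling rule in isolation, a sketch assuming hooked activations of shape [batch, token, unit]:

```python
import torch

values = torch.randn(1, 12, 768)    # hooked layer activations: [batch, token, unit]

bidirectional = True
if bidirectional:
    pooled = values[:, :1, :]       # first token ([CLS]) summarizes the passage
else:
    pooled = values[:, -1:, :]      # last token is the only one that saw everything
pooled = pooled.squeeze(0).cpu()    # -> [1, 768]: one vector per passage
```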

neuroid_coords = {
'layer': ('neuroid', np.concatenate([[layer] * values.shape[-1]
for (recording_target, recording_type, layer), values
@@ -288,7 +376,7 @@ def _register_hook(self,
def hook_function(_layer: torch.nn.Module, _input, output: torch.Tensor, key=key):
# fix for when taking out only the hidden state, this is different from dropout because of residual state
# see: https://github.com/huggingface/transformers/blob/c06d55564740ebdaaf866ffbbbabf8843b34df4b/src/transformers/models/gpt2/modeling_gpt2.py#L428
output = output[0] if len(output) > 1 else output
output = output[0] if isinstance(output, tuple) else output
target_dict[key] = output

hook = layer.register_forward_hook(hook_function)
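The `isinstance` fix above matters because `len(output) > 1` also fires on a bare tensor whose first dimension exceeds 1, wrongly slicing off a batch element; only tuples should be unwrapped. A self-contained sketch of the hook pattern (the choice of `distilgpt2` and block 3 is arbitrary):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
model = AutoModelForCausalLM.from_pretrained('distilgpt2')

captured = {}
def hook_function(module, inputs, output):
    # GPT-2 blocks return a tuple (hidden_state, ...); unwrap only tuples
    captured['h'] = output[0] if isinstance(output, tuple) else output

handle = model.transformer.h[3].register_forward_hook(hook_function)
with torch.no_grad():
    model(**tokenizer('the quick brown fox', return_tensors='pt'))
handle.remove()             # detach so later forward passes are unaffected
print(captured['h'].shape)  # torch.Size([1, 4, 768])
```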
11 changes: 11 additions & 0 deletions brainscore_language/models/bert/__init__.py
@@ -0,0 +1,11 @@
from brainscore_language import model_registry
from brainscore_language import ArtificialSubject
from brainscore_language.model_helpers.huggingface import HuggingfaceSubject

# layer assignment was determined by scoring each transformer layer against three neural
# benchmarks: Pereira2018.243sentences-linear, Pereira2018.384sentences-linear, and
# Blank2014-linear, and choosing the layer with the highest average score.

# BERT
model_registry['bert-base-uncased'] = lambda: HuggingfaceSubject(model_id='bert-base-uncased', region_layer_mapping={
ArtificialSubject.RecordingTarget.language_system: 'bert.encoder.layer.4'}, bidirectional=True)
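A usage sketch for the new registry entry, mirroring `brainscore_language/models/bert/test.py` below:

```python
from brainscore_language import load_model
from brainscore_language.artificial_subject import ArtificialSubject

model = load_model('bert-base-uncased')
model.start_behavioral_task(task=ArtificialSubject.Task.reading_times)
reading_times = model.digest_text(['the', 'quick', 'brown', 'fox'])['behavior']
```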
45 changes: 45 additions & 0 deletions brainscore_language/models/bert/test.py
@@ -0,0 +1,45 @@
import numpy as np
import pytest

from brainscore_language import load_model
from brainscore_language.artificial_subject import ArtificialSubject


@pytest.mark.parametrize('model_identifier, expected_reading_times', [
('bert-base-uncased', [np.nan, 15.068062, 13.729589, 16.449226,
18.178684, 18.060932, 17.804218, 26.74436]),
])
def test_reading_times(model_identifier, expected_reading_times):
model = load_model(model_identifier)
text = ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy']
model.start_behavioral_task(task=ArtificialSubject.Task.reading_times)
reading_times = model.digest_text(text)['behavior']
np.testing.assert_allclose(
reading_times,
expected_reading_times,
atol=0.0001)


@pytest.mark.parametrize('model_identifier, expected_next_words', [
('bert-base-uncased', ['eyes', 'into', 'fallen', 'again']),
])
def test_next_word(model_identifier, expected_next_words):
model = load_model(model_identifier)
text = ['The quick brown', 'fox jumps', 'over the', 'lazy dog']
model.start_behavioral_task(task=ArtificialSubject.Task.next_word)
next_word_predictions = model.digest_text(text)['behavior']
np.testing.assert_array_equal(next_word_predictions, expected_next_words)


@pytest.mark.parametrize('model_identifier, feature_size', [
('bert-base-uncased', 768),
])
def test_neural(model_identifier, feature_size):
model = load_model(model_identifier)
text = ['the quick brown fox', 'jumps over', 'the lazy dog']
model.start_neural_recording(recording_target=ArtificialSubject.RecordingTarget.language_system,
recording_type=ArtificialSubject.RecordingType.fMRI)
representations = model.digest_text(text)['neural']
assert len(representations['presentation']) == 3
np.testing.assert_array_equal(representations['stimulus'], text)
assert len(representations['neuroid']) == feature_size
45 changes: 33 additions & 12 deletions tests/test_model_helpers/test_huggingface.py
@@ -11,36 +11,38 @@


class TestNextWord:
@pytest.mark.parametrize('model_identifier, expected_next_word', [
pytest.param('bert-base-uncased', '.', marks=pytest.mark.memory_intense),
pytest.param('gpt2-xl', 'jumps', marks=pytest.mark.memory_intense),
('distilgpt2', 'es'),
@pytest.mark.parametrize('model_identifier, expected_next_word, bidirectional', [
pytest.param('bert-base-uncased', 'and', True, marks=pytest.mark.memory_intense),
pytest.param('bert-base-uncased', '.', False, marks=pytest.mark.memory_intense),
pytest.param('gpt2-xl', 'jumps', False, marks=pytest.mark.memory_intense),
('distilgpt2', 'es', False),
])
def test_single_string(self, model_identifier, expected_next_word):
def test_single_string(self, model_identifier, expected_next_word, bidirectional):
"""
This is a simple test that takes in text = 'the quick brown fox', and tests the next word.
This test is a stand-in prototype to check if our model definitions are correct.
"""

model = HuggingfaceSubject(model_id=model_identifier, region_layer_mapping={})
model = HuggingfaceSubject(model_id=model_identifier, region_layer_mapping={}, bidirectional=bidirectional)
text = 'the quick brown fox'
_logger.info(f'Running {model.identifier()} with text "{text}"')
model.start_behavioral_task(task=ArtificialSubject.Task.next_word)
next_word = model.digest_text(text)['behavior'].values
assert next_word == expected_next_word

@pytest.mark.parametrize('model_identifier, expected_next_words', [
pytest.param('bert-base-uncased', ['.', '.', '.'], marks=pytest.mark.memory_intense),
pytest.param('gpt2-xl', ['jumps', 'the', 'dog'], marks=pytest.mark.memory_intense),
('distilgpt2', ['es', 'the', 'fox']),
@pytest.mark.parametrize('model_identifier, expected_next_words, bidirectional', [
pytest.param('bert-base-uncased', [';', 'the', 'water'], True, marks=pytest.mark.memory_intense),
pytest.param('bert-base-uncased', ['.', '.', '.'], False, marks=pytest.mark.memory_intense),
pytest.param('gpt2-xl', ['jumps', 'the', 'dog'], False, marks=pytest.mark.memory_intense),
('distilgpt2', ['es', 'the', 'fox'], False),
])
def test_list_input(self, model_identifier, expected_next_words):
def test_list_input(self, model_identifier, expected_next_words, bidirectional):
"""
This is a simple test that takes in text = ['the quick brown fox', 'jumps over', 'the lazy'], and tests the
next word for each text part in the list.
This test is a stand-in prototype to check if our model definitions are correct.
"""
model = HuggingfaceSubject(model_id=model_identifier, region_layer_mapping={})
model = HuggingfaceSubject(model_id=model_identifier, region_layer_mapping={}, bidirectional=bidirectional)
text = ['the quick brown fox', 'jumps over', 'the lazy']
_logger.info(f'Running {model.identifier()} with text "{text}"')
model.start_behavioral_task(task=ArtificialSubject.Task.next_word)
@@ -173,6 +175,25 @@ def test_one_text_single_target(self):
assert len(representations['neuroid']) == 768
_logger.info(f'representation shape is correct: {representations.shape}')

@pytest.mark.memory_intense
def test_one_text_single_target_bidirectional(self):
"""
This is a simple test that takes in text = 'the quick brown fox', and asserts that a bidirectional BERT model
layer indexed by `representation_layer` has 1 text presentation and 768 neurons. This test is a stand-in prototype
to check if our model definitions are correct.
"""
model = HuggingfaceSubject(model_id='bert-base-uncased', region_layer_mapping={
ArtificialSubject.RecordingTarget.language_system: 'bert.encoder.layer.4'}, bidirectional=True)
text = 'the quick brown fox'
_logger.info(f'Running {model.identifier()} with text "{text}"')
model.start_neural_recording(recording_target=ArtificialSubject.RecordingTarget.language_system,
recording_type=ArtificialSubject.RecordingType.fMRI)
representations = model.digest_text(text)['neural']
assert len(representations['presentation']) == 1
assert representations['stimulus'].squeeze() == text
assert len(representations['neuroid']) == 768
_logger.info(f'representation shape is correct: {representations.shape}')

@pytest.mark.memory_intense
def test_one_text_two_targets(self):
model = HuggingfaceSubject(model_id='distilgpt2', region_layer_mapping={