diff --git a/brainscore_vision/model_helpers/activations/temporal/core/__init__.py b/brainscore_vision/model_helpers/activations/temporal/core/__init__.py index 90f0214ac..59e228fe5 100644 --- a/brainscore_vision/model_helpers/activations/temporal/core/__init__.py +++ b/brainscore_vision/model_helpers/activations/temporal/core/__init__.py @@ -1,3 +1,3 @@ from .extractor import ActivationsExtractor -from .executor import BatchExecutor +from .executor import BatchExecutor, OnlineExecutor from .inferencer import * \ No newline at end of file diff --git a/brainscore_vision/model_helpers/activations/temporal/core/executor.py b/brainscore_vision/model_helpers/activations/temporal/core/executor.py index 9670a85ff..f84bbcb01 100644 --- a/brainscore_vision/model_helpers/activations/temporal/core/executor.py +++ b/brainscore_vision/model_helpers/activations/temporal/core/executor.py @@ -1,5 +1,6 @@ import os import logging +import random import numpy as np from tqdm.auto import tqdm @@ -10,6 +11,13 @@ from brainscore_vision.model_helpers.utils import fullname from joblib import Parallel, delayed +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torchvision import transforms # Import torchvision transforms +from jepa.src.models.attentive_pooler import AttentiveClassifier # Ensure this import path is correct +from brainio.assemblies import NeuroidAssembly # a utility to apply a list of functions to inputs sequentially but only iterate over the first input def _pipeline(*funcs): @@ -217,3 +225,589 @@ def execute(self, layers): self.clear_stimuli() return layer_activations + +# Define the ReadoutModel class +class ReadoutModel(nn.Module): + def __init__(self, embed_dim=252, num_classes=1): + super(ReadoutModel, self).__init__() + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.num_classes = num_classes + self.embed_dim = embed_dim + self.linear_layer = None + self.attentive_pooler = AttentiveClassifier(embed_dim=embed_dim, num_classes=num_classes) + + def forward(self, x, mode=None): + x = x.float() + x = x.view(x.shape[0], x.shape[1], -1) # Flatten keeping the last dimension + N, T, D = x.shape + if self.linear_layer is None: + self.linear_layer = nn.Sequential( + nn.Linear(D, self.embed_dim), + nn.ReLU(), + nn.Linear(self.embed_dim, self.embed_dim), + ) + self.linear_layer = self.linear_layer.to(self.device) + x = self.linear_layer(x.flatten(0, 1)) + x = x.view(N, T, self.embed_dim) + x = self.attentive_pooler(x) + if self.num_classes == 1: + return torch.sigmoid(x) + else: + return F.softmax(x, dim=-1) + +# Define the WarmupScheduler class +class WarmupScheduler: + def __init__(self, optimizer, warmup_steps, initial_lr): + self.optimizer = optimizer + self.warmup_steps = warmup_steps + self.initial_lr = initial_lr + self.current_step = 0 + + def step(self): + self.current_step += 1 + if self.current_step < self.warmup_steps: + lr = self.initial_lr * (self.current_step / self.warmup_steps) + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr + +# Define the OnlineExecutor class +class OnlineExecutor(BatchExecutor): + """Executor for online processing using a readout model with data augmentation, warmup, validation, and early stopping. + + Parameters + ---------- + get_activations : function + Function that takes a list of processed stimuli and a list of layers, and returns a dictionary of activations. 
+    preprocessing : function
+        Function that takes a stimulus and returns a processed stimulus.
+    batch_size: int
+        Number of stimuli to process in each batch.
+    batch_padding: bool
+        Whether to pad the batch with the last stimulus to make it the same size as the specified batch size.
+    batch_grouper: function
+        Function that takes a stimulus and returns the property based on which the stimuli can be grouped.
+    max_workers: int
+        Number of workers for parallel processing. If None, the number of workers will be the number of CPUs.
+    num_classes: int
+        Number of output classes for the readout model; 1 means a binary (sigmoid) readout.
+    augmentation_function: Callable
+        Function to apply data augmentation to the batch before processing. If None, the default
+        horizontal-flip and resized-crop augmentation is used.
+    n_epochs: int
+        Number of epochs to train the readout model.
+    lr: float
+        Learning rate for training the readout model.
+    """
+
+    def __init__(self,
+                 get_activations: Callable[[List[Any]], Dict[str, np.array]],
+                 preprocessing: Callable[[List[Stimulus]], Any],
+                 batch_size: int,
+                 batch_padding: bool = False,
+                 batch_grouper: Callable[[Stimulus], Hashable] = None,
+                 max_workers: int = None,
+                 augmentation_function: Callable[[torch.Tensor], torch.Tensor] = None,
+                 n_epochs: int = 1000,
+                 lr: float = 1e-3,
+                 num_classes: int = 1):
+        super().__init__(get_activations, preprocessing, batch_size, batch_padding, batch_grouper, max_workers)
+        # Initialize the readout model with the given parameters
+        self.readout_model = ReadoutModel(num_classes=num_classes)
+        self.n_epochs = n_epochs
+        self.lr = lr
+
+        # Data augmentation function; fall back to the default if none is provided
+        self.augmentation_function = augmentation_function if augmentation_function is not None \
+            else self.default_augmentation_function
+
+        # Loss function and optimizer for training the readout model
+        self.num_classes = num_classes
+        self.criterion = nn.BCELoss() if self.num_classes == 1 else nn.CrossEntropyLoss()
+        self.optimizer = optim.AdamW(self.readout_model.parameters(), lr=self.lr)
+
+        # Early stopping parameters
+        self.patience = 300
+        self.best_loss = float('inf')
+        self.no_improvement_count = 0
+
+    def default_augmentation_function(self, video_batch):
+        """
+        Default data augmentation function for video inputs using PyTorch transforms.
+
+        Parameters
+        ----------
+        video_batch : list of torch.Tensor
+            A batch of videos to augment.
+
+        Returns
+        -------
+        augmented_batch : list of torch.Tensor
+            Augmented batch of videos.
+        """
+        # Define the augmentation transformations
+        transform = transforms.Compose([
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.RandomResizedCrop(size=(video_batch[0].shape[-2], video_batch[0].shape[-1]), scale=(0.95, 1.0)),
+        ])
+
+        augmented_batch = []
+        for video in video_batch:
+            # Apply augmentation frame-by-frame; videos are laid out as (C, T, H, W)
+            frames = video.permute(1, 0, 2, 3)  # -> (T, C, H, W)
+            augmented_video = torch.stack([transform(frame) for frame in frames]).permute(1, 0, 2, 3)
+            augmented_batch.append(augmented_video)
+        return augmented_batch
+
+    def _get_batches_trainer(
+            self,
+            data,
+            batch_size: int,
+            padding: bool = False
+    ):
+        """Group the data into class-balanced batches,
+        with random reuse of the minority class samples to maintain balance.
+
+        Parameters
+        ----------
+        data : array-like
+            List of data to be grouped, where each item is a tuple (video, label, train_flag).
+        batch_size, padding : int, bool, directly set by the class.
+
+        Returns
+        -------
+        indices : list
+            Indices of the source data after sorting.
+        masks : list of list
+            Masks for each batch to indicate whether the datum is a padding sample.
+        all_batches : list of list
+            List of batches, where each batch is a list of videos.
+ all_labels : list of list + List of batches, where each batch is a list of labels corresponding to the videos. + """ + + N = len(data) + + # Separate videos and labels for easier processing + videos, labels, _ = zip(*data) # Unzip the list of tuples into separate lists + videos = np.array(videos, dtype='object') # Convert to numpy array + labels = np.array(labels, dtype='object') + + # Separate positive and negative samples + pos_indices = np.where(labels == 1)[0] + neg_indices = np.where(labels == 0)[0] + pos_videos, pos_labels = videos[pos_indices], labels[pos_indices] + neg_videos, neg_labels = videos[neg_indices], labels[neg_indices] + + sorted_pos_indices = np.arange(len(pos_videos)) # Default sorted indices + sorted_neg_indices = np.arange(len(neg_videos)) # Default sorted indices + + # Sorted videos and labels based on properties + sorted_pos_videos = pos_videos[sorted_pos_indices] + sorted_pos_labels = pos_labels[sorted_pos_indices] + sorted_neg_videos = neg_videos[sorted_neg_indices] + sorted_neg_labels = neg_labels[sorted_neg_indices] + + index_pos, index_neg = 0, 0 + all_batches = [] + all_labels = [] + all_indices = [] + indices = [] + masks = [] + + while index_pos < len(sorted_pos_videos) or index_neg < len(sorted_neg_videos): + batch_videos = [] + batch_labels = [] + batch_indices = [] + + # Fill the batch with an equal number of positives and negatives + while len(batch_videos) < batch_size: + if index_pos < len(sorted_pos_videos) and (len(batch_videos) < batch_size / 2): + batch_videos.append(sorted_pos_videos[index_pos]) + batch_labels.append(sorted_pos_labels[index_pos]) + batch_indices.append(pos_indices[sorted_pos_indices[index_pos]]) + index_pos += 1 + elif index_neg < len(sorted_neg_videos) and (len(batch_videos) < batch_size): + batch_videos.append(sorted_neg_videos[index_neg]) + batch_labels.append(sorted_neg_labels[index_neg]) + batch_indices.append(neg_indices[sorted_neg_indices[index_neg]]) + index_neg += 1 + else: + break + + # If we run out of one class, randomly resample from the minority class + while len(batch_videos) < batch_size: + if len(batch_videos) < batch_size / 2: + # Resample from the positives if we're short on positives + random_index = random.choice(sorted_pos_indices) + batch_videos.append(pos_videos[random_index]) + batch_labels.append(1) + batch_indices.append(pos_indices[random_index]) + else: + # Resample from the negatives if we're short on negatives + random_index = random.choice(sorted_neg_indices) + batch_videos.append(neg_videos[random_index]) + batch_labels.append(0) + batch_indices.append(neg_indices[random_index]) + + if padding: + num_padding = batch_size - len(batch_videos) + if num_padding: + padding_video = batch_videos[-1] # Get the last video for padding + padding_label = batch_labels[-1] # Get the last label for padding + batch_videos += [padding_video] * num_padding # Add padding videos + batch_labels += [padding_label] * num_padding # Add padding labels + else: + num_padding = 0 + + mask = [True] * (len(batch_videos) - num_padding) + [False] * num_padding + + # Shuffle the combined batch + combined = list(zip(batch_videos, batch_labels, batch_indices, mask)) + random.shuffle(combined) + batch_videos, batch_labels, batch_indices, mask = zip(*combined) + + all_batches.append(batch_videos) + all_labels.append(batch_labels) + all_indices.append(batch_indices) + indices.extend(batch_indices) + masks.append(mask) + + return indices, masks, all_batches, all_labels + + + def _get_batches( + self, + data, + batch_size: int, + 
grouper: Callable[[Stimulus], Hashable] = None, + padding: bool = False + ): + """Group the data into batches based on the grouper. + + Parameters + ---------- + data : array-like + List of data to be grouped, where each item is a tuple (video, label). + batch_size, grouper, padding : int, function, bool, directly set by the class. + + Returns + ------- + indices : list + Indices of the source data after sorting. + masks : list of list + Masks for each batch to indicate whether the datum is a padding sample. + all_batches : list of list + List of batches, where each batch is a list of videos. + all_labels : list of list + List of batches, where each batch is a list of labels corresponding to the videos. + """ + + N = len(data) + + # Separate videos and labels for easier processing + videos, labels, _ = zip(*data) # Unzip the list of tuples into separate lists + videos = np.array(videos, dtype='object') # Convert to numpy array + labels = np.array(labels, dtype='object') + + if grouper is None: + sorted_indices = np.arange(N) # Default sorted indices + sorted_properties = [0] * N # Dummy properties since no grouping is required + else: + properties = np.array([hash(grouper(video)) for video in videos]) # Hash the properties + sorted_indices = np.argsort(properties) + sorted_properties = properties[sorted_indices] + + # Sort videos and labels based on properties + sorted_videos = videos[sorted_indices] + sorted_labels = labels[sorted_indices] + + inverse_indices = np.argsort(sorted_indices) # Inverse transform for retrieving original order + inverse_indices = list(inverse_indices) + + index = 0 + all_batches = [] + all_labels = [] + all_indices = [] + indices = [] + masks = [] + while index < N: + property_anchor = sorted_properties[index] + batch_videos = [] + batch_labels = [] + while index < N and len(batch_videos) < batch_size and sorted_properties[index] == property_anchor: + batch_videos.append(sorted_videos[index]) + batch_labels.append(sorted_labels[index]) + index += 1 + + batch_indices = inverse_indices[index - len(batch_videos):index] + + if padding: + num_padding = batch_size - len(batch_videos) + if num_padding: + padding_video = batch_videos[-1] # Get the last video for padding + padding_label = batch_labels[-1] # Get the last label for padding + batch_videos += [padding_video] * num_padding # Add padding videos + batch_labels += [padding_label] * num_padding # Add padding labels + else: + num_padding = 0 + + masks.append([True] * (len(batch_videos) - num_padding) + [False] * num_padding) + + all_batches.append(batch_videos) + all_labels.append(batch_labels) + all_indices.append(batch_indices) + indices.extend(batch_indices) + + return indices, masks, all_batches, all_labels + + def pad_tensors_to_max_length(self, tensors): + """ + Pads a list of tensors to the maximum length along the time dimension (T). + + Parameters + ---------- + tensors : list of torch.Tensor + A list of tensors, each of shape (3, T, 224, 224) where T may vary. + + Returns + ------- + padded_tensors : list of torch.Tensor + A list of tensors, each padded to the maximum length along the time dimension. 
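+            Padding frames are zero-filled and appended at the end of the time dimension.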
+ """ + # Determine the maximum length along the time dimension (T) + max_length = max(tensor.size(1) for tensor in tensors) # max(T) + + # Pad each tensor along the time dimension to the maximum length + padded_tensors = [] + for tensor in tensors: + C, T, H, W = tensor.shape + # Calculate the amount of padding needed + pad_length = max_length - T + if pad_length > 0: + # Pad tensor with zeros to make it (3, max_length, 224, 224) + padding = (0, 0, 0, 0, 0, pad_length) # (width, height, time) + padded_tensor = torch.nn.functional.pad(tensor, padding, mode='constant', value=0) + else: + padded_tensor = tensor + + padded_tensors.append(padded_tensor) + + return padded_tensors + + def get_dummy_activations(self, model_inputs, layers, features_dict): + """ + Generates a dictionary of dummy activations with zero tensors matching the shape of the model outputs. + + Parameters + ---------- + model_inputs : list + The preprocessed input stimuli. + layers : list + The list of layers for which activations are required. + + Returns + ------- + batch_activations : OrderedDict + A dictionary where each key corresponds to a layer and each value is a zero tensor matching + the shape of the model's output for that layer. + """ + batch_activations = OrderedDict() + for layer in layers: + # Simulate a dummy output with zeros of the appropriate shape + dummy_shape = (len(model_inputs), *features_dict[layer].shape[1:]) # Adjust this to the actual output shape needed + batch_activations[layer] = np.zeros(dummy_shape) + return batch_activations + + def execute(self, layers, train=False): + if train: + return self.execute_train(layers) + return self.execute_test(layers) + + def execute_train(self, layers): + seed = 42 # You can choose any integer you like + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # If using multi-GPU + + # Ensure deterministic behavior for CUDA operations + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + indices, masks, batches, labels = self._get_batches_trainer(self.stimuli, self.batch_size, + padding=self.batch_padding) + + before_pipe = _pipeline(*self.before_hooks) + after_pipe = _pipeline(*self.after_hooks) + + shuffled_data = list(zip(indices, masks, batches, labels)) + random.shuffle(shuffled_data) + indices, masks, batches, labels = zip(*shuffled_data) + + # Split data into training and validation + validation_split = 0.1 + num_validation_batches = int(validation_split * len(batches)) + train_batches = list(zip(masks, batches, labels))[:-num_validation_batches] + val_batches = list(zip(masks, batches, labels))[-num_validation_batches:] + + # Training loop with early stopping + # Initialize warmup scheduler and cosine annealing scheduler + self.total_steps = self.n_epochs * len(self.stimuli) // self.batch_size + self.warmup_steps = int(0.05 * self.total_steps) + self.warmup_scheduler = WarmupScheduler(self.optimizer, self.warmup_steps, self.lr) + self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.total_steps - self.warmup_steps, eta_min=0) + + # Check for GPU availability + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Move the model to GPU and enable data parallelism if multiple GPUs are available + self.readout_model.to(device) + #if torch.cuda.device_count() > 1: + # self.readout_model = torch.nn.DataParallel(self.readout_model) + + self.best_accuracy = 0 + for epoch in range(self.n_epochs): + 
epoch_loss = 0.0 + total_correct = 0 + total_steps, total_samples = 0, 0 + self.readout_model.train() + + # Training Phase + with tqdm(train_batches, desc=f"Training Epoch {epoch+1}", total=len(train_batches)) as pbar: + for idx, (mask, batch, label) in enumerate(pbar): + batch = [before_pipe(stimulus) for stimulus in batch] + model_inputs = self._mapper.map(self.preprocess, batch) + model_inputs = self.pad_tensors_to_max_length(model_inputs) + if self.augmentation_function: + model_inputs = self.augmentation_function(model_inputs) # Use the augmentation function on the batch + features_dict = self.get_activations(model_inputs, layers) + for layer, features in features_dict.items(): + features_tensor = torch.tensor(np.stack(features)).to(device) + labels = torch.tensor(label).to(device) + + self.optimizer.zero_grad() + readout_outputs = self.readout_model(features_tensor) + readout_outputs = readout_outputs.squeeze() + loss = self.criterion(readout_outputs, labels.float()) + loss.backward() + self.optimizer.step() + + epoch_loss += loss.item() + + # Calculate accuracy + predicted = (readout_outputs > 0.5).float() # Assuming a binary classification with threshold 0.5 + correct = (predicted == labels.float()).sum().item() + total_correct += correct + total_samples += labels.size(0) + total_steps += 1 + + # Update tqdm bar with the average accuracy + average_accuracy = total_correct / total_samples + pbar.set_postfix({'loss': epoch_loss / total_steps, 'accuracy': average_accuracy * 100}) + + self.warmup_scheduler.step() # Apply warmup scheduler + + self.scheduler.step() # Apply cosine scheduler + + # Validation Phase + val_loss = 0.0 + total_correct = 0 + total_samples = 0 + self.readout_model.eval() + + with torch.no_grad(): + with tqdm(val_batches, desc="Validation", total=len(val_batches)) as pbar: + for mask, batch, label in pbar: + batch = [before_pipe(stimulus) for stimulus in batch] + model_inputs = self._mapper.map(self.preprocess, batch) + model_inputs = self.pad_tensors_to_max_length(model_inputs) + features_dict = self.get_activations(model_inputs, layers) + + for layer, features in features_dict.items(): + features_tensor = torch.tensor(np.stack(features)).to(device) + labels = torch.tensor(label).to(device) + readout_outputs = self.readout_model(features_tensor) + readout_outputs = readout_outputs.squeeze() + # Compute loss + loss = self.criterion(readout_outputs, labels.float()) + val_loss += loss.item() + + # Calculate accuracy + predicted = (readout_outputs > 0.5).float() # Assuming binary classification + correct = (predicted == labels.float()).sum().item() + total_correct += correct + total_samples += labels.size(0) + + # Update tqdm bar with loss and accuracy + avg_val_loss = val_loss / (total_samples / labels.size(0)) + avg_val_accuracy = total_correct / total_samples * 100 + pbar.set_postfix({'loss': avg_val_loss, 'accuracy': avg_val_accuracy}) + + # Compute the average loss and accuracy + avg_val_loss = val_loss / len(val_batches) + avg_val_accuracy = total_correct / total_samples * 100 + print(f"Epoch {epoch+1} completed with training loss: {epoch_loss/len(train_batches)}, validation loss: {avg_val_loss}, validation accuracy: {avg_val_accuracy:.2f}%") + self._logger.info(f"Epoch {epoch+1} completed with training loss: {epoch_loss/len(train_batches)}, validation loss: {avg_val_loss}, validation accuracy: {avg_val_accuracy:.2f}%") + + # Early stopping check based on accuracy + if avg_val_accuracy > self.best_accuracy: + self.best_accuracy = avg_val_accuracy + 
self.no_improvement_count = 0 + torch.save(self.readout_model.state_dict(), 'transformer_readout.pt') + print(f"New best model saved with validation accuracy: {avg_val_accuracy:.2f}%") + self._logger.info(f"New best model saved with validation accuracy: {avg_val_accuracy:.2f}%") + else: + self.no_improvement_count += 1 + if self.no_improvement_count >= self.patience: + self._logger.info("Early stopping triggered.") + break + + indices, masks, batches, labels = self._get_batches(self.stimuli, self.batch_size, + grouper=self.batch_grouper, + padding=self.batch_padding) + + # Final execution with dummy activations + layer_activations = OrderedDict() + for mask, batch in tqdm(zip(masks, batches), desc="activations", total=len(batches)): + batch = [before_pipe(stimulus) for stimulus in batch] + model_inputs = self._mapper.map(self.preprocess, batch) + + # Get dummy activations with zero tensors that match the shape of model outputs + batch_activations = self.get_dummy_activations(model_inputs, layers, features_dict) + assert isinstance(batch_activations, OrderedDict) + + for layer, activations in batch_activations.items(): + results = [after_pipe(arr, layer, stimulus) + for not_pad, arr, stimulus in zip(mask, activations, batch) + if not_pad] + layer_activations.setdefault(layer, []).extend(results) + + # Reorganize activations in the original order + for layer, activations in layer_activations.items(): + layer_activations[layer] = [activations[i] for i in indices] + + self.clear_stimuli() + return layer_activations + + def execute_test(self, layers): + indices, masks, batches, labels = self._get_batches(self.stimuli, self.batch_size, + grouper=self.batch_grouper, + padding=self.batch_padding) + + before_pipe = _pipeline(*self.before_hooks) + after_pipe = _pipeline(*self.after_hooks) + + layer_activations = OrderedDict() + for mask, batch in tqdm(zip(masks, batches), desc="activations", total=len(batches)): + batch = [before_pipe(stimulus) for stimulus in batch] + model_inputs = self._mapper.map(self.preprocess, batch) + batch_activations = self.get_activations(model_inputs, layers) + assert isinstance(batch_activations, OrderedDict) + for layer, activations in batch_activations.items(): + results = [after_pipe(arr, layer, stimulus) + for not_pad, arr, stimulus in zip(mask, activations, batch) + if not_pad] + layer_activations.setdefault(layer, []).extend(results) + + for layer, activations in layer_activations.items(): + layer_activations[layer] = [activations[i] for i in indices] + + self.clear_stimuli() + return layer_activations diff --git a/brainscore_vision/model_helpers/activations/temporal/core/extractor.py b/brainscore_vision/model_helpers/activations/temporal/core/extractor.py index d85783b4f..a0b0f71e1 100644 --- a/brainscore_vision/model_helpers/activations/temporal/core/extractor.py +++ b/brainscore_vision/model_helpers/activations/temporal/core/extractor.py @@ -13,6 +13,7 @@ from brainscore_vision.model_helpers.utils import fullname from result_caching import store_xarray from .inferencer import Inferencer +from .inferencer.video import OnlineTemporalInferencer from ..inputs import Stimulus @@ -69,7 +70,9 @@ def __call__( if number_of_trials is not None and (number_of_trials > 1 or require_variance): self._logger.warning("CAUTION: number_of_trials > 1 or require_variance=True is not supported yet. 
" "Bypassing...") - if isinstance(stimuli, StimulusSet): + if isinstance(self.inferencer, OnlineTemporalInferencer): + return self.online_stimulus_set(stimulus_set=stimuli, layers=layers, stimuli_identifier=stimuli_identifier) + elif isinstance(stimuli, StimulusSet): return self.from_stimulus_set(stimulus_set=stimuli, layers=layers, stimuli_identifier=stimuli_identifier) else: return self.from_paths(stimuli_paths=stimuli, layers=layers, stimuli_identifier=stimuli_identifier) @@ -116,6 +119,29 @@ def from_paths( activations = self._expand_paths(activations, original_paths=stimuli_paths) return activations + def online_stimulus_set( + self, + stimulus_set : StimulusSet, + layers : List[str], + stimuli_identifier : str = None, + ): + """ + :param stimuli_identifier: a stimuli identifier for the stored results file. + False to disable saving. None to use `stimulus_set.identifier` + """ + if stimuli_identifier is None and hasattr(stimulus_set, 'identifier'): + stimuli_identifier = stimulus_set.identifier + for hook in self._stimulus_set_hooks.copy().values(): # copy to avoid stale handles + stimulus_set = hook(stimulus_set) + stimuli_paths = [(str(stimulus_set.get_stimulus(stimulus_id)), label, train_flag) + for (stimulus_id, label, train_flag) in zip(stimulus_set['stimulus_id'], + stimulus_set['label'], + stimulus_set['train_flag']) + ] + activations = self._from_paths(stimuli_paths=stimuli_paths, layers=layers) + activations = attach_stimulus_set_meta(activations, stimulus_set) + return activations + @store_xarray(identifier_ignore=['stimuli_paths', 'layers'], combine_fields={'layers': 'layer'}) def _from_paths_stored(self, identifier, layers, stimuli_identifier, stimuli_paths): stimuli_paths.sort() diff --git a/brainscore_vision/model_helpers/activations/temporal/core/inferencer/base.py b/brainscore_vision/model_helpers/activations/temporal/core/inferencer/base.py index 118845d8d..05f51b8de 100644 --- a/brainscore_vision/model_helpers/activations/temporal/core/inferencer/base.py +++ b/brainscore_vision/model_helpers/activations/temporal/core/inferencer/base.py @@ -11,7 +11,7 @@ from brainio.assemblies import NeuroidAssembly, walk_coords from brainscore_vision.model_helpers.utils import fullname -from brainscore_vision.model_helpers.activations.temporal.core.executor import BatchExecutor +from brainscore_vision.model_helpers.activations.temporal.core.executor import BatchExecutor, OnlineExecutor from brainscore_vision.model_helpers.activations.temporal.utils import stack_with_nan_padding, batch_2d_resize from brainscore_vision.model_helpers.activations.temporal.inputs import Stimulus @@ -84,6 +84,8 @@ def __init__( batch_grouper : Callable[[Stimulus], Hashable] = None, batch_padding : bool = False, max_workers : int = None, + online_execution : bool = False, + num_classes : int = 1, *args, **kwargs ): @@ -96,6 +98,9 @@ def __init__( self.visual_degrees = visual_degrees self.dtype = dtype self._executor = BatchExecutor(get_activations, preprocessing, batch_size, batch_padding, batch_grouper, max_workers) + if online_execution: + self._executor = OnlineExecutor(get_activations, preprocessing, batch_size, batch_padding, batch_grouper, + max_workers, num_classes) self._stimulus_set_hooks = {} self._batch_activations_hooks = {} self._logger = logging.getLogger(fullname(self)) diff --git a/brainscore_vision/model_helpers/activations/temporal/core/inferencer/video/__init__.py b/brainscore_vision/model_helpers/activations/temporal/core/inferencer/video/__init__.py index 5d99e466a..3f4970e15 100644 
--- a/brainscore_vision/model_helpers/activations/temporal/core/inferencer/video/__init__.py
+++ b/brainscore_vision/model_helpers/activations/temporal/core/inferencer/video/__init__.py
@@ -1,2 +1,2 @@
-from .base import TemporalInferencer
+from .base import TemporalInferencer, OnlineTemporalInferencer
 from .temporal_context import *
\ No newline at end of file
diff --git a/brainscore_vision/model_helpers/activations/temporal/core/inferencer/video/base.py b/brainscore_vision/model_helpers/activations/temporal/core/inferencer/video/base.py
index e1cf7a9fa..3098a6c48 100644
--- a/brainscore_vision/model_helpers/activations/temporal/core/inferencer/video/base.py
+++ b/brainscore_vision/model_helpers/activations/temporal/core/inferencer/video/base.py
@@ -1,11 +1,17 @@
 import numpy as np
 from typing import Union, Tuple, Callable, Hashable, List, Dict
 from pathlib import Path
+from tqdm.auto import tqdm
+from collections import OrderedDict
+
+import gc
 
 from brainscore_vision.model_helpers.activations.temporal.inputs import Video, Stimulus
 from brainscore_vision.model_helpers.activations.temporal.utils import assembly_align_to_fps, stack_with_nan_padding
 from brainio.assemblies import NeuroidAssembly
+from brainscore_vision.model_helpers.activations.temporal.core.executor import OnlineExecutor
+
 from ..base import Inferencer
 from . import time_aligner as time_aligners
@@ -132,3 +138,152 @@ def _check_video(self, video: Video):
             assert self.num_frames[0] <= estimated_num_frames <= self.num_frames[1], f"The number of frames must be within {self.num_frames}, but got {estimated_num_frames}"
         if self.duration is not None:
             assert self.duration[0] <= video.duration <= self.duration[1], f"The duration must be within {self.duration}, but got {video.duration}"
+
+
+class OnlineTemporalInferencer(Inferencer):
+    """Inferencer for video stimuli. The model takes video stimuli as input and generates the activations over time.
+    Then, the activations will be aligned to video time by the time_aligner specified in the constructor. The aligned
+    activations will be again unified to the fps specified within the constructor (self.fps). Finally, the activations
+    will be packaged into a NeuroidAssembly.
+    Unlike TemporalInferencer, this online variant carries a label and a train flag with every stimulus and
+    dispatches training or test execution of the OnlineExecutor accordingly (see `inference`).
+
+    NOTE: for all time_alignment methods, the inference of time bins will only be done with the longest video,
+    ignoring all other input videos.
+
+    Example:
+        temporal_inferencer = OnlineTemporalInferencer(..., fps=10)
+        model_assembly = temporal_inferencer(video_paths[1000ms], layers)
+        model_assembly.time_bins -> [(0, 100), (100, 200), ..., (900, 1000)]  # 1000ms, 10fps
+
+    Parameters
+    ----------
+    fps: float
+        frame rate of the model sampling.
+
+    num_frames: int, or (int, int)
+        - If None, the model accepts videos of any length.
+        - If a single int is passed, specify how many frames the model takes.
+        - If a tuple of two ints is passed, specify the range of the number of frames the model takes (inclusive). If you need to specify infinite, use np.inf.
+
+    duration: float, or (float, float)
+        - If None, the model accepts videos of any length.
+        - If a single float is passed, specify the duration the model takes, in ms.
+        - If a tuple of two floats is passed, specify the range of the duration the model takes (inclusive). If you need to specify infinite, use np.inf.
+
+    time_alignment: str
+        specify the method to align the activations in time.
+        The options and specifications are in the time_aligners module.
The current options are: + - evenly_spaced: align the activations to have evenly spaced time bins across the whole video time span. + - ignore_time: ignore the time information and make a single time bin of the entire video. + - estimate_layer_fps: estimate the fps of the layer based on the video fps. + - per_frame_aligned: align the activations to the video frames. + + convert_img_to_video: bool + whether to convert the input images to videos. + img_duration: float + specify the duration of the images, in ms. This will work only if convert_img_to_video is True. + batch_size: int + number of stimuli to process in each batch. + batch_grouper: function + function that takes a stimulus and return the property based on which the stimuli can be grouped. + """ + def __init__( + self, + *args, + fps : float, + num_frames : Union[int, Tuple[int, int]] = None, + duration : Union[float, Tuple[float, float]] = None, + time_alignment : str = "evenly_spaced", + convert_img_to_video : bool = True, + img_duration : float = 1000.0, + batch_size : int = 32, + online_execution: bool = False, + batch_grouper : Callable[[Video], Hashable] = lambda video: (round(video.duration, 6), video.fps), # not including video.frame_size because most preprocessors will change the frame size to be the same + **kwargs, + ): + super().__init__(*args, stimulus_type=Video, batch_size=batch_size, + batch_grouper=batch_grouper, online_execution=online_execution, + **kwargs) + # Initialize the executor with the chosen class + self.fps = fps + self.num_frames = self._make_range(num_frames, type="num_frames") + self.duration = self._make_range(duration, type="duration") + assert hasattr(time_aligners, time_alignment), f"Unknown time alignment method: {time_alignment}" + self.time_aligner = getattr(time_aligners, time_alignment) + + if convert_img_to_video: + assert img_duration is not None, "img_duration should be specified if convert_img_to_video is True" + self.img_duration = img_duration + self.convert_to_video = convert_img_to_video + + @property + def identifier(self) -> str: + id = f"{super().identifier}.{self.time_aligner.__name__}.fps={float(self.fps)}" + if self.convert_to_video: + id += f".img_dur={float(self.img_duration)}" + return id + + def load_stimulus(self, path: Union[str, Path]) -> Video: + path, label, train_flag = path + if self.convert_to_video and Stimulus.is_image_path(path): + video = Video.from_img_path(path, self.img_duration, self.fps) + else: + video = Video.from_path(path) + video = video.set_fps(self.fps) + self._check_video(video) + return video, label, train_flag + + # given the paths of the stimuli and the layers, return the model activations as a NeuroidAssembly + def __call__(self, paths: List[Union[str, Path]], layers: List[str]): + stimuli = self.load_stimuli(paths) + paths = [path for (path, label, train_flag) in paths] + layer_activations = self.inference(stimuli, layers) + stimuli = [stim for (stim, label, train_flag) in stimuli] + layer_assemblies = OrderedDict() + for layer in tqdm(layers, desc="Packaging layers"): + layer_assemblies[layer] = self.package_layer(layer_activations[layer], self.layer_activation_format[layer], stimuli) + del layer_activations[layer] + gc.collect() # reduce memory usage + model_assembly = self.package(layer_assemblies, paths) + return model_assembly + + # process the list of stimulus and return the activations (list of np.array, + # whose length is the number of stimuli) of the specified layers + def inference(self, stimuli : List[Stimulus], layers : List[str]) 
-> Dict[str, List[np.array]]: + self._executor.add_stimuli(stimuli) + train_flag = stimuli[0][2] + return self._executor.execute(layers, train_flag) + + def package_layer( + self, + layer_activations : List[np.array], + layer_spec : str, + stimuli : List[Stimulus] + ): + assert len(layer_activations) == len(stimuli) + longest_stimulus = stimuli[np.argmax(np.array([stimulus.duration for stimulus in stimuli]))] + ignore_time = self.time_aligner is time_aligners.ignore_time + channels = self._map_dims(layer_spec) + layer_activations = stack_with_nan_padding(layer_activations) + assembly = self._package(layer_activations, ["stimulus_path"] + channels) + # align to the longest stimulus + assembly = self.time_aligner(assembly, longest_stimulus) + if "channel_temporal" in channels and not ignore_time: + channels.remove("channel_temporal") + assembly = self._stack_neuroid(assembly, channels) + if not ignore_time: + assembly = assembly_align_to_fps(assembly, self.fps) + return assembly + + def _make_range(self, num, type="num_frames"): + if num is None: + return (1 if type=='num_frames' else 0, np.inf) + if isinstance(num, (tuple, list)): + return num + else: + return (num, num) + + def _check_video(self, video: Video): + if self.num_frames is not None: + estimated_num_frames = int(self.fps * video.duration / 1000) + assert self.num_frames[0] <= estimated_num_frames <= self.num_frames[1], f"The number of frames must be within {self.num_frames}, but got {estimated_num_frames}" + if self.duration is not None: + assert self.duration[0] <= video.duration <= self.duration[1], f"The duration must be within {self.duration}, but got {video.duration}" diff --git a/brainscore_vision/model_helpers/brain_transformation/__init__.py b/brainscore_vision/model_helpers/brain_transformation/__init__.py index a1a3c1dc7..56dc6c51f 100644 --- a/brainscore_vision/model_helpers/brain_transformation/__init__.py +++ b/brainscore_vision/model_helpers/brain_transformation/__init__.py @@ -2,7 +2,7 @@ from brainscore_vision.model_helpers.brain_transformation.temporal import TemporalAligned from brainscore_vision.model_interface import BrainModel from brainscore_vision.utils import LazyLoad -from .behavior import BehaviorArbiter, LabelBehavior, ProbabilitiesMapping, OddOneOut +from .behavior import BehaviorArbiter, LabelBehavior, ProbabilitiesMapping, OddOneOut, VideoReadoutMapping from .neural import LayerMappedModel, LayerSelection, LayerScores STANDARD_REGION_BENCHMARKS = { @@ -21,7 +21,7 @@ class ModelCommitment(BrainModel): def __init__(self, identifier, activations_model, layers, behavioral_readout_layer=None, region_layer_map=None, - visual_degrees=8): + visual_degrees=8, num_classes=1): self.layers = layers self.activations_model = activations_model # We set the visual degrees of the ActivationsExtractorHelper here to avoid changing its signature. 
@@ -46,9 +46,12 @@ def __init__(self, identifier, layer=behavioral_readout_layer) odd_one_out = OddOneOut(identifier=identifier, activations_model=activations_model, layer=behavioral_readout_layer) + video_readout_behavior = VideoReadoutMapping(identifier=identifier, activations_model=activations_model, + layer=behavioral_readout_layer, num_classes=num_classes) self.behavior_model = BehaviorArbiter({BrainModel.Task.label: logits_behavior, BrainModel.Task.probabilities: probabilities_behavior, BrainModel.Task.odd_one_out: odd_one_out, + BrainModel.Task.video_readout: video_readout_behavior, }) self.do_behavior = False @@ -62,9 +65,9 @@ def start_task(self, task: BrainModel.Task, *args, **kwargs): else: self.do_behavior = False - def look_at(self, stimuli, number_of_trials: int = 1, require_variance: bool = False): + def look_at(self, stimuli, number_of_trials: int = 1, require_variance: bool = False, **kwargs): if self.do_behavior: - return self.behavior_model.look_at(stimuli, number_of_trials=number_of_trials, require_variance=require_variance) + return self.behavior_model.look_at(stimuli, number_of_trials=number_of_trials, require_variance=require_variance, **kwargs) else: return self.layer_model.look_at(stimuli, number_of_trials=number_of_trials) @@ -94,4 +97,4 @@ def __getitem__(self, region): def commit_region(self, region): benchmark = self.region_benchmarks[region] best_layer = self.layer_selection(selection_identifier=region, benchmark=benchmark) - self[region] = best_layer + self[region] = best_layer \ No newline at end of file diff --git a/brainscore_vision/model_helpers/brain_transformation/behavior.py b/brainscore_vision/model_helpers/brain_transformation/behavior.py index 2c01012a0..6bd7325b4 100644 --- a/brainscore_vision/model_helpers/brain_transformation/behavior.py +++ b/brainscore_vision/model_helpers/brain_transformation/behavior.py @@ -2,16 +2,23 @@ from collections import OrderedDict from typing import Union, List +import math import numpy as np import pandas as pd import xarray as xr +import random import sklearn.linear_model import sklearn.multioutput +import torch +import torch.nn as nn +import torch.nn.functional as F -from brainio.assemblies import walk_coords, array_is_element, BehavioralAssembly, DataAssembly +from brainio.assemblies import walk_coords, array_is_element, BehavioralAssembly, DataAssembly, NeuroidAssembly from brainio.stimuli import StimulusSet from brainscore_vision.model_helpers.utils import make_list from brainscore_vision.model_interface import BrainModel +from torch.utils.data import Dataset +from torch.utils.data.sampler import Sampler class BehaviorArbiter(BrainModel): @@ -346,3 +353,447 @@ def calculate_choices(self, similarity_matrix, triplets): idx = triplet[2 - np.argmax(sims)] choice_predictions.append(idx) return choice_predictions + +from brainscore_vision.model_helpers.activations.temporal.core.inferencer.video import OnlineTemporalInferencer + +class VideoReadoutMapping(BrainModel): + def __init__(self, identifier, activations_model, layer, num_classes=1): + """ + :param identifier: a string to identify the model + :param activations_model: the model from which to retrieve representations for stimuli + :param layer: the single behavioral readout layer or a list of layers to read out of. 
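+        :param num_classes: number of readout classes; 1 yields a binary sigmoid readout, otherwise a softmax over `num_classes` classes.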
+ """ + self._identifier = identifier + self.activations_model = activations_model + self.readout = make_list(layer) + self.classifier = None + self.current_task = None + self.num_classes = num_classes + self.n_time_bins = None + + @property + def identifier(self): + return self._identifier + + def start_task(self, task: BrainModel.Task, fitting_stimuli, simulation=False): + assert task in BrainModel.Task.video_readout + self.current_task = task + fitting_stimuli['train_flag'] = [True]*len(fitting_stimuli['stimulus_id']) + fitting_features = self.activations_model(fitting_stimuli, layers=self.readout) + assert all(fitting_features['stimulus_id'].values == fitting_stimuli['stimulus_id']), \ + "stimulus_id ordering is incorrect" + fitting_features = fitting_features.transpose('presentation', 'time_bin', 'neuroid') + self.classifier = VideoReadoutMapping.TransformerReadout(np.prod(fitting_features.shape[2:]), + self.num_classes) + + if not isinstance(self.activations_model._extractor.inferencer, OnlineTemporalInferencer): + self.n_time_bins = fitting_features.shape[1] + if self.num_classes == 1: + self.classifier.fit(fitting_features, + fitting_features['label'].values) + else: + self.classifier.fit(fitting_features, + fitting_features['contacts'].values) + + def look_at(self, stimuli, number_of_trials=1, require_variance=False, simulation=False): + stimuli['train_flag'] = [False]*len(stimuli['stimulus_id']) + if not simulation: + features = self.activations_model(stimuli, layers=self.readout) + prediction = self.classifier.predict(features) + return prediction + else: + assert '-SIM' in self._identifier, "model is unable to do simulation" + features = self.activations_model(stimuli, layers=self.readout) + + # Determine the number of splits + n_time_bins = self.n_time_bins + n_splits = features['time_bin'].size // n_time_bins + + # Split the features into smaller chunks along the time_bin dimension + splits = [features.isel(time_bin=slice(i * n_time_bins, (i + 1) * n_time_bins)) for i in range(n_splits)] + + # Prepare the data for each split + split_data = [] + for i, split in enumerate(splits): + # Duplicate the presentation dimensions with updated indices + new_split = split.copy() + + # Update the stimulus_id by inserting the index before the '.mp4' extension + new_stimulus_ids = [] + for id in features['stimulus_id'].values: + if id.endswith(".mp4"): # Check if the stimulus ID ends with '.mp4' + base_id = id.rsplit('_', 1)[0] # Split to remove the last part after the underscore + new_id = f"{base_id}_img_snippet_{i}.mp4" # Insert index before the '.mp4' extension + else: + new_id = id # If it doesn't match the expected pattern, keep it unchanged + new_stimulus_ids.append(new_id) + + # Create a new NeuroidAssembly + layer_assembly = NeuroidAssembly(split.values, + coords= + {'stimulus_id': ('presentation', new_stimulus_ids), + 'length': ('presentation', [n_time_bins] * features['length'].values.shape[0]), + 'label': ('presentation', features['label'].values)}, + dims=['neuroid', 'time_bin', 'presentation']) + + split_data.append(layer_assembly) + + # Combine all splits back into a single xarray dataset or NeuroidAssembly + combined_features = xr.concat(split_data, dim='presentation') + prediction = self.classifier.predict(combined_features) + return prediction + + class TransformerReadout: + def __init__(self, model_dim, num_classes=1): + super(VideoReadoutMapping.TransformerReadout, self).__init__() + self.model = VideoReadoutMapping.ReadoutModel(model_dim, num_classes=num_classes) + 
self.num_classes = num_classes + self.num_epochs = 1000 + self.lr = 1e-4 + self.val_after = 5 + self.best_val_accuracy = 0 + self.convergence_thresh = 20 + self.counter_converge = 0 + self.prob_threshold = 0.5 + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + self.set_seed(43) + + def set_seed(self, seed_value=42): + """Set seed for reproducibility.""" + torch.manual_seed(seed_value) + torch.cuda.manual_seed_all(seed_value) # For CUDA devices + np.random.seed(seed_value) + random.seed(seed_value) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + def build_loader(self, features, labels, mode='train', indices=None): + if mode == 'train': + indices = list(range(features.shape[0])) + random.shuffle(indices) + split_point = int(0.9 * len(indices)) + train_indices = indices[:split_point] + val_indices = indices[split_point:] + + train_dataset = VideoReadoutMapping.TransformerLoader(features, labels, + indices=train_indices, + num_classes=self.num_classes) + + if self.num_classes == 1: + sampler = VideoReadoutMapping.BalancedBatchSampler(train_dataset.positive_indices, + train_dataset.negative_indices, + batch_size=128, seed=42) + + train_loader = VideoReadoutMapping.MultiEpochsDataLoader(train_dataset, + batch_sampler=sampler, + num_workers=4) + else: + train_loader = VideoReadoutMapping.MultiEpochsDataLoader(train_dataset, + batch_size=128, + shuffle=True, + num_workers=4) + + val_dataset = VideoReadoutMapping.TransformerLoader(features, labels, indices=val_indices, + num_classes=self.num_classes) + val_loader = VideoReadoutMapping.MultiEpochsDataLoader(val_dataset, + batch_size=128, + shuffle=False, + num_workers=4) + else: + train_dataset = VideoReadoutMapping.TransformerLoader(features, None, num_classes=self.num_classes) + train_loader = VideoReadoutMapping.MultiEpochsDataLoader(train_dataset, + batch_size=128, shuffle=False, + num_workers=4) + val_loader = None + return train_loader, val_loader + + def fit(self, features, labels): + train_loader, val_loader = self.build_loader(features, labels, mode='train') + self.model = self.model.to(self.device) + if torch.cuda.device_count() > 1: + self.model = nn.DataParallel(self.model) + # Optimizer + optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr) + criterion = nn.BCELoss() + if self.num_classes != 1: + criterion = nn.CrossEntropyLoss() + + # Define the total number of steps + total_steps = self.num_epochs * len(train_loader) + warmup_steps = int(0.05 * total_steps) # for example, 10% of total steps + + # Initialize the warmup scheduler + warmup_scheduler = VideoReadoutMapping.WarmupScheduler(optimizer, warmup_steps, initial_lr=self.lr) + + # Replace the ReduceLROnPlateau scheduler with CosineAnnealingLR + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=0) + + for epoch in range(self.num_epochs): + train_loss, train_acc = self.train(train_loader, optimizer, criterion, + epoch, scheduler, warmup_scheduler, warmup_steps) + # optimize for prob + if epoch % self.val_after == 0: + val_accuracy = self.validate(val_loader) + print(f'Epoch:{epoch+1}, Val Accuracy:{val_accuracy*100:.2f}%') + + if val_accuracy > self.best_val_accuracy: + self.counter_converge = 0 + self.best_val_accuracy = val_accuracy + print(f'Saving best model with val accuracy:{val_accuracy:.5f}') + torch.save(self.model.state_dict(), 'transformer_readout.pt') + else: + self.counter_converge += 1 + + if 
self.counter_converge >= self.convergence_thresh: + break + + # optimize prob threshold + + def binary_accuracy(self, preds, y): + # Round predictions to the closest integer (0 or 1) + if self.num_classes == 1: + rounded_preds = (preds > self.prob_threshold).float() + else: + rounded_preds = torch.argmax(preds, dim=-1) + correct = (rounded_preds == y).float() # Convert into float for division + acc = correct.sum() / len(correct) + return acc + + def train(self, data_loader, optimizer, criterion, + epoch, scheduler, warmup_scheduler, + warmup_steps, log_step=1): + self.model.train() + total_loss = 0 + total_acc = 0 + + for batch_idx, data in enumerate(data_loader): + # Warmup for the initial warmup_steps + if epoch * len(data_loader) + batch_idx < warmup_steps: + warmup_scheduler.step() + else: + # Once the warmup steps are completed, use cosine annealing + scheduler.step(epoch * len(data_loader) + batch_idx - warmup_steps) + + inputs, labels = data['feature'], data['label'] + inputs, targets = inputs.to(self.device), labels.to(self.device) + + optimizer.zero_grad() + outputs = self.model(inputs) # Ensure the output is of correct shape + outputs = outputs.squeeze() + if self.num_classes == 1: + loss = criterion(outputs, targets.float()) + else: + loss = criterion(outputs, targets.long()) + acc = self.binary_accuracy(outputs, targets) + + if batch_idx % log_step == 0: + print(f'Epoch:{epoch+1}, Step: [{batch_idx}/{len(data_loader)}], Train Accuracy:{acc:.5f}') + + loss.backward() + optimizer.step() + + total_loss += loss.item() + total_acc += acc.item() + + # Calculate the average loss and accuracy over all batches + avg_loss = total_loss / len(data_loader) + avg_acc = total_acc / len(data_loader) + + return avg_loss, avg_acc + + def validate(self, val_loader): + self.model.eval() + correct = 0 + total = 0 + with torch.no_grad(): + for data in val_loader: + inputs, labels = data['feature'], data['label'] + inputs, labels = inputs.to(self.device), labels.to(self.device) + outputs = self.model(inputs) + outputs = outputs.squeeze(dim=1) + if self.num_classes == 1: + predicted = (outputs > self.prob_threshold).float() + else: + predicted = torch.argmax(outputs, dim=-1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + self.model.train() + return correct / total + + def predict(self, features): + features = features.transpose('presentation', 'time_bin', 'neuroid') + test_loader, _ = self.build_loader(features, None, mode='test') + self.model.load_state_dict(torch.load('transformer_readout.pt')) + self.model = self.model.to(self.device) + if torch.cuda.device_count() > 1: + self.model = torch.nn.DataParallel(self.model) + + self.model.eval() # Set the model to evaluation mode + predictions, proba = [], [] + + with torch.no_grad(): # No gradients needed + for data in test_loader: + inputs = data['feature'] + inputs = inputs.to(self.device) + outputs = self.model(inputs) + outputs = outputs.squeeze(dim=1) + if self.num_classes == 1: + proba.extend(outputs.tolist()) + predicted = (outputs > self.prob_threshold).float().tolist() + else: + proba.extend(torch.max(outputs, dim=-1).values.tolist()) + predicted = torch.argmax(outputs, dim=-1).tolist() + predictions.extend(predicted) + + all_scenarios = ['roll', 'drop', 'towers', 'link', + 'collision', 'contain', 'dominoes'] + scenario = [next((sc for sc in all_scenarios if sc in filename), 'unknown') + for filename in features['stimulus_id'].data] + map_ = {'collision': 'Collide', 'contain': 'Contain', 'link': 'Link', 'towers': 
'Support', + 'dominoes': 'Dominoes', 'drop': 'Drop', 'roll': 'Roll'} + proba = BehavioralAssembly(proba, + coords= + {'stimulus_id': ('presentation', features['stimulus_id'].data), + 'label': ('presentation', features['label'].data), + 'choice': ('presentation', predictions), + 'scenario': ('presentation', [map_[s] for s in scenario]), + 'choice_threshold': ('presentation', [self.prob_threshold]*len(predictions))}, + dims=['presentation']) + return proba + + # Define the ReadoutModel class + class ReadoutModel(nn.Module): + def __init__(self, input_dim, embed_dim=252, num_classes=1): + super(VideoReadoutMapping.ReadoutModel, self).__init__() + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.num_classes = num_classes + self.embed_dim = embed_dim + self.linear_layer = nn.Sequential( + nn.Linear(input_dim, self.embed_dim), + nn.ReLU(), + nn.Linear(self.embed_dim, self.embed_dim), + ) + from jepa.src.models.attentive_pooler import AttentiveClassifier # Ensure this import path is correct + self.attentive_pooler = AttentiveClassifier(embed_dim=embed_dim, num_classes=num_classes) + + def forward(self, x, mode=None): + x = x.float() + x = x.view(x.shape[0], x.shape[1], -1) # Flatten keeping the last dimension + N, T, D = x.shape + x = self.linear_layer(x.flatten(0, 1)) + x = x.view(N, T, self.embed_dim) + x = self.attentive_pooler(x) + if self.num_classes == 1: + return torch.sigmoid(x) + else: + return F.softmax(x, dim=-1) + + class TransformerLoader(Dataset): + def __init__(self, features, labels, indices=None, num_classes=1): + self.indices = indices + self.labels = labels + self.features = features + if self.labels is not None and num_classes == 1: + self.positive_indices, self.negative_indices = [], [] + for i, l in enumerate(indices): + if self.labels[l] == 1: + self.positive_indices += [i] + else: + self.negative_indices += [i] + + def __len__(self): + length = len(self.indices) if self.indices is not None else self.features.shape[0] + return length + + def __getitem__(self, idx): + if self.indices is not None: + actual_idx = self.indices[idx] + else: + actual_idx = idx + feature = np.nan_to_num(self.features[actual_idx], nan=0.0) + if self.labels is not None: + label = self.labels[actual_idx] + else: + label = 0#None + return { + 'feature': feature, + 'label': label, + } + + class BalancedBatchSampler(Sampler): + def __init__(self, positive_indices, negative_indices, batch_size, seed=None): + self.positive_indices = np.array(positive_indices) + self.negative_indices = np.array(negative_indices) + self.batch_size = batch_size + self.seed = seed + assert batch_size % 2 == 0, "Batch size must be even." 
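+            # Each batch draws batch_size // 2 samples from each class (see __iter__), hence the even batch size requirement.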
+
+            # If a seed is provided, use it for random operations
+            if self.seed is not None:
+                np.random.seed(self.seed)
+                torch.manual_seed(self.seed)
+
+        def __iter__(self):
+            np.random.shuffle(self.positive_indices)
+            np.random.shuffle(self.negative_indices)
+            n_batches = min(len(self.positive_indices), len(self.negative_indices)) // (self.batch_size // 2)
+
+            for i in range(n_batches):
+                batch_indices = np.concatenate([
+                    self.positive_indices[i*(self.batch_size//2):(i+1)*(self.batch_size//2)],
+                    self.negative_indices[i*(self.batch_size//2):(i+1)*(self.batch_size//2)]
+                ])
+                np.random.shuffle(batch_indices)
+                yield batch_indices
+
+        def __len__(self):
+            return min(len(self.positive_indices), len(self.negative_indices)) // (self.batch_size // 2)
+
+    class MultiEpochsDataLoader(torch.utils.data.DataLoader):
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self._DataLoader__initialized = False
+            self.batch_sampler = VideoReadoutMapping._RepeatSampler(self.batch_sampler)
+            self._DataLoader__initialized = True
+            self.iterator = super().__iter__()
+
+        def __len__(self):
+            return len(self.batch_sampler.sampler)
+
+        def __iter__(self):
+            for i in range(len(self)):
+                yield next(self.iterator)
+
+    class _RepeatSampler(object):
+        """ Sampler that repeats forever.
+        Args:
+            sampler (Sampler)
+        """
+
+        def __init__(self, sampler):
+            self.sampler = sampler
+
+        def __iter__(self):
+            while True:
+                yield from iter(self.sampler)
+
+        def __len__(self):
+            return len(self.sampler)
+
+    class WarmupScheduler:
+        def __init__(self, optimizer, warmup_steps, initial_lr):
+            self.optimizer = optimizer
+            self.warmup_steps = warmup_steps
+            self.initial_lr = initial_lr
+            self.step_num = 0
+
+        def step(self):
+            self.step_num += 1
+            if self.step_num <= self.warmup_steps:
+                lr = self.initial_lr * self.step_num / self.warmup_steps
+                for param_group in self.optimizer.param_groups:
+                    param_group['lr'] = lr
diff --git a/brainscore_vision/model_interface.py b/brainscore_vision/model_interface.py
index 48622282d..046e19411 100644
--- a/brainscore_vision/model_interface.py
+++ b/brainscore_vision/model_interface.py
@@ -118,6 +118,35 @@ class Task:
         - choice    (choice) object 'dog' 'cat' 'chair' 'flower' 'plane'
     """
 
+    video_readout = 'video_readout'
+    """
+    Predict a label probability for each video stimulus via a trained transformer readout.
+    Output a :class:`~brainio.assemblies.BehavioralAssembly` with probabilities as the values.
+
+    The model must be supplied with `fitting_stimuli` in the second argument, which allows it to train a readout
+    for a particular set of labels and video distribution.
+    The `fitting_stimuli` are a :class:`~brainio.stimuli.StimulusSet` and must include a `label` column
+    which is used as the labels to fit to. Each stimulus is additionally assigned a `train_flag` column to
+    distinguish fitting stimuli from test stimuli.
+
+    Example:
+
+        Setting up a video readout task `start_task(BrainModel.Task.video_readout, fitting_stimuli)`
+        (where `fitting_stimuli` includes 2 distinct labels)
+        and calling `look_at()` could output
+
+        .. code-block:: python
+
+
+           array([[0.9]
+                  [0.4]
+                  [0.5]]),  # the probabilities
+           Coordinates:
+             * presentation      (presentation) MultiIndex
+             - stimulus_path     (presentation) object '/home/me/.brainio/demo_stimuli/video.mp4' ...
+             - choice            (presentation) object '1' '0' '0'
+             - choice_threshold  (presentation) object '0.5' '0.5' '0.5'
+    """
+
     odd_one_out = 'odd_one_out'
     """
     Predict the odd-one-out elements for a list of triplets of stimuli.
@@ -211,4 +240,4 @@ def look_at(self, stimuli: Union[StimulusSet, List[str]], number_of_trials=1) \
             E.g. 10 or 35.
Non-stochastic models can likely ignore this parameter. :return: task behaviors or recordings as instructed """ - raise NotImplementedError() + raise NotImplementedError() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 83ab968f7..6746bef12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "networkx", "eva-decord", "psutil", + "jepa@ git+https://github.com/thekej/jepa.git#egg=jepa", ] [project.optional-dependencies]