From bfec0cf0d3d5453b3640ed2162c5f9243a598f2c Mon Sep 17 00:00:00 2001 From: Negiiiin Date: Thu, 19 Dec 2024 00:58:00 -0500 Subject: [PATCH 1/8] Added llinear probing files --- mmlearn/tasks/__init__.py | 2 + mmlearn/tasks/linear_probing.py | 403 ++++++++++++++++++ .../experiment/linear_probing_eval.yaml | 81 ++++ 3 files changed, 486 insertions(+) create mode 100644 mmlearn/tasks/linear_probing.py create mode 100644 projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml diff --git a/mmlearn/tasks/__init__.py b/mmlearn/tasks/__init__.py index 552cae0..93db59f 100644 --- a/mmlearn/tasks/__init__.py +++ b/mmlearn/tasks/__init__.py @@ -4,6 +4,7 @@ from mmlearn.tasks.ijepa import IJEPA from mmlearn.tasks.zero_shot_classification import ZeroShotClassification from mmlearn.tasks.zero_shot_retrieval import ZeroShotCrossModalRetrieval +from mmlearn.tasks.linear_probing import LinearClassifierModule __all__ = [ @@ -11,4 +12,5 @@ "IJEPA", "ZeroShotCrossModalRetrieval", "ZeroShotClassification", + "LinearClassifierModule", ] diff --git a/mmlearn/tasks/linear_probing.py b/mmlearn/tasks/linear_probing.py new file mode 100644 index 0000000..9069666 --- /dev/null +++ b/mmlearn/tasks/linear_probing.py @@ -0,0 +1,403 @@ +"""A Module for linear evaluation of pretrained encoders.""" + +from contextlib import nullcontext +from functools import partial +from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Callable + +import hydra +import inspect +import lightning as L # noqa: N812 +import torch +from lightning.fabric.utilities.cloud_io import _load as pl_load +from lightning_utilities.core.rank_zero import rank_zero_warn +from omegaconf import DictConfig +from lightning.pytorch.utilities.types import OptimizerLRScheduler +from torch import nn +from torchmetrics import MetricCollection, Accuracy, AUROC, Precision, Recall, F1Score +from hydra_zen import store + +from mmlearn.datasets.core import Modalities, find_matching_indices +from mmlearn.datasets.core.modalities import Modality +from mmlearn.tasks.hooks import EvaluationHooks + + +def extract_vision_encoder(encoder: Any, encoder_checkpoint_path: Optional[str]) -> nn.Module: + model: L.LightningModule = hydra.utils.instantiate(encoder) + if encoder_checkpoint_path is None: + rank_zero_warn("No encoder_checkpoint_path path was provided for linear evaluation.") + else: + checkpoint = torch.load(encoder_checkpoint_path) + if 'state_dict' not in checkpoint: + raise KeyError("'state_dict' not found in checkpoint") + + state_dict = checkpoint['state_dict'] + + # Filter keys that are related to vision encoder + encoder_keys = { + k.replace("encoders.rgb.", "") if k.startswith("encoders.rgb") else k: v + for k, v in state_dict.items() if "encoders.rgb" in k + } + try: + if encoder_keys: + model["rgb"].load_state_dict(encoder_keys, strict=True) + print("Encoder state dict loaded successfully") + except Exception as e: + print(f"Error loading state dict: {e}") + return model["rgb"] + + +@store(group="task", provider="mmlearn") +class LinearClassifierModule(L.LightningModule): + """A linear classifier module for evaluating pretrained encoders. + + Parameters + ---------- + encoder : torch.nn.Module + A pretrained encoder model, outputting features for the linear classifier. + modality : str + The modality of the input data to be passed through the encoder. See + `common.constants.Modality` for valid values. The target label key is + inferred from this modality. This means that, for example, that if the + modality is 'rgb', the target label key is expected to be 'rgb_target'. + num_output_features : int + Output features from the encoder, defining the linear classifier's input size. + num_classes : int + Number of classes for the classification task. + task : str + Classification task type. One of 'binary', 'multiclass', or 'multilabel'. + freeze_encoder : bool, default = True + If True, encoder's parameters are frozen during training. + pre_classifier_batch_norm : bool, default = False + If True, a batch normalization layer without affine transformation is + added before the linear classifier, following [1]. + top_k_list : List[int], optional, default = None + A list of integers specifying the `k` values for top-k accuracy metrics. + For each `k` in this list, top-k accuracy is calculated and tracked during + training and validation. This allows for the evaluation of the model's + performance at predicting the top `k` most probable classes. + optimizer : DictConfig, optional, default = None + The configuration for the optimizer. This will be instantiated using + `hydra.utils.instantiate`, so it should include the `_target_` field, + which should point to the optimizer class. + lr_scheduler : DictConfig, optional, default = None + The configuration for the learning rate scheduler. Two fields are expected: + `scheduler` (required) and `extras` (optional). The `scheduler` field should + contain configurations for the learning rate scheduler and must include the + `_target_` field, which, like the optimizer, should point to the scheduler + class. The `extras` field may contain one or more of the following: + - `interval` (str): The interval to apply the learning rate scheduler. + One of "epoch" or "step". Default is "epoch". + - `frequency` (int): The frequency to apply the learning rate scheduler + in the specified interval. Default is 1. + - `monitor` (str): The metric to monitor for schedulers like ReduceLROnPlateau. + - `strict` (bool): Whether to strictly enforce the availability of the + monitored metric (if `True`) or raise a warning if the metric is not found + (if `False`). Default is `True`. + + Attributes + ---------- + accuracy_metrics : torchmetrics.MetricCollection + A collection of metrics that includes accuracy for each `k` in `top_k_list`, + providing a comprehensive evaluation of model performance across different + levels of top-k predictions. + linear_eval : torch.nn.Linear + Linear classification layer atop the encoder. Input and output features are + determined by `encoder_output_features` and `num_classes`, respectively. + + References + ---------- + [1] He, K., Chen, X., Xie, S., Li, Y., Doll'ar, P., & Girshick, R.B. (2021). + Masked Autoencoders Are Scalable Vision Learners. 2022 IEEE/CVF Conference + on Computer Vision and Pattern Recognition (CVPR), 15979-15988. + """ + + def __init__( + self, + # encoder: torch.nn.Module, + encoder: nn.Module, + encoder_checkpoint_path: Optional[str], + modality: str, + num_output_features: int, + num_classes: int, + task: Literal["binary", "multiclass", "multilabel"], + freeze_encoder: bool = True, + pre_classifier_batch_norm: bool = False, + top_k_list: Optional[List[int]] = None, + optimizer: Optional[partial[torch.optim.Optimizer]] = None, + lr_scheduler: Optional[ + Union[ + Dict[str, partial[torch.optim.lr_scheduler.LRScheduler]], + partial[torch.optim.lr_scheduler.LRScheduler], + ] + ] = None, + vision_extractor: Callable[[Any, str], Optional[str]] = extract_vision_encoder, + ): + super().__init__() + assert task in ["binary", "multiclass", "multilabel"], ( + f"Invalid task type: {task}. " + "Expected one of ['binary', 'multiclass', 'multilabel']." + ) + + self.modality = modality + + self.encoder: nn.Module = extract_vision_encoder(encoder, encoder_checkpoint_path) + + linear_layer = nn.Linear(num_output_features, num_classes) + if pre_classifier_batch_norm: + linear_layer = nn.Sequential( + nn.BatchNorm1d(num_output_features, affine=False), + linear_layer, + ) + self.classifier = linear_layer + + self.freeze_encoder = freeze_encoder + self.num_classes = num_classes + + if self.freeze_encoder: + for param in self.encoder.parameters(): + param.requires_grad = False + + self.loss_fn = nn.CrossEntropyLoss() + + self.top_k_list = top_k_list + if task == "multiclass": + if self.top_k_list is None: + self.top_k_list = [1, 5] + accuracy_metrics = { + f"top_{k}_accuracy": Accuracy(task=task, num_classes=num_classes, top_k=k) + for k in self.top_k_list + } + + # Additional metrics for multiclass classification + additional_metrics = { + "precision": Precision(task=task, num_classes=num_classes, average="macro"), + "recall": Recall(task=task, num_classes=num_classes, average="macro"), + "f1_score": F1Score(task=task, num_classes=num_classes, average="macro"), + "auc": AUROC(task=task, num_classes=num_classes, average="macro") # AUROC for multiclass + } + + elif task == "multilabel": + # Accuracy and other metrics for multilabel classification + accuracy_metrics = {"accuracy": Accuracy(task=task, num_labels=num_classes)} + + # Additional metrics for multilabel classification + additional_metrics = { + "precision": Precision(task=task, num_labels=num_classes, average="macro"), + "recall": Recall(task=task, num_labels=num_classes, average="macro"), + "f1_score": F1Score(task=task, num_labels=num_classes, average="macro"), + "auc": AUROC(task=task, num_labels=num_classes) # AUC for multilabel + } + + else: # binary + # Accuracy and other metrics for binary classification + accuracy_metrics = {"accuracy": Accuracy(task=task)} + + # Additional metrics for binary classification + additional_metrics = { + "precision": Precision(task=task), + "recall": Recall(task=task), + "f1_score": F1Score(task=task), + "auc": AUROC(task=task) # AUROC for binary classification + } + + # combine all metrics + metrics = MetricCollection({**accuracy_metrics, **additional_metrics}) + self.train_metrics = metrics.clone(prefix="train/") + self.valid_metrics = metrics.clone(prefix="val/") + + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Perform a forward pass through the encoder and linear classifier. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + + Returns + ------- + torch.Tensor + The logits predicted by the classifier. + """ + with torch.no_grad() if self.freeze_encoder else nullcontext(): + x = self.encoder(x) + return self.classifier(x[0]) + + def _get_logits_and_labels( + self, batch: Dict[str, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Return the logits and labels for a batch of data.""" + x : torch.Tensor = batch + y = batch[Modalities.get_modality(self.modality).target] + + logits = self(x) + return logits, y + + def _compute_loss( + self, batch: Dict[str, Any] + ) -> Optional[torch.Tensor]: + if self.loss_fn is None: + return None + + if self.freeze_encoder: + self.encoder.eval() + + logits, y = self._get_logits_and_labels(batch) + + loss: torch.Tensor = self.loss_fn(logits, y) + self.train_metrics.update(logits, y) + + return loss + + def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: + """Compute the loss for the batch. + + Parameters + ---------- + batch : Dict[str, Any] + The batch of data to process. + batch_idx : int + The index of the batch. + + Returns + ------- + torch.Tensor + The loss for the batch. + """ + loss = self._compute_loss(batch) + + if loss is None: + raise ValueError("The loss function must be provided for training.") + + self.log("train/loss", loss, prog_bar=True) + + return loss + + def validation_step( + self, + batch: Dict[str, torch.Tensor], + batch_idx: int, + ) -> torch.Tensor: + """ + Execute a validation step using a single batch. + + Parameters + ---------- + batch : Dict[str, torch.Tensor] + The current batch of validation data, including input tensors and labels. + batch_idx : int + The index of the current validation batch. + + Returns + ------- + torch.Tensor + The loss computed for the batch. + """ + + logits, y = self._get_logits_and_labels(batch) + + loss: torch.Tensor = self.loss_fn(logits, y) + self.log("val/loss", self.all_gather(loss.clone().detach()).mean()) + + self.valid_metrics.update(logits, y) + return loss + + def on_validation_epoch_end(self) -> None: + """Compute validation metrics accumulated over the epoch.""" + val_metrics = self.valid_metrics.compute() + print(f"val_metrics: {val_metrics}") + self.log_dict(val_metrics) + self.valid_metrics.reset() + + + def configure_optimizers(self) -> OptimizerLRScheduler: # noqa: PLR0912 + """Configure the optimizer and learning rate scheduler.""" + if self.optimizer is None: + rank_zero_warn( + "Optimizer not provided. Training will continue without an optimizer. " + "LR scheduler will not be used.", + ) + return None + + weight_decay: Optional[float] = self.optimizer.keywords.get( + "weight_decay", None + ) + if weight_decay is None: # try getting default value + kw_param = inspect.signature(self.optimizer.func).parameters.get( + "weight_decay" + ) + if kw_param is not None and kw_param.default != inspect.Parameter.empty: + weight_decay = kw_param.default + + parameters = [param for param in self.parameters() if param.requires_grad] + + if weight_decay is not None: + decay_params = [] + no_decay_params = [] + + for param in self.parameters(): + if not param.requires_grad: + continue + + if param.ndim < 2: # includes all bias and normalization parameters + no_decay_params.append(param) + else: + decay_params.append(param) + + parameters = [ + { + "params": decay_params, + "weight_decay": weight_decay, + "name": "weight_decay_params", + }, + { + "params": no_decay_params, + "weight_decay": 0.0, + "name": "no_weight_decay_params", + }, + ] + + optimizer = self.optimizer(parameters) + if not isinstance(optimizer, torch.optim.Optimizer): + raise TypeError( + "Expected optimizer to be an instance of `torch.optim.Optimizer`, " + f"but got {type(optimizer)}.", + ) + + if self.lr_scheduler is not None: + if isinstance(self.lr_scheduler, dict): + if "scheduler" not in self.lr_scheduler: + raise ValueError( + "Expected 'scheduler' key in the learning rate scheduler dictionary.", + ) + + lr_scheduler = self.lr_scheduler["scheduler"](optimizer) + if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler): + raise TypeError( + "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, " + f"but got {type(lr_scheduler)}.", + ) + lr_scheduler_dict: Dict[ + str, Union[torch.optim.lr_scheduler.LRScheduler, Any] + ] = {"scheduler": lr_scheduler} + + if self.lr_scheduler.get("extras"): + extras = self.lr_scheduler["extras"] + if isinstance(extras, partial): + # Extract the keywords from the partial object + lr_scheduler_dict.update(extras.keywords) + + return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict} + + lr_scheduler = self.lr_scheduler(optimizer) + if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler): + raise TypeError( + "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, " + f"but got {type(lr_scheduler)}.", + ) + return [optimizer], [lr_scheduler] + + return optimizer diff --git a/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml b/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml new file mode 100644 index 0000000..dd779b1 --- /dev/null +++ b/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml @@ -0,0 +1,81 @@ +# @package _global_ + +defaults: + - /datasets@datasets.train.bach: BACH + - /datasets/transforms@datasets.train.bach.transform: med_clip_vision_transform + - /datasets@datasets.val.bach: BACH + - /modules/optimizers@task.optimizer: AdamW + - /datasets/transforms@datasets.val.bach.transform: med_clip_vision_transform + - /modules/lr_schedulers@task.lr_scheduler.scheduler: CosineAnnealingLR + - /modules/encoders@task.encoder.rgb: HFCLIPVisionEncoderWithProjection + # - /trainer/logger@trainer.logger.wandb: WandbLogger + - override /task: LinearClassifierModule + - _self_ + +seed: 0 + +datasets: + train: + bach: + split: train + transform: + job_type: train + val: + bach: + split: test + transform: + job_type: eval + +dataloader: + train: + batch_size: 64 + num_workers: 4 + shuffle: False + val: + batch_size: 64 + num_workers: 4 + shuffle: False + +task: + task: multiclass + num_classes: 4 + num_output_features: 512 + modality: rgb + encoder_checkpoint_path: /path/to/checkpoint + top_k_list: [1] + optimizer: + betas: + - 0.9 + - 0.98 + lr: 0.1 + weight_decay: 0.1 + eps: 1.0e-6 + lr_scheduler: + scheduler: + T_max: 250 # make sure to change this if max_epochs or accumulate_grad_batches is changed + extras: + interval: step + + +trainer: + precision: 16-mixed + deterministic: False + benchmark: True + sync_batchnorm: False # set to True if using DDP with batchnorm + log_every_n_steps: 100 + max_epochs: 40 + callbacks: + model_checkpoint: + monitor: val/loss + save_top_k: 1 + save_last: True + every_n_epochs: 1 + dirpath: /checkpoint/${oc.env:USER}/${oc.env:SLURM_JOB_ID} # only works on Vector SLURM environment + model_summary: + max_depth: 2 + + +tags: + - ${experiment_name} + - linear_probing + - classification From 023ca9cf4276ce6dbdea89e58c87c595c711bea3 Mon Sep 17 00:00:00 2001 From: Negiiiin Date: Thu, 19 Dec 2024 11:25:44 -0500 Subject: [PATCH 2/8] Fixed pre-commit issues --- mmlearn/tasks/linear_probing.py | 95 ++++++++++++++++++++------------- 1 file changed, 58 insertions(+), 37 deletions(-) diff --git a/mmlearn/tasks/linear_probing.py b/mmlearn/tasks/linear_probing.py index 9069666..53b5a6b 100644 --- a/mmlearn/tasks/linear_probing.py +++ b/mmlearn/tasks/linear_probing.py @@ -1,41 +1,54 @@ """A Module for linear evaluation of pretrained encoders.""" +import inspect from contextlib import nullcontext from functools import partial -from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Callable +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import hydra -import inspect import lightning as L # noqa: N812 import torch -from lightning.fabric.utilities.cloud_io import _load as pl_load -from lightning_utilities.core.rank_zero import rank_zero_warn -from omegaconf import DictConfig +from hydra_zen import store from lightning.pytorch.utilities.types import OptimizerLRScheduler +from lightning_utilities.core.rank_zero import rank_zero_warn from torch import nn -from torchmetrics import MetricCollection, Accuracy, AUROC, Precision, Recall, F1Score -from hydra_zen import store +from torchmetrics import AUROC, Accuracy, F1Score, MetricCollection, Precision, Recall -from mmlearn.datasets.core import Modalities, find_matching_indices -from mmlearn.datasets.core.modalities import Modality -from mmlearn.tasks.hooks import EvaluationHooks +from mmlearn.datasets.core import Modalities -def extract_vision_encoder(encoder: Any, encoder_checkpoint_path: Optional[str]) -> nn.Module: +def extract_vision_encoder( + encoder: Any, encoder_checkpoint_path: Optional[str] +) -> nn.Module: + """ + Extract the vision encoder from a PyTorch Lightning model. + + Args: + encoder (Any): The encoder configuration or model to be instantiated. + encoder_checkpoint_path (Optional[str]): Path to the checkpoint file containing + the encoder's state_dict. + + Returns + ------- + nn.Module: The vision encoder module extracted and initialized. + """ model: L.LightningModule = hydra.utils.instantiate(encoder) if encoder_checkpoint_path is None: - rank_zero_warn("No encoder_checkpoint_path path was provided for linear evaluation.") + rank_zero_warn( + "No encoder_checkpoint_path path was provided for linear evaluation." + ) else: checkpoint = torch.load(encoder_checkpoint_path) - if 'state_dict' not in checkpoint: + if "state_dict" not in checkpoint: raise KeyError("'state_dict' not found in checkpoint") - state_dict = checkpoint['state_dict'] + state_dict = checkpoint["state_dict"] # Filter keys that are related to vision encoder encoder_keys = { k.replace("encoders.rgb.", "") if k.startswith("encoders.rgb") else k: v - for k, v in state_dict.items() if "encoders.rgb" in k + for k, v in state_dict.items() + if "encoders.rgb" in k } try: if encoder_keys: @@ -114,7 +127,7 @@ class LinearClassifierModule(L.LightningModule): def __init__( self, # encoder: torch.nn.Module, - encoder: nn.Module, + encoder: nn.Module, encoder_checkpoint_path: Optional[str], modality: str, num_output_features: int, @@ -139,8 +152,10 @@ def __init__( ) self.modality = modality - - self.encoder: nn.Module = extract_vision_encoder(encoder, encoder_checkpoint_path) + + self.encoder: nn.Module = extract_vision_encoder( + encoder, encoder_checkpoint_path + ) linear_layer = nn.Linear(num_output_features, num_classes) if pre_classifier_batch_norm: @@ -164,40 +179,50 @@ def __init__( if self.top_k_list is None: self.top_k_list = [1, 5] accuracy_metrics = { - f"top_{k}_accuracy": Accuracy(task=task, num_classes=num_classes, top_k=k) + f"top_{k}_accuracy": Accuracy( + task=task, num_classes=num_classes, top_k=k + ) for k in self.top_k_list } - + # Additional metrics for multiclass classification additional_metrics = { - "precision": Precision(task=task, num_classes=num_classes, average="macro"), + "precision": Precision( + task=task, num_classes=num_classes, average="macro" + ), "recall": Recall(task=task, num_classes=num_classes, average="macro"), - "f1_score": F1Score(task=task, num_classes=num_classes, average="macro"), - "auc": AUROC(task=task, num_classes=num_classes, average="macro") # AUROC for multiclass + "f1_score": F1Score( + task=task, num_classes=num_classes, average="macro" + ), + "auc": AUROC( + task=task, num_classes=num_classes, average="macro" + ), # AUROC for multiclass } elif task == "multilabel": # Accuracy and other metrics for multilabel classification accuracy_metrics = {"accuracy": Accuracy(task=task, num_labels=num_classes)} - + # Additional metrics for multilabel classification additional_metrics = { - "precision": Precision(task=task, num_labels=num_classes, average="macro"), + "precision": Precision( + task=task, num_labels=num_classes, average="macro" + ), "recall": Recall(task=task, num_labels=num_classes, average="macro"), "f1_score": F1Score(task=task, num_labels=num_classes, average="macro"), - "auc": AUROC(task=task, num_labels=num_classes) # AUC for multilabel + "auc": AUROC(task=task, num_labels=num_classes), # AUC for multilabel } else: # binary # Accuracy and other metrics for binary classification accuracy_metrics = {"accuracy": Accuracy(task=task)} - + # Additional metrics for binary classification additional_metrics = { "precision": Precision(task=task), "recall": Recall(task=task), "f1_score": F1Score(task=task), - "auc": AUROC(task=task) # AUROC for binary classification + "auc": AUROC(task=task), # AUROC for binary classification } # combine all metrics @@ -230,18 +255,16 @@ def _get_logits_and_labels( self, batch: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: """Return the logits and labels for a batch of data.""" - x : torch.Tensor = batch + x: torch.Tensor = batch y = batch[Modalities.get_modality(self.modality).target] - + logits = self(x) return logits, y - def _compute_loss( - self, batch: Dict[str, Any] - ) -> Optional[torch.Tensor]: + def _compute_loss(self, batch: Dict[str, Any]) -> Optional[torch.Tensor]: if self.loss_fn is None: return None - + if self.freeze_encoder: self.encoder.eval() @@ -296,7 +319,6 @@ def validation_step( torch.Tensor The loss computed for the batch. """ - logits, y = self._get_logits_and_labels(batch) loss: torch.Tensor = self.loss_fn(logits, y) @@ -312,7 +334,6 @@ def on_validation_epoch_end(self) -> None: self.log_dict(val_metrics) self.valid_metrics.reset() - def configure_optimizers(self) -> OptimizerLRScheduler: # noqa: PLR0912 """Configure the optimizer and learning rate scheduler.""" if self.optimizer is None: @@ -389,7 +410,7 @@ def configure_optimizers(self) -> OptimizerLRScheduler: # noqa: PLR0912 if isinstance(extras, partial): # Extract the keywords from the partial object lr_scheduler_dict.update(extras.keywords) - + return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict} lr_scheduler = self.lr_scheduler(optimizer) From 12f9150ea6bf64a0bba1911d4d7585805201dd21 Mon Sep 17 00:00:00 2001 From: Negiiiin Date: Thu, 19 Dec 2024 11:32:23 -0500 Subject: [PATCH 3/8] Fixed pre-commit issues --- mmlearn/tasks/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmlearn/tasks/__init__.py b/mmlearn/tasks/__init__.py index 93db59f..d42fe58 100644 --- a/mmlearn/tasks/__init__.py +++ b/mmlearn/tasks/__init__.py @@ -2,9 +2,9 @@ from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining from mmlearn.tasks.ijepa import IJEPA +from mmlearn.tasks.linear_probing import LinearClassifierModule from mmlearn.tasks.zero_shot_classification import ZeroShotClassification from mmlearn.tasks.zero_shot_retrieval import ZeroShotCrossModalRetrieval -from mmlearn.tasks.linear_probing import LinearClassifierModule __all__ = [ From 63b036e96d2f2889cbe79f24eae7056be9be738f Mon Sep 17 00:00:00 2001 From: Negiiiin Date: Thu, 9 Jan 2025 16:30:37 -0500 Subject: [PATCH 4/8] Changed Vision Encoder Extraction --- mmlearn/tasks/linear_probing.py | 70 +++-- .../experiment/linear_probing_eval.yaml | 261 ++++++++++++++++-- 2 files changed, 294 insertions(+), 37 deletions(-) diff --git a/mmlearn/tasks/linear_probing.py b/mmlearn/tasks/linear_probing.py index 53b5a6b..5a7039b 100644 --- a/mmlearn/tasks/linear_probing.py +++ b/mmlearn/tasks/linear_probing.py @@ -3,7 +3,7 @@ import inspect from contextlib import nullcontext from functools import partial -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import hydra import lightning as L # noqa: N812 @@ -15,44 +15,70 @@ from torchmetrics import AUROC, Accuracy, F1Score, MetricCollection, Precision, Recall from mmlearn.datasets.core import Modalities +from mmlearn.modules.layers import MLP def extract_vision_encoder( - encoder: Any, encoder_checkpoint_path: Optional[str] + encoder: Any, + model_checkpoint_path: Optional[str], + keys_to_remove: Optional[List[str]] = None, + keys_to_rename: Optional[Dict[str, str]] = None, # Default for renaming + keys_to_ignore: Optional[List[str]] = None, ) -> nn.Module: """ Extract the vision encoder from a PyTorch Lightning model. Args: encoder (Any): The encoder configuration or model to be instantiated. - encoder_checkpoint_path (Optional[str]): Path to the checkpoint file containing + model_checkpoint_path (Optional[str]): Path to the checkpoint file containing the encoder's state_dict. + keys_to_remove (Optional[list]): List of keys to be removed from the state_dict. + keys_to_rename (Optional[dict]): Dictionary of prefixes or key replacements + mapping + old prefixes to new replacements (default removes 'encoders.rgb.'). + keys_to_ignore (Optional[list]): List of keys to ignore when loading the + state_dict. Returns ------- nn.Module: The vision encoder module extracted and initialized. """ model: L.LightningModule = hydra.utils.instantiate(encoder) - if encoder_checkpoint_path is None: + if model_checkpoint_path is None: rank_zero_warn( - "No encoder_checkpoint_path path was provided for linear evaluation." + "No model_checkpoint_path path was provided for linear evaluation." ) else: - checkpoint = torch.load(encoder_checkpoint_path) + checkpoint = torch.load(model_checkpoint_path) if "state_dict" not in checkpoint: raise KeyError("'state_dict' not found in checkpoint") state_dict = checkpoint["state_dict"] - # Filter keys that are related to vision encoder - encoder_keys = { - k.replace("encoders.rgb.", "") if k.startswith("encoders.rgb") else k: v - for k, v in state_dict.items() - if "encoders.rgb" in k - } + # Remove unwanted keys + if keys_to_remove: + state_dict = { + k: v for k, v in state_dict.items() if k not in keys_to_remove + } + + # Ignore specific keys + if keys_to_ignore: + state_dict = { + k: v for k, v in state_dict.items() if k not in keys_to_ignore + } + + # Rename keys based on input mappings + if keys_to_rename: + state_dict = { + k.replace(old_prefix, new_prefix): v + for k, v in state_dict.items() + for old_prefix, new_prefix in keys_to_rename.items() + if k.startswith(old_prefix) + } + try: - if encoder_keys: - model["rgb"].load_state_dict(encoder_keys, strict=True) + if state_dict: + model["rgb"].load_state_dict(state_dict, strict=True) print("Encoder state dict loaded successfully") except Exception as e: print(f"Error loading state dict: {e}") @@ -76,6 +102,8 @@ class LinearClassifierModule(L.LightningModule): Output features from the encoder, defining the linear classifier's input size. num_classes : int Number of classes for the classification task. + hidden_dims : list[int] + Size of each hidden layer of the model task : str Classification task type. One of 'binary', 'multiclass', or 'multilabel'. freeze_encoder : bool, default = True @@ -128,11 +156,12 @@ def __init__( self, # encoder: torch.nn.Module, encoder: nn.Module, - encoder_checkpoint_path: Optional[str], + model_checkpoint_path: Optional[str], # change name modality: str, num_output_features: int, num_classes: int, - task: Literal["binary", "multiclass", "multilabel"], + hidden_dims: Optional[List[int]] = None, + task: Literal["binary", "multiclass", "multilabel"] = "multiclass", freeze_encoder: bool = True, pre_classifier_batch_norm: bool = False, top_k_list: Optional[List[int]] = None, @@ -143,7 +172,6 @@ def __init__( partial[torch.optim.lr_scheduler.LRScheduler], ] ] = None, - vision_extractor: Callable[[Any, str], Optional[str]] = extract_vision_encoder, ): super().__init__() assert task in ["binary", "multiclass", "multilabel"], ( @@ -154,10 +182,11 @@ def __init__( self.modality = modality self.encoder: nn.Module = extract_vision_encoder( - encoder, encoder_checkpoint_path + encoder, model_checkpoint_path, keys_to_rename={"encoders.rgb.": ""} ) - linear_layer = nn.Linear(num_output_features, num_classes) + linear_layer = MLP(num_output_features, num_classes, hidden_dims) + if pre_classifier_batch_norm: linear_layer = nn.Sequential( nn.BatchNorm1d(num_output_features, affine=False), @@ -330,7 +359,8 @@ def validation_step( def on_validation_epoch_end(self) -> None: """Compute validation metrics accumulated over the epoch.""" val_metrics = self.valid_metrics.compute() - print(f"val_metrics: {val_metrics}") + for metric, value in val_metrics.items(): + print(f" {metric}: {value.item()}") self.log_dict(val_metrics) self.valid_metrics.reset() diff --git a/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml b/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml index dd779b1..bfed37f 100644 --- a/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml +++ b/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml @@ -4,34 +4,260 @@ defaults: - /datasets@datasets.train.bach: BACH - /datasets/transforms@datasets.train.bach.transform: med_clip_vision_transform - /datasets@datasets.val.bach: BACH - - /modules/optimizers@task.optimizer: AdamW - /datasets/transforms@datasets.val.bach.transform: med_clip_vision_transform + - /datasets@datasets.train.lc25k_lung: LC25000 + - /datasets/transforms@datasets.train.lc25k_lung.transform: med_clip_vision_transform + - /datasets@datasets.val.lc25k_lung: LC25000 + - /datasets/transforms@datasets.val.lc25k_lung.transform: med_clip_vision_transform + - /datasets@datasets.train.lc25k_colon: LC25000 + - /datasets/transforms@datasets.train.lc25k_colon.transform: med_clip_vision_transform + - /datasets@datasets.val.lc25k_colon: LC25000 + - /datasets/transforms@datasets.val.lc25k_colon.transform: med_clip_vision_transform + - /datasets@datasets.train.nck_crc: NckCrc + - /datasets/transforms@datasets.train.nck_crc.transform: med_clip_vision_transform + - /datasets@datasets.val.nck_crc: NckCrc + - /datasets/transforms@datasets.val.nck_crc.transform: med_clip_vision_transform + - /datasets@datasets.train.pad_ufes_20: PadUfes20 + - /datasets/transforms@datasets.train.pad_ufes_20.transform: med_clip_vision_transform + - /datasets@datasets.val.pad_ufes_20: PadUfes20 + - /datasets/transforms@datasets.val.pad_ufes_20.transform: med_clip_vision_transform + - /datasets@datasets.train.pcam: PCAM + - /datasets/transforms@datasets.train.pcam.transform: med_clip_vision_transform + - /datasets@datasets.val.pcam: PCAM + - /datasets/transforms@datasets.val.pcam.transform: med_clip_vision_transform + - /datasets@datasets.train.sicap: SICAP + - /datasets/transforms@datasets.train.sicap.transform: med_clip_vision_transform + - /datasets@datasets.val.sicap: SICAP + - /datasets/transforms@datasets.val.sicap.transform: med_clip_vision_transform + - /datasets@datasets.train.pathmnist: MedMNISTPlus + - /datasets/transforms@datasets.train.pathmnist.transform: med_clip_vision_transform + - /datasets@datasets.val.pathmnist: MedMNISTPlus + - /datasets/transforms@datasets.val.pathmnist.transform: med_clip_vision_transform + - /datasets@datasets.train.dermamnist: MedMNISTPlus + - /datasets/transforms@datasets.train.dermamnist.transform: med_clip_vision_transform + - /datasets@datasets.val.dermamnist: MedMNISTPlus + - /datasets/transforms@datasets.val.dermamnist.transform: med_clip_vision_transform + - /datasets@datasets.train.octmnist: MedMNISTPlus + - /datasets/transforms@datasets.train.octmnist.transform: med_clip_vision_transform + - /datasets@datasets.val.octmnist: MedMNISTPlus + - /datasets/transforms@datasets.val.octmnist.transform: med_clip_vision_transform + - /datasets@datasets.train.pneumoniamnist: MedMNISTPlus + - /datasets/transforms@datasets.train.pneumoniamnist.transform: med_clip_vision_transform + - /datasets@datasets.val.pneumoniamnist: MedMNISTPlus + - /datasets/transforms@datasets.val.pneumoniamnist.transform: med_clip_vision_transform + - /datasets@datasets.train.retinamnist: MedMNISTPlus + - /datasets/transforms@datasets.train.retinamnist.transform: med_clip_vision_transform + - /datasets@datasets.val.retinamnist: MedMNISTPlus + - /datasets/transforms@datasets.val.retinamnist.transform: med_clip_vision_transform + - /datasets@datasets.train.breastmnist: MedMNISTPlus + - /datasets/transforms@datasets.train.breastmnist.transform: med_clip_vision_transform + - /datasets@datasets.val.breastmnist: MedMNISTPlus + - /datasets/transforms@datasets.val.breastmnist.transform: med_clip_vision_transform + - /datasets@datasets.train.bloodmnist: MedMNISTPlus + - /datasets/transforms@datasets.train.bloodmnist.transform: med_clip_vision_transform + - /datasets@datasets.val.bloodmnist: MedMNISTPlus + - /datasets/transforms@datasets.val.bloodmnist.transform: med_clip_vision_transform + - /datasets@datasets.train.tissuemnist: MedMNISTPlus + - /datasets/transforms@datasets.train.tissuemnist.transform: med_clip_vision_transform + - /datasets@datasets.val.tissuemnist: MedMNISTPlus + - /datasets/transforms@datasets.val.tissuemnist.transform: med_clip_vision_transform + - /datasets@datasets.train.organamnist: MedMNISTPlus + - /datasets/transforms@datasets.train.organamnist.transform: med_clip_vision_transform + - /datasets@datasets.val.organamnist: MedMNISTPlus + - /datasets/transforms@datasets.val.organamnist.transform: med_clip_vision_transform + - /datasets@datasets.train.organcmnist: MedMNISTPlus + - /datasets/transforms@datasets.train.organcmnist.transform: med_clip_vision_transform + - /datasets@datasets.val.organcmnist: MedMNISTPlus + - /datasets/transforms@datasets.val.organcmnist.transform: med_clip_vision_transform + - /datasets@datasets.train.organsmnist: MedMNISTPlus + - /datasets/transforms@datasets.train.organsmnist.transform: med_clip_vision_transform + - /datasets@datasets.val.organsmnist: MedMNISTPlus + - /datasets/transforms@datasets.val.organsmnist.transform: med_clip_vision_transform + - /modules/optimizers@task.optimizer: AdamW - /modules/lr_schedulers@task.lr_scheduler.scheduler: CosineAnnealingLR - /modules/encoders@task.encoder.rgb: HFCLIPVisionEncoderWithProjection - # - /trainer/logger@trainer.logger.wandb: WandbLogger + - /trainer/logger@trainer.logger.wandb: WandbLogger - override /task: LinearClassifierModule - _self_ seed: 0 datasets: - train: - bach: - split: train - transform: - job_type: train - val: - bach: - split: test - transform: - job_type: eval - -dataloader: train: - batch_size: 64 - num_workers: 4 - shuffle: False + pcam: + split: train + transform: + job_type: train + bach: + split: train + transform: + job_type: train + lc25k_lung: + split: train + transform: + job_type: train + lc25k_colon: + root_dir: ${oc.env:LC25000_COLON_ROOT_DIR} + split: train + organ: colon + transform: + job_type: train + pathmnist: + split: train + name: pathmnist + transform: + job_type: train + dermamnist: + split: train + name: dermamnist + transform: + job_type: train + octmnist: + split: train + name: octmnist + transform: + job_type: train + pneumoniamnist: + split: train + name: pneumoniamnist + transform: + job_type: train + retinamnist: + split: train + name: retinamnist + transform: + job_type: train + breastmnist: + split: train + name: breastmnist + transform: + job_type: train + bloodmnist: + split: train + name: bloodmnist + transform: + job_type: train + tissuemnist: + split: train + name: tissuemnist + transform: + job_type: train + organamnist: + split: train + name: organamnist + transform: + job_type: train + organcmnist: + split: train + name: organcmnist + transform: + job_type: train + organsmnist: + split: train + name: organsmnist + transform: + job_type: train + nck_crc: + split: train + transform: + job_type: train + pad_ufes_20: + split: train + transform: + job_type: train + sicap: + split: train + transform: + job_type: train val: + pcam: + split: test + transform: + job_type: eval + bach: + split: test + transform: + job_type: eval + lc25k_lung: + split: test + transform: + job_type: eval + lc25k_colon: + root_dir: ${oc.env:LC25000_COLON_ROOT_DIR} + split: test + organ: colon + transform: + job_type: eval + pathmnist: + split: test + name: pathmnist + transform: + job_type: eval + dermamnist: + split: test + name: dermamnist + transform: + job_type: eval + octmnist: + split: test + name: octmnist + transform: + job_type: eval + pneumoniamnist: + split: test + name: pneumoniamnist + transform: + job_type: eval + retinamnist: + split: test + name: retinamnist + transform: + job_type: eval + breastmnist: + split: test + name: breastmnist + transform: + job_type: eval + bloodmnist: + split: test + name: bloodmnist + transform: + job_type: eval + tissuemnist: + split: test + name: tissuemnist + transform: + job_type: eval + organamnist: + split: test + name: organamnist + transform: + job_type: eval + organcmnist: + split: test + name: organcmnist + transform: + job_type: eval + organsmnist: + split: test + name: organsmnist + transform: + job_type: eval + nck_crc: + split: validation + transform: + job_type: eval + pad_ufes_20: + split: test + transform: + job_type: eval + sicap: + split: test + transform: + job_type: eval + +dataloader: + test: batch_size: 64 num_workers: 4 shuffle: False @@ -40,6 +266,7 @@ task: task: multiclass num_classes: 4 num_output_features: 512 + hidden_dims: [256] modality: rgb encoder_checkpoint_path: /path/to/checkpoint top_k_list: [1] From cdc2e820d8ee70677c9b0f5d17ab6464f8ee0f08 Mon Sep 17 00:00:00 2001 From: Negiiiin Date: Sun, 30 Mar 2025 13:54:50 -0400 Subject: [PATCH 5/8] Fix issues raised in PR review --- mmlearn/tasks/__init__.py | 4 +- mmlearn/tasks/linear_probing.py | 188 +++++++------ .../experiment/linear_probing_eval.yaml | 247 +----------------- 3 files changed, 126 insertions(+), 313 deletions(-) diff --git a/mmlearn/tasks/__init__.py b/mmlearn/tasks/__init__.py index d42fe58..7b295f2 100644 --- a/mmlearn/tasks/__init__.py +++ b/mmlearn/tasks/__init__.py @@ -2,7 +2,7 @@ from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining from mmlearn.tasks.ijepa import IJEPA -from mmlearn.tasks.linear_probing import LinearClassifierModule +from mmlearn.tasks.linear_probing import LinearClassifier from mmlearn.tasks.zero_shot_classification import ZeroShotClassification from mmlearn.tasks.zero_shot_retrieval import ZeroShotCrossModalRetrieval @@ -12,5 +12,5 @@ "IJEPA", "ZeroShotCrossModalRetrieval", "ZeroShotClassification", - "LinearClassifierModule", + "LinearClassifier", ] diff --git a/mmlearn/tasks/linear_probing.py b/mmlearn/tasks/linear_probing.py index 5a7039b..c1043a8 100644 --- a/mmlearn/tasks/linear_probing.py +++ b/mmlearn/tasks/linear_probing.py @@ -16,14 +16,17 @@ from mmlearn.datasets.core import Modalities from mmlearn.modules.layers import MLP +from mmlearn.tasks.base import TrainingTask + +from mmlearn.tasks.zero_shot_classification import ZeroShotClassification def extract_vision_encoder( encoder: Any, model_checkpoint_path: Optional[str], + modality_to_extract: Optional[str] = "rgb", keys_to_remove: Optional[List[str]] = None, keys_to_rename: Optional[Dict[str, str]] = None, # Default for renaming - keys_to_ignore: Optional[List[str]] = None, ) -> nn.Module: """ Extract the vision encoder from a PyTorch Lightning model. @@ -61,12 +64,6 @@ def extract_vision_encoder( k: v for k, v in state_dict.items() if k not in keys_to_remove } - # Ignore specific keys - if keys_to_ignore: - state_dict = { - k: v for k, v in state_dict.items() if k not in keys_to_ignore - } - # Rename keys based on input mappings if keys_to_rename: state_dict = { @@ -78,15 +75,15 @@ def extract_vision_encoder( try: if state_dict: - model["rgb"].load_state_dict(state_dict, strict=True) + model[modality_to_extract].load_state_dict(state_dict, strict=True) print("Encoder state dict loaded successfully") except Exception as e: print(f"Error loading state dict: {e}") - return model["rgb"] + return model[modality_to_extract] @store(group="task", provider="mmlearn") -class LinearClassifierModule(L.LightningModule): +class LinearClassifier(TrainingTask): """A linear classifier module for evaluating pretrained encoders. Parameters @@ -98,7 +95,7 @@ class LinearClassifierModule(L.LightningModule): `common.constants.Modality` for valid values. The target label key is inferred from this modality. This means that, for example, that if the modality is 'rgb', the target label key is expected to be 'rgb_target'. - num_output_features : int + embed_dim : int Output features from the encoder, defining the linear classifier's input size. num_classes : int Number of classes for the classification task. @@ -154,18 +151,19 @@ class LinearClassifierModule(L.LightningModule): def __init__( self, - # encoder: torch.nn.Module, encoder: nn.Module, - model_checkpoint_path: Optional[str], # change name + model_checkpoint_path: Optional[str], modality: str, - num_output_features: int, + embed_dim: int, num_classes: int, hidden_dims: Optional[List[int]] = None, task: Literal["binary", "multiclass", "multilabel"] = "multiclass", freeze_encoder: bool = True, - pre_classifier_batch_norm: bool = False, + keys_to_remove: Optional[Dict[str, str]] = None, + keys_to_rename: Optional[Dict[str, str]] = {"encoders.rgb.": ""}, top_k_list: Optional[List[int]] = None, optimizer: Optional[partial[torch.optim.Optimizer]] = None, + pre_classifier_batch_norm: bool = False, lr_scheduler: Optional[ Union[ Dict[str, partial[torch.optim.lr_scheduler.LRScheduler]], @@ -173,7 +171,7 @@ def __init__( ] ] = None, ): - super().__init__() + super().__init__(loss_fn=nn.CrossEntropyLoss()) assert task in ["binary", "multiclass", "multilabel"], ( f"Invalid task type: {task}. " "Expected one of ['binary', 'multiclass', 'multilabel']." @@ -182,16 +180,13 @@ def __init__( self.modality = modality self.encoder: nn.Module = extract_vision_encoder( - encoder, model_checkpoint_path, keys_to_rename={"encoders.rgb.": ""} + encoder, model_checkpoint_path, keys_to_rename=keys_to_rename, + keys_to_remove=keys_to_remove, ) - linear_layer = MLP(num_output_features, num_classes, hidden_dims) + linear_layer = MLP(embed_dim, num_classes, hidden_dims, + norm_layer=nn.BatchNorm1d if pre_classifier_batch_norm else None) - if pre_classifier_batch_norm: - linear_layer = nn.Sequential( - nn.BatchNorm1d(num_output_features, affine=False), - linear_layer, - ) self.classifier = linear_layer self.freeze_encoder = freeze_encoder @@ -201,61 +196,67 @@ def __init__( for param in self.encoder.parameters(): param.requires_grad = False - self.loss_fn = nn.CrossEntropyLoss() + if task == "multilabel": + self.loss_fn = nn.BCEWithLogitsLoss() + self.top_k_list = top_k_list - if task == "multiclass": - if self.top_k_list is None: - self.top_k_list = [1, 5] - accuracy_metrics = { - f"top_{k}_accuracy": Accuracy( - task=task, num_classes=num_classes, top_k=k - ) - for k in self.top_k_list - } - - # Additional metrics for multiclass classification - additional_metrics = { - "precision": Precision( - task=task, num_classes=num_classes, average="macro" - ), - "recall": Recall(task=task, num_classes=num_classes, average="macro"), - "f1_score": F1Score( - task=task, num_classes=num_classes, average="macro" - ), - "auc": AUROC( - task=task, num_classes=num_classes, average="macro" - ), # AUROC for multiclass - } - - elif task == "multilabel": - # Accuracy and other metrics for multilabel classification - accuracy_metrics = {"accuracy": Accuracy(task=task, num_labels=num_classes)} - - # Additional metrics for multilabel classification - additional_metrics = { - "precision": Precision( - task=task, num_labels=num_classes, average="macro" - ), - "recall": Recall(task=task, num_labels=num_classes, average="macro"), - "f1_score": F1Score(task=task, num_labels=num_classes, average="macro"), - "auc": AUROC(task=task, num_labels=num_classes), # AUC for multilabel - } - - else: # binary - # Accuracy and other metrics for binary classification - accuracy_metrics = {"accuracy": Accuracy(task=task)} - - # Additional metrics for binary classification - additional_metrics = { - "precision": Precision(task=task), - "recall": Recall(task=task), - "f1_score": F1Score(task=task), - "auc": AUROC(task=task), # AUROC for binary classification - } + # if task == "multiclass": + # if self.top_k_list is None: + # self.top_k_list = [1, 5] + # accuracy_metrics = { + # f"top_{k}_accuracy": Accuracy( + # task=task, num_classes=num_classes, top_k=k + # ) + # for k in self.top_k_list + # } + + # # Additional metrics for multiclass classification + # additional_metrics = { + # "precision": Precision( + # task=task, num_classes=num_classes, average="macro" + # ), + # "recall": Recall(task=task, num_classes=num_classes, average="macro"), + # "f1_score": F1Score( + # task=task, num_classes=num_classes, average="macro" + # ), + # "auc": AUROC( + # task=task, num_classes=num_classes, average="macro" + # ), # AUROC for multiclass + # } + + # elif task == "multilabel": + # # Accuracy and other metrics for multilabel classification + # accuracy_metrics = {"accuracy": Accuracy(task=task, num_labels=num_classes)} + + # # Additional metrics for multilabel classification + # additional_metrics = { + # "precision": Precision( + # task=task, num_labels=num_classes, average="macro" + # ), + # "recall": Recall(task=task, num_labels=num_classes, average="macro"), + # "f1_score": F1Score(task=task, num_labels=num_classes, average="macro"), + # "auc": AUROC(task=task, num_labels=num_classes), # AUC for multilabel + # } + + # else: # binary + # # Accuracy and other metrics for binary classification + # accuracy_metrics = {"accuracy": Accuracy(task=task)} + + # # Additional metrics for binary classification + # additional_metrics = { + # "precision": Precision(task=task), + # "recall": Recall(task=task), + # "f1_score": F1Score(task=task), + # "auc": AUROC(task=task), # AUROC for binary classification + # } # combine all metrics - metrics = MetricCollection({**accuracy_metrics, **additional_metrics}) + # metrics = MetricCollection({**accuracy_metrics, **additional_metrics}) + metrics = ZeroShotClassification._create_metrics(num_classes=num_classes, + top_k=self.top_k_list, + prefix="", + postfix="",) self.train_metrics = metrics.clone(prefix="train/") self.valid_metrics = metrics.clone(prefix="val/") @@ -349,12 +350,40 @@ def validation_step( The loss computed for the batch. """ logits, y = self._get_logits_and_labels(batch) - + loss: torch.Tensor = self.loss_fn(logits, y) self.log("val/loss", self.all_gather(loss.clone().detach()).mean()) self.valid_metrics.update(logits, y) return loss + + def test_step( + self, + batch: Dict[str, torch.Tensor], + batch_idx: int, + ) -> torch.Tensor: + """ + Execute a test step using a single batch. + + Parameters + ---------- + batch : Dict[str, torch.Tensor] + The current batch of test data, including input tensors and labels. + batch_idx : int + The index of the current test batch. + + Returns + ------- + torch.Tensor + The loss computed for the batch. + """ + logits, y = self._get_logits_and_labels(batch) + + loss: torch.Tensor = self.loss_fn(logits, y) + self.log("val/loss", self.all_gather(loss.clone().detach()).mean()) + + self.test_metrics.update(logits, y) + return loss def on_validation_epoch_end(self) -> None: """Compute validation metrics accumulated over the epoch.""" @@ -363,6 +392,15 @@ def on_validation_epoch_end(self) -> None: print(f" {metric}: {value.item()}") self.log_dict(val_metrics) self.valid_metrics.reset() + + + def on_test_epoch_end(self) -> None: + """Compute test metrics accumulated over the epoch.""" + val_metrics = self.test_metrics.compute() + for metric, value in val_metrics.items(): + print(f" {metric}: {value.item()}") + self.log_dict(val_metrics) + self.test_metrics.reset() def configure_optimizers(self) -> OptimizerLRScheduler: # noqa: PLR0912 """Configure the optimizer and learning rate scheduler.""" diff --git a/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml b/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml index bfed37f..d92d064 100644 --- a/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml +++ b/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml @@ -5,259 +5,33 @@ defaults: - /datasets/transforms@datasets.train.bach.transform: med_clip_vision_transform - /datasets@datasets.val.bach: BACH - /datasets/transforms@datasets.val.bach.transform: med_clip_vision_transform - - /datasets@datasets.train.lc25k_lung: LC25000 - - /datasets/transforms@datasets.train.lc25k_lung.transform: med_clip_vision_transform - - /datasets@datasets.val.lc25k_lung: LC25000 - - /datasets/transforms@datasets.val.lc25k_lung.transform: med_clip_vision_transform - - /datasets@datasets.train.lc25k_colon: LC25000 - - /datasets/transforms@datasets.train.lc25k_colon.transform: med_clip_vision_transform - - /datasets@datasets.val.lc25k_colon: LC25000 - - /datasets/transforms@datasets.val.lc25k_colon.transform: med_clip_vision_transform - - /datasets@datasets.train.nck_crc: NckCrc - - /datasets/transforms@datasets.train.nck_crc.transform: med_clip_vision_transform - - /datasets@datasets.val.nck_crc: NckCrc - - /datasets/transforms@datasets.val.nck_crc.transform: med_clip_vision_transform - - /datasets@datasets.train.pad_ufes_20: PadUfes20 - - /datasets/transforms@datasets.train.pad_ufes_20.transform: med_clip_vision_transform - - /datasets@datasets.val.pad_ufes_20: PadUfes20 - - /datasets/transforms@datasets.val.pad_ufes_20.transform: med_clip_vision_transform - - /datasets@datasets.train.pcam: PCAM - - /datasets/transforms@datasets.train.pcam.transform: med_clip_vision_transform - - /datasets@datasets.val.pcam: PCAM - - /datasets/transforms@datasets.val.pcam.transform: med_clip_vision_transform - - /datasets@datasets.train.sicap: SICAP - - /datasets/transforms@datasets.train.sicap.transform: med_clip_vision_transform - - /datasets@datasets.val.sicap: SICAP - - /datasets/transforms@datasets.val.sicap.transform: med_clip_vision_transform - - /datasets@datasets.train.pathmnist: MedMNISTPlus - - /datasets/transforms@datasets.train.pathmnist.transform: med_clip_vision_transform - - /datasets@datasets.val.pathmnist: MedMNISTPlus - - /datasets/transforms@datasets.val.pathmnist.transform: med_clip_vision_transform - - /datasets@datasets.train.dermamnist: MedMNISTPlus - - /datasets/transforms@datasets.train.dermamnist.transform: med_clip_vision_transform - - /datasets@datasets.val.dermamnist: MedMNISTPlus - - /datasets/transforms@datasets.val.dermamnist.transform: med_clip_vision_transform - - /datasets@datasets.train.octmnist: MedMNISTPlus - - /datasets/transforms@datasets.train.octmnist.transform: med_clip_vision_transform - - /datasets@datasets.val.octmnist: MedMNISTPlus - - /datasets/transforms@datasets.val.octmnist.transform: med_clip_vision_transform - - /datasets@datasets.train.pneumoniamnist: MedMNISTPlus - - /datasets/transforms@datasets.train.pneumoniamnist.transform: med_clip_vision_transform - - /datasets@datasets.val.pneumoniamnist: MedMNISTPlus - - /datasets/transforms@datasets.val.pneumoniamnist.transform: med_clip_vision_transform - - /datasets@datasets.train.retinamnist: MedMNISTPlus - - /datasets/transforms@datasets.train.retinamnist.transform: med_clip_vision_transform - - /datasets@datasets.val.retinamnist: MedMNISTPlus - - /datasets/transforms@datasets.val.retinamnist.transform: med_clip_vision_transform - - /datasets@datasets.train.breastmnist: MedMNISTPlus - - /datasets/transforms@datasets.train.breastmnist.transform: med_clip_vision_transform - - /datasets@datasets.val.breastmnist: MedMNISTPlus - - /datasets/transforms@datasets.val.breastmnist.transform: med_clip_vision_transform - - /datasets@datasets.train.bloodmnist: MedMNISTPlus - - /datasets/transforms@datasets.train.bloodmnist.transform: med_clip_vision_transform - - /datasets@datasets.val.bloodmnist: MedMNISTPlus - - /datasets/transforms@datasets.val.bloodmnist.transform: med_clip_vision_transform - - /datasets@datasets.train.tissuemnist: MedMNISTPlus - - /datasets/transforms@datasets.train.tissuemnist.transform: med_clip_vision_transform - - /datasets@datasets.val.tissuemnist: MedMNISTPlus - - /datasets/transforms@datasets.val.tissuemnist.transform: med_clip_vision_transform - - /datasets@datasets.train.organamnist: MedMNISTPlus - - /datasets/transforms@datasets.train.organamnist.transform: med_clip_vision_transform - - /datasets@datasets.val.organamnist: MedMNISTPlus - - /datasets/transforms@datasets.val.organamnist.transform: med_clip_vision_transform - - /datasets@datasets.train.organcmnist: MedMNISTPlus - - /datasets/transforms@datasets.train.organcmnist.transform: med_clip_vision_transform - - /datasets@datasets.val.organcmnist: MedMNISTPlus - - /datasets/transforms@datasets.val.organcmnist.transform: med_clip_vision_transform - - /datasets@datasets.train.organsmnist: MedMNISTPlus - - /datasets/transforms@datasets.train.organsmnist.transform: med_clip_vision_transform - - /datasets@datasets.val.organsmnist: MedMNISTPlus - - /datasets/transforms@datasets.val.organsmnist.transform: med_clip_vision_transform - /modules/optimizers@task.optimizer: AdamW - /modules/lr_schedulers@task.lr_scheduler.scheduler: CosineAnnealingLR - /modules/encoders@task.encoder.rgb: HFCLIPVisionEncoderWithProjection - /trainer/logger@trainer.logger.wandb: WandbLogger - - override /task: LinearClassifierModule + - override /task: LinearClassifier - _self_ seed: 0 datasets: train: - pcam: - split: train - transform: - job_type: train bach: split: train transform: job_type: train - lc25k_lung: - split: train - transform: - job_type: train - lc25k_colon: - root_dir: ${oc.env:LC25000_COLON_ROOT_DIR} - split: train - organ: colon - transform: - job_type: train - pathmnist: - split: train - name: pathmnist - transform: - job_type: train - dermamnist: - split: train - name: dermamnist - transform: - job_type: train - octmnist: - split: train - name: octmnist - transform: - job_type: train - pneumoniamnist: - split: train - name: pneumoniamnist - transform: - job_type: train - retinamnist: - split: train - name: retinamnist - transform: - job_type: train - breastmnist: - split: train - name: breastmnist - transform: - job_type: train - bloodmnist: - split: train - name: bloodmnist - transform: - job_type: train - tissuemnist: - split: train - name: tissuemnist - transform: - job_type: train - organamnist: - split: train - name: organamnist - transform: - job_type: train - organcmnist: - split: train - name: organcmnist - transform: - job_type: train - organsmnist: - split: train - name: organsmnist - transform: - job_type: train - nck_crc: - split: train - transform: - job_type: train - pad_ufes_20: - split: train - transform: - job_type: train - sicap: - split: train - transform: - job_type: train val: - pcam: - split: test - transform: - job_type: eval bach: split: test transform: job_type: eval - lc25k_lung: - split: test - transform: - job_type: eval - lc25k_colon: - root_dir: ${oc.env:LC25000_COLON_ROOT_DIR} - split: test - organ: colon - transform: - job_type: eval - pathmnist: - split: test - name: pathmnist - transform: - job_type: eval - dermamnist: - split: test - name: dermamnist - transform: - job_type: eval - octmnist: - split: test - name: octmnist - transform: - job_type: eval - pneumoniamnist: - split: test - name: pneumoniamnist - transform: - job_type: eval - retinamnist: - split: test - name: retinamnist - transform: - job_type: eval - breastmnist: - split: test - name: breastmnist - transform: - job_type: eval - bloodmnist: - split: test - name: bloodmnist - transform: - job_type: eval - tissuemnist: - split: test - name: tissuemnist - transform: - job_type: eval - organamnist: - split: test - name: organamnist - transform: - job_type: eval - organcmnist: - split: test - name: organcmnist - transform: - job_type: eval - organsmnist: - split: test - name: organsmnist - transform: - job_type: eval - nck_crc: - split: validation - transform: - job_type: eval - pad_ufes_20: - split: test - transform: - job_type: eval - sicap: - split: test - transform: - job_type: eval dataloader: - test: + train: + batch_size: 64 + num_workers: 4 + shuffle: False + val: batch_size: 64 num_workers: 4 shuffle: False @@ -265,10 +39,9 @@ dataloader: task: task: multiclass num_classes: 4 - num_output_features: 512 - hidden_dims: [256] + embed_dim: 512 modality: rgb - encoder_checkpoint_path: /path/to/checkpoint + model_checkpoint_path: /projects/multimodal/checkpoints/mmlearn/med_benchmarking/vit_base_patch16_224_ep11.ckpt top_k_list: [1] optimizer: betas: @@ -279,9 +52,11 @@ task: eps: 1.0e-6 lr_scheduler: scheduler: - T_max: 250 # make sure to change this if max_epochs or accumulate_grad_batches is changed + T_max: 188 # make sure to change this if max_epochs or accumulate_grad_batches is changed extras: interval: step + keys_to_rename: + "encoders.rgb.": "" trainer: From 595c9e0b3703762189176bd39b8e0d8d784f626e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:55:07 +0000 Subject: [PATCH 6/8] [pre-commit.ci] Add auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mmlearn/tasks/linear_probing.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/mmlearn/tasks/linear_probing.py b/mmlearn/tasks/linear_probing.py index c1043a8..00b7b7a 100644 --- a/mmlearn/tasks/linear_probing.py +++ b/mmlearn/tasks/linear_probing.py @@ -12,12 +12,10 @@ from lightning.pytorch.utilities.types import OptimizerLRScheduler from lightning_utilities.core.rank_zero import rank_zero_warn from torch import nn -from torchmetrics import AUROC, Accuracy, F1Score, MetricCollection, Precision, Recall from mmlearn.datasets.core import Modalities from mmlearn.modules.layers import MLP from mmlearn.tasks.base import TrainingTask - from mmlearn.tasks.zero_shot_classification import ZeroShotClassification @@ -180,12 +178,18 @@ def __init__( self.modality = modality self.encoder: nn.Module = extract_vision_encoder( - encoder, model_checkpoint_path, keys_to_rename=keys_to_rename, + encoder, + model_checkpoint_path, + keys_to_rename=keys_to_rename, keys_to_remove=keys_to_remove, ) - linear_layer = MLP(embed_dim, num_classes, hidden_dims, - norm_layer=nn.BatchNorm1d if pre_classifier_batch_norm else None) + linear_layer = MLP( + embed_dim, + num_classes, + hidden_dims, + norm_layer=nn.BatchNorm1d if pre_classifier_batch_norm else None, + ) self.classifier = linear_layer @@ -198,7 +202,6 @@ def __init__( if task == "multilabel": self.loss_fn = nn.BCEWithLogitsLoss() - self.top_k_list = top_k_list # if task == "multiclass": @@ -253,10 +256,9 @@ def __init__( # combine all metrics # metrics = MetricCollection({**accuracy_metrics, **additional_metrics}) - metrics = ZeroShotClassification._create_metrics(num_classes=num_classes, - top_k=self.top_k_list, - prefix="", - postfix="",) + metrics = ZeroShotClassification._create_metrics( + num_classes=num_classes, top_k=self.top_k_list, prefix="", postfix="" + ) self.train_metrics = metrics.clone(prefix="train/") self.valid_metrics = metrics.clone(prefix="val/") @@ -350,13 +352,13 @@ def validation_step( The loss computed for the batch. """ logits, y = self._get_logits_and_labels(batch) - + loss: torch.Tensor = self.loss_fn(logits, y) self.log("val/loss", self.all_gather(loss.clone().detach()).mean()) self.valid_metrics.update(logits, y) return loss - + def test_step( self, batch: Dict[str, torch.Tensor], @@ -392,8 +394,7 @@ def on_validation_epoch_end(self) -> None: print(f" {metric}: {value.item()}") self.log_dict(val_metrics) self.valid_metrics.reset() - - + def on_test_epoch_end(self) -> None: """Compute test metrics accumulated over the epoch.""" val_metrics = self.test_metrics.compute() From fa7259e5baf66c3aba14dd44d0721c979d164c9c Mon Sep 17 00:00:00 2001 From: Negiiiin Date: Sun, 30 Mar 2025 15:24:35 -0400 Subject: [PATCH 7/8] Fix pre-commit isssues --- mmlearn/tasks/linear_probing.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mmlearn/tasks/linear_probing.py b/mmlearn/tasks/linear_probing.py index 00b7b7a..3b54463 100644 --- a/mmlearn/tasks/linear_probing.py +++ b/mmlearn/tasks/linear_probing.py @@ -154,12 +154,12 @@ def __init__( modality: str, embed_dim: int, num_classes: int, + top_k_list: List[int], hidden_dims: Optional[List[int]] = None, task: Literal["binary", "multiclass", "multilabel"] = "multiclass", freeze_encoder: bool = True, - keys_to_remove: Optional[Dict[str, str]] = None, - keys_to_rename: Optional[Dict[str, str]] = {"encoders.rgb.": ""}, - top_k_list: Optional[List[int]] = None, + keys_to_remove: Optional[List[str]] = None, + keys_to_rename: Optional[Dict[str, str]] = None, optimizer: Optional[partial[torch.optim.Optimizer]] = None, pre_classifier_batch_norm: bool = False, lr_scheduler: Optional[ @@ -204,6 +204,7 @@ def __init__( self.loss_fn = nn.BCEWithLogitsLoss() self.top_k_list = top_k_list +<<<<<<< HEAD # if task == "multiclass": # if self.top_k_list is None: # self.top_k_list = [1, 5] @@ -256,6 +257,9 @@ def __init__( # combine all metrics # metrics = MetricCollection({**accuracy_metrics, **additional_metrics}) +======= + +>>>>>>> Fix pre-commit isssues metrics = ZeroShotClassification._create_metrics( num_classes=num_classes, top_k=self.top_k_list, prefix="", postfix="" ) From b9b01f13b1d3d0fb25eae964f69435c5245eae10 Mon Sep 17 00:00:00 2001 From: fcogidi <41602287+fcogidi@users.noreply.github.com> Date: Mon, 31 Mar 2025 09:17:39 -0400 Subject: [PATCH 8/8] Update linear probing implementation and configs --- mmlearn/tasks/__init__.py | 4 +- mmlearn/tasks/linear_evaluation.py | 538 ++++++++++++++++++ mmlearn/tasks/linear_probing.py | 497 ---------------- projects/ijepa/configs/__init__.py | 45 +- ...duce_imagenet.yaml => in1k_vit_small.yaml} | 60 +- .../configs/experiment/linear_evaluation.yaml | 90 +++ .../experiment/linear_probing_eval.yaml | 15 +- 7 files changed, 715 insertions(+), 534 deletions(-) create mode 100644 mmlearn/tasks/linear_evaluation.py delete mode 100644 mmlearn/tasks/linear_probing.py rename projects/ijepa/configs/experiment/{reproduce_imagenet.yaml => in1k_vit_small.yaml} (57%) create mode 100644 projects/ijepa/configs/experiment/linear_evaluation.yaml diff --git a/mmlearn/tasks/__init__.py b/mmlearn/tasks/__init__.py index 7b295f2..244615f 100644 --- a/mmlearn/tasks/__init__.py +++ b/mmlearn/tasks/__init__.py @@ -2,7 +2,7 @@ from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining from mmlearn.tasks.ijepa import IJEPA -from mmlearn.tasks.linear_probing import LinearClassifier +from mmlearn.tasks.linear_evaluation import LinearEvaluation from mmlearn.tasks.zero_shot_classification import ZeroShotClassification from mmlearn.tasks.zero_shot_retrieval import ZeroShotCrossModalRetrieval @@ -12,5 +12,5 @@ "IJEPA", "ZeroShotCrossModalRetrieval", "ZeroShotClassification", - "LinearClassifier", + "LinearEvaluation", ] diff --git a/mmlearn/tasks/linear_evaluation.py b/mmlearn/tasks/linear_evaluation.py new file mode 100644 index 0000000..fa26d04 --- /dev/null +++ b/mmlearn/tasks/linear_evaluation.py @@ -0,0 +1,538 @@ +"""A Module for linear evaluation of pretrained encoders.""" + +import re +import warnings +from contextlib import nullcontext +from functools import partial +from typing import Any, Callable, Literal, Optional, Union + +import torch +from hydra_zen import store +from torch import nn +from torchmetrics import AUROC, Accuracy, F1Score, MetricCollection, Precision, Recall + +from mmlearn.datasets.core import Modalities +from mmlearn.modules.layers import MLP +from mmlearn.tasks.base import TrainingTask + + +@store(group="task", provider="mmlearn") +class LinearEvaluation(TrainingTask): + """Linear evaluation task for pretrained encoders. + + Parameters + ---------- + encoder : torch.nn.Module + A pretrained encoder model, outputting features for the linear classifier. + modality : str + The modality of the input data to be passed through the encoder. See + `common.constants.Modality` for valid values. The target label key is + inferred from this modality. This means that, for example, that if the + modality is 'rgb', the target label key is expected to be 'rgb_target'. + num_output_features : int + Output features from the encoder, defining the linear classifier's input size. + num_classes : int + Number of classes for the classification task. + hidden_dims : list[int] + Size of each hidden layer of the model + task : str + Classification task type. One of 'binary', 'multiclass', or 'multilabel'. + freeze_encoder : bool, default = True + If True, encoder's parameters are frozen during training. + pre_classifier_batch_norm : bool, default = False + If True, a batch normalization layer without affine transformation is + added before the linear classifier, following [1]. + top_k_list : list[int], optional, default = None + A list of integers specifying the `k` values for top-k accuracy metrics. + For each `k` in this list, top-k accuracy is calculated and tracked during + training and validation. This allows for the evaluation of the model's + performance at predicting the top `k` most probable classes. + optimizer : DictConfig, optional, default = None + The configuration for the optimizer. This will be instantiated using + `hydra.utils.instantiate`, so it should include the `_target_` field, + which should point to the optimizer class. + lr_scheduler : DictConfig, optional, default = None + The configuration for the learning rate scheduler. Two fields are expected: + `scheduler` (required) and `extras` (optional). The `scheduler` field should + contain configurations for the learning rate scheduler and must include the + `_target_` field, which, like the optimizer, should point to the scheduler + class. The `extras` field may contain one or more of the following: + - `interval` (str): The interval to apply the learning rate scheduler. + One of "epoch" or "step". Default is "epoch". + - `frequency` (int): The frequency to apply the learning rate scheduler + in the specified interval. Default is 1. + - `monitor` (str): The metric to monitor for schedulers like ReduceLROnPlateau. + - `strict` (bool): Whether to strictly enforce the availability of the + monitored metric (if `True`) or raise a warning if the metric is not found + (if `False`). Default is `True`. + + Attributes + ---------- + accuracy_metrics : torchmetrics.MetricCollection + A collection of metrics that includes accuracy for each `k` in `top_k_list`, + providing a comprehensive evaluation of model performance across different + levels of top-k predictions. + linear_eval : torch.nn.Linear + Linear classification layer atop the encoder. Input and output features are + determined by `encoder_output_features` and `num_classes`, respectively. + + References + ---------- + [1] He, K., Chen, X., Xie, S., Li, Y., Doll'ar, P., & Girshick, R.B. (2021). + Masked Autoencoders Are Scalable Vision Learners. 2022 IEEE/CVF Conference + on Computer Vision and Pattern Recognition (CVPR), 15979-15988. + """ + + def __init__( + self, + encoder: nn.Module, + checkpoint_path: Optional[str], + modality: str, + num_output_features: int, + num_classes: int, + pre_classifier_batch_norm: bool = False, + classifier_hidden_dims: Optional[list[int]] = None, + classifier_norm_layer: Optional[Callable[..., torch.nn.Module]] = None, + classifier_activation_layer: Optional[ + Callable[..., torch.nn.Module] + ] = torch.nn.ReLU, + classifier_bias: Union[bool, list[bool]] = True, + classifier_dropout: Union[float, list[float]] = 0.0, + freeze_encoder: bool = True, + encoder_input_kwargs: Optional[dict[str, Any]] = None, + encoder_outputs_processor: Optional[Callable[..., torch.Tensor]] = None, + encoder_state_dict_key: str = "state_dict", + state_dict_pattern_replacement_map: Optional[dict[str, str]] = None, + state_dict_patterns_to_exclude: Optional[list[str]] = None, + task: Literal["binary", "multiclass", "multilabel"] = "multiclass", + top_k_list: Optional[list[int]] = None, + optimizer: Optional[partial[torch.optim.Optimizer]] = None, + lr_scheduler: Optional[ + Union[ + dict[str, partial[torch.optim.lr_scheduler.LRScheduler]], + partial[torch.optim.lr_scheduler.LRScheduler], + ] + ] = None, + compute_validation_loss: bool = True, + compute_test_loss: bool = True, + ) -> None: + # input validation + assert task in ["binary", "multiclass", "multilabel"], ( + f"Invalid task type: {task}. " + "Expected one of ['binary', 'multiclass', 'multilabel']." + ) + + super().__init__( + optimizer=optimizer, + lr_scheduler=lr_scheduler, + loss_fn=nn.CrossEntropyLoss() + if task == "multiclass" + else nn.BCEWithLogitsLoss(), + compute_validation_loss=compute_validation_loss, + compute_test_loss=compute_test_loss, + ) + + self.encoder = encoder + self.modality = Modalities.get_modality(modality) + self.num_output_features = num_output_features + self.num_classes = num_classes + self.pre_classifier_batch_norm = pre_classifier_batch_norm + self.freeze_encoder = freeze_encoder + self.encoder_outputs_processor = encoder_outputs_processor + self.task = task + self.top_k_list = top_k_list + self.encoder_input_kwargs = encoder_input_kwargs + + checkpoint_dict = torch.load( + checkpoint_path, map_location=self.device, weights_only=True + ) + state_dict = get_state_dict( + checkpoint_dict, + state_dict_key=encoder_state_dict_key, + pattern_replacement_map=state_dict_pattern_replacement_map, + patterns_to_exclude=state_dict_patterns_to_exclude, + ) + self.encoder.load_state_dict(state_dict) + + linear_layer = MLP( + in_dim=num_output_features, + out_dim=num_classes, + hidden_dims=classifier_hidden_dims, + norm_layer=classifier_norm_layer, + activation_layer=classifier_activation_layer, + bias=classifier_bias, + dropout=classifier_dropout, + ) + + if pre_classifier_batch_norm: + linear_layer = nn.Sequential( + nn.BatchNorm1d(num_output_features, affine=False), + linear_layer, + ) + self.classifier = linear_layer + + if self.freeze_encoder: + for param in self.encoder.parameters(): + param.requires_grad = False + + if task == "multiclass": + if self.top_k_list is None: + self.top_k_list = [1, 5] + + metrics = MetricCollection( + { + f"top_{k}_accuracy": Accuracy( + task=task, num_classes=num_classes, top_k=k + ) + for k in self.top_k_list + } + ) + + metrics.add_metrics( + { + "precision": Precision( + task=task, num_classes=num_classes, average="macro" + ), + "recall": Recall( + task=task, num_classes=num_classes, average="macro" + ), + "f1_score": F1Score( + task=task, num_classes=num_classes, average="macro" + ), + "auc": AUROC(task=task, num_classes=num_classes, average="macro"), + } + ) + elif task == "multilabel": + metrics = MetricCollection( + { + "accuracy": Accuracy(task=task, num_labels=num_classes), + "precision": Precision( + task=task, num_labels=num_classes, average="macro" + ), + "recall": Recall( + task=task, num_labels=num_classes, average="macro" + ), + "f1_score": F1Score( + task=task, num_labels=num_classes, average="macro" + ), + "auc": AUROC(task=task, num_labels=num_classes), + } + ) + else: # binary + metrics = MetricCollection( + { + "accuracy": Accuracy(task=task), + "precision": Precision(task=task), + "recall": Recall(task=task), + "f1_score": F1Score(task=task), + "auc": AUROC(task=task), + } + ) + + self._metrics = { + "train": metrics.clone(prefix="train/"), + "val": metrics.clone(prefix="val/"), + "test": metrics.clone(prefix="test/"), + } + + def forward(self, inputs: dict[str, torch.Tensor]) -> torch.Tensor: + """Perform a forward pass through the encoder and linear classifier. + + Parameters + ---------- + inputs : dict[str, torch.Tensor] + Dictionary containing input tensors for the encoder. + + Returns + ------- + torch.Tensor + The logits predicted by the classifier. + """ + with torch.no_grad() if self.freeze_encoder else nullcontext(): + enc_out = self.encoder(inputs, **self.encoder_input_kwargs) + if self.encoder_outputs_processor is not None: + enc_out = self.encoder_outputs_processor(enc_out) + + return self.classifier(enc_out) + + def on_fit_start(self) -> None: + """Move the metrics to the device of the Lightning module.""" + self._metrics = { + step_name: metric.to(self.device) + for step_name, metric in self._metrics.items() + if step_name in ["train", "val"] + } + + def on_train_epoch_start(self) -> None: + """Set the encoder to evaluation mode if it is frozen.""" + self.encoder = self.encoder.train(mode=not self.freeze_encoder) + + def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor: + """Compute the loss for the batch. + + Parameters + ---------- + batch : dict[str, Any] + The batch of data to process. + batch_idx : int + The index of the batch. + + Returns + ------- + torch.Tensor + The loss for the batch. + """ + return self._shared_step(batch, "train") + + def on_train_epoch_end(self) -> None: + """Compute metrics at the end of a training epoch.""" + self._on_epoch_end("train") + + def validation_step( + self, batch: dict[str, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """ + Execute a validation step using a single batch. + + Parameters + ---------- + batch : dict[str, torch.Tensor] + The current batch of validation data, including input tensors and labels. + batch_idx : int + The index of the current validation batch. + + Returns + ------- + torch.Tensor + The loss computed for the batch. + """ + return self._shared_step(batch, "val") + + def on_validation_epoch_end(self) -> None: + """Compute validation metrics accumulated over the epoch.""" + self._on_epoch_end("val") + + def on_test_start(self) -> None: + """Move the metrics to the device of the Lightning module.""" + self._metrics["test"] = self._metrics["test"].to(self.device) + + def test_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Execute a test step using a single batch. + + Parameters + ---------- + batch : dict[str, torch.Tensor] + The current batch of test data, including input tensors and labels. + batch_idx : int + The index of the current test batch. + + Returns + ------- + torch.Tensor + The loss computed for the batch. + """ + return self._shared_step(batch, "test") + + def on_test_epoch_end(self) -> None: + """Compute test metrics accumulated over the epoch.""" + self._on_epoch_end("test") + + def _shared_step( + self, batch: dict[str, torch.Tensor], step_name: Literal["train", "val", "test"] + ) -> Optional[torch.Tensor]: + """ + Execute a shared step for training, validation, or testing. + + Parameters + ---------- + batch : dict[str, torch.Tensor] + The current batch of data. + step_name : Literal["train", "val", "test"] + The name of the step to execute. + """ + if step_name == "train" and self.loss_fn is None: + raise ValueError("The loss function must be provided for training.") + + logits = self(batch) + y = batch[self.modality.target] + + if self.loss_fn is not None: + loss: torch.Tensor = self.loss_fn(logits, y) + self.log(f"{step_name}/loss", loss, prog_bar=True, sync_dist=True) + + self._metrics[step_name].update(logits, y) + + return loss if self.loss_fn is not None else None + + def _on_epoch_end(self, step_name: Literal["train", "val", "test"]) -> None: + """ + Compute metrics at the end of an epoch. + + Parameters + ---------- + step_name : Literal["train", "val", "test"] + The name of the step to execute + """ + metrics = self._metrics[step_name].compute() + self.log_dict(metrics, prog_bar=step_name in ["val", "test"]) + self._metrics[step_name].reset() + + +def get_state_dict( # noqa: PLR0912 + checkpoint_dict: dict[str, Any], + state_dict_key: str = "state_dict", + pattern_replacement_map: Optional[dict[str, str]] = None, + patterns_to_exclude: Optional[list[str]] = None, +) -> dict[str, Any]: + """Process a state dictionary by applying regex pattern replacements and exclusions. + + Parameters + ---------- + checkpoint_dict : dict[str, Any] + Dictionary containing the state dict in one of its keys. + state_dict_key : str, default="state_dict" + Key in ``checkpoint_dict`` containing the state dictionary to process. + pattern_replacement_map : dict[str, str], optional, default=None + Dictionary mapping regex patterns to their replacement strings. + patterns_to_exclude : list[str], optional, default=None + List of regex patterns for keys to exclude from the processed state dictionary. + + Returns + ------- + Processed state dictionary + + Raises + ------ + TypeError + If inputs are not of expected types. + KeyError + If state_dict_key is not in ``checkpoint_dict``. + ValueError + If regex patterns are invalid. + """ + if not isinstance(checkpoint_dict, dict): + raise TypeError( + "Expected ``checkpoint_dict`` to be a dictionary, " + f"but got {type(checkpoint_dict)}" + ) + if state_dict_key not in checkpoint_dict: + raise KeyError( + f"Key '{state_dict_key}' not found in ``checkpoint_dict``. " + f"Available keys: {list(checkpoint_dict.keys())}" + ) + + state_dict = checkpoint_dict[state_dict_key] + if not isinstance(state_dict, dict): + raise TypeError( + "Expected state dictionary in ``checkpoint_dict`` to be a dictionary, " + f"but got {type(state_dict)}" + ) + + if pattern_replacement_map is None: + pattern_replacement_map = {} + if patterns_to_exclude is None: + patterns_to_exclude = [] + + if not isinstance(pattern_replacement_map, dict): + raise TypeError( + "Expected ``pattern_replacement_map`` to be a dictionary, " + f"but got {type(pattern_replacement_map)}" + ) + if not isinstance(patterns_to_exclude, list): + raise TypeError( + "Expected ``patterns_to_exclude`` to be a list, " + f"but got {type(patterns_to_exclude)}" + ) + + processed_state_dict = {} + + # apply pattern replacements + for key, value in state_dict.items(): + if not isinstance(key, str): + raise TypeError( + f"Dictionary keys must be strings for regex operations, found {type(key)}" + ) + + new_key = key + for pattern, replacement in pattern_replacement_map.items(): + try: + new_key = re.sub(pattern, replacement, new_key) + except re.error as e: + raise ValueError(f"Invalid regex pattern '{pattern}': {str(e)}") from e + + # check for key collisions + if new_key in processed_state_dict: + warnings.warn( + f"Key '{new_key}' already exists and will be overwritten.", + UserWarning, + stacklevel=2, + ) + + processed_state_dict[new_key] = value + + # apply exclusions + if patterns_to_exclude: + filtered_dict = {} + for key, value in processed_state_dict.items(): + exclude_key = False + for pattern in patterns_to_exclude: + try: + if re.match(pattern, key): + exclude_key = True + break + except re.error as e: + raise ValueError( + f"Invalid regex pattern '{pattern}' in exclusion list: {str(e)}" + ) from e + + if not exclude_key: + filtered_dict[key] = value + + processed_state_dict = filtered_dict + + return processed_state_dict + + +@store(group="helpers", provider="mmlearn", zen_partial=False) # type: ignore[misc] +def avg_pool_last_n_hidden_states( + encoder_output: tuple[torch.Tensor, Optional[list[torch.Tensor]]], n: int = 1 +) -> torch.Tensor: + """Average pool the last ``n`` intermediate layer outputs of an encoder. + + Parameters + ---------- + encoder_output : tuple[torch.Tensor, Optional[list[torch.Tensor]]] + Tuple of encoder outputs where the first element is the output of the last layer + and the second element is an optional list of intermediate layer outputs. + n : int, default=1 + The number of layers to average pool. + + Returns + ------- + torch.Tensor + The average pooled encoder output. + + Raises + ------ + ValueError + If intermediate layer outputs are not available or if ``n`` is less than 1 + or greater than the number of available intermediate layers. + """ + if encoder_output[1] is None: + raise ValueError("Intermediate layer outputs are not available.") + + if n < 1: + raise ValueError("Number of layers to average pool must be greater than 0.") + + if n > len(encoder_output[1]): + raise ValueError( + f"Requested {n} layers for average pooling, but only {len(encoder_output[1])} " + "intermediate layers are available." + ) + # each layer output is a tensor of shape (batch_size, num_patches, num_features) + # take the average across the num_patches dimension, then concatenate the results + return torch.cat( + [layer_output.mean(dim=1) for layer_output in encoder_output[1][-n:]], + dim=-1, + ) diff --git a/mmlearn/tasks/linear_probing.py b/mmlearn/tasks/linear_probing.py deleted file mode 100644 index 3b54463..0000000 --- a/mmlearn/tasks/linear_probing.py +++ /dev/null @@ -1,497 +0,0 @@ -"""A Module for linear evaluation of pretrained encoders.""" - -import inspect -from contextlib import nullcontext -from functools import partial -from typing import Any, Dict, List, Literal, Optional, Tuple, Union - -import hydra -import lightning as L # noqa: N812 -import torch -from hydra_zen import store -from lightning.pytorch.utilities.types import OptimizerLRScheduler -from lightning_utilities.core.rank_zero import rank_zero_warn -from torch import nn - -from mmlearn.datasets.core import Modalities -from mmlearn.modules.layers import MLP -from mmlearn.tasks.base import TrainingTask -from mmlearn.tasks.zero_shot_classification import ZeroShotClassification - - -def extract_vision_encoder( - encoder: Any, - model_checkpoint_path: Optional[str], - modality_to_extract: Optional[str] = "rgb", - keys_to_remove: Optional[List[str]] = None, - keys_to_rename: Optional[Dict[str, str]] = None, # Default for renaming -) -> nn.Module: - """ - Extract the vision encoder from a PyTorch Lightning model. - - Args: - encoder (Any): The encoder configuration or model to be instantiated. - model_checkpoint_path (Optional[str]): Path to the checkpoint file containing - the encoder's state_dict. - keys_to_remove (Optional[list]): List of keys to be removed from the state_dict. - keys_to_rename (Optional[dict]): Dictionary of prefixes or key replacements - mapping - old prefixes to new replacements (default removes 'encoders.rgb.'). - keys_to_ignore (Optional[list]): List of keys to ignore when loading the - state_dict. - - Returns - ------- - nn.Module: The vision encoder module extracted and initialized. - """ - model: L.LightningModule = hydra.utils.instantiate(encoder) - if model_checkpoint_path is None: - rank_zero_warn( - "No model_checkpoint_path path was provided for linear evaluation." - ) - else: - checkpoint = torch.load(model_checkpoint_path) - if "state_dict" not in checkpoint: - raise KeyError("'state_dict' not found in checkpoint") - - state_dict = checkpoint["state_dict"] - - # Remove unwanted keys - if keys_to_remove: - state_dict = { - k: v for k, v in state_dict.items() if k not in keys_to_remove - } - - # Rename keys based on input mappings - if keys_to_rename: - state_dict = { - k.replace(old_prefix, new_prefix): v - for k, v in state_dict.items() - for old_prefix, new_prefix in keys_to_rename.items() - if k.startswith(old_prefix) - } - - try: - if state_dict: - model[modality_to_extract].load_state_dict(state_dict, strict=True) - print("Encoder state dict loaded successfully") - except Exception as e: - print(f"Error loading state dict: {e}") - return model[modality_to_extract] - - -@store(group="task", provider="mmlearn") -class LinearClassifier(TrainingTask): - """A linear classifier module for evaluating pretrained encoders. - - Parameters - ---------- - encoder : torch.nn.Module - A pretrained encoder model, outputting features for the linear classifier. - modality : str - The modality of the input data to be passed through the encoder. See - `common.constants.Modality` for valid values. The target label key is - inferred from this modality. This means that, for example, that if the - modality is 'rgb', the target label key is expected to be 'rgb_target'. - embed_dim : int - Output features from the encoder, defining the linear classifier's input size. - num_classes : int - Number of classes for the classification task. - hidden_dims : list[int] - Size of each hidden layer of the model - task : str - Classification task type. One of 'binary', 'multiclass', or 'multilabel'. - freeze_encoder : bool, default = True - If True, encoder's parameters are frozen during training. - pre_classifier_batch_norm : bool, default = False - If True, a batch normalization layer without affine transformation is - added before the linear classifier, following [1]. - top_k_list : List[int], optional, default = None - A list of integers specifying the `k` values for top-k accuracy metrics. - For each `k` in this list, top-k accuracy is calculated and tracked during - training and validation. This allows for the evaluation of the model's - performance at predicting the top `k` most probable classes. - optimizer : DictConfig, optional, default = None - The configuration for the optimizer. This will be instantiated using - `hydra.utils.instantiate`, so it should include the `_target_` field, - which should point to the optimizer class. - lr_scheduler : DictConfig, optional, default = None - The configuration for the learning rate scheduler. Two fields are expected: - `scheduler` (required) and `extras` (optional). The `scheduler` field should - contain configurations for the learning rate scheduler and must include the - `_target_` field, which, like the optimizer, should point to the scheduler - class. The `extras` field may contain one or more of the following: - - `interval` (str): The interval to apply the learning rate scheduler. - One of "epoch" or "step". Default is "epoch". - - `frequency` (int): The frequency to apply the learning rate scheduler - in the specified interval. Default is 1. - - `monitor` (str): The metric to monitor for schedulers like ReduceLROnPlateau. - - `strict` (bool): Whether to strictly enforce the availability of the - monitored metric (if `True`) or raise a warning if the metric is not found - (if `False`). Default is `True`. - - Attributes - ---------- - accuracy_metrics : torchmetrics.MetricCollection - A collection of metrics that includes accuracy for each `k` in `top_k_list`, - providing a comprehensive evaluation of model performance across different - levels of top-k predictions. - linear_eval : torch.nn.Linear - Linear classification layer atop the encoder. Input and output features are - determined by `encoder_output_features` and `num_classes`, respectively. - - References - ---------- - [1] He, K., Chen, X., Xie, S., Li, Y., Doll'ar, P., & Girshick, R.B. (2021). - Masked Autoencoders Are Scalable Vision Learners. 2022 IEEE/CVF Conference - on Computer Vision and Pattern Recognition (CVPR), 15979-15988. - """ - - def __init__( - self, - encoder: nn.Module, - model_checkpoint_path: Optional[str], - modality: str, - embed_dim: int, - num_classes: int, - top_k_list: List[int], - hidden_dims: Optional[List[int]] = None, - task: Literal["binary", "multiclass", "multilabel"] = "multiclass", - freeze_encoder: bool = True, - keys_to_remove: Optional[List[str]] = None, - keys_to_rename: Optional[Dict[str, str]] = None, - optimizer: Optional[partial[torch.optim.Optimizer]] = None, - pre_classifier_batch_norm: bool = False, - lr_scheduler: Optional[ - Union[ - Dict[str, partial[torch.optim.lr_scheduler.LRScheduler]], - partial[torch.optim.lr_scheduler.LRScheduler], - ] - ] = None, - ): - super().__init__(loss_fn=nn.CrossEntropyLoss()) - assert task in ["binary", "multiclass", "multilabel"], ( - f"Invalid task type: {task}. " - "Expected one of ['binary', 'multiclass', 'multilabel']." - ) - - self.modality = modality - - self.encoder: nn.Module = extract_vision_encoder( - encoder, - model_checkpoint_path, - keys_to_rename=keys_to_rename, - keys_to_remove=keys_to_remove, - ) - - linear_layer = MLP( - embed_dim, - num_classes, - hidden_dims, - norm_layer=nn.BatchNorm1d if pre_classifier_batch_norm else None, - ) - - self.classifier = linear_layer - - self.freeze_encoder = freeze_encoder - self.num_classes = num_classes - - if self.freeze_encoder: - for param in self.encoder.parameters(): - param.requires_grad = False - - if task == "multilabel": - self.loss_fn = nn.BCEWithLogitsLoss() - - self.top_k_list = top_k_list -<<<<<<< HEAD - # if task == "multiclass": - # if self.top_k_list is None: - # self.top_k_list = [1, 5] - # accuracy_metrics = { - # f"top_{k}_accuracy": Accuracy( - # task=task, num_classes=num_classes, top_k=k - # ) - # for k in self.top_k_list - # } - - # # Additional metrics for multiclass classification - # additional_metrics = { - # "precision": Precision( - # task=task, num_classes=num_classes, average="macro" - # ), - # "recall": Recall(task=task, num_classes=num_classes, average="macro"), - # "f1_score": F1Score( - # task=task, num_classes=num_classes, average="macro" - # ), - # "auc": AUROC( - # task=task, num_classes=num_classes, average="macro" - # ), # AUROC for multiclass - # } - - # elif task == "multilabel": - # # Accuracy and other metrics for multilabel classification - # accuracy_metrics = {"accuracy": Accuracy(task=task, num_labels=num_classes)} - - # # Additional metrics for multilabel classification - # additional_metrics = { - # "precision": Precision( - # task=task, num_labels=num_classes, average="macro" - # ), - # "recall": Recall(task=task, num_labels=num_classes, average="macro"), - # "f1_score": F1Score(task=task, num_labels=num_classes, average="macro"), - # "auc": AUROC(task=task, num_labels=num_classes), # AUC for multilabel - # } - - # else: # binary - # # Accuracy and other metrics for binary classification - # accuracy_metrics = {"accuracy": Accuracy(task=task)} - - # # Additional metrics for binary classification - # additional_metrics = { - # "precision": Precision(task=task), - # "recall": Recall(task=task), - # "f1_score": F1Score(task=task), - # "auc": AUROC(task=task), # AUROC for binary classification - # } - - # combine all metrics - # metrics = MetricCollection({**accuracy_metrics, **additional_metrics}) -======= - ->>>>>>> Fix pre-commit isssues - metrics = ZeroShotClassification._create_metrics( - num_classes=num_classes, top_k=self.top_k_list, prefix="", postfix="" - ) - self.train_metrics = metrics.clone(prefix="train/") - self.valid_metrics = metrics.clone(prefix="val/") - - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Perform a forward pass through the encoder and linear classifier. - - Parameters - ---------- - x : torch.Tensor - The input tensor. - - Returns - ------- - torch.Tensor - The logits predicted by the classifier. - """ - with torch.no_grad() if self.freeze_encoder else nullcontext(): - x = self.encoder(x) - return self.classifier(x[0]) - - def _get_logits_and_labels( - self, batch: Dict[str, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Return the logits and labels for a batch of data.""" - x: torch.Tensor = batch - y = batch[Modalities.get_modality(self.modality).target] - - logits = self(x) - return logits, y - - def _compute_loss(self, batch: Dict[str, Any]) -> Optional[torch.Tensor]: - if self.loss_fn is None: - return None - - if self.freeze_encoder: - self.encoder.eval() - - logits, y = self._get_logits_and_labels(batch) - - loss: torch.Tensor = self.loss_fn(logits, y) - self.train_metrics.update(logits, y) - - return loss - - def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: - """Compute the loss for the batch. - - Parameters - ---------- - batch : Dict[str, Any] - The batch of data to process. - batch_idx : int - The index of the batch. - - Returns - ------- - torch.Tensor - The loss for the batch. - """ - loss = self._compute_loss(batch) - - if loss is None: - raise ValueError("The loss function must be provided for training.") - - self.log("train/loss", loss, prog_bar=True) - - return loss - - def validation_step( - self, - batch: Dict[str, torch.Tensor], - batch_idx: int, - ) -> torch.Tensor: - """ - Execute a validation step using a single batch. - - Parameters - ---------- - batch : Dict[str, torch.Tensor] - The current batch of validation data, including input tensors and labels. - batch_idx : int - The index of the current validation batch. - - Returns - ------- - torch.Tensor - The loss computed for the batch. - """ - logits, y = self._get_logits_and_labels(batch) - - loss: torch.Tensor = self.loss_fn(logits, y) - self.log("val/loss", self.all_gather(loss.clone().detach()).mean()) - - self.valid_metrics.update(logits, y) - return loss - - def test_step( - self, - batch: Dict[str, torch.Tensor], - batch_idx: int, - ) -> torch.Tensor: - """ - Execute a test step using a single batch. - - Parameters - ---------- - batch : Dict[str, torch.Tensor] - The current batch of test data, including input tensors and labels. - batch_idx : int - The index of the current test batch. - - Returns - ------- - torch.Tensor - The loss computed for the batch. - """ - logits, y = self._get_logits_and_labels(batch) - - loss: torch.Tensor = self.loss_fn(logits, y) - self.log("val/loss", self.all_gather(loss.clone().detach()).mean()) - - self.test_metrics.update(logits, y) - return loss - - def on_validation_epoch_end(self) -> None: - """Compute validation metrics accumulated over the epoch.""" - val_metrics = self.valid_metrics.compute() - for metric, value in val_metrics.items(): - print(f" {metric}: {value.item()}") - self.log_dict(val_metrics) - self.valid_metrics.reset() - - def on_test_epoch_end(self) -> None: - """Compute test metrics accumulated over the epoch.""" - val_metrics = self.test_metrics.compute() - for metric, value in val_metrics.items(): - print(f" {metric}: {value.item()}") - self.log_dict(val_metrics) - self.test_metrics.reset() - - def configure_optimizers(self) -> OptimizerLRScheduler: # noqa: PLR0912 - """Configure the optimizer and learning rate scheduler.""" - if self.optimizer is None: - rank_zero_warn( - "Optimizer not provided. Training will continue without an optimizer. " - "LR scheduler will not be used.", - ) - return None - - weight_decay: Optional[float] = self.optimizer.keywords.get( - "weight_decay", None - ) - if weight_decay is None: # try getting default value - kw_param = inspect.signature(self.optimizer.func).parameters.get( - "weight_decay" - ) - if kw_param is not None and kw_param.default != inspect.Parameter.empty: - weight_decay = kw_param.default - - parameters = [param for param in self.parameters() if param.requires_grad] - - if weight_decay is not None: - decay_params = [] - no_decay_params = [] - - for param in self.parameters(): - if not param.requires_grad: - continue - - if param.ndim < 2: # includes all bias and normalization parameters - no_decay_params.append(param) - else: - decay_params.append(param) - - parameters = [ - { - "params": decay_params, - "weight_decay": weight_decay, - "name": "weight_decay_params", - }, - { - "params": no_decay_params, - "weight_decay": 0.0, - "name": "no_weight_decay_params", - }, - ] - - optimizer = self.optimizer(parameters) - if not isinstance(optimizer, torch.optim.Optimizer): - raise TypeError( - "Expected optimizer to be an instance of `torch.optim.Optimizer`, " - f"but got {type(optimizer)}.", - ) - - if self.lr_scheduler is not None: - if isinstance(self.lr_scheduler, dict): - if "scheduler" not in self.lr_scheduler: - raise ValueError( - "Expected 'scheduler' key in the learning rate scheduler dictionary.", - ) - - lr_scheduler = self.lr_scheduler["scheduler"](optimizer) - if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler): - raise TypeError( - "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, " - f"but got {type(lr_scheduler)}.", - ) - lr_scheduler_dict: Dict[ - str, Union[torch.optim.lr_scheduler.LRScheduler, Any] - ] = {"scheduler": lr_scheduler} - - if self.lr_scheduler.get("extras"): - extras = self.lr_scheduler["extras"] - if isinstance(extras, partial): - # Extract the keywords from the partial object - lr_scheduler_dict.update(extras.keywords) - - return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict} - - lr_scheduler = self.lr_scheduler(optimizer) - if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler): - raise TypeError( - "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, " - f"but got {type(lr_scheduler)}.", - ) - return [optimizer], [lr_scheduler] - - return optimizer diff --git a/projects/ijepa/configs/__init__.py b/projects/ijepa/configs/__init__.py index aefa1e4..1543137 100644 --- a/projects/ijepa/configs/__init__.py +++ b/projects/ijepa/configs/__init__.py @@ -7,15 +7,54 @@ import torch from torchvision import transforms from mmlearn.conf import external_store +from timm.data.transforms import ResizeKeepRatio + logger = getLogger() +@external_store(group="datasets/transforms") +def linear_eval_transforms( + crop_size: int = 224, + normalization: tuple = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + job_type: Literal["train", "eval"] = "train", +) -> transforms.Compose: + """ + Create transforms for linear evaluation. + + Parameters + ---------- + crop_size : int, default=224 + Size of the image crop. + normalization : tuple, default=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) + Mean and std for normalization. + job_type : {"train", "eval"}, default="train" + Type of the job (training or evaluation) for which the transforms are needed. + + Returns + ------- + transforms.Compose + Composed transforms for linear evaluation with images. + """ + transforms_list = [] + if job_type == "train": + transforms_list.append(transforms.RandomResizedCrop(crop_size)) + transforms_list.append(transforms.RandomHorizontalFlip()) + else: + transforms_list.append(ResizeKeepRatio(crop_size + 32, interpolation="bicubic")) + transforms_list.append(transforms.CenterCrop(crop_size)) + + transforms_list.append(transforms.ToTensor()) + transforms_list.append(transforms.Normalize(normalization[0], normalization[1])) + + return transforms.Compose(transforms_list) + + @external_store(group="datasets/transforms") def ijepa_transforms( crop_size: int = 224, crop_scale: tuple = (0.3, 1.0), - color_jitter: float = 0.0, + color_jitter_strength: float = 0.0, horizontal_flip: bool = False, color_distortion: bool = False, gaussian_blur: bool = False, @@ -31,7 +70,7 @@ def ijepa_transforms( Size of the image crop. crop_scale : tuple, default=(0.3, 1.0) Range for the random resized crop scaling. - color_jitter : float, default=0.0 + color_jitter_strength : float, default=0.0 Strength of color jitter. horizontal_flip : bool, default=False Whether to apply random horizontal flip. @@ -89,7 +128,7 @@ def __call__(self, img): if horizontal_flip: transforms_list.append(transforms.RandomHorizontalFlip()) if color_distortion: - transforms_list.append(get_color_distortion(s=color_jitter)) + transforms_list.append(get_color_distortion(s=color_jitter_strength)) if gaussian_blur: transforms_list.append(GaussianBlur(p=0.5)) else: diff --git a/projects/ijepa/configs/experiment/reproduce_imagenet.yaml b/projects/ijepa/configs/experiment/in1k_vit_small.yaml similarity index 57% rename from projects/ijepa/configs/experiment/reproduce_imagenet.yaml rename to projects/ijepa/configs/experiment/in1k_vit_small.yaml index 58b01fe..4f76c8c 100644 --- a/projects/ijepa/configs/experiment/reproduce_imagenet.yaml +++ b/projects/ijepa/configs/experiment/in1k_vit_small.yaml @@ -5,13 +5,12 @@ defaults: - /datasets/transforms@datasets.train.transform: ijepa_transforms - /datasets@datasets.val: ImageNet - /datasets/transforms@datasets.val.transform: ijepa_transforms - - /modules/encoders@task.encoder: vit_base + - /modules/encoders@task.encoder: vit_small - /modules/encoders@task.predictor: vit_predictor - /modules/optimizers@task.optimizer: AdamW - - /modules/lr_schedulers@task.lr_scheduler.scheduler: CosineAnnealingLR + - /modules/lr_schedulers@task.lr_scheduler.scheduler: linear_warmup_cosine_annealing_lr - /trainer/callbacks@trainer.callbacks.lr_monitor: LearningRateMonitor - /trainer/callbacks@trainer.callbacks.model_checkpoint: ModelCheckpoint - - /trainer/callbacks@trainer.callbacks.early_stopping: EarlyStopping - /trainer/callbacks@trainer.callbacks.model_summary: ModelSummary - /trainer/logger@trainer.logger.wandb: WandbLogger - override /task: IJEPA @@ -20,6 +19,16 @@ defaults: seed: 0 datasets: + train: + transform: + color_jitter_strength: 0.4 + horizontal_flip: true + color_distortion: true + gaussian_blur: false + crop_scale: + - 0.3 + - 1.0 + crop_size: 224 val: split: val transform: @@ -28,45 +37,50 @@ datasets: dataloader: train: batch_size: 256 - num_workers: 10 + num_workers: 8 + pin_memory: true + drop_last: true val: batch_size: 256 - num_workers: 10 + num_workers: 8 + pin_memory: false task: + ema_decay: 0.996 + ema_decay_end: 1.0 + ema_anneal_end_step: ${task.lr_scheduler.scheduler.max_steps} + predictor: + kwargs: + embed_dim: 384 + predictor_embed_dim: 384 + depth: 6 + num_heads: 6 optimizer: - betas: - - 0.9 - - 0.999 lr: 1.0e-3 weight_decay: 0.05 - eps: 1.0e-8 lr_scheduler: scheduler: - T_max: ${trainer.max_epochs} + warmup_steps: 12_510 + max_steps: 125_100 + start_factor: 0.2 + eta_min: 1.0e-6 extras: - interval: epoch + interval: step trainer: - max_epochs: 300 - precision: 16-mixed + max_epochs: 100 + precision: bf16-mixed deterministic: False benchmark: True sync_batchnorm: False # Set to True if using DDP with batchnorm - log_every_n_steps: 100 - accumulate_grad_batches: 4 + log_every_n_steps: 10 + accumulate_grad_batches: 1 check_val_every_n_epoch: 1 callbacks: model_checkpoint: - monitor: val/loss - save_top_k: 1 save_last: True - every_n_epochs: 1 - dirpath: /checkpoint/${oc.env:USER}/${oc.env:SLURM_JOB_ID} # only works on Vector SLURM environment - early_stopping: - monitor: val/loss - patience: 5 - mode: min + every_n_epochs: 10 + dirpath: /checkpoint/${oc.env:USER}/${oc.env:SLURM_JOB_ID} # only works on VI's SLURM environment model_summary: max_depth: 2 diff --git a/projects/ijepa/configs/experiment/linear_evaluation.yaml b/projects/ijepa/configs/experiment/linear_evaluation.yaml new file mode 100644 index 0000000..5033578 --- /dev/null +++ b/projects/ijepa/configs/experiment/linear_evaluation.yaml @@ -0,0 +1,90 @@ +# @package _global_ + +defaults: + - /datasets@datasets.train: ImageNet + - /datasets/transforms@datasets.train.transform: linear_eval_transforms + - /datasets@datasets.val: ImageNet + - /datasets/transforms@datasets.val.transform: linear_eval_transforms + - /modules/encoders@task.encoder: vit_small + - /helpers@task.encoder_outputs_processor: avg_pool_last_n_hidden_states + - /modules/optimizers@task.optimizer: SGD + - /modules/lr_schedulers@task.lr_scheduler.scheduler: MultiStepLR + - /trainer/callbacks@trainer.callbacks.lr_monitor: LearningRateMonitor + - /trainer/callbacks@trainer.callbacks.model_checkpoint: ModelCheckpoint + - /trainer/callbacks@trainer.callbacks.model_summary: ModelSummary + - /trainer/logger@trainer.logger.wandb: WandbLogger + - override /task: LinearEvaluation + - _self_ + +seed: 0 + +datasets: + val: + split: val + transform: + job_type: eval + +dataloader: + train: + batch_size: 256 + num_workers: 8 + pin_memory: true + drop_last: true + val: + batch_size: 256 + num_workers: 8 + pin_memory: true + +task: + encoder: + kwargs: + modality: rgb + checkpoint_path: ??? + modality: ${task.encoder.kwargs.modality} + num_output_features: 1536 # 384 * task.encoder_output_processors.n for vit_small + num_classes: 1_000 + state_dict_pattern_replacement_map: + encoder.: "" + state_dict_patterns_to_exclude: + - "predictor.*" + encoder_input_kwargs: + return_hidden_states: True + encoder_outputs_processor: + _partial_: True + n: 4 + optimizer: + lr: 0.01 + weight_decay: 0.0005 + momentum: 0.9 + nesterov: True + lr_scheduler: + scheduler: + milestones: [8, 16, 24] + gamma: 0.1 + extras: + interval: epoch + +trainer: + max_epochs: 28 + precision: bf16-mixed + deterministic: False + benchmark: True + sync_batchnorm: False # Set to True if using DDP with batchnorm + log_every_n_steps: 10 + accumulate_grad_batches: 1 + check_val_every_n_epoch: 1 + callbacks: + model_checkpoint: + monitor: val/loss + save_top_k: 1 + save_last: True + every_n_epochs: 1 + dirpath: /checkpoint/${oc.env:USER}/${oc.env:SLURM_JOB_ID} # only works on VI's SLURM environment + model_summary: + max_depth: 2 + +tags: + - ${experiment_name} + - linear evaluation + - vit_small + - ImageNet diff --git a/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml b/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml index d92d064..913258c 100644 --- a/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml +++ b/projects/med_benchmarking/configs/experiment/linear_probing_eval.yaml @@ -9,7 +9,7 @@ defaults: - /modules/lr_schedulers@task.lr_scheduler.scheduler: CosineAnnealingLR - /modules/encoders@task.encoder.rgb: HFCLIPVisionEncoderWithProjection - /trainer/logger@trainer.logger.wandb: WandbLogger - - override /task: LinearClassifier + - override /task: LinearEvaluation - _self_ seed: 0 @@ -37,12 +37,12 @@ dataloader: shuffle: False task: - task: multiclass - num_classes: 4 - embed_dim: 512 + encoder: ??? modality: rgb - model_checkpoint_path: /projects/multimodal/checkpoints/mmlearn/med_benchmarking/vit_base_patch16_224_ep11.ckpt - top_k_list: [1] + num_output_features: ??? + checkpoint_path: ??? + state_dict_pattern_replacement_map: + "encoders.rgb.": "" optimizer: betas: - 0.9 @@ -55,9 +55,6 @@ task: T_max: 188 # make sure to change this if max_epochs or accumulate_grad_batches is changed extras: interval: step - keys_to_rename: - "encoders.rgb.": "" - trainer: precision: 16-mixed