From b689cb121c32a3b2315348a220ac9360e54ce667 Mon Sep 17 00:00:00 2001 From: Jklubienski Date: Wed, 10 Sep 2025 08:30:59 +0100 Subject: [PATCH 1/3] Implement TIGER TIL regression task --- .../offline/regression/tiger_til_score.yaml | 137 ++++++++++++++++ src/eva/core/data/datasets/__init__.py | 6 + .../core/data/datasets/regression/__init__.py | 6 + .../data/datasets/regression/embeddings.py | 39 +++++ .../datasets/regression/multi_embeddings.py | 108 +++++++++++++ src/eva/core/data/transforms/__init__.py | 4 +- .../core/data/transforms/dtype/__init__.py | 3 +- src/eva/core/data/transforms/dtype/tensor.py | 15 ++ src/eva/core/metrics/__init__.py | 7 +- src/eva/core/metrics/defaults/__init__.py | 6 +- .../metrics/defaults/regression/__init__.py | 5 + .../defaults/regression/regression_metrics.py | 33 ++++ src/eva/vision/data/datasets/__init__.py | 2 + .../data/datasets/regression/__init__.py | 5 + .../datasets/regression/tiger_til_score.py | 148 ++++++++++++++++++ 15 files changed, 516 insertions(+), 8 deletions(-) create mode 100644 configs/vision/pathology/offline/regression/tiger_til_score.yaml create mode 100644 src/eva/core/data/datasets/regression/__init__.py create mode 100644 src/eva/core/data/datasets/regression/embeddings.py create mode 100644 src/eva/core/data/datasets/regression/multi_embeddings.py create mode 100644 src/eva/core/data/transforms/dtype/tensor.py create mode 100644 src/eva/core/metrics/defaults/regression/__init__.py create mode 100644 src/eva/core/metrics/defaults/regression/regression_metrics.py create mode 100644 src/eva/vision/data/datasets/regression/__init__.py create mode 100644 src/eva/vision/data/datasets/regression/tiger_til_score.py diff --git a/configs/vision/pathology/offline/regression/tiger_til_score.yaml b/configs/vision/pathology/offline/regression/tiger_til_score.yaml new file mode 100644 index 000000000..8b105b54e --- /dev/null +++ b/configs/vision/pathology/offline/regression/tiger_til_score.yaml @@ -0,0 +1,137 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 20} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/tiger_til} + max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 100} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} + callbacks: + - class_path: eva.callbacks.ConfigurationLogger + - class_path: lightning.pytorch.callbacks.TQDMProgressBar + init_args: + refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: ${oc.env:SAVE_LAST, false} + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MeanAbsoluteError} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: ${oc.env:PATIENCE, 20} + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + - class_path: eva.callbacks.ClassificationEmbeddingsWriter + init_args: + output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings/${oc.env:MODEL_NAME, dino_vits16}/tiger_til} + dataloader_idx_map: + 0: train + 1: val + 2: test + metadata_keys: ["wsi_id"] + backbone: + class_path: eva.vision.models.ModelFromRegistry + init_args: + model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} + model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} + overwrite: false + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.HeadModule + init_args: + head: + class_path: eva.vision.models.networks.ABMIL + init_args: + input_size: ${oc.env:IN_FEATURES, 384} + output_size: &NUM_CLASSES 1 + # task: regression + criterion: torch.nn.MSELoss + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: ${oc.env:LR_VALUE, 0.001} + betas: [0.9, 0.999] + metrics: + common: + - class_path: eva.core.metrics.AverageLoss + - class_path: eva.core.metrics.RegressionMetrics + init_args: + prefix: null + postfix: null +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.datasets.MultiEmbeddingsRegressionDataset + init_args: &DATASET_ARGS + root: *DATASET_EMBEDDINGS_ROOT + manifest_file: manifest.csv + split: train + embeddings_transforms: + class_path: eva.core.data.transforms.Pad2DTensor + init_args: + pad_size: &N_PATCHES ${oc.env:N_PATCHES, 200} + target_transforms: + class_path: eva.core.data.transforms.dtype.SqueezeTensor + val: + class_path: eva.datasets.MultiEmbeddingsRegressionDataset + init_args: + <<: *DATASET_ARGS + split: val + test: + class_path: eva.datasets.MultiEmbeddingsRegressionDataset + init_args: + <<: *DATASET_ARGS + split: test + predict: + - class_path: eva.vision.datasets.TIGERTILScore + init_args: &PREDICT_DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data/training/wsitils} + sampler: + class_path: eva.vision.data.wsi.patching.samplers.ForegroundGridSampler + init_args: + max_samples: *N_PATCHES + width: 224 + height: 224 + target_mpp: 0.5 + split: train + coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv + image_transforms: + class_path: eva.vision.data.transforms.common.ResizeAndCrop + init_args: + size: ${oc.env:RESIZE_DIM, 224} + mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} + std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} + - class_path: eva.vision.datasets.TIGERTILScore + init_args: + <<: *PREDICT_DATASET_ARGS + split: val + - class_path: eva.vision.datasets.TIGERTILScore + init_args: + <<: *PREDICT_DATASET_ARGS + split: test + dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 32} + num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} + shuffle: true + val: + batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS + test: + batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS + predict: + batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} + num_workers: *N_DATA_WORKERS diff --git a/src/eva/core/data/datasets/__init__.py b/src/eva/core/data/datasets/__init__.py index c5e366827..6da04d3b4 100644 --- a/src/eva/core/data/datasets/__init__.py +++ b/src/eva/core/data/datasets/__init__.py @@ -6,6 +6,10 @@ MultiEmbeddingsClassificationDataset, ) from eva.core.data.datasets.dataset import TorchDataset +from eva.core.data.datasets.regression import ( + EmbeddingsRegressionDataset, + MultiEmbeddingsRegressionDataset, +) from eva.core.data.datasets.typings import DataSample __all__ = [ @@ -13,6 +17,8 @@ "MapDataset", "EmbeddingsClassificationDataset", "MultiEmbeddingsClassificationDataset", + "EmbeddingsRegressionDataset", + "MultiEmbeddingsRegressionDataset", "TorchDataset", "DataSample", ] diff --git a/src/eva/core/data/datasets/regression/__init__.py b/src/eva/core/data/datasets/regression/__init__.py new file mode 100644 index 000000000..2c000653f --- /dev/null +++ b/src/eva/core/data/datasets/regression/__init__.py @@ -0,0 +1,6 @@ +"""Embedding regression datasets API.""" + +from eva.core.data.datasets.regression.embeddings import EmbeddingsRegressionDataset +from eva.core.data.datasets.regression.multi_embeddings import MultiEmbeddingsRegressionDataset + +__all__ = ["EmbeddingsRegressionDataset", "MultiEmbeddingsRegressionDataset"] diff --git a/src/eva/core/data/datasets/regression/embeddings.py b/src/eva/core/data/datasets/regression/embeddings.py new file mode 100644 index 000000000..fa995177d --- /dev/null +++ b/src/eva/core/data/datasets/regression/embeddings.py @@ -0,0 +1,39 @@ +"""Embeddings regression dataset.""" + +import os + +import torch +from typing_extensions import override + +from eva.core.data.datasets import embeddings as embeddings_base + + +class EmbeddingsRegressionDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]): + """Embeddings dataset class for regression tasks. + + NOTE: This barely changes from the EmbeddingsClassificationDataset + but they have been kept apart for abstraction + + """ + + @override + def load_embeddings(self, index: int) -> torch.Tensor: + filename = self.filename(index) + embeddings_path = os.path.join(self._root, filename) + tensor = torch.load(embeddings_path, map_location="cpu") + if isinstance(tensor, list): + if len(tensor) > 1: + raise ValueError( + f"Expected a single tensor in the .pt file, but found {len(tensor)}." + ) + tensor = tensor[0] + return tensor.squeeze(0) + + @override + def load_target(self, index: int) -> torch.Tensor: + target = self._data.at[index, self._column_mapping["target"]] + return torch.tensor(float(target), dtype=torch.float32) + + @override + def __len__(self) -> int: + return len(self._data) diff --git a/src/eva/core/data/datasets/regression/multi_embeddings.py b/src/eva/core/data/datasets/regression/multi_embeddings.py new file mode 100644 index 000000000..57e9bce74 --- /dev/null +++ b/src/eva/core/data/datasets/regression/multi_embeddings.py @@ -0,0 +1,108 @@ +"""Dataset class for where a sample corresponds to multiple embeddings (regression).""" + +import os +from typing import Callable, Dict, List, Literal + +import torch +from typing_extensions import override + +from eva.core.data.datasets import embeddings as embeddings_base + + +class MultiEmbeddingsRegressionDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]): + """Dataset class for regression with multiple embeddings per sample.""" + + def __init__( + self, + root: str, + manifest_file: str, + split: Literal["train", "val", "test"], + column_mapping: Dict[str, str] = embeddings_base.default_column_mapping, + embeddings_transforms: Callable | None = None, + target_transforms: Callable | None = None, + ): + """Initialize dataset. + + Expects a manifest file listing the paths of `.pt` files containing tensor embeddings. + + The manifest must have a `column_mapping["multi_id"]` column that contains the + unique identifier group of embeddings. For oncology datasets, this would be usually + the slide id. Each row in the manifest file points to a .pt file that can contain + one or multiple embeddings (either as a list or stacked tensors). There can also be + multiple rows for the same `multi_id`, in which case the embeddings from the different + .pt files corresponding to that same `multi_id` will be stacked along the first dimension. + + Args: + root: Root directory of the dataset. + manifest_file: The path to the manifest file, which is relative to + the `root` argument. + split: The dataset split to use. The `split` column of the manifest + file will be splitted based on this value. + column_mapping: Defines the map between the variables and the manifest + columns. It will overwrite the `default_column_mapping` with + the provided values, so that `column_mapping` can contain only the + values which are altered or missing. + embeddings_transforms: A function/transform that transforms the embedding. + target_transforms: A function/transform that transforms the target. + """ + super().__init__( + manifest_file=manifest_file, + root=root, + split=split, + column_mapping=column_mapping, + embeddings_transforms=embeddings_transforms, + target_transforms=target_transforms, + ) + self._multi_ids: List[int] + + @override + def setup(self): + super().setup() + self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique()) + + @override + def load_embeddings(self, index: int) -> torch.Tensor: + """Loads and stacks all embedding corresponding to the `index`'th multi_id.""" + # Get all embeddings for the given index (multi_id) + multi_id = self._multi_ids[index] + embedding_paths = self._data.loc[ + self._data[self._column_mapping["multi_id"]] == multi_id, + self._column_mapping["path"], + ].to_list() + + embeddings = [] + for path in embedding_paths: + embedding = torch.load(os.path.join(self._root, path), map_location="cpu") + if isinstance(embedding, list): + embedding = torch.stack(embedding, dim=0) + embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding) + embeddings = torch.cat(embeddings, dim=0) + + if embeddings.ndim != 2: + raise ValueError( + f"Expected 2D tensor, got \ + {embeddings.ndim} for {multi_id}." + ) + + return embeddings + + @override + def load_target(self, index: int) -> torch.Tensor: + """Returns the target corresponding to the `index`'th multi_id. + + This method assumes that all the embeddings corresponding to the same `multi_id` + have the same target. If this is not the case, it will raise an error. + """ + multi_id = self._multi_ids[index] + targets = self._data.loc[ + self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"] + ] + + if not targets.nunique() == 1: + raise ValueError(f"Multiple targets found for {multi_id}.") + + return torch.tensor(targets.iloc[0], dtype=torch.float32) + + @override + def __len__(self) -> int: + return len(self._multi_ids) diff --git a/src/eva/core/data/transforms/__init__.py b/src/eva/core/data/transforms/__init__.py index 385ba9e15..f36289d7d 100644 --- a/src/eva/core/data/transforms/__init__.py +++ b/src/eva/core/data/transforms/__init__.py @@ -1,7 +1,7 @@ """Core data transforms.""" -from eva.core.data.transforms.dtype import ArrayToFloatTensor, ArrayToTensor +from eva.core.data.transforms.dtype import ArrayToFloatTensor, ArrayToTensor, SqueezeTensor from eva.core.data.transforms.padding import Pad2DTensor from eva.core.data.transforms.sampling import SampleFromAxis -__all__ = ["ArrayToFloatTensor", "ArrayToTensor", "Pad2DTensor", "SampleFromAxis"] +__all__ = ["ArrayToFloatTensor", "ArrayToTensor", "Pad2DTensor", "SampleFromAxis", "SqueezeTensor"] diff --git a/src/eva/core/data/transforms/dtype/__init__.py b/src/eva/core/data/transforms/dtype/__init__.py index 50b6fb207..943e6d431 100644 --- a/src/eva/core/data/transforms/dtype/__init__.py +++ b/src/eva/core/data/transforms/dtype/__init__.py @@ -1,5 +1,6 @@ """Type casting related transforms.""" from eva.core.data.transforms.dtype.array import ArrayToFloatTensor, ArrayToTensor +from eva.core.data.transforms.dtype.tensor import SqueezeTensor -__all__ = ["ArrayToFloatTensor", "ArrayToTensor"] +__all__ = ["ArrayToFloatTensor", "ArrayToTensor", "SqueezeTensor"] diff --git a/src/eva/core/data/transforms/dtype/tensor.py b/src/eva/core/data/transforms/dtype/tensor.py new file mode 100644 index 000000000..6d4a3d43c --- /dev/null +++ b/src/eva/core/data/transforms/dtype/tensor.py @@ -0,0 +1,15 @@ +"""Transformations to change the shape of tensors.""" + +import torch + + +class SqueezeTensor: + """Squeezes a [B, 1] tensor to [B].""" + + def __call__(self, tensor: torch.Tensor) -> torch.Tensor: + """Call method for the transformation. + + Args: + tensor: The input tensor to be squeezed. + """ + return tensor.squeeze(-1) diff --git a/src/eva/core/metrics/__init__.py b/src/eva/core/metrics/__init__.py index aed8c33ea..32b5af0c4 100644 --- a/src/eva/core/metrics/__init__.py +++ b/src/eva/core/metrics/__init__.py @@ -2,7 +2,11 @@ from eva.core.metrics.average_loss import AverageLoss from eva.core.metrics.binary_balanced_accuracy import BinaryBalancedAccuracy -from eva.core.metrics.defaults import BinaryClassificationMetrics, MulticlassClassificationMetrics +from eva.core.metrics.defaults import ( + BinaryClassificationMetrics, + MulticlassClassificationMetrics, + RegressionMetrics, +) from eva.core.metrics.structs import Metric, MetricCollection, MetricModule, MetricsSchema __all__ = [ @@ -10,6 +14,7 @@ "BinaryBalancedAccuracy", "BinaryClassificationMetrics", "MulticlassClassificationMetrics", + "RegressionMetrics", "Metric", "MetricCollection", "MetricModule", diff --git a/src/eva/core/metrics/defaults/__init__.py b/src/eva/core/metrics/defaults/__init__.py index be65d7579..3a9bb789c 100644 --- a/src/eva/core/metrics/defaults/__init__.py +++ b/src/eva/core/metrics/defaults/__init__.py @@ -4,8 +4,6 @@ BinaryClassificationMetrics, MulticlassClassificationMetrics, ) +from eva.core.metrics.defaults.regression import RegressionMetrics -__all__ = [ - "MulticlassClassificationMetrics", - "BinaryClassificationMetrics", -] +__all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics", "RegressionMetrics"] diff --git a/src/eva/core/metrics/defaults/regression/__init__.py b/src/eva/core/metrics/defaults/regression/__init__.py new file mode 100644 index 000000000..083c20122 --- /dev/null +++ b/src/eva/core/metrics/defaults/regression/__init__.py @@ -0,0 +1,5 @@ +"""Default regression metric collections API.""" + +from eva.core.metrics.defaults.regression.regression_metrics import RegressionMetrics + +__all__ = ["RegressionMetrics"] diff --git a/src/eva/core/metrics/defaults/regression/regression_metrics.py b/src/eva/core/metrics/defaults/regression/regression_metrics.py new file mode 100644 index 000000000..0a51b6130 --- /dev/null +++ b/src/eva/core/metrics/defaults/regression/regression_metrics.py @@ -0,0 +1,33 @@ +"""Default metric collection for regression tasks.""" + +from torchmetrics import MeanAbsoluteError, MeanSquaredError, R2Score + +from eva.core.metrics import structs + + +class RegressionMetrics(structs.MetricCollection): + """Default metrics for regression tasks. + + Supports: + Mean Absolute Error + Root Mean Squared Error + R^2 score + """ + + def __init__( + self, + prefix: str | None = None, + postfix: str | None = None, + ) -> None: + """Initialises regression metrics. + + Args: + prefix: A string to prepend to metric names. + postfix: A string to append after metric names. + """ + super().__init__( + metrics=[MeanAbsoluteError(), MeanSquaredError(squared=False), R2Score()], + prefix=prefix, + postfix=postfix, + compute_groups=[["MeanAbsoluteError", "MeanSquaredError", "R2Score"]], + ) diff --git a/src/eva/vision/data/datasets/__init__.py b/src/eva/vision/data/datasets/__init__.py index 95ed8d847..81da744f7 100644 --- a/src/eva/vision/data/datasets/__init__.py +++ b/src/eva/vision/data/datasets/__init__.py @@ -14,6 +14,7 @@ UniToPatho, WsiClassificationDataset, ) +from eva.vision.data.datasets.regression import TIGERTILScore from eva.vision.data.datasets.segmentation import ( BCSS, BTCV, @@ -49,4 +50,5 @@ "VisionDataset", "MultiWsiDataset", "WsiDataset", + "TIGERTILScore", ] diff --git a/src/eva/vision/data/datasets/regression/__init__.py b/src/eva/vision/data/datasets/regression/__init__.py new file mode 100644 index 000000000..6c5f51cdf --- /dev/null +++ b/src/eva/vision/data/datasets/regression/__init__.py @@ -0,0 +1,5 @@ +"""Regression datasets API.""" + +from eva.vision.data.datasets.regression.tiger_til_score import TIGERTILScore + +__all__ = ["TIGERTILScore"] diff --git a/src/eva/vision/data/datasets/regression/tiger_til_score.py b/src/eva/vision/data/datasets/regression/tiger_til_score.py new file mode 100644 index 000000000..b91e05197 --- /dev/null +++ b/src/eva/vision/data/datasets/regression/tiger_til_score.py @@ -0,0 +1,148 @@ +"""Tiger dataset class for regression targets.""" + +import functools +import glob +import os +from pathlib import Path +from typing import Any, Callable, Dict, List, Literal + +import pandas as pd +import torch +from torchvision import tv_tensors +from torchvision.transforms.v2 import functional +from typing_extensions import override + +from eva.vision.data.datasets import _validators, vision, wsi +from eva.vision.data.wsi.patching import samplers + + +class TIGERTILScore(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]): + """Dataset class for TIGERBULK regression tasks with per-slide targets.""" + + def __init__( + self, + root: str, + sampler: samplers.Sampler, + split: Literal["train", "val", "test"] | None = None, + width: int = 224, + height: int = 224, + target_mpp: float = 0.5, + backend: str = "openslide", + image_transforms: Callable | None = None, + coords_path: str | None = None, + seed: int = 42, + n_patches: int = 200, + ) -> None: + """Initializes the dataset. + + Args: + root: Root directory of the dataset. + sampler: The sampler to use for sampling patch coordinates. + split: Dataset split to use. If `None`, the entire dataset is used. + width: Patch width in pixels. + height: Patch height in pixels. + target_mpp: Target microns per pixel (mpp) for patches. + backend: WSI reading backend. + image_transforms: Transforms to apply to patches. + coords_path: Optional path to save patch coordinates. + seed: Random seed. + n_patches: Number of patches per slide. + targets_csv: Path to CSV containing per-slide regression targets. + Must have columns: slide_name,target + """ + self._split = split + self._root = root + self._width = width + self._height = height + self._target_mpp = target_mpp + self._seed = seed + self._n_patches = n_patches + + wsi.MultiWsiDataset.__init__( + self, + root=root, + file_paths=self._load_file_paths(split), + width=width, + height=height, + sampler=sampler, + target_mpp=target_mpp, + backend=backend, + image_transforms=image_transforms, + coords_path=coords_path, + ) + + @functools.cached_property + def annotations(self) -> Dict[str, float]: + """Loads per-slide regression targets from a CSV file. + + Expected CSV format: + image-id,tils-score + 103S,0.70 + ... + """ + targets_csv_path = os.path.join(self._root, "tiger-til-scores-wsitils.csv") + + if not os.path.isfile(targets_csv_path): + raise FileNotFoundError(f"Targets CSV file not found at: {targets_csv_path}") + + df = pd.read_csv(targets_csv_path) + if not {"image-id", "tils-score"} <= set(df.columns): + raise ValueError("targets_csv must contain 'image-id' and 'tils-score' columns.") + + return {str(row["image-id"]): float(row["tils-score"]) for _, row in df.iterrows()} + + @override + def prepare_data(self) -> None: + _validators.check_dataset_exists(self._root, False) + + @override + def __getitem__(self, index: int): + return vision.VisionDataset.__getitem__(self, index) + + @override + def load_data(self, index: int) -> tv_tensors.Image: + image_array = wsi.MultiWsiDataset.__getitem__(self, index) + return functional.to_image(image_array) + + @override + def load_target(self, index: int) -> torch.Tensor: + slide_idx = index // self._n_patches + file_path = self._file_paths[slide_idx] + slide_name = Path(file_path).stem + + target_value = self.annotations[slide_name] + tensor = torch.tensor([target_value], dtype=torch.float32) + return tensor + + @override + def load_metadata(self, index: int) -> Dict[str, Any]: + return wsi.MultiWsiDataset.load_metadata(self, index) + + def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]: + """Loads the file paths of WSIs from wsitils/images. + + Splits are assigned 70% train, 15% val, 15% test by filename sorting. + """ + image_dir = os.path.join(self._root, "images") + + all_paths = sorted(glob.glob(os.path.join(image_dir, "*.tif"))) + + if not all_paths: + raise FileNotFoundError(f"No .tif files found in {image_dir}") + + n_total = len(all_paths) + n_train = int(n_total * 0.7) + n_val = int(n_total * 0.15) + + if split == "train": + selected_paths = all_paths[:n_train] + elif split == "val": + selected_paths = all_paths[n_train : n_train + n_val] + elif split == "test": + selected_paths = all_paths[n_train + n_val :] + elif split is None: + selected_paths = all_paths + else: + raise ValueError("Invalid split. Use 'train', 'val', 'test', or None.") + + return [os.path.relpath(path, self._root) for path in selected_paths] From ef84783719324c7439c23af9188cd41d6b209243 Mon Sep 17 00:00:00 2001 From: Jklubienski Date: Wed, 24 Sep 2025 10:00:40 +0100 Subject: [PATCH 2/3] Refactor for clarity and address suggested changes --- .../offline/regression/tiger_til_score.yaml | 8 +- .../classification/multi_embeddings.py | 108 +-------------- .../core/data/datasets/multi_embeddings.py | 114 ++++++++++++++++ .../data/datasets/regression/embeddings.py | 30 +--- .../datasets/regression/multi_embeddings.py | 112 ++------------- src/eva/core/data/transforms/__init__.py | 4 +- .../core/data/transforms/dtype/__init__.py | 3 +- src/eva/core/data/transforms/dtype/tensor.py | 15 -- .../defaults/regression/regression_metrics.py | 8 +- .../datasets/regression/tiger_til_score.py | 114 ++-------------- src/eva/vision/data/datasets/tiger.py | 129 ++++++++++++++++++ src/eva/vision/data/datasets/wsi.py | 6 +- src/eva/vision/models/networks/abmil.py | 2 +- 13 files changed, 290 insertions(+), 363 deletions(-) create mode 100644 src/eva/core/data/datasets/multi_embeddings.py delete mode 100644 src/eva/core/data/transforms/dtype/tensor.py create mode 100644 src/eva/vision/data/datasets/tiger.py diff --git a/configs/vision/pathology/offline/regression/tiger_til_score.yaml b/configs/vision/pathology/offline/regression/tiger_til_score.yaml index 8b105b54e..99f42ad07 100644 --- a/configs/vision/pathology/offline/regression/tiger_til_score.yaml +++ b/configs/vision/pathology/offline/regression/tiger_til_score.yaml @@ -19,7 +19,7 @@ trainer: filename: best save_last: ${oc.env:SAVE_LAST, false} save_top_k: 1 - monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MeanAbsoluteError} + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MAE} mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: @@ -53,8 +53,6 @@ model: class_path: eva.vision.models.networks.ABMIL init_args: input_size: ${oc.env:IN_FEATURES, 384} - output_size: &NUM_CLASSES 1 - # task: regression criterion: torch.nn.MSELoss optimizer: class_path: torch.optim.AdamW @@ -83,7 +81,9 @@ data: init_args: pad_size: &N_PATCHES ${oc.env:N_PATCHES, 200} target_transforms: - class_path: eva.core.data.transforms.dtype.SqueezeTensor + class_path: eva.vision.data.transforms.common.Squeeze + init_args: + dim: -1 val: class_path: eva.datasets.MultiEmbeddingsRegressionDataset init_args: diff --git a/src/eva/core/data/datasets/classification/multi_embeddings.py b/src/eva/core/data/datasets/classification/multi_embeddings.py index 399d5eab9..ba8a1e223 100644 --- a/src/eva/core/data/datasets/classification/multi_embeddings.py +++ b/src/eva/core/data/datasets/classification/multi_embeddings.py @@ -1,110 +1,16 @@ -"""Dataset class for where a sample corresponds to multiple embeddings.""" - -import os -from typing import Callable, Dict, List, Literal +"""Dataset class for where a classification task sample corresponds to multiple embeddings.""" import numpy as np -import torch -from typing_extensions import override -from eva.core.data.datasets import embeddings as embeddings_base +from eva.core.data.datasets.multi_embeddings import MultiEmbeddingsDataset -class MultiEmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]): +class MultiEmbeddingsClassificationDataset(MultiEmbeddingsDataset): """Dataset class for where a sample corresponds to multiple embeddings. - Example use case: Slide level dataset where each slide has multiple patch embeddings. + Specialised for classification data with an int target type. """ - def __init__( - self, - root: str, - manifest_file: str, - split: Literal["train", "val", "test"], - column_mapping: Dict[str, str] = embeddings_base.default_column_mapping, - embeddings_transforms: Callable | None = None, - target_transforms: Callable | None = None, - ): - """Initialize dataset. - - Expects a manifest file listing the paths of `.pt` files containing tensor embeddings. - - The manifest must have a `column_mapping["multi_id"]` column that contains the - unique identifier group of embeddings. For oncology datasets, this would be usually - the slide id. Each row in the manifest file points to a .pt file that can contain - one or multiple embeddings (either as a list or stacked tensors). There can also be - multiple rows for the same `multi_id`, in which case the embeddings from the different - .pt files corresponding to that same `multi_id` will be stacked along the first dimension. - - Args: - root: Root directory of the dataset. - manifest_file: The path to the manifest file, which is relative to - the `root` argument. - split: The dataset split to use. The `split` column of the manifest - file will be splitted based on this value. - column_mapping: Defines the map between the variables and the manifest - columns. It will overwrite the `default_column_mapping` with - the provided values, so that `column_mapping` can contain only the - values which are altered or missing. - embeddings_transforms: A function/transform that transforms the embedding. - target_transforms: A function/transform that transforms the target. - """ - super().__init__( - manifest_file=manifest_file, - root=root, - split=split, - column_mapping=column_mapping, - embeddings_transforms=embeddings_transforms, - target_transforms=target_transforms, - ) - - self._multi_ids: List[int] - - @override - def setup(self): - super().setup() - self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique()) - - @override - def load_embeddings(self, index: int) -> torch.Tensor: - """Loads and stacks all embedding corresponding to the `index`'th multi_id.""" - # Get all embeddings for the given index (multi_id) - multi_id = self._multi_ids[index] - embedding_paths = self._data.loc[ - self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["path"] - ].to_list() - - # Load embeddings and stack them accross the first dimension - embeddings = [] - for path in embedding_paths: - embedding = torch.load(os.path.join(self._root, path), map_location="cpu") - if isinstance(embedding, list): - embedding = torch.stack(embedding, dim=0) - embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding) - embeddings = torch.cat(embeddings, dim=0) - - if not embeddings.ndim == 2: - raise ValueError(f"Expected 2D tensor, got {embeddings.ndim} for {multi_id}.") - - return embeddings - - @override - def load_target(self, index: int) -> np.ndarray: - """Returns the target corresponding to the `index`'th multi_id. - - This method assumes that all the embeddings corresponding to the same `multi_id` - have the same target. If this is not the case, it will raise an error. - """ - multi_id = self._multi_ids[index] - targets = self._data.loc[ - self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"] - ] - - if not targets.nunique() == 1: - raise ValueError(f"Multiple targets found for {multi_id}.") - - return np.asarray(targets.iloc[0], dtype=np.int64) - - @override - def __len__(self) -> int: - return len(self._multi_ids) + def __init__(self, *args, **kwargs): + """Initialize dataset with the correct return type.""" + super().__init__(*args, target_type=np.int64, **kwargs) diff --git a/src/eva/core/data/datasets/multi_embeddings.py b/src/eva/core/data/datasets/multi_embeddings.py new file mode 100644 index 000000000..dc7213569 --- /dev/null +++ b/src/eva/core/data/datasets/multi_embeddings.py @@ -0,0 +1,114 @@ +"""Dataset class for where a sample corresponds to multiple embeddings.""" + +import os +from typing import Any, Callable, Dict, List, Literal + +import numpy as np +import numpy.typing as npt +import torch +from typing_extensions import override + +from eva.core.data.datasets import embeddings as embeddings_base + + +class MultiEmbeddingsDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]): + """Dataset class for where a sample corresponds to multiple embeddings. + + Example use case: Slide level dataset where each slide has multiple patch embeddings. + """ + + def __init__( + self, + root: str, + manifest_file: str, + split: Literal["train", "val", "test"], + column_mapping: Dict[str, str] = embeddings_base.default_column_mapping, + embeddings_transforms: Callable | None = None, + target_transforms: Callable | None = None, + target_type: type[np.generic] = np.int64, + ): + """Initialize dataset. + + Expects a manifest file listing the paths of `.pt` files containing tensor embeddings. + + The manifest must have a `column_mapping["multi_id"]` column that contains the + unique identifier group of embeddings. For oncology datasets, this would be usually + the slide id. Each row in the manifest file points to a .pt file that can contain + one or multiple embeddings (either as a list or stacked tensors). There can also be + multiple rows for the same `multi_id`, in which case the embeddings from the different + .pt files corresponding to that same `multi_id` will be stacked along the first dimension. + + Args: + root: Root directory of the dataset. + manifest_file: The path to the manifest file, which is relative to + the `root` argument. + split: The dataset split to use. The `split` column of the manifest + file will be splitted based on this value. + column_mapping: Defines the map between the variables and the manifest + columns. It will overwrite the `default_column_mapping` with + the provided values, so that `column_mapping` can contain only the + values which are altered or missing. + embeddings_transforms: A function/transform that transforms the embedding. + target_transforms: A function/transform that transforms the target. + target_type: Desired type of the target data + """ + super().__init__( + manifest_file=manifest_file, + root=root, + split=split, + column_mapping=column_mapping, + embeddings_transforms=embeddings_transforms, + target_transforms=target_transforms, + ) + + self._multi_ids: List[int] + self._target_type = target_type + + @override + def setup(self): + super().setup() + self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique()) + + @override + def load_embeddings(self, index: int) -> torch.Tensor: + """Loads and stacks all embedding corresponding to the `index`'th multi_id.""" + # Get all embeddings for the given index (multi_id) + multi_id = self._multi_ids[index] + embedding_paths = self._data.loc[ + self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["path"] + ].to_list() + + # Load embeddings and stack them accross the first dimension + embeddings = [] + for path in embedding_paths: + embedding = torch.load(os.path.join(self._root, path), map_location="cpu") + if isinstance(embedding, list): + embedding = torch.stack(embedding, dim=0) + embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding) + embeddings = torch.cat(embeddings, dim=0) + + if not embeddings.ndim == 2: + raise ValueError(f"Expected 2D tensor, got {embeddings.ndim} for {multi_id}.") + + return embeddings + + @override + def load_target(self, index: int) -> npt.NDArray[Any]: + """Returns the target corresponding to the `index`'th multi_id. + + This method assumes that all the embeddings corresponding to the same `multi_id` + have the same target. If this is not the case, it will raise an error. + """ + multi_id = self._multi_ids[index] + targets = self._data.loc[ + self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"] + ] + + if not targets.nunique() == 1: + raise ValueError(f"Multiple targets found for {multi_id}.") + + return np.asarray(targets.iloc[0], dtype=self._target_type) + + @override + def __len__(self) -> int: + return len(self._multi_ids) diff --git a/src/eva/core/data/datasets/regression/embeddings.py b/src/eva/core/data/datasets/regression/embeddings.py index fa995177d..50dd01671 100644 --- a/src/eva/core/data/datasets/regression/embeddings.py +++ b/src/eva/core/data/datasets/regression/embeddings.py @@ -1,39 +1,15 @@ """Embeddings regression dataset.""" -import os - import torch from typing_extensions import override -from eva.core.data.datasets import embeddings as embeddings_base - +from eva.core.data.datasets.classification import EmbeddingsClassificationDataset -class EmbeddingsRegressionDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]): - """Embeddings dataset class for regression tasks. - NOTE: This barely changes from the EmbeddingsClassificationDataset - but they have been kept apart for abstraction - - """ - - @override - def load_embeddings(self, index: int) -> torch.Tensor: - filename = self.filename(index) - embeddings_path = os.path.join(self._root, filename) - tensor = torch.load(embeddings_path, map_location="cpu") - if isinstance(tensor, list): - if len(tensor) > 1: - raise ValueError( - f"Expected a single tensor in the .pt file, but found {len(tensor)}." - ) - tensor = tensor[0] - return tensor.squeeze(0) +class EmbeddingsRegressionDataset(EmbeddingsClassificationDataset): + """Embeddings dataset class for regression tasks.""" @override def load_target(self, index: int) -> torch.Tensor: target = self._data.at[index, self._column_mapping["target"]] return torch.tensor(float(target), dtype=torch.float32) - - @override - def __len__(self) -> int: - return len(self._data) diff --git a/src/eva/core/data/datasets/regression/multi_embeddings.py b/src/eva/core/data/datasets/regression/multi_embeddings.py index 57e9bce74..d3db9cee5 100644 --- a/src/eva/core/data/datasets/regression/multi_embeddings.py +++ b/src/eva/core/data/datasets/regression/multi_embeddings.py @@ -1,108 +1,16 @@ -"""Dataset class for where a sample corresponds to multiple embeddings (regression).""" +"""Dataset class for where a regression task sample corresponds to multiple embeddings.""" -import os -from typing import Callable, Dict, List, Literal +import numpy as np -import torch -from typing_extensions import override +from eva.core.data.datasets.multi_embeddings import MultiEmbeddingsDataset -from eva.core.data.datasets import embeddings as embeddings_base +class MultiEmbeddingsRegressionDataset(MultiEmbeddingsDataset): + """Dataset class for where a sample corresponds to multiple embeddings. -class MultiEmbeddingsRegressionDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]): - """Dataset class for regression with multiple embeddings per sample.""" + Specialised for regression data with a float target type. + """ - def __init__( - self, - root: str, - manifest_file: str, - split: Literal["train", "val", "test"], - column_mapping: Dict[str, str] = embeddings_base.default_column_mapping, - embeddings_transforms: Callable | None = None, - target_transforms: Callable | None = None, - ): - """Initialize dataset. - - Expects a manifest file listing the paths of `.pt` files containing tensor embeddings. - - The manifest must have a `column_mapping["multi_id"]` column that contains the - unique identifier group of embeddings. For oncology datasets, this would be usually - the slide id. Each row in the manifest file points to a .pt file that can contain - one or multiple embeddings (either as a list or stacked tensors). There can also be - multiple rows for the same `multi_id`, in which case the embeddings from the different - .pt files corresponding to that same `multi_id` will be stacked along the first dimension. - - Args: - root: Root directory of the dataset. - manifest_file: The path to the manifest file, which is relative to - the `root` argument. - split: The dataset split to use. The `split` column of the manifest - file will be splitted based on this value. - column_mapping: Defines the map between the variables and the manifest - columns. It will overwrite the `default_column_mapping` with - the provided values, so that `column_mapping` can contain only the - values which are altered or missing. - embeddings_transforms: A function/transform that transforms the embedding. - target_transforms: A function/transform that transforms the target. - """ - super().__init__( - manifest_file=manifest_file, - root=root, - split=split, - column_mapping=column_mapping, - embeddings_transforms=embeddings_transforms, - target_transforms=target_transforms, - ) - self._multi_ids: List[int] - - @override - def setup(self): - super().setup() - self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique()) - - @override - def load_embeddings(self, index: int) -> torch.Tensor: - """Loads and stacks all embedding corresponding to the `index`'th multi_id.""" - # Get all embeddings for the given index (multi_id) - multi_id = self._multi_ids[index] - embedding_paths = self._data.loc[ - self._data[self._column_mapping["multi_id"]] == multi_id, - self._column_mapping["path"], - ].to_list() - - embeddings = [] - for path in embedding_paths: - embedding = torch.load(os.path.join(self._root, path), map_location="cpu") - if isinstance(embedding, list): - embedding = torch.stack(embedding, dim=0) - embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding) - embeddings = torch.cat(embeddings, dim=0) - - if embeddings.ndim != 2: - raise ValueError( - f"Expected 2D tensor, got \ - {embeddings.ndim} for {multi_id}." - ) - - return embeddings - - @override - def load_target(self, index: int) -> torch.Tensor: - """Returns the target corresponding to the `index`'th multi_id. - - This method assumes that all the embeddings corresponding to the same `multi_id` - have the same target. If this is not the case, it will raise an error. - """ - multi_id = self._multi_ids[index] - targets = self._data.loc[ - self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"] - ] - - if not targets.nunique() == 1: - raise ValueError(f"Multiple targets found for {multi_id}.") - - return torch.tensor(targets.iloc[0], dtype=torch.float32) - - @override - def __len__(self) -> int: - return len(self._multi_ids) + def __init__(self, *args, **kwargs): + """Initialize dataset with the correct return type.""" + super().__init__(*args, target_type=np.float32, **kwargs) diff --git a/src/eva/core/data/transforms/__init__.py b/src/eva/core/data/transforms/__init__.py index f36289d7d..385ba9e15 100644 --- a/src/eva/core/data/transforms/__init__.py +++ b/src/eva/core/data/transforms/__init__.py @@ -1,7 +1,7 @@ """Core data transforms.""" -from eva.core.data.transforms.dtype import ArrayToFloatTensor, ArrayToTensor, SqueezeTensor +from eva.core.data.transforms.dtype import ArrayToFloatTensor, ArrayToTensor from eva.core.data.transforms.padding import Pad2DTensor from eva.core.data.transforms.sampling import SampleFromAxis -__all__ = ["ArrayToFloatTensor", "ArrayToTensor", "Pad2DTensor", "SampleFromAxis", "SqueezeTensor"] +__all__ = ["ArrayToFloatTensor", "ArrayToTensor", "Pad2DTensor", "SampleFromAxis"] diff --git a/src/eva/core/data/transforms/dtype/__init__.py b/src/eva/core/data/transforms/dtype/__init__.py index 943e6d431..50b6fb207 100644 --- a/src/eva/core/data/transforms/dtype/__init__.py +++ b/src/eva/core/data/transforms/dtype/__init__.py @@ -1,6 +1,5 @@ """Type casting related transforms.""" from eva.core.data.transforms.dtype.array import ArrayToFloatTensor, ArrayToTensor -from eva.core.data.transforms.dtype.tensor import SqueezeTensor -__all__ = ["ArrayToFloatTensor", "ArrayToTensor", "SqueezeTensor"] +__all__ = ["ArrayToFloatTensor", "ArrayToTensor"] diff --git a/src/eva/core/data/transforms/dtype/tensor.py b/src/eva/core/data/transforms/dtype/tensor.py deleted file mode 100644 index 6d4a3d43c..000000000 --- a/src/eva/core/data/transforms/dtype/tensor.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Transformations to change the shape of tensors.""" - -import torch - - -class SqueezeTensor: - """Squeezes a [B, 1] tensor to [B].""" - - def __call__(self, tensor: torch.Tensor) -> torch.Tensor: - """Call method for the transformation. - - Args: - tensor: The input tensor to be squeezed. - """ - return tensor.squeeze(-1) diff --git a/src/eva/core/metrics/defaults/regression/regression_metrics.py b/src/eva/core/metrics/defaults/regression/regression_metrics.py index 0a51b6130..3dddf2305 100644 --- a/src/eva/core/metrics/defaults/regression/regression_metrics.py +++ b/src/eva/core/metrics/defaults/regression/regression_metrics.py @@ -26,8 +26,12 @@ def __init__( postfix: A string to append after metric names. """ super().__init__( - metrics=[MeanAbsoluteError(), MeanSquaredError(squared=False), R2Score()], + metrics={ + "MAE": MeanAbsoluteError(), + "RMSE": MeanSquaredError(squared=False), + "R2": R2Score(), + }, prefix=prefix, postfix=postfix, - compute_groups=[["MeanAbsoluteError", "MeanSquaredError", "R2Score"]], + compute_groups=[["MAE", "RMSE", "R2"]], ) diff --git a/src/eva/vision/data/datasets/regression/tiger_til_score.py b/src/eva/vision/data/datasets/regression/tiger_til_score.py index b91e05197..0f62c4d37 100644 --- a/src/eva/vision/data/datasets/regression/tiger_til_score.py +++ b/src/eva/vision/data/datasets/regression/tiger_til_score.py @@ -1,75 +1,22 @@ """Tiger dataset class for regression targets.""" import functools -import glob import os from pathlib import Path -from typing import Any, Callable, Dict, List, Literal +from typing import Dict import pandas as pd import torch -from torchvision import tv_tensors -from torchvision.transforms.v2 import functional from typing_extensions import override -from eva.vision.data.datasets import _validators, vision, wsi -from eva.vision.data.wsi.patching import samplers +from eva.vision.data.datasets import tiger -class TIGERTILScore(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]): - """Dataset class for TIGERBULK regression tasks with per-slide targets.""" +class TIGERTILScore(tiger.TIGERBase): + """Dataset class for regression tasks using the TIGERTILS partition of the TIGER dataset. - def __init__( - self, - root: str, - sampler: samplers.Sampler, - split: Literal["train", "val", "test"] | None = None, - width: int = 224, - height: int = 224, - target_mpp: float = 0.5, - backend: str = "openslide", - image_transforms: Callable | None = None, - coords_path: str | None = None, - seed: int = 42, - n_patches: int = 200, - ) -> None: - """Initializes the dataset. - - Args: - root: Root directory of the dataset. - sampler: The sampler to use for sampling patch coordinates. - split: Dataset split to use. If `None`, the entire dataset is used. - width: Patch width in pixels. - height: Patch height in pixels. - target_mpp: Target microns per pixel (mpp) for patches. - backend: WSI reading backend. - image_transforms: Transforms to apply to patches. - coords_path: Optional path to save patch coordinates. - seed: Random seed. - n_patches: Number of patches per slide. - targets_csv: Path to CSV containing per-slide regression targets. - Must have columns: slide_name,target - """ - self._split = split - self._root = root - self._width = width - self._height = height - self._target_mpp = target_mpp - self._seed = seed - self._n_patches = n_patches - - wsi.MultiWsiDataset.__init__( - self, - root=root, - file_paths=self._load_file_paths(split), - width=width, - height=height, - sampler=sampler, - target_mpp=target_mpp, - backend=backend, - image_transforms=image_transforms, - coords_path=coords_path, - ) + Predicts TIL scores, i.e. the proportion of the cell infiltrated by TILs. + """ @functools.cached_property def annotations(self) -> Dict[str, float]: @@ -91,58 +38,13 @@ def annotations(self) -> Dict[str, float]: return {str(row["image-id"]): float(row["tils-score"]) for _, row in df.iterrows()} - @override - def prepare_data(self) -> None: - _validators.check_dataset_exists(self._root, False) - - @override - def __getitem__(self, index: int): - return vision.VisionDataset.__getitem__(self, index) - - @override - def load_data(self, index: int) -> tv_tensors.Image: - image_array = wsi.MultiWsiDataset.__getitem__(self, index) - return functional.to_image(image_array) - @override def load_target(self, index: int) -> torch.Tensor: - slide_idx = index // self._n_patches + metadata = self.load_metadata(index=index) + slide_idx = metadata["slide_idx"] file_path = self._file_paths[slide_idx] slide_name = Path(file_path).stem target_value = self.annotations[slide_name] tensor = torch.tensor([target_value], dtype=torch.float32) return tensor - - @override - def load_metadata(self, index: int) -> Dict[str, Any]: - return wsi.MultiWsiDataset.load_metadata(self, index) - - def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]: - """Loads the file paths of WSIs from wsitils/images. - - Splits are assigned 70% train, 15% val, 15% test by filename sorting. - """ - image_dir = os.path.join(self._root, "images") - - all_paths = sorted(glob.glob(os.path.join(image_dir, "*.tif"))) - - if not all_paths: - raise FileNotFoundError(f"No .tif files found in {image_dir}") - - n_total = len(all_paths) - n_train = int(n_total * 0.7) - n_val = int(n_total * 0.15) - - if split == "train": - selected_paths = all_paths[:n_train] - elif split == "val": - selected_paths = all_paths[n_train : n_train + n_val] - elif split == "test": - selected_paths = all_paths[n_train + n_val :] - elif split is None: - selected_paths = all_paths - else: - raise ValueError("Invalid split. Use 'train', 'val', 'test', or None.") - - return [os.path.relpath(path, self._root) for path in selected_paths] diff --git a/src/eva/vision/data/datasets/tiger.py b/src/eva/vision/data/datasets/tiger.py new file mode 100644 index 000000000..e0e32f874 --- /dev/null +++ b/src/eva/vision/data/datasets/tiger.py @@ -0,0 +1,129 @@ +"""Abstract base class for TIGER datasets spanning different task types.""" + +import abc +import glob +import os +from typing import Any, Callable, Dict, List, Literal, Tuple + +import numpy as np +import torch +from torchvision import tv_tensors +from torchvision.transforms.v2 import functional +from typing_extensions import override + +from eva.vision.data.datasets import _validators, vision, wsi +from eva.vision.data.wsi.patching import samplers + + +class TIGERBase( + wsi.MultiWsiDataset, + vision.VisionDataset[tv_tensors.Image, torch.Tensor], + abc.ABC, +): + """Abstract base class for TIGER datasets spanning different task types.""" + + _train_split_ratio: float = 0.7 + _val_split_ratio: float = 0.15 + + # target microns per pixel (mpp) for patches. + _target_mpp: float = 0.5 + + def __init__( + self, + root: str, + sampler: samplers.Sampler, + split: Literal["train", "val", "test"] | None = None, + width: int = 224, + height: int = 224, + backend: str = "openslide", + image_transforms: Callable | None = None, + coords_path: str | None = None, + seed: int = 42, + ) -> None: + """Initializes the dataset. + + Args: + root: Root directory of the dataset. + sampler: The sampler to use for sampling patch coordinates. + split: Dataset split to use. If `None`, the entire dataset is used. + width: Patch width in pixels. + height: Patch height in pixels. + backend: WSI reading backend. + image_transforms: Transforms to apply to patches. + coords_path: Optional path to save patch coordinates. + seed: Random seed. + """ + self._root = root + self._split = split + self._width = width + self._height = height + self._seed = seed + + wsi.MultiWsiDataset.__init__( + self, + root=root, + file_paths=self._load_file_paths(split), + width=width, + height=height, + sampler=sampler, + target_mpp=self._target_mpp, + backend=backend, + image_transforms=image_transforms, + coords_path=coords_path, + ) + + @override + def prepare_data(self) -> None: + _validators.check_dataset_exists(self._root, False) + + @override + def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]: + return vision.VisionDataset.__getitem__(self, index) + + @override + def load_data(self, index: int) -> tv_tensors.Image: + image_array = wsi.MultiWsiDataset.__getitem__(self, index) + return functional.to_image(image_array) + + @override + def load_metadata(self, index: int) -> Dict[str, Any]: + return wsi.MultiWsiDataset.load_metadata(self, index) + + @abc.abstractmethod + def annotations(self) -> Dict[str, Any]: + """Annotates target data.""" + raise NotImplementedError + + @abc.abstractmethod + def load_target(self, index: int): + """Task-specific target loading.""" + raise NotImplementedError + + def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]: + """Loads the file paths of WSIs from wsibulk/images. + + Splits are assigned 70% train, 15% val, 15% test by filename sorting. + """ + image_dir = os.path.join(self._root, "images") + all_paths = sorted(glob.glob(os.path.join(image_dir, "*.tif"))) + + if not all_paths: + raise FileNotFoundError(f"No .tif files found in {image_dir}") + + rng = np.random.default_rng(self._seed) # nosec B311 + rng.shuffle(all_paths) + + n_total = len(all_paths) + n_train = int(n_total * self._train_split_ratio) + n_val = int(n_total * self._val_split_ratio) + + if split == "train": + selected_paths = all_paths[:n_train] + elif split == "val": + selected_paths = all_paths[n_train : n_train + n_val] + elif split == "test": + selected_paths = all_paths[n_train + n_val :] + elif split is None: + selected_paths = all_paths + + return [os.path.relpath(path, self._root) for path in selected_paths] diff --git a/src/eva/vision/data/datasets/wsi.py b/src/eva/vision/data/datasets/wsi.py index 4c1c789a3..8e31d5644 100644 --- a/src/eva/vision/data/datasets/wsi.py +++ b/src/eva/vision/data/datasets/wsi.py @@ -179,7 +179,11 @@ def load_metadata(self, index: int) -> Dict[str, Any]: """Loads the metadata for the patch at the specified index.""" dataset_index, sample_index = self._get_dataset_idx(index), self._get_sample_idx(index) patch_metadata = self.datasets[dataset_index].load_metadata(sample_index) - return {"wsi_id": self.filename(index).split(".")[0]} | patch_metadata + return { + "wsi_id": self.filename(index).split(".")[0], + "slide_idx": dataset_index, + "patch_idx": sample_index, + } | patch_metadata def _load_datasets(self) -> list[WsiDataset]: logger.info(f"Initializing dataset with {len(self._file_paths)} WSIs ...") diff --git a/src/eva/vision/models/networks/abmil.py b/src/eva/vision/models/networks/abmil.py index bb2ca4820..03553bef1 100644 --- a/src/eva/vision/models/networks/abmil.py +++ b/src/eva/vision/models/networks/abmil.py @@ -34,8 +34,8 @@ class ABMIL(torch.nn.Module): def __init__( self, input_size: int, - output_size: int, projected_input_size: int | None, + output_size: int = 1, hidden_size_attention: int = 128, hidden_sizes_mlp: tuple = (128, 64), use_bias: bool = True, From 4ec5faedcd8eddc5f60a9266f9aff930af40a8db Mon Sep 17 00:00:00 2001 From: Jklubienski Date: Mon, 6 Oct 2025 09:27:14 +0100 Subject: [PATCH 3/3] Add secondary changes from review comments --- .../pathology/offline/regression/tiger_til_score.yaml | 3 +-- src/eva/core/data/datasets/multi_embeddings.py | 6 +++--- src/eva/vision/data/datasets/tiger.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/configs/vision/pathology/offline/regression/tiger_til_score.yaml b/configs/vision/pathology/offline/regression/tiger_til_score.yaml index 99f42ad07..eb8430e33 100644 --- a/configs/vision/pathology/offline/regression/tiger_til_score.yaml +++ b/configs/vision/pathology/offline/regression/tiger_til_score.yaml @@ -20,7 +20,7 @@ trainer: save_last: ${oc.env:SAVE_LAST, false} save_top_k: 1 monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MAE} - mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, min} - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 @@ -104,7 +104,6 @@ data: max_samples: *N_PATCHES width: 224 height: 224 - target_mpp: 0.5 split: train coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv image_transforms: diff --git a/src/eva/core/data/datasets/multi_embeddings.py b/src/eva/core/data/datasets/multi_embeddings.py index dc7213569..93b68d8ef 100644 --- a/src/eva/core/data/datasets/multi_embeddings.py +++ b/src/eva/core/data/datasets/multi_embeddings.py @@ -8,10 +8,10 @@ import torch from typing_extensions import override -from eva.core.data.datasets import embeddings as embeddings_base +from eva.core.data.datasets import embeddings as base -class MultiEmbeddingsDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]): +class MultiEmbeddingsDataset(base.EmbeddingsDataset[torch.Tensor]): """Dataset class for where a sample corresponds to multiple embeddings. Example use case: Slide level dataset where each slide has multiple patch embeddings. @@ -22,7 +22,7 @@ def __init__( root: str, manifest_file: str, split: Literal["train", "val", "test"], - column_mapping: Dict[str, str] = embeddings_base.default_column_mapping, + column_mapping: Dict[str, str] = base.default_column_mapping, embeddings_transforms: Callable | None = None, target_transforms: Callable | None = None, target_type: type[np.generic] = np.int64, diff --git a/src/eva/vision/data/datasets/tiger.py b/src/eva/vision/data/datasets/tiger.py index e0e32f874..12d34cfed 100644 --- a/src/eva/vision/data/datasets/tiger.py +++ b/src/eva/vision/data/datasets/tiger.py @@ -25,8 +25,8 @@ class TIGERBase( _train_split_ratio: float = 0.7 _val_split_ratio: float = 0.15 - # target microns per pixel (mpp) for patches. _target_mpp: float = 0.5 + '''Target microns per pixel (mpp) for patches''' def __init__( self,