diff --git a/configs/vision/pathology/offline/classification/bracs.yaml b/configs/vision/pathology/offline/classification/bracs.yaml index c31cccf35..aa1a25275 100644 --- a/configs/vision/pathology/offline/classification/bracs.yaml +++ b/configs/vision/pathology/offline/classification/bracs.yaml @@ -2,7 +2,7 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + n_runs: &N_RUNS ${oc.env:N_RUNS, 20} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/bracs} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} @@ -24,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: ${oc.env:PATIENCE, 167} + patience: ${oc.env:PATIENCE, 22} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.ClassificationEmbeddingsWriter @@ -49,15 +49,17 @@ model: class_path: eva.HeadModule init_args: head: - class_path: torch.nn.Linear + class_path: eva.vision.models.networks.ABMIL init_args: - in_features: ${oc.env:IN_FEATURES, 384} - out_features: &NUM_CLASSES 7 + input_size: ${oc.env:IN_FEATURES, 384} + output_size: &NUM_CLASSES 7 + projected_input_size: 128 criterion: torch.nn.CrossEntropyLoss optimizer: class_path: torch.optim.AdamW init_args: - lr: ${oc.env:LR_VALUE, 0.0003} + lr: ${oc.env:LR_VALUE, 0.001} + betas: [0.9, 0.999] metrics: common: - class_path: eva.metrics.AverageLoss @@ -69,18 +71,24 @@ data: init_args: datasets: train: - class_path: eva.datasets.EmbeddingsClassificationDataset + class_path: eva.datasets.MultiEmbeddingsClassificationDataset init_args: &DATASET_ARGS root: *DATASET_EMBEDDINGS_ROOT manifest_file: manifest.csv split: train + column_mapping: + multi_id: origin + embeddings_transforms: + class_path: eva.core.data.transforms.Pad2DTensor + init_args: + pad_size: &N_PATCHES ${oc.env:N_PATCHES, 250} val: - class_path: eva.datasets.EmbeddingsClassificationDataset + class_path: eva.datasets.MultiEmbeddingsClassificationDataset init_args: <<: *DATASET_ARGS split: val test: - class_path: eva.datasets.EmbeddingsClassificationDataset + class_path: eva.datasets.MultiEmbeddingsClassificationDataset init_args: <<: *DATASET_ARGS split: test @@ -89,10 +97,15 @@ data: init_args: &PREDICT_DATASET_ARGS root: ${oc.env:DATA_ROOT, ./data/bracs} split: train + sampler: + class_path: eva.vision.data.wsi.patching.samplers.GridSampler + init_args: + max_samples: *N_PATCHES + include_partial_patches: true transforms: class_path: eva.vision.data.transforms.common.ResizeAndCrop init_args: - size: ${oc.env:RESIZE_DIM, 224} + size: ${oc.env:RESIZE_DIM, 224} mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} - class_path: eva.vision.datasets.BRACS @@ -105,7 +118,7 @@ data: split: test dataloaders: train: - batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 32} num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} shuffle: true val: diff --git a/configs/vision/pathology/online/classification/bracs.yaml b/configs/vision/pathology/online/classification/bracs.yaml deleted file mode 100644 index 7b429cadd..000000000 --- a/configs/vision/pathology/online/classification/bracs.yaml +++ /dev/null @@ -1,94 +0,0 @@ ---- -trainer: - class_path: eva.Trainer - init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 5} - default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/online/bracs} - max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} - checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} - callbacks: - - class_path: eva.callbacks.ConfigurationLogger - - class_path: lightning.pytorch.callbacks.TQDMProgressBar - init_args: - refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: epoch - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - filename: best - save_last: ${oc.env:SAVE_LAST, false} - save_top_k: 1 - monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} - mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - min_delta: 0 - patience: ${oc.env:PATIENCE, 167} - monitor: *MONITOR_METRIC - mode: *MONITOR_METRIC_MODE - logger: - - class_path: lightning.pytorch.loggers.TensorBoardLogger - init_args: - save_dir: *OUTPUT_ROOT - name: "" -model: - class_path: eva.HeadModule - init_args: - backbone: - class_path: eva.vision.models.ModelFromRegistry - init_args: - model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} - model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} - head: - class_path: torch.nn.Linear - init_args: - in_features: ${oc.env:IN_FEATURES, 384} - out_features: &NUM_CLASSES 7 - criterion: torch.nn.CrossEntropyLoss - optimizer: - class_path: torch.optim.AdamW - init_args: - lr: ${oc.env:LR_VALUE, 0.0003} - metrics: - common: - - class_path: eva.metrics.AverageLoss - - class_path: eva.metrics.MulticlassClassificationMetrics - init_args: - num_classes: *NUM_CLASSES -data: - class_path: eva.DataModule - init_args: - datasets: - train: - class_path: eva.vision.datasets.BRACS - init_args: &DATASET_ARGS - root: ${oc.env:DATA_ROOT, ./data/bracs} - split: train - transforms: - class_path: eva.vision.data.transforms.common.ResizeAndCrop - init_args: - size: ${oc.env:RESIZE_DIM, 224} - mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} - std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} - val: - class_path: eva.vision.datasets.BRACS - init_args: - <<: *DATASET_ARGS - split: val - test: - class_path: eva.vision.datasets.BRACS - init_args: - <<: *DATASET_ARGS - split: test - dataloaders: - train: - batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} - num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} - shuffle: true - val: - batch_size: *BATCH_SIZE - num_workers: *N_DATA_WORKERS - test: - batch_size: *BATCH_SIZE - num_workers: *N_DATA_WORKERS diff --git a/docs/datasets/bracs.md b/docs/datasets/bracs.md index 568560c67..962b39242 100644 --- a/docs/datasets/bracs.md +++ b/docs/datasets/bracs.md @@ -29,6 +29,8 @@ we use in this benchmarks contains 4539 extracted ROIs / patches. | **Files format** | `png` | | **Number of images** | 4539 | +Given that some images in the `BRACS_ROI` can be very large, we split the image into a grid of 224x224 patches and then treat it as a multi instance learning (MIL) classification problem. + ### Splits diff --git a/src/eva/vision/data/datasets/classification/bracs.py b/src/eva/vision/data/datasets/classification/bracs.py index 848627c3c..997ff922e 100644 --- a/src/eva/vision/data/datasets/classification/bracs.py +++ b/src/eva/vision/data/datasets/classification/bracs.py @@ -1,27 +1,28 @@ """BRACS dataset class.""" import os -from typing import Callable, Dict, List, Literal, Tuple +from typing import Any, Callable, Dict, List, Literal, Tuple import torch from torchvision import tv_tensors from torchvision.datasets import folder +from torchvision.transforms.v2 import functional from typing_extensions import override -from eva.vision.data.datasets import _validators +from eva.vision.data.datasets import _validators, wsi from eva.vision.data.datasets.classification import base -from eva.vision.utils import io +from eva.vision.data.wsi.patching import samplers -class BRACS(base.ImageClassification): +class BRACS(wsi.MultiWsiDataset, base.ImageClassification): """Dataset class for BRACS images and corresponding targets.""" - _expected_dataset_lengths: Dict[str, int] = { + _expected_files: Dict[str, int] = { "train": 3657, "val": 312, "test": 570, } - """Expected dataset lengths for the splits and complete dataset.""" + """Expected number of files for each split.""" _license: str = "CC BY-NC 4.0 (https://creativecommons.org/licenses/by-nc/4.0/)" """Dataset license.""" @@ -30,6 +31,10 @@ def __init__( self, root: str, split: Literal["train", "val", "test"], + sampler: samplers.Sampler | None = None, + width: int = 224, + height: int = 224, + target_mpp: float = 0.25, transforms: Callable | None = None, ) -> None: """Initializes the dataset. @@ -37,15 +42,31 @@ def __init__( Args: root: Path to the root directory of the dataset. split: Dataset split to use. + sampler: The sampler to use for sampling patch coordinates. + If `None`, it will use the ::class::`ForegroundGridSampler` sampler. + width: Width of the patches to be extracted, in pixels. + height: Height of the patches to be extracted, in pixels. + target_mpp: Target microns per pixel (mpp) for the patches. transforms: A function/transform which returns a transformed version of the raw data samples. """ - super().__init__(transforms=transforms) - self._root = root self._split = split + self._path_to_target = self._make_dataset() + self._file_to_path = {os.path.basename(p): p for p in self._path_to_target.keys()} - self._samples: List[Tuple[str, int]] = [] + wsi.MultiWsiDataset.__init__( + self, + root=root, + file_paths=sorted(self._path_to_target.keys()), + width=width, + height=height, + sampler=sampler or samplers.ForegroundGridSampler(max_samples=25), + target_mpp=target_mpp, + overwrite_mpp=0.25, + backend="pil", + image_transforms=transforms, + ) @property @override @@ -57,55 +78,52 @@ def classes(self) -> List[str]: def class_to_idx(self) -> Dict[str, int]: return {name: index for index, name in enumerate(self.classes)} - @override - def filename(self, index: int) -> str: - image_path, *_ = self._samples[index] - return os.path.relpath(image_path, self._dataset_path) - @override def prepare_data(self) -> None: _validators.check_dataset_exists(self._root, True) - @override - def configure(self) -> None: - self._samples = self._make_dataset() - @override def validate(self) -> None: + if len(self._path_to_target) != self._expected_files[self._split]: + raise ValueError( + f"Expected {self._split} split to have {self._expected_files[self._split]} files, " + f"but found {len(self._path_to_target)} files." + ) + _validators.check_dataset_integrity( self, - length=self._expected_dataset_lengths[self._split], + length=None, n_classes=7, first_and_last_labels=("0_N", "6_IC"), ) @override - def load_image(self, index: int) -> tv_tensors.Image: - image_path, _ = self._samples[index] - return io.read_image_as_tensor(image_path) + def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]: + return base.ImageClassification.__getitem__(self, index) @override - def load_target(self, index: int) -> torch.Tensor: - _, target = self._samples[index] - return torch.tensor(target, dtype=torch.long) + def load_image(self, index: int) -> tv_tensors.Image: + image_array = wsi.MultiWsiDataset.__getitem__(self, index) + return functional.to_image(image_array) @override - def __len__(self) -> int: - return len(self._samples) + def load_target(self, index: int) -> torch.Tensor: + path = self._file_to_path[self.filename(index)] + return torch.tensor(self._path_to_target[path], dtype=torch.long) @property def _dataset_path(self) -> str: """Returns the full path of dataset directory.""" return os.path.join(self._root, "BRACS_RoI/latest_version") - def _make_dataset(self) -> List[Tuple[str, int]]: + def _make_dataset(self) -> Dict[str, int]: """Builds the dataset for the specified split.""" dataset = folder.make_dataset( directory=os.path.join(self._dataset_path, self._split), class_to_idx=self.class_to_idx, extensions=(".png"), ) - return dataset + return dict(dataset) def _print_license(self) -> None: """Prints the dataset license.""" diff --git a/src/eva/vision/data/datasets/segmentation/consep.py b/src/eva/vision/data/datasets/segmentation/consep.py index f0cd1b551..9c0ed9dc9 100644 --- a/src/eva/vision/data/datasets/segmentation/consep.py +++ b/src/eva/vision/data/datasets/segmentation/consep.py @@ -55,7 +55,6 @@ def __init__( width: Width of the patches to be extracted, in pixels. height: Height of the patches to be extracted, in pixels. target_mpp: Target microns per pixel (mpp) for the patches. - backend: The backend to use for reading the whole-slide images. transforms: Transforms to apply to the extracted image & mask patches. """ self._split = split diff --git a/src/eva/vision/data/wsi/backends/pil.py b/src/eva/vision/data/wsi/backends/pil.py index fdef88b01..fb37b0038 100644 --- a/src/eva/vision/data/wsi/backends/pil.py +++ b/src/eva/vision/data/wsi/backends/pil.py @@ -50,3 +50,16 @@ def _read_region( ) ) return np.array(patch) + + @override + def _verify_location(self, location: Tuple[int, int], size: Tuple[int, int]) -> None: + """Verifies that the requested coordinates are within the slide dimensions. + + Args: + location: Top-left corner (x, y) coordinates to read. + size: Size of the requested region (width, height) + """ + x_max, y_max = self.level_dimensions[0] + + if int(location[0]) >= x_max or int(location[1]) >= y_max: + raise ValueError(f"Out of bounds region: {location}, {size}") diff --git a/src/eva/vision/data/wsi/patching/samplers/_utils.py b/src/eva/vision/data/wsi/patching/samplers/_utils.py index f1fa3b7e0..9376d0a38 100644 --- a/src/eva/vision/data/wsi/patching/samplers/_utils.py +++ b/src/eva/vision/data/wsi/patching/samplers/_utils.py @@ -1,35 +1,49 @@ -from typing import Tuple +from typing import List, Tuple import numpy as np -def get_grid_coords_and_indices( - layer_shape: Tuple[int, int], +def get_grid_coords( + image_size: Tuple[int, int], width: int, height: int, overlap: Tuple[int, int], shuffle: bool = True, + include_partial_patches: bool = False, seed: int = 42, -): +) -> List[Tuple[int, int]]: """Get grid coordinates and indices. Args: - layer_shape: The shape of the layer. + image_size: The shape of the complete image. width: The width of the patches. height: The height of the patches. overlap: The overlap between patches in the grid. - shuffle: Whether to shuffle the indices. + shuffle: Whether to shuffle the order of the coordinates. + include_partial_patches: Whether to include coordinates of the last patch when it + it partially exceeds the image and therefore is smaller than the + specified patch size. seed: The random seed. + + Returns: + A list of tuples with the (x, y) coordinates. """ - x_range = range(0, layer_shape[0] - width + 1, width - overlap[0]) - y_range = range(0, layer_shape[1] - height + 1, height - overlap[1]) - x_y = [(x, y) for x in x_range for y in y_range] + if not include_partial_patches and (width > image_size[0] or height > image_size[1]): + raise ValueError("The patch size cannot be bigger than the image.") + + x_stop = image_size[0] if include_partial_patches else image_size[0] - width + 1 + y_stop = image_size[1] if include_partial_patches else image_size[1] - height + 1 + + x_coords = list(range(0, x_stop, width - overlap[0])) or [0] + y_coords = list(range(0, y_stop, height - overlap[1])) or [0] + + x_y = [(x, y) for x in x_coords for y in y_coords] - indices = list(range(len(x_y))) if shuffle: random_generator = np.random.default_rng(seed) - random_generator.shuffle(indices) - return x_y, indices + random_generator.shuffle(x_y) + + return x_y def validate_dimensions(width: int, height: int, layer_shape: Tuple[int, int]) -> None: diff --git a/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py b/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py index 7671ab9f5..fe70ac16c 100644 --- a/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py +++ b/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py @@ -45,13 +45,10 @@ def sample( layer_shape: The shape of the layer. mask: The mask of the image. """ - _utils.validate_dimensions(width, height, layer_shape) - x_y, indices = _utils.get_grid_coords_and_indices( - layer_shape, width, height, self.overlap, seed=self.seed - ) + x_y = _utils.get_grid_coords(layer_shape, width, height, self.overlap, seed=self.seed) count = 0 - for i in indices: + for i in range(len(x_y)): if count >= self.max_samples: break diff --git a/src/eva/vision/data/wsi/patching/samplers/grid.py b/src/eva/vision/data/wsi/patching/samplers/grid.py index 3f2b00819..beea4f116 100644 --- a/src/eva/vision/data/wsi/patching/samplers/grid.py +++ b/src/eva/vision/data/wsi/patching/samplers/grid.py @@ -6,24 +6,29 @@ class GridSampler(base.Sampler): - """Sample patches based on a grid. - - Args: - max_samples: The maximum number of samples to return. - overlap: The overlap between patches in the grid. - seed: The random seed. - """ + """Sample patches based on a grid.""" def __init__( self, max_samples: int | None = None, overlap: Tuple[int, int] = (0, 0), seed: int = 42, + include_partial_patches: bool = False, ): - """Initializes the sampler.""" + """Initializes the sampler. + + Args: + max_samples: The maximum number of samples to return. + overlap: The overlap between patches in the grid. + seed: The random seed. + include_partial_patches: Whether to include coordinates of the last patch when it + it partially exceeds the image and therefore is smaller than the + specified patch size. + """ self.max_samples = max_samples self.overlap = overlap self.seed = seed + self.include_partial_patches = include_partial_patches def sample( self, @@ -38,10 +43,14 @@ def sample( height: The height of the patches. layer_shape: The shape of the layer. """ - _utils.validate_dimensions(width, height, layer_shape) - x_y, indices = _utils.get_grid_coords_and_indices( - layer_shape, width, height, self.overlap, seed=self.seed + x_y = _utils.get_grid_coords( + layer_shape, + width, + height, + self.overlap, + seed=self.seed, + include_partial_patches=self.include_partial_patches, ) - max_samples = len(indices) if self.max_samples is None else self.max_samples - for i in indices[:max_samples]: + max_samples = len(x_y) if self.max_samples is None else min(self.max_samples, len(x_y)) + for i in range(max_samples): yield x_y[i] diff --git a/tests/eva/vision/data/datasets/classification/test_bracs.py b/tests/eva/vision/data/datasets/classification/test_bracs.py index 2bd2c829a..5073d92c7 100644 --- a/tests/eva/vision/data/datasets/classification/test_bracs.py +++ b/tests/eva/vision/data/datasets/classification/test_bracs.py @@ -9,6 +9,7 @@ from torchvision import tv_tensors from eva.vision.data import datasets +from eva.vision.data.wsi.patching import samplers @pytest.mark.parametrize( @@ -31,13 +32,13 @@ def test_sample(bracs_dataset: datasets.BRACS, index: int) -> None: # assert the format of the `image` and `target` image, target, _ = sample assert isinstance(image, tv_tensors.Image) - assert image.shape == (3, 40, 40) + assert image.shape == (3, 10, 10) assert isinstance(target, torch.Tensor) assert target in [0, 1, 2, 3, 4, 5, 6, 7, 8] @pytest.fixture(scope="function") -def bracs_dataset(split: Literal["train", "val"], assets_path: str) -> datasets.BRACS: +def bracs_dataset(split: Literal["train", "val", "test"], assets_path: str) -> datasets.BRACS: """BRACS dataset fixture.""" with mock.patch.object( datasets.BRACS, "classes", new_callable=mock.PropertyMock @@ -46,6 +47,10 @@ def bracs_dataset(split: Literal["train", "val"], assets_path: str) -> datasets. dataset = datasets.BRACS( root=os.path.join(assets_path, "vision", "datasets", "bracs"), split=split, + width=10, + height=10, + target_mpp=0.25, + sampler=samplers.GridSampler(), ) dataset.prepare_data() dataset.configure() diff --git a/tests/eva/vision/data/wsi/patching/samplers/test_grid.py b/tests/eva/vision/data/wsi/patching/samplers/test_grid.py index 42cbdbf80..3ab2086db 100644 --- a/tests/eva/vision/data/wsi/patching/samplers/test_grid.py +++ b/tests/eva/vision/data/wsi/patching/samplers/test_grid.py @@ -65,28 +65,51 @@ def test_different_seed(max_samples: int, seed_1: int, seed_2: int) -> None: assert x_y_1 != x_y_2 -def test_invalid_width_height() -> None: - """Tests if the sampler raises an error when width / height is bigger than layer_shape.""" - sampler = samplers.GridSampler(max_samples=10, seed=42) - - with pytest.raises(ValueError): - list(sampler.sample(width=200, height=200, layer_shape=(100, 100))) - - @pytest.mark.parametrize( - "width, height, layer_shape", + "width, height, layer_shape, include_partial_patches, expected_n_samples", [ - (5, 5, (25, 25)), - (5, 5, (100, 100)), - (224, 224, (1000, 1000)), + (5, 5, (25, 25), False, 25), + (5, 5, (25, 25), True, 25), + (224, 224, (1000, 1000), False, 16), + (224, 224, (1000, 1000), True, 25), + (10, 10, (5, 5), True, 1), ], ) -def test_expected_n_patches(width: int, height: int, layer_shape: Tuple[int, int]) -> None: +def test_expected_n_patches( + width: int, + height: int, + layer_shape: Tuple[int, int], + include_partial_patches: bool, + expected_n_samples: int, +) -> None: """Tests if the sampler respects the max_samples limit.""" - sampler = samplers.GridSampler(max_samples=None) - - expected_max_samples = (layer_shape[0] // width) * (layer_shape[1] // height) + sampler = samplers.GridSampler( + max_samples=None, include_partial_patches=include_partial_patches + ) x_y = list(sampler.sample(width=width, height=height, layer_shape=layer_shape)) - assert len(x_y) == expected_max_samples + assert len(x_y) == expected_n_samples + + +@pytest.mark.parametrize("include_partial_patches", [False, True]) +def test_patch_bigger_than_image(include_partial_patches: bool) -> None: + """Test edge case where the patch size is bigger than the image.""" + sampler = samplers.GridSampler( + max_samples=10, seed=42, include_partial_patches=include_partial_patches + ) + patch_dim, image_dim = 200, 100 + + if not include_partial_patches: + with pytest.raises(ValueError, match="The patch size cannot be bigger than the image."): + list( + sampler.sample( + width=patch_dim, height=patch_dim, layer_shape=(image_dim, image_dim) + ) + ) + else: + x_y = list( + sampler.sample(width=patch_dim, height=patch_dim, layer_shape=(image_dim, image_dim)) + ) + assert len(x_y) == 1 + assert x_y[0] == (0, 0)