From 4e7b108468f645104278703dc9ae95ed57bd3f77 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Mon, 31 Mar 2025 12:53:00 +0200 Subject: [PATCH 1/7] converted BRACS into patched wsi dataset --- .../offline/classification/bracs.yaml | 32 +++++-- .../online/classification/bracs.yaml | 94 ------------------- .../data/datasets/classification/bracs.py | 76 +++++++++------ .../data/datasets/segmentation/consep.py | 1 - .../data/wsi/patching/samplers/_utils.py | 17 +++- .../vision/data/wsi/patching/samplers/grid.py | 29 ++++-- .../datasets/classification/test_bracs.py | 9 +- .../patching/samplers/test_foreground_grid.py | 1 + .../data/wsi/patching/samplers/test_random.py | 1 + 9 files changed, 112 insertions(+), 148 deletions(-) delete mode 100644 configs/vision/pathology/online/classification/bracs.yaml diff --git a/configs/vision/pathology/offline/classification/bracs.yaml b/configs/vision/pathology/offline/classification/bracs.yaml index c31cccf35..8f662b7ca 100644 --- a/configs/vision/pathology/offline/classification/bracs.yaml +++ b/configs/vision/pathology/offline/classification/bracs.yaml @@ -49,15 +49,17 @@ model: class_path: eva.HeadModule init_args: head: - class_path: torch.nn.Linear + class_path: eva.vision.models.networks.ABMIL init_args: - in_features: ${oc.env:IN_FEATURES, 384} - out_features: &NUM_CLASSES 7 + input_size: ${oc.env:IN_FEATURES, 384} + output_size: &NUM_CLASSES 7 + projected_input_size: 128 criterion: torch.nn.CrossEntropyLoss optimizer: class_path: torch.optim.AdamW init_args: - lr: ${oc.env:LR_VALUE, 0.0003} + lr: ${oc.env:LR_VALUE, 0.001} + betas: [0.9, 0.999] metrics: common: - class_path: eva.metrics.AverageLoss @@ -69,18 +71,24 @@ data: init_args: datasets: train: - class_path: eva.datasets.EmbeddingsClassificationDataset + class_path: eva.datasets.MultiEmbeddingsClassificationDataset init_args: &DATASET_ARGS root: *DATASET_EMBEDDINGS_ROOT manifest_file: manifest.csv split: train + column_mapping: + multi_id: origin + embeddings_transforms: + class_path: eva.core.data.transforms.Pad2DTensor + init_args: + pad_size: &N_PATCHES ${oc.env:N_PATCHES, 500} val: - class_path: eva.datasets.EmbeddingsClassificationDataset + class_path: eva.datasets.MultiEmbeddingsClassificationDataset init_args: <<: *DATASET_ARGS split: val test: - class_path: eva.datasets.EmbeddingsClassificationDataset + class_path: eva.datasets.MultiEmbeddingsClassificationDataset init_args: <<: *DATASET_ARGS split: test @@ -89,10 +97,16 @@ data: init_args: &PREDICT_DATASET_ARGS root: ${oc.env:DATA_ROOT, ./data/bracs} split: train + sampler: + class_path: eva.vision.data.wsi.patching.samplers.GridSampler + init_args: + max_samples: *N_PATCHES + validate_dimensions: false # Some images are smaller than 224x224 + include_last: true transforms: class_path: eva.vision.data.transforms.common.ResizeAndCrop init_args: - size: ${oc.env:RESIZE_DIM, 224} + size: ${oc.env:RESIZE_DIM, 224} mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} - class_path: eva.vision.datasets.BRACS @@ -105,7 +119,7 @@ data: split: test dataloaders: train: - batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 32} num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} shuffle: true val: diff --git a/configs/vision/pathology/online/classification/bracs.yaml b/configs/vision/pathology/online/classification/bracs.yaml deleted file mode 100644 index 7b429cadd..000000000 --- a/configs/vision/pathology/online/classification/bracs.yaml +++ /dev/null @@ -1,94 +0,0 @@ ---- -trainer: - class_path: eva.Trainer - init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 5} - default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/online/bracs} - max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} - checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} - callbacks: - - class_path: eva.callbacks.ConfigurationLogger - - class_path: lightning.pytorch.callbacks.TQDMProgressBar - init_args: - refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: epoch - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - filename: best - save_last: ${oc.env:SAVE_LAST, false} - save_top_k: 1 - monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} - mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - min_delta: 0 - patience: ${oc.env:PATIENCE, 167} - monitor: *MONITOR_METRIC - mode: *MONITOR_METRIC_MODE - logger: - - class_path: lightning.pytorch.loggers.TensorBoardLogger - init_args: - save_dir: *OUTPUT_ROOT - name: "" -model: - class_path: eva.HeadModule - init_args: - backbone: - class_path: eva.vision.models.ModelFromRegistry - init_args: - model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} - model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} - head: - class_path: torch.nn.Linear - init_args: - in_features: ${oc.env:IN_FEATURES, 384} - out_features: &NUM_CLASSES 7 - criterion: torch.nn.CrossEntropyLoss - optimizer: - class_path: torch.optim.AdamW - init_args: - lr: ${oc.env:LR_VALUE, 0.0003} - metrics: - common: - - class_path: eva.metrics.AverageLoss - - class_path: eva.metrics.MulticlassClassificationMetrics - init_args: - num_classes: *NUM_CLASSES -data: - class_path: eva.DataModule - init_args: - datasets: - train: - class_path: eva.vision.datasets.BRACS - init_args: &DATASET_ARGS - root: ${oc.env:DATA_ROOT, ./data/bracs} - split: train - transforms: - class_path: eva.vision.data.transforms.common.ResizeAndCrop - init_args: - size: ${oc.env:RESIZE_DIM, 224} - mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} - std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} - val: - class_path: eva.vision.datasets.BRACS - init_args: - <<: *DATASET_ARGS - split: val - test: - class_path: eva.vision.datasets.BRACS - init_args: - <<: *DATASET_ARGS - split: test - dataloaders: - train: - batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} - num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} - shuffle: true - val: - batch_size: *BATCH_SIZE - num_workers: *N_DATA_WORKERS - test: - batch_size: *BATCH_SIZE - num_workers: *N_DATA_WORKERS diff --git a/src/eva/vision/data/datasets/classification/bracs.py b/src/eva/vision/data/datasets/classification/bracs.py index 848627c3c..997ff922e 100644 --- a/src/eva/vision/data/datasets/classification/bracs.py +++ b/src/eva/vision/data/datasets/classification/bracs.py @@ -1,27 +1,28 @@ """BRACS dataset class.""" import os -from typing import Callable, Dict, List, Literal, Tuple +from typing import Any, Callable, Dict, List, Literal, Tuple import torch from torchvision import tv_tensors from torchvision.datasets import folder +from torchvision.transforms.v2 import functional from typing_extensions import override -from eva.vision.data.datasets import _validators +from eva.vision.data.datasets import _validators, wsi from eva.vision.data.datasets.classification import base -from eva.vision.utils import io +from eva.vision.data.wsi.patching import samplers -class BRACS(base.ImageClassification): +class BRACS(wsi.MultiWsiDataset, base.ImageClassification): """Dataset class for BRACS images and corresponding targets.""" - _expected_dataset_lengths: Dict[str, int] = { + _expected_files: Dict[str, int] = { "train": 3657, "val": 312, "test": 570, } - """Expected dataset lengths for the splits and complete dataset.""" + """Expected number of files for each split.""" _license: str = "CC BY-NC 4.0 (https://creativecommons.org/licenses/by-nc/4.0/)" """Dataset license.""" @@ -30,6 +31,10 @@ def __init__( self, root: str, split: Literal["train", "val", "test"], + sampler: samplers.Sampler | None = None, + width: int = 224, + height: int = 224, + target_mpp: float = 0.25, transforms: Callable | None = None, ) -> None: """Initializes the dataset. @@ -37,15 +42,31 @@ def __init__( Args: root: Path to the root directory of the dataset. split: Dataset split to use. + sampler: The sampler to use for sampling patch coordinates. + If `None`, it will use the ::class::`ForegroundGridSampler` sampler. + width: Width of the patches to be extracted, in pixels. + height: Height of the patches to be extracted, in pixels. + target_mpp: Target microns per pixel (mpp) for the patches. transforms: A function/transform which returns a transformed version of the raw data samples. """ - super().__init__(transforms=transforms) - self._root = root self._split = split + self._path_to_target = self._make_dataset() + self._file_to_path = {os.path.basename(p): p for p in self._path_to_target.keys()} - self._samples: List[Tuple[str, int]] = [] + wsi.MultiWsiDataset.__init__( + self, + root=root, + file_paths=sorted(self._path_to_target.keys()), + width=width, + height=height, + sampler=sampler or samplers.ForegroundGridSampler(max_samples=25), + target_mpp=target_mpp, + overwrite_mpp=0.25, + backend="pil", + image_transforms=transforms, + ) @property @override @@ -57,55 +78,52 @@ def classes(self) -> List[str]: def class_to_idx(self) -> Dict[str, int]: return {name: index for index, name in enumerate(self.classes)} - @override - def filename(self, index: int) -> str: - image_path, *_ = self._samples[index] - return os.path.relpath(image_path, self._dataset_path) - @override def prepare_data(self) -> None: _validators.check_dataset_exists(self._root, True) - @override - def configure(self) -> None: - self._samples = self._make_dataset() - @override def validate(self) -> None: + if len(self._path_to_target) != self._expected_files[self._split]: + raise ValueError( + f"Expected {self._split} split to have {self._expected_files[self._split]} files, " + f"but found {len(self._path_to_target)} files." + ) + _validators.check_dataset_integrity( self, - length=self._expected_dataset_lengths[self._split], + length=None, n_classes=7, first_and_last_labels=("0_N", "6_IC"), ) @override - def load_image(self, index: int) -> tv_tensors.Image: - image_path, _ = self._samples[index] - return io.read_image_as_tensor(image_path) + def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]: + return base.ImageClassification.__getitem__(self, index) @override - def load_target(self, index: int) -> torch.Tensor: - _, target = self._samples[index] - return torch.tensor(target, dtype=torch.long) + def load_image(self, index: int) -> tv_tensors.Image: + image_array = wsi.MultiWsiDataset.__getitem__(self, index) + return functional.to_image(image_array) @override - def __len__(self) -> int: - return len(self._samples) + def load_target(self, index: int) -> torch.Tensor: + path = self._file_to_path[self.filename(index)] + return torch.tensor(self._path_to_target[path], dtype=torch.long) @property def _dataset_path(self) -> str: """Returns the full path of dataset directory.""" return os.path.join(self._root, "BRACS_RoI/latest_version") - def _make_dataset(self) -> List[Tuple[str, int]]: + def _make_dataset(self) -> Dict[str, int]: """Builds the dataset for the specified split.""" dataset = folder.make_dataset( directory=os.path.join(self._dataset_path, self._split), class_to_idx=self.class_to_idx, extensions=(".png"), ) - return dataset + return dict(dataset) def _print_license(self) -> None: """Prints the dataset license.""" diff --git a/src/eva/vision/data/datasets/segmentation/consep.py b/src/eva/vision/data/datasets/segmentation/consep.py index f0cd1b551..9c0ed9dc9 100644 --- a/src/eva/vision/data/datasets/segmentation/consep.py +++ b/src/eva/vision/data/datasets/segmentation/consep.py @@ -55,7 +55,6 @@ def __init__( width: Width of the patches to be extracted, in pixels. height: Height of the patches to be extracted, in pixels. target_mpp: Target microns per pixel (mpp) for the patches. - backend: The backend to use for reading the whole-slide images. transforms: Transforms to apply to the extracted image & mask patches. """ self._split = split diff --git a/src/eva/vision/data/wsi/patching/samplers/_utils.py b/src/eva/vision/data/wsi/patching/samplers/_utils.py index f1fa3b7e0..caa840686 100644 --- a/src/eva/vision/data/wsi/patching/samplers/_utils.py +++ b/src/eva/vision/data/wsi/patching/samplers/_utils.py @@ -9,6 +9,7 @@ def get_grid_coords_and_indices( height: int, overlap: Tuple[int, int], shuffle: bool = True, + include_last: bool = False, seed: int = 42, ): """Get grid coordinates and indices. @@ -19,16 +20,26 @@ def get_grid_coords_and_indices( height: The height of the patches. overlap: The overlap between patches in the grid. shuffle: Whether to shuffle the indices. + include_last: Whether to include coordinates of the last patch when it + it partially exceeds the image. seed: The random seed. """ - x_range = range(0, layer_shape[0] - width + 1, width - overlap[0]) - y_range = range(0, layer_shape[1] - height + 1, height - overlap[1]) - x_y = [(x, y) for x in x_range for y in y_range] + x_coords = list(range(0, layer_shape[0] - width + 1, width - overlap[0])) + y_coords = list(range(0, layer_shape[1] - height + 1, height - overlap[1])) + + if include_last: + if layer_shape[0] % (width - overlap[0]) != 0: + x_coords.append(x_coords[-1] + width - overlap[0]) + if layer_shape[1] % (height - overlap[1]) != 0: + y_coords.append(y_coords[-1] + height - overlap[1]) + + x_y = [(x, y) for x in x_coords for y in y_coords] indices = list(range(len(x_y))) if shuffle: random_generator = np.random.default_rng(seed) random_generator.shuffle(indices) + return x_y, indices diff --git a/src/eva/vision/data/wsi/patching/samplers/grid.py b/src/eva/vision/data/wsi/patching/samplers/grid.py index 3f2b00819..9205cf2f0 100644 --- a/src/eva/vision/data/wsi/patching/samplers/grid.py +++ b/src/eva/vision/data/wsi/patching/samplers/grid.py @@ -6,24 +6,32 @@ class GridSampler(base.Sampler): - """Sample patches based on a grid. - - Args: - max_samples: The maximum number of samples to return. - overlap: The overlap between patches in the grid. - seed: The random seed. - """ + """Sample patches based on a grid.""" def __init__( self, max_samples: int | None = None, overlap: Tuple[int, int] = (0, 0), seed: int = 42, + validate_dimensions: bool = True, + include_last: bool = False, ): - """Initializes the sampler.""" + """Initializes the sampler. + + Args: + max_samples: The maximum number of samples to return. + overlap: The overlap between patches in the grid. + seed: The random seed. + validate_dimensions: Whether to validate the dimensions the image. It + expects the patch size to be smaller than the image size. + include_last: Whether to include coordinates of the last patch when it + it partially exceeds the image. + """ self.max_samples = max_samples self.overlap = overlap self.seed = seed + self.validate_dimensions = validate_dimensions + self.include_last = include_last def sample( self, @@ -38,9 +46,10 @@ def sample( height: The height of the patches. layer_shape: The shape of the layer. """ - _utils.validate_dimensions(width, height, layer_shape) + if self.validate_dimensions: + _utils.validate_dimensions(width, height, layer_shape) x_y, indices = _utils.get_grid_coords_and_indices( - layer_shape, width, height, self.overlap, seed=self.seed + layer_shape, width, height, self.overlap, seed=self.seed, include_last=self.include_last ) max_samples = len(indices) if self.max_samples is None else self.max_samples for i in indices[:max_samples]: diff --git a/tests/eva/vision/data/datasets/classification/test_bracs.py b/tests/eva/vision/data/datasets/classification/test_bracs.py index 2bd2c829a..5073d92c7 100644 --- a/tests/eva/vision/data/datasets/classification/test_bracs.py +++ b/tests/eva/vision/data/datasets/classification/test_bracs.py @@ -9,6 +9,7 @@ from torchvision import tv_tensors from eva.vision.data import datasets +from eva.vision.data.wsi.patching import samplers @pytest.mark.parametrize( @@ -31,13 +32,13 @@ def test_sample(bracs_dataset: datasets.BRACS, index: int) -> None: # assert the format of the `image` and `target` image, target, _ = sample assert isinstance(image, tv_tensors.Image) - assert image.shape == (3, 40, 40) + assert image.shape == (3, 10, 10) assert isinstance(target, torch.Tensor) assert target in [0, 1, 2, 3, 4, 5, 6, 7, 8] @pytest.fixture(scope="function") -def bracs_dataset(split: Literal["train", "val"], assets_path: str) -> datasets.BRACS: +def bracs_dataset(split: Literal["train", "val", "test"], assets_path: str) -> datasets.BRACS: """BRACS dataset fixture.""" with mock.patch.object( datasets.BRACS, "classes", new_callable=mock.PropertyMock @@ -46,6 +47,10 @@ def bracs_dataset(split: Literal["train", "val"], assets_path: str) -> datasets. dataset = datasets.BRACS( root=os.path.join(assets_path, "vision", "datasets", "bracs"), split=split, + width=10, + height=10, + target_mpp=0.25, + sampler=samplers.GridSampler(), ) dataset.prepare_data() dataset.configure() diff --git a/tests/eva/vision/data/wsi/patching/samplers/test_foreground_grid.py b/tests/eva/vision/data/wsi/patching/samplers/test_foreground_grid.py index c87ee8f12..ca43acf87 100644 --- a/tests/eva/vision/data/wsi/patching/samplers/test_foreground_grid.py +++ b/tests/eva/vision/data/wsi/patching/samplers/test_foreground_grid.py @@ -55,6 +55,7 @@ def test_same_seed(max_samples: int, seed: int, x_y_expected: list) -> None: x_y_2 = list(sampler.sample(**TEST_ARGS)) assert x_y_1 == x_y_2 + assert x_y_1 == x_y_expected @pytest.mark.parametrize("max_samples, seed_1, seed_2", [(3, 1, 2), (5, 3, 4)]) diff --git a/tests/eva/vision/data/wsi/patching/samplers/test_random.py b/tests/eva/vision/data/wsi/patching/samplers/test_random.py index 09c1d2796..dbe643ba1 100644 --- a/tests/eva/vision/data/wsi/patching/samplers/test_random.py +++ b/tests/eva/vision/data/wsi/patching/samplers/test_random.py @@ -48,6 +48,7 @@ def test_same_seed(n_samples: int, seed: int, x_y_expected: int) -> None: x_y_2 = list(sampler_2.sample(**TEST_ARGS)) assert x_y_1 == x_y_2 + assert x_y_1 == x_y_expected @pytest.mark.parametrize("n_samples, seed_1, seed_2", [(10, 1, 2), (22, 3, 4)]) From 45411f12cdf1845507714f2594fcf19d690b8a5f Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Thu, 3 Apr 2025 10:55:12 +0200 Subject: [PATCH 2/7] refactoring --- .../data/wsi/patching/samplers/_utils.py | 31 +++++++++---------- .../wsi/patching/samplers/foreground_grid.py | 4 +-- .../vision/data/wsi/patching/samplers/grid.py | 9 +++--- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/eva/vision/data/wsi/patching/samplers/_utils.py b/src/eva/vision/data/wsi/patching/samplers/_utils.py index caa840686..ae847baeb 100644 --- a/src/eva/vision/data/wsi/patching/samplers/_utils.py +++ b/src/eva/vision/data/wsi/patching/samplers/_utils.py @@ -1,9 +1,9 @@ -from typing import Tuple +from typing import Tuple, List import numpy as np -def get_grid_coords_and_indices( +def get_grid_coords( layer_shape: Tuple[int, int], width: int, height: int, @@ -11,7 +11,7 @@ def get_grid_coords_and_indices( shuffle: bool = True, include_last: bool = False, seed: int = 42, -): +) -> List[Tuple[int, int]]: """Get grid coordinates and indices. Args: @@ -19,28 +19,27 @@ def get_grid_coords_and_indices( width: The width of the patches. height: The height of the patches. overlap: The overlap between patches in the grid. - shuffle: Whether to shuffle the indices. + shuffle: Whether to shuffle the order of the coordinates. include_last: Whether to include coordinates of the last patch when it - it partially exceeds the image. + it partially exceeds the image and therefore is smaller than the + specified patch size. seed: The random seed. - """ - x_coords = list(range(0, layer_shape[0] - width + 1, width - overlap[0])) - y_coords = list(range(0, layer_shape[1] - height + 1, height - overlap[1])) - if include_last: - if layer_shape[0] % (width - overlap[0]) != 0: - x_coords.append(x_coords[-1] + width - overlap[0]) - if layer_shape[1] % (height - overlap[1]) != 0: - y_coords.append(y_coords[-1] + height - overlap[1]) + Returns: + A list of tuples with the (x, y) coordinates. + """ + x_range = range(0, layer_shape[0] - width + 1, width - overlap[0]) + y_range = range(0, layer_shape[1] - height + 1, height - overlap[1]) - x_y = [(x, y) for x in x_coords for y in y_coords] + x_y = [(x, y) for x in x_range for y in y_range] - indices = list(range(len(x_y))) if shuffle: + indices = list(range(len(x_y))) random_generator = np.random.default_rng(seed) random_generator.shuffle(indices) + x_y = [x_y[i] for i in indices] - return x_y, indices + return x_y def validate_dimensions(width: int, height: int, layer_shape: Tuple[int, int]) -> None: diff --git a/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py b/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py index 7671ab9f5..9dd7ca4dc 100644 --- a/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py +++ b/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py @@ -46,12 +46,12 @@ def sample( mask: The mask of the image. """ _utils.validate_dimensions(width, height, layer_shape) - x_y, indices = _utils.get_grid_coords_and_indices( + x_y = _utils.get_grid_coords( layer_shape, width, height, self.overlap, seed=self.seed ) count = 0 - for i in indices: + for i in range(len(x_y)): if count >= self.max_samples: break diff --git a/src/eva/vision/data/wsi/patching/samplers/grid.py b/src/eva/vision/data/wsi/patching/samplers/grid.py index 9205cf2f0..151003dbc 100644 --- a/src/eva/vision/data/wsi/patching/samplers/grid.py +++ b/src/eva/vision/data/wsi/patching/samplers/grid.py @@ -25,7 +25,8 @@ def __init__( validate_dimensions: Whether to validate the dimensions the image. It expects the patch size to be smaller than the image size. include_last: Whether to include coordinates of the last patch when it - it partially exceeds the image. + it partially exceeds the image and therefore is smaller than the + specified patch size. """ self.max_samples = max_samples self.overlap = overlap @@ -48,9 +49,9 @@ def sample( """ if self.validate_dimensions: _utils.validate_dimensions(width, height, layer_shape) - x_y, indices = _utils.get_grid_coords_and_indices( + x_y = _utils.get_grid_coords( layer_shape, width, height, self.overlap, seed=self.seed, include_last=self.include_last ) - max_samples = len(indices) if self.max_samples is None else self.max_samples - for i in indices[:max_samples]: + max_samples = len(x_y) if self.max_samples is None else self.max_samples + for i in range(max_samples): yield x_y[i] From e87cb25694c5b02d9d60f030f6622be79efe9447 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Thu, 3 Apr 2025 11:51:42 +0200 Subject: [PATCH 3/7] added include_partial_patches feature in grid sampler --- .../offline/classification/bracs.yaml | 5 +- .../data/wsi/patching/samplers/_utils.py | 26 +++++---- .../wsi/patching/samplers/foreground_grid.py | 5 +- .../vision/data/wsi/patching/samplers/grid.py | 21 ++++--- .../data/wsi/patching/samplers/test_grid.py | 57 +++++++++++++------ 5 files changed, 68 insertions(+), 46 deletions(-) diff --git a/configs/vision/pathology/offline/classification/bracs.yaml b/configs/vision/pathology/offline/classification/bracs.yaml index 8f662b7ca..e97414021 100644 --- a/configs/vision/pathology/offline/classification/bracs.yaml +++ b/configs/vision/pathology/offline/classification/bracs.yaml @@ -81,7 +81,7 @@ data: embeddings_transforms: class_path: eva.core.data.transforms.Pad2DTensor init_args: - pad_size: &N_PATCHES ${oc.env:N_PATCHES, 500} + pad_size: &N_PATCHES ${oc.env:N_PATCHES, 250} val: class_path: eva.datasets.MultiEmbeddingsClassificationDataset init_args: @@ -101,8 +101,7 @@ data: class_path: eva.vision.data.wsi.patching.samplers.GridSampler init_args: max_samples: *N_PATCHES - validate_dimensions: false # Some images are smaller than 224x224 - include_last: true + include_partial_patches: true transforms: class_path: eva.vision.data.transforms.common.ResizeAndCrop init_args: diff --git a/src/eva/vision/data/wsi/patching/samplers/_utils.py b/src/eva/vision/data/wsi/patching/samplers/_utils.py index ae847baeb..9376d0a38 100644 --- a/src/eva/vision/data/wsi/patching/samplers/_utils.py +++ b/src/eva/vision/data/wsi/patching/samplers/_utils.py @@ -1,26 +1,26 @@ -from typing import Tuple, List +from typing import List, Tuple import numpy as np def get_grid_coords( - layer_shape: Tuple[int, int], + image_size: Tuple[int, int], width: int, height: int, overlap: Tuple[int, int], shuffle: bool = True, - include_last: bool = False, + include_partial_patches: bool = False, seed: int = 42, ) -> List[Tuple[int, int]]: """Get grid coordinates and indices. Args: - layer_shape: The shape of the layer. + image_size: The shape of the complete image. width: The width of the patches. height: The height of the patches. overlap: The overlap between patches in the grid. shuffle: Whether to shuffle the order of the coordinates. - include_last: Whether to include coordinates of the last patch when it + include_partial_patches: Whether to include coordinates of the last patch when it it partially exceeds the image and therefore is smaller than the specified patch size. seed: The random seed. @@ -28,16 +28,20 @@ def get_grid_coords( Returns: A list of tuples with the (x, y) coordinates. """ - x_range = range(0, layer_shape[0] - width + 1, width - overlap[0]) - y_range = range(0, layer_shape[1] - height + 1, height - overlap[1]) + if not include_partial_patches and (width > image_size[0] or height > image_size[1]): + raise ValueError("The patch size cannot be bigger than the image.") + + x_stop = image_size[0] if include_partial_patches else image_size[0] - width + 1 + y_stop = image_size[1] if include_partial_patches else image_size[1] - height + 1 + + x_coords = list(range(0, x_stop, width - overlap[0])) or [0] + y_coords = list(range(0, y_stop, height - overlap[1])) or [0] - x_y = [(x, y) for x in x_range for y in y_range] + x_y = [(x, y) for x in x_coords for y in y_coords] if shuffle: - indices = list(range(len(x_y))) random_generator = np.random.default_rng(seed) - random_generator.shuffle(indices) - x_y = [x_y[i] for i in indices] + random_generator.shuffle(x_y) return x_y diff --git a/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py b/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py index 9dd7ca4dc..fe70ac16c 100644 --- a/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py +++ b/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py @@ -45,10 +45,7 @@ def sample( layer_shape: The shape of the layer. mask: The mask of the image. """ - _utils.validate_dimensions(width, height, layer_shape) - x_y = _utils.get_grid_coords( - layer_shape, width, height, self.overlap, seed=self.seed - ) + x_y = _utils.get_grid_coords(layer_shape, width, height, self.overlap, seed=self.seed) count = 0 for i in range(len(x_y)): diff --git a/src/eva/vision/data/wsi/patching/samplers/grid.py b/src/eva/vision/data/wsi/patching/samplers/grid.py index 151003dbc..beea4f116 100644 --- a/src/eva/vision/data/wsi/patching/samplers/grid.py +++ b/src/eva/vision/data/wsi/patching/samplers/grid.py @@ -13,8 +13,7 @@ def __init__( max_samples: int | None = None, overlap: Tuple[int, int] = (0, 0), seed: int = 42, - validate_dimensions: bool = True, - include_last: bool = False, + include_partial_patches: bool = False, ): """Initializes the sampler. @@ -22,17 +21,14 @@ def __init__( max_samples: The maximum number of samples to return. overlap: The overlap between patches in the grid. seed: The random seed. - validate_dimensions: Whether to validate the dimensions the image. It - expects the patch size to be smaller than the image size. - include_last: Whether to include coordinates of the last patch when it + include_partial_patches: Whether to include coordinates of the last patch when it it partially exceeds the image and therefore is smaller than the specified patch size. """ self.max_samples = max_samples self.overlap = overlap self.seed = seed - self.validate_dimensions = validate_dimensions - self.include_last = include_last + self.include_partial_patches = include_partial_patches def sample( self, @@ -47,11 +43,14 @@ def sample( height: The height of the patches. layer_shape: The shape of the layer. """ - if self.validate_dimensions: - _utils.validate_dimensions(width, height, layer_shape) x_y = _utils.get_grid_coords( - layer_shape, width, height, self.overlap, seed=self.seed, include_last=self.include_last + layer_shape, + width, + height, + self.overlap, + seed=self.seed, + include_partial_patches=self.include_partial_patches, ) - max_samples = len(x_y) if self.max_samples is None else self.max_samples + max_samples = len(x_y) if self.max_samples is None else min(self.max_samples, len(x_y)) for i in range(max_samples): yield x_y[i] diff --git a/tests/eva/vision/data/wsi/patching/samplers/test_grid.py b/tests/eva/vision/data/wsi/patching/samplers/test_grid.py index 42cbdbf80..3ab2086db 100644 --- a/tests/eva/vision/data/wsi/patching/samplers/test_grid.py +++ b/tests/eva/vision/data/wsi/patching/samplers/test_grid.py @@ -65,28 +65,51 @@ def test_different_seed(max_samples: int, seed_1: int, seed_2: int) -> None: assert x_y_1 != x_y_2 -def test_invalid_width_height() -> None: - """Tests if the sampler raises an error when width / height is bigger than layer_shape.""" - sampler = samplers.GridSampler(max_samples=10, seed=42) - - with pytest.raises(ValueError): - list(sampler.sample(width=200, height=200, layer_shape=(100, 100))) - - @pytest.mark.parametrize( - "width, height, layer_shape", + "width, height, layer_shape, include_partial_patches, expected_n_samples", [ - (5, 5, (25, 25)), - (5, 5, (100, 100)), - (224, 224, (1000, 1000)), + (5, 5, (25, 25), False, 25), + (5, 5, (25, 25), True, 25), + (224, 224, (1000, 1000), False, 16), + (224, 224, (1000, 1000), True, 25), + (10, 10, (5, 5), True, 1), ], ) -def test_expected_n_patches(width: int, height: int, layer_shape: Tuple[int, int]) -> None: +def test_expected_n_patches( + width: int, + height: int, + layer_shape: Tuple[int, int], + include_partial_patches: bool, + expected_n_samples: int, +) -> None: """Tests if the sampler respects the max_samples limit.""" - sampler = samplers.GridSampler(max_samples=None) - - expected_max_samples = (layer_shape[0] // width) * (layer_shape[1] // height) + sampler = samplers.GridSampler( + max_samples=None, include_partial_patches=include_partial_patches + ) x_y = list(sampler.sample(width=width, height=height, layer_shape=layer_shape)) - assert len(x_y) == expected_max_samples + assert len(x_y) == expected_n_samples + + +@pytest.mark.parametrize("include_partial_patches", [False, True]) +def test_patch_bigger_than_image(include_partial_patches: bool) -> None: + """Test edge case where the patch size is bigger than the image.""" + sampler = samplers.GridSampler( + max_samples=10, seed=42, include_partial_patches=include_partial_patches + ) + patch_dim, image_dim = 200, 100 + + if not include_partial_patches: + with pytest.raises(ValueError, match="The patch size cannot be bigger than the image."): + list( + sampler.sample( + width=patch_dim, height=patch_dim, layer_shape=(image_dim, image_dim) + ) + ) + else: + x_y = list( + sampler.sample(width=patch_dim, height=patch_dim, layer_shape=(image_dim, image_dim)) + ) + assert len(x_y) == 1 + assert x_y[0] == (0, 0) From 6c585a2b2561ce8c2ccea2e06dde44c3ef6e738a Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Thu, 3 Apr 2025 13:57:29 +0200 Subject: [PATCH 4/7] override _verify_location for PILImage --- src/eva/vision/data/wsi/backends/pil.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/eva/vision/data/wsi/backends/pil.py b/src/eva/vision/data/wsi/backends/pil.py index fdef88b01..fb37b0038 100644 --- a/src/eva/vision/data/wsi/backends/pil.py +++ b/src/eva/vision/data/wsi/backends/pil.py @@ -50,3 +50,16 @@ def _read_region( ) ) return np.array(patch) + + @override + def _verify_location(self, location: Tuple[int, int], size: Tuple[int, int]) -> None: + """Verifies that the requested coordinates are within the slide dimensions. + + Args: + location: Top-left corner (x, y) coordinates to read. + size: Size of the requested region (width, height) + """ + x_max, y_max = self.level_dimensions[0] + + if int(location[0]) >= x_max or int(location[1]) >= y_max: + raise ValueError(f"Out of bounds region: {location}, {size}") From 3f38318cf52952cc58dced2cb14520f470899dea Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Thu, 3 Apr 2025 14:36:38 +0200 Subject: [PATCH 5/7] updated docs --- docs/datasets/bracs.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/datasets/bracs.md b/docs/datasets/bracs.md index 568560c67..962b39242 100644 --- a/docs/datasets/bracs.md +++ b/docs/datasets/bracs.md @@ -29,6 +29,8 @@ we use in this benchmarks contains 4539 extracted ROIs / patches. | **Files format** | `png` | | **Number of images** | 4539 | +Given that some images in the `BRACS_ROI` can be very large, we split the image into a grid of 224x224 patches and then treat it as a multi instance learning (MIL) classification problem. + ### Splits From 28294c96ddb70331fccde96e762951f2fff1336c Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Thu, 3 Apr 2025 14:49:03 +0200 Subject: [PATCH 6/7] increase number of runs to 20 --- configs/vision/pathology/offline/classification/bracs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/vision/pathology/offline/classification/bracs.yaml b/configs/vision/pathology/offline/classification/bracs.yaml index e97414021..54e454b1a 100644 --- a/configs/vision/pathology/offline/classification/bracs.yaml +++ b/configs/vision/pathology/offline/classification/bracs.yaml @@ -2,7 +2,7 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + n_runs: &N_RUNS ${oc.env:N_RUNS, 20} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/bracs} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} From dbc515285bf36c658067289953d0e53432de70df Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Thu, 3 Apr 2025 14:57:42 +0200 Subject: [PATCH 7/7] configured patience --- configs/vision/pathology/offline/classification/bracs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/vision/pathology/offline/classification/bracs.yaml b/configs/vision/pathology/offline/classification/bracs.yaml index 54e454b1a..aa1a25275 100644 --- a/configs/vision/pathology/offline/classification/bracs.yaml +++ b/configs/vision/pathology/offline/classification/bracs.yaml @@ -24,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: ${oc.env:PATIENCE, 167} + patience: ${oc.env:PATIENCE, 22} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.ClassificationEmbeddingsWriter