35 changes: 24 additions & 11 deletions configs/vision/pathology/offline/classification/bracs.yaml
@@ -2,7 +2,7 @@
trainer:
class_path: eva.Trainer
init_args:
n_runs: &N_RUNS ${oc.env:N_RUNS, 5}
n_runs: &N_RUNS ${oc.env:N_RUNS, 20}
default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/bracs}
max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500}
checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best}
@@ -24,7 +24,7 @@ trainer:
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
min_delta: 0
patience: ${oc.env:PATIENCE, 167}
patience: ${oc.env:PATIENCE, 22}
monitor: *MONITOR_METRIC
mode: *MONITOR_METRIC_MODE
- class_path: eva.callbacks.ClassificationEmbeddingsWriter
@@ -49,15 +49,17 @@ model:
class_path: eva.HeadModule
init_args:
head:
class_path: torch.nn.Linear
class_path: eva.vision.models.networks.ABMIL
init_args:
in_features: ${oc.env:IN_FEATURES, 384}
out_features: &NUM_CLASSES 7
input_size: ${oc.env:IN_FEATURES, 384}
output_size: &NUM_CLASSES 7
projected_input_size: 128
criterion: torch.nn.CrossEntropyLoss
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: ${oc.env:LR_VALUE, 0.0003}
lr: ${oc.env:LR_VALUE, 0.001}
betas: [0.9, 0.999]
metrics:
common:
- class_path: eva.metrics.AverageLoss
@@ -69,18 +71,24 @@ data:
init_args:
datasets:
train:
class_path: eva.datasets.EmbeddingsClassificationDataset
class_path: eva.datasets.MultiEmbeddingsClassificationDataset
init_args: &DATASET_ARGS
root: *DATASET_EMBEDDINGS_ROOT
manifest_file: manifest.csv
split: train
column_mapping:
multi_id: origin
embeddings_transforms:
class_path: eva.core.data.transforms.Pad2DTensor
init_args:
pad_size: &N_PATCHES ${oc.env:N_PATCHES, 250}
val:
class_path: eva.datasets.EmbeddingsClassificationDataset
class_path: eva.datasets.MultiEmbeddingsClassificationDataset
init_args:
<<: *DATASET_ARGS
split: val
test:
class_path: eva.datasets.EmbeddingsClassificationDataset
class_path: eva.datasets.MultiEmbeddingsClassificationDataset
init_args:
<<: *DATASET_ARGS
split: test
@@ -89,10 +97,15 @@ data:
init_args: &PREDICT_DATASET_ARGS
root: ${oc.env:DATA_ROOT, ./data/bracs}
split: train
sampler:
class_path: eva.vision.data.wsi.patching.samplers.GridSampler
init_args:
max_samples: *N_PATCHES
include_partial_patches: true
transforms:
class_path: eva.vision.data.transforms.common.ResizeAndCrop
init_args:
size: ${oc.env:RESIZE_DIM, 224}
size: ${oc.env:RESIZE_DIM, 224}
mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
- class_path: eva.vision.datasets.BRACS
@@ -105,7 +118,7 @@ data:
split: test
dataloaders:
train:
batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256}
batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 32}
num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4}
shuffle: true
val:
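The net effect of this config change: the linear probe becomes an attention-based MIL pipeline. Patch embeddings of each ROI are grouped per slide (`multi_id: origin`), zero-padded to `N_PATCHES` (250) by `Pad2DTensor`, and aggregated by the `ABMIL` head into one slide-level prediction. Below is a minimal sketch of what such an attention-pooling head computes, with dimensions mirroring the config (384-d embeddings, 128-d projection, 7 classes); it is an illustration in the style of Ilse et al.'s ABMIL, not eva's actual `eva.vision.models.networks.ABMIL` implementation.

```python
import torch
import torch.nn as nn


class AttentionMILHead(nn.Module):
    """Illustrative attention-based MIL head (not eva's ABMIL)."""

    def __init__(self, input_size: int = 384, projected_size: int = 128, n_classes: int = 7):
        super().__init__()
        self.project = nn.Linear(input_size, projected_size)
        # Scores each patch embedding; softmax turns the scores into attention weights.
        self.attention = nn.Sequential(nn.Linear(projected_size, 64), nn.Tanh(), nn.Linear(64, 1))
        self.classifier = nn.Linear(projected_size, n_classes)

    def forward(self, bags: torch.Tensor) -> torch.Tensor:
        # bags: (batch, n_patches, input_size), zero-padded per slide.
        h = self.project(bags)                             # (B, N, 128)
        weights = torch.softmax(self.attention(h), dim=1)  # (B, N, 1)
        pooled = (weights * h).sum(dim=1)                  # (B, 128) bag embedding
        return self.classifier(pooled)                     # (B, 7) logits


logits = AttentionMILHead()(torch.randn(2, 250, 384))
assert logits.shape == (2, 7)
```

A production head would also mask the zero-padded rows before the softmax. The accompanying hyperparameter shifts (batch size 256 → 32, learning rate 3e-4 → 1e-3, patience 167 → 22) are consistent with each sample now being a bag of up to 250 embeddings rather than a single vector.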
94 changes: 0 additions & 94 deletions configs/vision/pathology/online/classification/bracs.yaml

This file was deleted.

2 changes: 2 additions & 0 deletions docs/datasets/bracs.md
@@ -29,6 +29,8 @@ we use in this benchmark contains 4539 extracted ROIs / patches.
| **Files format** | `png` |
| **Number of images** | 4539 |

Given that some images in `BRACS_RoI` can be very large, we split each image into a grid of 224x224 patches and treat the task as a multiple instance learning (MIL) classification problem.


### Splits

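The grid patching described in this doc change is straightforward to sketch. Below is a hedged illustration of tiling a large ROI into 224x224 patches, capped at `N_PATCHES` and keeping partial border tiles, mirroring the `GridSampler` settings (`max_samples`, `include_partial_patches: true`) from the config above; the function name is illustrative, not eva's API.

```python
from typing import List, Tuple


def grid_coordinates(
    width: int, height: int, patch: int = 224, max_samples: int = 250
) -> List[Tuple[int, int]]:
    """Top-left corners of a patch grid covering the ROI, border tiles included."""
    coords = [
        (x, y)
        for y in range(0, height, patch)  # rows; a partial bottom row is kept
        for x in range(0, width, patch)   # columns; a partial right column is kept
    ]
    return coords[:max_samples]


# A 1000x800 ROI yields a 5x4 grid of (possibly partial) tiles.
assert len(grid_coordinates(1000, 800)) == 20
```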
76 changes: 47 additions & 29 deletions src/eva/vision/data/datasets/classification/bracs.py
@@ -1,27 +1,28 @@
"""BRACS dataset class."""

import os
from typing import Callable, Dict, List, Literal, Tuple
from typing import Any, Callable, Dict, List, Literal, Tuple

import torch
from torchvision import tv_tensors
from torchvision.datasets import folder
from torchvision.transforms.v2 import functional
from typing_extensions import override

from eva.vision.data.datasets import _validators
from eva.vision.data.datasets import _validators, wsi
from eva.vision.data.datasets.classification import base
from eva.vision.utils import io
from eva.vision.data.wsi.patching import samplers


class BRACS(base.ImageClassification):
class BRACS(wsi.MultiWsiDataset, base.ImageClassification):
"""Dataset class for BRACS images and corresponding targets."""

_expected_dataset_lengths: Dict[str, int] = {
_expected_files: Dict[str, int] = {
"train": 3657,
"val": 312,
"test": 570,
}
"""Expected dataset lengths for the splits and complete dataset."""
"""Expected number of files for each split."""

_license: str = "CC BY-NC 4.0 (https://creativecommons.org/licenses/by-nc/4.0/)"
"""Dataset license."""
@@ -30,22 +31,42 @@ def __init__(
self,
root: str,
split: Literal["train", "val", "test"],
sampler: samplers.Sampler | None = None,
width: int = 224,
height: int = 224,
target_mpp: float = 0.25,
transforms: Callable | None = None,
) -> None:
"""Initializes the dataset.

Args:
root: Path to the root directory of the dataset.
split: Dataset split to use.
sampler: The sampler to use for sampling patch coordinates.
If `None`, it will use the :class:`ForegroundGridSampler` sampler.
width: Width of the patches to be extracted, in pixels.
height: Height of the patches to be extracted, in pixels.
target_mpp: Target microns per pixel (mpp) for the patches.
transforms: A function/transform which returns a transformed
version of the raw data samples.
"""
super().__init__(transforms=transforms)

self._root = root
self._split = split
self._path_to_target = self._make_dataset()
self._file_to_path = {os.path.basename(p): p for p in self._path_to_target.keys()}

self._samples: List[Tuple[str, int]] = []
wsi.MultiWsiDataset.__init__(
self,
root=root,
file_paths=sorted(self._path_to_target.keys()),
width=width,
height=height,
sampler=sampler or samplers.ForegroundGridSampler(max_samples=25),
target_mpp=target_mpp,
overwrite_mpp=0.25,
backend="pil",
image_transforms=transforms,
)

@property
@override
@@ -57,55 +78,52 @@ def classes(self) -> List[str]:
def class_to_idx(self) -> Dict[str, int]:
return {name: index for index, name in enumerate(self.classes)}

@override
def filename(self, index: int) -> str:
image_path, *_ = self._samples[index]
return os.path.relpath(image_path, self._dataset_path)

@override
def prepare_data(self) -> None:
_validators.check_dataset_exists(self._root, True)

@override
def configure(self) -> None:
self._samples = self._make_dataset()

@override
def validate(self) -> None:
if len(self._path_to_target) != self._expected_files[self._split]:
raise ValueError(
f"Expected {self._split} split to have {self._expected_files[self._split]} files, "
f"but found {len(self._path_to_target)} files."
)

_validators.check_dataset_integrity(
self,
length=self._expected_dataset_lengths[self._split],
length=None,
n_classes=7,
first_and_last_labels=("0_N", "6_IC"),
)

@override
def load_image(self, index: int) -> tv_tensors.Image:
image_path, _ = self._samples[index]
return io.read_image_as_tensor(image_path)
def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
return base.ImageClassification.__getitem__(self, index)

@override
def load_target(self, index: int) -> torch.Tensor:
_, target = self._samples[index]
return torch.tensor(target, dtype=torch.long)
def load_image(self, index: int) -> tv_tensors.Image:
image_array = wsi.MultiWsiDataset.__getitem__(self, index)
return functional.to_image(image_array)

@override
def __len__(self) -> int:
return len(self._samples)
def load_target(self, index: int) -> torch.Tensor:
path = self._file_to_path[self.filename(index)]
return torch.tensor(self._path_to_target[path], dtype=torch.long)

@property
def _dataset_path(self) -> str:
"""Returns the full path of dataset directory."""
return os.path.join(self._root, "BRACS_RoI/latest_version")

def _make_dataset(self) -> List[Tuple[str, int]]:
def _make_dataset(self) -> Dict[str, int]:
"""Builds the dataset for the specified split."""
dataset = folder.make_dataset(
directory=os.path.join(self._dataset_path, self._split),
class_to_idx=self.class_to_idx,
extensions=(".png"),
)
return dataset
return dict(dataset)

def _print_license(self) -> None:
"""Prints the dataset license."""
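A hedged usage sketch of the reworked dataset class, assuming a local copy of BRACS under `./data/bracs`. With the `MultiWsiDataset` mixin, each index now yields a single patch together with its ROI's slide-level label; eva's data module normally drives the `prepare_data`/`configure`/`validate` lifecycle, so the explicit calls here are only for illustration.

```python
from eva.vision.data.wsi.patching import samplers
from eva.vision.datasets import BRACS

dataset = BRACS(
    root="./data/bracs",  # assumed local dataset location
    split="val",
    sampler=samplers.GridSampler(max_samples=250, include_partial_patches=True),
)
dataset.prepare_data()
dataset.configure()
dataset.validate()

# One 224x224 patch, the label of the ROI it came from, and patch metadata.
image, target, metadata = dataset[0]
```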
1 change: 0 additions & 1 deletion src/eva/vision/data/datasets/segmentation/consep.py
@@ -55,7 +55,6 @@ def __init__(
width: Width of the patches to be extracted, in pixels.
height: Height of the patches to be extracted, in pixels.
target_mpp: Target microns per pixel (mpp) for the patches.
backend: The backend to use for reading the whole-slide images.
transforms: Transforms to apply to the extracted image & mask patches.
"""
self._split = split
13 changes: 13 additions & 0 deletions src/eva/vision/data/wsi/backends/pil.py
@@ -50,3 +50,16 @@
)
)
return np.array(patch)

@override
def _verify_location(self, location: Tuple[int, int], size: Tuple[int, int]) -> None:
"""Verifies that the requested coordinates are within the slide dimensions.

Args:
location: Top-left corner (x, y) coordinates to read.
size: Size of the requested region (width, height).
"""
x_max, y_max = self.level_dimensions[0]

if int(location[0]) >= x_max or int(location[1]) >= y_max:
raise ValueError(f"Out of bounds region: {location}, {size}")