Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions trident/InferenceStrategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Protocol

import numpy as np
import torch
from torch.utils.data import DataLoader

from trident.patch_encoder_models.load import BasePatchEncoder, CustomInferenceEncoder


class InferenceStrategy(Protocol):
    """
    An interface that supports arbitrary strategies for inference of
    image patch embedding models.

    Implementations are invoked by `WSI.extract_patch_features` as
    `inference_strategy.forward(dataloader, patch_encoder, device, precision)`.
    """

    def forward(
        self,
        dataloader: DataLoader,
        patch_encoder: BasePatchEncoder | CustomInferenceEncoder,
        device: torch.device,
        precision: torch.dtype,
    ) -> np.ndarray:
        """
        Run inference over every batch produced by `dataloader`.

        Args:
            dataloader (DataLoader):
                Yields `(imgs, labels)` batches of image patches.
            patch_encoder (BasePatchEncoder | CustomInferenceEncoder):
                The image patch embedding model.
            device (torch.device):
                Device to run feature extraction on (e.g., 'cuda:0').
            precision (torch.dtype):
                Precision of embedding model weights (e.g., torch.float32).

        Returns:
            np.ndarray: Embeddings for all patches, concatenated along axis 0.
        """
        ...


class DefaultInferenceStrategy(InferenceStrategy):
    """
    The default inference strategy for embedding image patches.
    It sequentially processes one batch at a time on the specified device.
    Automatic mixed precision is enabled if `precision` != torch.float32.
    """

    def forward(
        self,
        dataloader: DataLoader,
        patch_encoder: BasePatchEncoder | CustomInferenceEncoder,
        device: torch.device,
        precision: torch.dtype,
    ) -> np.ndarray:
        """
        Embed every patch yielded by `dataloader` with `patch_encoder`.

        Args:
            dataloader (DataLoader):
                A dataloader that generates `(imgs, labels)` batches of image patches.
            patch_encoder (BasePatchEncoder | CustomInferenceEncoder):
                The image patch embedding model.
            device (torch.device | str):
                Device to run feature extraction on (e.g., 'cuda:0'). Callers
                such as `WSI.extract_patch_features` pass a plain string.
            precision (torch.dtype):
                Precision of embedding model weights (e.g., torch.float32).

        Returns:
            np.ndarray: The embeddings for all image patches, concatenated
            along axis 0. Shape (0,) if the dataloader is empty.
        """
        # Derive the autocast device type from `device` rather than
        # hard-coding "cuda", so CPU (and other backend) inference works.
        # torch.device() accepts both torch.device objects and strings.
        device_type = torch.device(device).type
        use_amp = precision != torch.float32

        batch_features: list[np.ndarray] = []
        for imgs, _ in dataloader:
            imgs = imgs.to(device)
            with torch.autocast(device_type=device_type, dtype=precision, enabled=use_amp):
                out = patch_encoder(imgs)
            batch_features.append(out.cpu().numpy())

        # np.concatenate raises ValueError on an empty list; return an
        # explicitly empty array if the dataloader produced no batches.
        if not batch_features:
            return np.empty((0,), dtype=np.float32)
        return np.concatenate(batch_features, axis=0)
40 changes: 33 additions & 7 deletions trident/Processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
import os
import sys
import torch
from tqdm import tqdm
import shutil
from typing import Optional, List, Dict, Any
Expand All @@ -12,6 +13,9 @@
from trident.Maintenance import deprecated
from trident.Converter import OPENSLIDE_EXTENSIONS, PIL_EXTENSIONS
from trident import WSIReaderType
from trident.patch_encoder_models.load import BasePatchEncoder, CustomInferenceEncoder
from trident.InferenceStrategy import InferenceStrategy, DefaultInferenceStrategy



class Processor:
Expand Down Expand Up @@ -132,7 +136,7 @@ def __init__(
print(f'Using local cache at {wsi_cache}, which currently contains {len(os.listdir(wsi_cache))} files.')

# Lazy-init WSI objects
self.wsis = []
self.wsis: List[OpenSlideWSI] = []
for wsi_idx, wsi in enumerate(valid_slides):
wsi_path = os.path.join(self.wsi_cache, wsi) if self.wsi_cache is not None else os.path.join(self.wsi_source, wsi)

Expand Down Expand Up @@ -464,11 +468,13 @@ def run_feature_extraction_job(
def run_patch_feature_extraction_job(
self,
coords_dir: str,
patch_encoder: torch.nn.Module,
patch_encoder: BasePatchEncoder | CustomInferenceEncoder,
device: str,
saveas: str = 'h5',
batch_limit: int = 512,
saveto: str | None = None
saveto: str | None = None,
inference_strategy: InferenceStrategy = DefaultInferenceStrategy(),
pin_memory: bool = True
) -> str:
"""
The `run_feature_extraction_job` function computes features from the patches generated during the
Expand All @@ -478,8 +484,8 @@ def run_patch_feature_extraction_job(
Parameters:
coords_dir (str):
Path to the directory containing patch coordinates, which are used to locate patches for feature extraction.
patch_encoder (torch.nn.Module):
A pre-trained PyTorch model used to compute features from the extracted patches.
patch_encoder (BasePatchEncoder | CustomInferenceEncoder):
A pre-trained model used to compute features from the extracted patches.
device (str):
The computation device to use (e.g., 'cuda:0' for GPU or 'cpu' for CPU).
saveas (str, optional):
Expand All @@ -489,6 +495,14 @@ def run_patch_feature_extraction_job(
saveto (str, optional):
Directory where the extracted features will be saved. If not provided, a directory name will
be generated automatically. Defaults to None.
inference_strategy (InferenceStrategy, optional):
Allows you to provide arbitrary logic for running inference. Useful
for non-PyTorch models, or advanced needs such as model and data parallelism.
The default implementation assumes a PyTorch model running on a single
GPU, and enables automatic mixed precision when dtype != torch.float32.
pin_memory (bool, optional):
If True, the data loader will copy Tensors into device/CUDA pinned memory
before returning them. Defaults to True.

Returns:
str: The absolute path to where the features are saved.
Expand Down Expand Up @@ -518,13 +532,23 @@ def run_patch_feature_extraction_job(
ignore = ['patch_encoder', 'loop', 'valid_slides', 'wsis']
)

# If patch_encoder is a Pytorch model, or a CustomInferenceEncoder that wraps
# a PyTorch model, we automatically set eval mode and move weights to specified device.
# In all other cases, the user must ensure data and weights reside on the same device.
if isinstance(patch_encoder, torch.nn.Module):
patch_encoder.eval()
patch_encoder.to(device)
elif isinstance(patch_encoder.model, torch.nn.Module):
patch_encoder.model.eval()
patch_encoder.model.to(device)

log_fp = os.path.join(self.job_dir, coords_dir, f'_logs_feats_{patch_encoder.enc_name}.txt')
self.loop = tqdm(self.wsis, desc=f'Extracting patch features from coords in {coords_dir}', total = len(self.wsis))
for wsi in self.loop:
wsi_feats_fp = os.path.join(self.job_dir, saveto, f'{wsi.name}.{saveas}')
# Check if features already exist
if os.path.exists(wsi_feats_fp) and not is_locked(wsi_feats_fp):
self.loop.set_postfix_str(f'Features already extracted for {wsi}. Skipping...')
self.loop.set_postfix_str(f'Features already extracted for {wsi.name}. Skipping...')
update_log(log_fp, f'{wsi.name}{wsi.ext}', 'Features extracted.')
self.cleanup(f'{wsi.name}{wsi.ext}')
continue
Expand Down Expand Up @@ -560,7 +584,9 @@ def run_patch_feature_extraction_job(
save_features=os.path.join(self.job_dir, saveto),
device=device,
saveas=saveas,
batch_limit=batch_limit
batch_limit=batch_limit,
inference_strategy=inference_strategy,
pin_memory=pin_memory
)

remove_lock(wsi_feats_fp)
Expand Down
23 changes: 15 additions & 8 deletions trident/patch_encoder_models/load.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import traceback
from abc import abstractmethod
from typing import Optional
from typing import Optional, Callable
import numpy as np
import torch
import os

Expand Down Expand Up @@ -171,8 +172,14 @@ def _build(self, **build_kwargs):
pass


class CustomInferenceEncoder(BasePatchEncoder):
def __init__(self, enc_name, model, transforms, precision):
class CustomInferenceEncoder:
def __init__(
self,
enc_name: str,
model: Callable[..., np.ndarray],
transforms: Callable,
precision: torch.dtype,
):
"""
Initialize a CustomInferenceEncoder from user-defined components.

Expand All @@ -182,8 +189,8 @@ def __init__(self, enc_name, model, transforms, precision):
Args:
enc_name (str):
A unique name or identifier for the encoder (used for registry or logging).
model (torch.nn.Module):
A PyTorch model instance to use for inference.
model (Callable[..., np.ndarray]):
A model instance to use for inference.
transforms (Callable):
A callable (e.g., torchvision or timm transform) to preprocess input images for evaluation.
precision (torch.dtype):
Expand All @@ -194,9 +201,9 @@ def __init__(self, enc_name, model, transforms, precision):
self.model = model
self.eval_transforms = transforms
self.precision = precision
def _build(self):
return None, None, None

def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)


class MuskInferenceEncoder(BasePatchEncoder):
Expand Down
36 changes: 20 additions & 16 deletions trident/wsi_objects/WSI.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
import torch
from typing import List, Tuple, Optional, Literal
from torch.utils.data import DataLoader
import geopandas as gpd
import cv2

from trident.wsi_objects.WSIPatcher import *
from trident.patch_encoder_models.load import BasePatchEncoder, CustomInferenceEncoder
from trident.InferenceStrategy import InferenceStrategy, DefaultInferenceStrategy
from trident.wsi_objects.WSIPatcher import OpenSlideWSIPatcher
from trident.wsi_objects.WSIPatcherDataset import WSIPatcherDataset
from trident.IO import (
save_h5, read_coords, read_coords_legacy,
Expand Down Expand Up @@ -611,12 +615,14 @@ def visualize_coords(self, coords_path: str, save_patch_viz: str) -> str:
@torch.inference_mode()
def extract_patch_features(
self,
patch_encoder: torch.nn.Module,
patch_encoder: BasePatchEncoder | CustomInferenceEncoder,
coords_path: str,
save_features: str,
device: str = 'cuda:0',
saveas: str = 'h5',
batch_limit: int = 512
batch_limit: int = 512,
inference_strategy: InferenceStrategy = DefaultInferenceStrategy(),
pin_memory: bool = True
) -> str:
"""
The `extract_patch_features` function of the class `WSI` extracts feature embeddings
Expand All @@ -625,7 +631,7 @@ def extract_patch_features(

Args:
-----
patch_encoder : torch.nn.Module
patch_encoder : BasePatchEncoder | CustomInferenceEncoder
The model used for feature extraction.
coords_path : str
Path to the file containing patch coordinates.
Expand All @@ -637,6 +643,14 @@ def extract_patch_features(
Format to save the features ('h5' or 'pt'). Defaults to 'h5'.
batch_limit : int, optional
Maximum batch size for feature extraction. Defaults to 512.
inference_strategy (InferenceStrategy, optional):
Allows you to provide arbitrary logic for running inference. Useful
for non-PyTorch models, or advanced needs such as model and data parallelism.
The default implementation assumes a PyTorch model running on a single
GPU, and enables automatic mixed precision when dtype != torch.float32.
pin_memory (bool, optional):
If True, the data loader will copy Tensors into device/CUDA pinned memory
before returning them. Defaults to True.

Returns:
--------
Expand All @@ -645,7 +659,7 @@ def extract_patch_features(

Example:
--------
>>> features_path = wsi.extract_features(patch_encoder, "output_coords/sample_name_patches.h5", "output_features")
>>> features_path = wsi.extract_patch_features(patch_encoder, "output_coords/sample_name_patches.h5", "output_features")
>>> print(features_path)
output_features/sample_name.h5
"""
Expand Down Expand Up @@ -679,17 +693,7 @@ def extract_patch_features(
)
dataset = WSIPatcherDataset(patcher, patch_transforms)
dataloader = DataLoader(dataset, batch_size=batch_limit, num_workers=get_num_workers(batch_limit, max_workers=self.max_workers), pin_memory=True)
# dataloader = DataLoader(dataset, batch_size=batch_limit, num_workers=0, pin_memory=True)

features = []
for imgs, _ in dataloader:
imgs = imgs.to(device)
with torch.autocast(device_type='cuda', dtype=precision, enabled=(precision != torch.float32)):
batch_features = patch_encoder(imgs)
features.append(batch_features.cpu().numpy())

# Concatenate features
features = np.concatenate(features, axis=0)
features = inference_strategy.forward(dataloader, patch_encoder, device, precision)

# Save the features to disk
os.makedirs(save_features, exist_ok=True)
Expand Down