diff --git a/src/hescape/constants.py b/src/hescape/constants.py
new file mode 100644
index 0000000..d3e4476
--- /dev/null
+++ b/src/hescape/constants.py
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+from enum import Enum
+
+import torch
+import torchvision.transforms.v2 as T
+
+
+class DatasetEnum(str, Enum):
+    """Enum of datasets covariates."""
+
+    NAME = "name"
+    IMG = "image"
+    GEXP = "gexp"
+    COORDS = "cell_coords"
+    SOURCE = "source"
+    ATLAS = "atlas"
+    AGE = "age"
+    DIAGNOSIS = "diagnosis"
+    CANCER = "cancer"
+    ONCOTREE_CODE = "oncotree_code"
+    TISSUE = "tissue"
+    TUMOR_GRADE = "tumor_grade"
+    GENDER = "gender"
+    RACE = "race"
+    TREATMENT_TYPE = "treatment_type"
+    THERAPEUTIC_AGENTS = "therapeutic_agents"
+    TUMOR_TISSUE_TYPE = "tumor_tissue_type"
+    ASSAY = "assay"
+    PRESERVATION_METHOD = "preservation_method"
+    STAIN = "stain"
+    SPACERANGER = "spaceranger"
+    SPECIES = "species"
+    CYTASSIST = "cytassist"
+
+    # PIXEL_SIZE = "pixel_size_um_estimated"
+    # MAGNIFICATION = "magnification"
+    # # UCE = "uce"
+
+
+class CameoDatasetEnum(str, Enum):
+    """Enum of datasets covariates."""
+
+    IMG = "image"
+    SPECIES = "species"
+    CANCER = "cancer"
+    TISSUE = "tissue"
+    NAME = "name"
+    ID = "id"
+    IMG_EMBED = "img_embed"
+    GAT = "gat"
+    ANNOTATION = "annotation"
+    CELL_TYPE_RATIO = "cell_type_ratio"
+    GEXP = "gexp"
+    MASK = "mask"
+    COORDS = "cell_coords"
+    SPOT_GEXP = "spot_gexp"
+
+
+EVAL_TRANSFORMS = {
+    "conch": T.Compose(
+        [
+            T.ToImage(),
+            T.RandomChoice([T.CenterCrop(512)]),
+            T.Resize((480, 480), antialias=True, interpolation=T.InterpolationMode.BICUBIC),
+            T.CenterCrop((480, 480)),
+            T.ConvertImageDtype(torch.float32),
+            T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
+        ]
+    ),
+    "optimus": T.Compose(
+        [
+            T.ToImage(),
+            T.RandomChoice([T.CenterCrop(512), T.CenterCrop(256)]),
+            T.Resize((224, 224)),
+            T.ConvertImageDtype(torch.float32),
+            T.Normalize(mean=(0.707223, 0.578729, 0.703617), std=(0.211883, 0.230117, 0.177517)),
+        ]
+    ),
+    "h0-mini": T.Compose(
+        [
+            T.ToImage(),
+            T.RandomChoice([T.CenterCrop(512), T.CenterCrop(256)]),
+            T.Resize((224, 224)),
+            T.ConvertImageDtype(torch.float32),
+            T.Normalize(mean=(0.707223, 0.578729, 0.703617), std=(0.211883, 0.230117, 0.177517)),
+        ]
+    ),
+    "uni": T.Compose(
+        [
+            T.ToImage(),
+            T.RandomChoice([T.CenterCrop(512), T.CenterCrop(256)]),
+            T.Resize((224, 224)),
+            T.ConvertImageDtype(torch.float32),
+            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+        ]
+    ),
+    "ctranspath": T.Compose(
+        [
+            T.ToImage(),
+            T.RandomChoice([T.CenterCrop(512), T.CenterCrop(256)]),
+            T.Resize((224, 224)),
+            T.ConvertImageDtype(torch.float32),
+            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+        ]
+    ),
+    "gigapath": T.Compose(
+        [
+            T.ToImage(),
+            T.RandomChoice([T.CenterCrop(512), T.CenterCrop(256)]),
+            T.Resize((224, 224)),
+            T.ConvertImageDtype(torch.float32),
+            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+        ]
+    ),
+}
diff --git a/src/hescape/data_modules/image_gexp_dataset.py b/src/hescape/data_modules/image_gexp_dataset.py
index ef26cff..a12d047 100644
--- a/src/hescape/data_modules/image_gexp_dataset.py
+++ b/src/hescape/data_modules/image_gexp_dataset.py
@@ -325,6 +325,16 @@ def forward(self, x):
             T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
         ]
     ),
+    "lunit": T.Compose(
+        [
+            T.ToImage(),
+            T.CenterCrop(512),
+            T.Resize(size=248, interpolation=T.InterpolationMode.BICUBIC, max_size=None, antialias=True),
+            T.CenterCrop(size=[224, 224]),
+            T.ConvertImageDtype(torch.float32),
+            T.Normalize(mean=(0.4850, 0.4560, 0.4060), std=(0.2290, 0.2240, 0.2250)),
+        ]
+    ),
     "augment": T.Compose(
         [
             T.RandomHorizontalFlip(),
@@ -345,7 +355,8 @@ def __init__(
         self,
         dataset_path: Path,
         data_gene_reference_path: Path,
-        img_model_name: Literal["ctranspath", "densenet", "uni", "optimus", "conch", "gigapath", "h0-mini"] | str,
+        img_model_name: Literal["ctranspath", "densenet", "uni", "optimus", "conch", "gigapath", "h0-mini", "lunit"]
+        | str,
         gene_model_name: Literal["drvi", "nicheformer", "scfoundation", "generic"] | str,
         source_key: str = "source",
         source_value=None,
diff --git a/src/hescape/models/clip.py b/src/hescape/models/clip.py
index cc7507b..076a1d6 100644
--- a/src/hescape/models/clip.py
+++ b/src/hescape/models/clip.py
@@ -47,7 +47,7 @@ def __init__(
         self,
         input_genes: int,
         embed_dim: int,
-        img_enc_name: Literal["ctranspath", "densenet", "uni", "optimus", "conch", "gigapath"],
+        img_enc_name: Literal["ctranspath", "densenet", "uni", "optimus", "conch", "gigapath", "lunit"],
         gene_enc_name: Literal["drvi", "nicheformer", "scfoundation", "uce", "generic"],
         loss: Literal["CLIP", "SIGLIP"],
         img_finetune: bool = False,
diff --git a/src/hescape/models/image_models/image_encoder.py b/src/hescape/models/image_models/image_encoder.py
index e756613..0f4bdb2 100644
--- a/src/hescape/models/image_models/image_encoder.py
+++ b/src/hescape/models/image_models/image_encoder.py
@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn
 from peft import LoraConfig, get_peft_model
+from safetensors.torch import load_file
 from timm.layers import Mlp
 
 from hescape.models._utils import print_trainable_parameters
@@ -23,7 +24,7 @@ class ImageEncoder(nn.Module):
 
     def __init__(
         self,
-        model_name: Literal["ctranspath", "densenet", "uni", "optimus", "conch", "gigapath", "h0-mini"] | str,
+        model_name: Literal["ctranspath", "densenet", "uni", "optimus", "conch", "gigapath", "h0-mini", "lunit"] | str,
         finetune: bool = False,
         embed_dim: int = -1,
         proj: str = "mlp",
@@ -117,6 +118,17 @@ def _build_trunk(self, model_name: str, checkpoint_root: Path, **kwargs: Any) ->
             # total_blocks may differ, set it according to your needs
             total_blocks = 12  # Example
 
+        elif model_name == "lunit":
+            trunk = timm.create_model(
+                model_name="hf-hub:1aurent/vit_small_patch8_224.lunit_dino",
+                pretrained=False,
+            )
+            checkpoint_path = checkpoint_root / model_name / "model.safetensors"
+            trunk.load_state_dict(load_file(checkpoint_path), strict=True)
+            print(f"Successfully loaded weights for {model_name}")
+            # total_blocks may differ, set it according to your needs
+            total_blocks = 12  # Example
+
         else:
             raise ValueError(f"Unknown model name: {model_name}")
 
@@ -143,6 +155,7 @@ def get_ft_model(self, model_name: str, trunk, lora: bool = False) -> object:
             "optimus": {"r": 8, "lora_alpha": 16, "target_modules": ["qkv", "proj"]},
             "h0-mini": {"r": 8, "lora_alpha": 16, "target_modules": ["qkv", "proj"]},
             "gigapath": {"r": 8, "lora_alpha": 16, "target_modules": ["qkv", "proj"]},
+            "lunit": {"r": 8, "lora_alpha": 16, "target_modules": ["qkv", "proj"]},
         }
 
         if lora:
diff --git a/src/hescape/modules/pretrain_module.py b/src/hescape/modules/pretrain_module.py
index d0acbe9..d21e694 100644
--- a/src/hescape/modules/pretrain_module.py
+++ b/src/hescape/modules/pretrain_module.py
@@ -46,7 +46,7 @@ def __init__(
         self,
         input_genes: int,
         embed_dim: int,
-        img_enc_name: Literal["ctranspath", "uni", "conch", "optimus", "densenet", "gigapath"],
+        img_enc_name: Literal["ctranspath", "uni", "conch", "optimus", "densenet", "gigapath", "lunit"],
         gene_enc_name: Literal["drvi", "nicheformer", "uce", "scfoundation", "generic"],
         loss: Literal["CLIP", "SIGLIP"],
         img_finetune: bool,