diff --git a/configs/vision/pathology/offline/classification/bach-linear-experiments/eval_dinov3_custom.yaml b/configs/vision/pathology/offline/classification/bach-linear-experiments/eval_dinov3_custom.yaml new file mode 100644 index 000000000..835847ba1 --- /dev/null +++ b/configs/vision/pathology/offline/classification/bach-linear-experiments/eval_dinov3_custom.yaml @@ -0,0 +1,114 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/bach} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 6000} #usually 12500 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} + callbacks: + - class_path: eva.callbacks.ConfigurationLogger + - class_path: lightning.pytorch.callbacks.TQDMProgressBar + init_args: + refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: ${oc.env:SAVE_LAST, false} + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: ${oc.env:PATIENCE, 1250} + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + - class_path: eva.callbacks.ClassificationEmbeddingsWriter + init_args: + output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, dino_vits16}/bach + dataloader_idx_map: + 0: train + 1: val + backbone: + class_path: eva.core.models.wrappers.ModelFromLocal + init_args: + local_repo_path: "./dinov3" + model_name: "dinov3_vith16plus" #The model you want to load + checkpoint_path: "./dinov3_vith16plus_pretrain_lvd1689m-7c1da9a5.pth" #Path to the model checkpoint + + + overwrite: true + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.HeadModule + init_args: + head: + class_path: eva.core.models.networks.batch_linear.batch_linear + init_args: + in_features: ${oc.env:IN_FEATURES, 1280} #MODIFY ME + out_features: &NUM_CLASSES 4 + criterion: torch.nn.CrossEntropyLoss + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: ${oc.env:LR_VALUE, 0.0001} + metrics: + common: + - class_path: eva.metrics.AverageLoss + - class_path: eva.metrics.MulticlassClassificationMetrics + init_args: + num_classes: *NUM_CLASSES +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: &DATASET_ARGS + root: *DATASET_EMBEDDINGS_ROOT + manifest_file: manifest.csv + split: train + val: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: + <<: *DATASET_ARGS + split: val + predict: + - class_path: eva.vision.datasets.BACH + init_args: &PREDICT_DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data/bach} + split: train + download: ${oc.env:DOWNLOAD_DATA, false} + # Set `download: true` to download the dataset from https://zenodo.org/records/3632035 + # The BACH dataset is distributed under the following license + # Attribution-NonCommercial-NoDerivs 4.0 International license + # (see: https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode) + transforms: + class_path: eva.vision.data.transforms.common.ResizeAndCrop + init_args: + size: ${oc.env:RESIZE_DIM, 224} + mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} #Flagged for maybe needs fixing + std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} + - class_path: eva.vision.datasets.BACH + init_args: + <<: *PREDICT_DATASET_ARGS + split: val + dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} + num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} + shuffle: true + val: + batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS + predict: + batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} + num_workers: *N_DATA_WORKERS diff --git a/configs/vision/pathology/offline/classification/bach-linear-experiments/eval_dinov3_custom_adj_pooled_linear.yaml b/configs/vision/pathology/offline/classification/bach-linear-experiments/eval_dinov3_custom_adj_pooled_linear.yaml new file mode 100644 index 000000000..af46c4d2e --- /dev/null +++ b/configs/vision/pathology/offline/classification/bach-linear-experiments/eval_dinov3_custom_adj_pooled_linear.yaml @@ -0,0 +1,114 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/bach} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 6000} #usually 12500 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} + callbacks: + - class_path: eva.callbacks.ConfigurationLogger + - class_path: lightning.pytorch.callbacks.TQDMProgressBar + init_args: + refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: ${oc.env:SAVE_LAST, false} + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: ${oc.env:PATIENCE, 1250} + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + - class_path: eva.callbacks.ClassificationEmbeddingsWriter + init_args: + output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, dino_vits16}/bach + dataloader_idx_map: + 0: train + 1: val + backbone: + class_path: eva.core.models.wrappers.ModelFromLocal + init_args: + local_repo_path: "./dinov3" + model_name: "dinov3_vith16plus" #The model you want to load + checkpoint_path: "./dinov3_vith16plus_pretrain_lvd1689m-7c1da9a5.pth" #Path to the model checkpoint + + + overwrite: true + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.HeadModule + init_args: + head: + class_path: eva.core.models.networks.adj_pooled_linear.adj_pooled_linear + init_args: + in_features: ${oc.env:IN_FEATURES, 1280} #MODIFY ME + out_features: &NUM_CLASSES 4 + criterion: torch.nn.CrossEntropyLoss + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: ${oc.env:LR_VALUE, 0.0001} + metrics: + common: + - class_path: eva.metrics.AverageLoss + - class_path: eva.metrics.MulticlassClassificationMetrics + init_args: + num_classes: *NUM_CLASSES +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: &DATASET_ARGS + root: *DATASET_EMBEDDINGS_ROOT + manifest_file: manifest.csv + split: train + val: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: + <<: *DATASET_ARGS + split: val + predict: + - class_path: eva.vision.datasets.BACH + init_args: &PREDICT_DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data/bach} + split: train + download: ${oc.env:DOWNLOAD_DATA, false} + # Set `download: true` to download the dataset from https://zenodo.org/records/3632035 + # The BACH dataset is distributed under the following license + # Attribution-NonCommercial-NoDerivs 4.0 International license + # (see: https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode) + transforms: + class_path: eva.vision.data.transforms.common.ResizeAndCrop + init_args: + size: ${oc.env:RESIZE_DIM, 224} + mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} #Flagged for maybe needs fixing + std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} + - class_path: eva.vision.datasets.BACH + init_args: + <<: *PREDICT_DATASET_ARGS + split: val + dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} + num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} + shuffle: true + val: + batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS + predict: + batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} + num_workers: *N_DATA_WORKERS diff --git a/configs/vision/pathology/offline/classification/bach-linear-experiments/eval_dinov3_custom_adj_pooled_linear_noreg.yaml b/configs/vision/pathology/offline/classification/bach-linear-experiments/eval_dinov3_custom_adj_pooled_linear_noreg.yaml new file mode 100644 index 000000000..65d5c1817 --- /dev/null +++ b/configs/vision/pathology/offline/classification/bach-linear-experiments/eval_dinov3_custom_adj_pooled_linear_noreg.yaml @@ -0,0 +1,114 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/bach} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 6000} #usually 12500 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} + callbacks: + - class_path: eva.callbacks.ConfigurationLogger + - class_path: lightning.pytorch.callbacks.TQDMProgressBar + init_args: + refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: ${oc.env:SAVE_LAST, false} + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: ${oc.env:PATIENCE, 1250} + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + - class_path: eva.callbacks.ClassificationEmbeddingsWriter + init_args: + output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, dino_vits16}/bach + dataloader_idx_map: + 0: train + 1: val + backbone: + class_path: eva.core.models.wrappers.ModelFromLocal + init_args: + local_repo_path: "./dinov3" + model_name: "dinov3_vith16plus" #The model you want to load + checkpoint_path: "./dinov3_vith16plus_pretrain_lvd1689m-7c1da9a5.pth" #Path to the model checkpoint + + + overwrite: true + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.HeadModule + init_args: + head: + class_path: eva.core.models.networks.adj_pooled_linear_noreg.adj_pooled_linear_noreg + init_args: + in_features: ${oc.env:IN_FEATURES, 1280} #MODIFY ME + out_features: &NUM_CLASSES 4 + criterion: torch.nn.CrossEntropyLoss + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: ${oc.env:LR_VALUE, 0.0001} + metrics: + common: + - class_path: eva.metrics.AverageLoss + - class_path: eva.metrics.MulticlassClassificationMetrics + init_args: + num_classes: *NUM_CLASSES +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: &DATASET_ARGS + root: *DATASET_EMBEDDINGS_ROOT + manifest_file: manifest.csv + split: train + val: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: + <<: *DATASET_ARGS + split: val + predict: + - class_path: eva.vision.datasets.BACH + init_args: &PREDICT_DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data/bach} + split: train + download: ${oc.env:DOWNLOAD_DATA, false} + # Set `download: true` to download the dataset from https://zenodo.org/records/3632035 + # The BACH dataset is distributed under the following license + # Attribution-NonCommercial-NoDerivs 4.0 International license + # (see: https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode) + transforms: + class_path: eva.vision.data.transforms.common.ResizeAndCrop + init_args: + size: ${oc.env:RESIZE_DIM, 224} + mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} #Flagged for maybe needs fixing + std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} + - class_path: eva.vision.datasets.BACH + init_args: + <<: *PREDICT_DATASET_ARGS + split: val + dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} + num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} + shuffle: true + val: + batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS + predict: + batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} + num_workers: *N_DATA_WORKERS diff --git a/configs/vision/pathology/offline/classification/bach-linear-experiments/eval_dinov3_custom_tied_double_linear.yaml b/configs/vision/pathology/offline/classification/bach-linear-experiments/eval_dinov3_custom_tied_double_linear.yaml new file mode 100644 index 000000000..863016308 --- /dev/null +++ b/configs/vision/pathology/offline/classification/bach-linear-experiments/eval_dinov3_custom_tied_double_linear.yaml @@ -0,0 +1,114 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/bach} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 6000} #usually 12500 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} + callbacks: + - class_path: eva.callbacks.ConfigurationLogger + - class_path: lightning.pytorch.callbacks.TQDMProgressBar + init_args: + refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: ${oc.env:SAVE_LAST, false} + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: ${oc.env:PATIENCE, 1250} + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + - class_path: eva.callbacks.ClassificationEmbeddingsWriter + init_args: + output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, dino_vits16}/bach + dataloader_idx_map: + 0: train + 1: val + backbone: + class_path: eva.core.models.wrappers.ModelFromLocal + init_args: + local_repo_path: "./dinov3" + model_name: "dinov3_vith16plus" #The model you want to load + checkpoint_path: "./dinov3_vith16plus_pretrain_lvd1689m-7c1da9a5.pth" #Path to the model checkpoint + + + overwrite: true + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.HeadModule + init_args: + head: + class_path: eva.core.models.networks.tied_double_linear.tied_double_linear + init_args: + in_features: ${oc.env:IN_FEATURES, 1280} #MODIFY ME + out_features: &NUM_CLASSES 4 + criterion: torch.nn.CrossEntropyLoss + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: ${oc.env:LR_VALUE, 0.0001} + metrics: + common: + - class_path: eva.metrics.AverageLoss + - class_path: eva.metrics.MulticlassClassificationMetrics + init_args: + num_classes: *NUM_CLASSES +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: &DATASET_ARGS + root: *DATASET_EMBEDDINGS_ROOT + manifest_file: manifest.csv + split: train + val: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: + <<: *DATASET_ARGS + split: val + predict: + - class_path: eva.vision.datasets.BACH + init_args: &PREDICT_DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data/bach} + split: train + download: ${oc.env:DOWNLOAD_DATA, false} + # Set `download: true` to download the dataset from https://zenodo.org/records/3632035 + # The BACH dataset is distributed under the following license + # Attribution-NonCommercial-NoDerivs 4.0 International license + # (see: https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode) + transforms: + class_path: eva.vision.data.transforms.common.ResizeAndCrop + init_args: + size: ${oc.env:RESIZE_DIM, 224} + mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} #Flagged for maybe needs fixing + std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} + - class_path: eva.vision.datasets.BACH + init_args: + <<: *PREDICT_DATASET_ARGS + split: val + dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} + num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} + shuffle: true + val: + batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS + predict: + batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} + num_workers: *N_DATA_WORKERS diff --git a/eval_dinov3_custom.yaml b/eval_dinov3_custom.yaml index d7a700fa2..835847ba1 100644 --- a/eval_dinov3_custom.yaml +++ b/eval_dinov3_custom.yaml @@ -51,7 +51,7 @@ model: class_path: eva.HeadModule init_args: head: - class_path: batch_linear.batch_linear + class_path: eva.core.models.networks.batch_linear.batch_linear init_args: in_features: ${oc.env:IN_FEATURES, 1280} #MODIFY ME out_features: &NUM_CLASSES 4 diff --git a/src/eva/core/data/datasets/classification/embeddings.py b/src/eva/core/data/datasets/classification/embeddings.py index 489b20dd2..1ef4ce3de 100644 --- a/src/eva/core/data/datasets/classification/embeddings.py +++ b/src/eva/core/data/datasets/classification/embeddings.py @@ -26,7 +26,7 @@ def load_embeddings(self, index: int) -> torch.Tensor: @override def load_target(self, index: int) -> torch.Tensor: - target = self._data.at[index, self._column_mapping["target"]] + target = int(self._data.at[index, self._column_mapping["target"]]) return torch.tensor(target, dtype=torch.int64) @override diff --git a/src/eva/core/models/networks/adj_pooled_linear.py b/src/eva/core/models/networks/adj_pooled_linear.py new file mode 100644 index 000000000..02fb422f2 --- /dev/null +++ b/src/eva/core/models/networks/adj_pooled_linear.py @@ -0,0 +1,61 @@ +import torch +from torch import nn +import math + +class adj_pooled_linear(nn.Module): + + def __init__( + self, + in_features: int, + out_features: int, + pool_kernel_size = (7, 7), + patch_grid_hw = (14, 14), + ): + super().__init__() + + self.in_features = in_features + self.grid_h, self.grid_w = patch_grid_hw + self.n_patches = self.grid_h * self.grid_w + + if self.grid_h <= 0 or self.grid_w <= 0: + raise ValueError(f"patch_grid_hw must be positive, got {patch_grid_hw}.") + + k_h, k_w = pool_kernel_size + if (self.grid_h % k_h) != 0 or (self.grid_w % k_w) != 0: + raise ValueError(f"pool_kernel_size {pool_kernel_size} must tile patch_grid_hw {patch_grid_hw}.") + + self.pooled_h = self.grid_h // k_h + self.pooled_w = self.grid_w // k_w + pooled_n_patches = self.pooled_h * self.pooled_w + + self.batch = nn.RMSNorm(in_features) + self.pool = nn.AvgPool2d(kernel_size=pool_kernel_size, stride=pool_kernel_size) + + self.cls_lin = nn.Linear(in_features, out_features) + self.patch_linear = nn.Linear(in_features * pooled_n_patches, out_features) + self.reg_linear = nn.Linear(in_features * 4, out_features) + + def forward(self, x): + + x = self.batch(x) + + B = x.shape[0] + + registers = x[:, 0:4, :] + patch = x[:, 4:-1, :] + cls = x[:, -1, :] + + patch_grid = patch.reshape(B, self.grid_h, self.grid_w, self.in_features).permute(0, 3, 1, 2) + pooled_patches_flat = self.pool(patch_grid).reshape(B, -1) + + registers_flat = registers.reshape(B, -1) + + registers_out = self.reg_linear(registers_flat) + patch_out = self.patch_linear(pooled_patches_flat) + cls_out = self.cls_lin(cls) + + xs = torch.stack((registers_out, patch_out, cls_out), dim=1) + + x_out = torch.mean(xs, dim=1) + + return x_out diff --git a/src/eva/core/models/networks/adj_pooled_linear_noreg.py b/src/eva/core/models/networks/adj_pooled_linear_noreg.py new file mode 100644 index 000000000..d42932cb9 --- /dev/null +++ b/src/eva/core/models/networks/adj_pooled_linear_noreg.py @@ -0,0 +1,56 @@ +import torch +from torch import nn +import math + +class adj_pooled_linear_noreg(nn.Module): + + def __init__( + self, + in_features: int, + out_features: int, + pool_kernel_size = (7, 7), + patch_grid_hw = (14, 14), + ): + super().__init__() + + self.in_features = in_features + self.grid_h, self.grid_w = patch_grid_hw + self.n_patches = self.grid_h * self.grid_w + + if self.grid_h <= 0 or self.grid_w <= 0: + raise ValueError(f"patch_grid_hw must be positive, got {patch_grid_hw}.") + + k_h, k_w = pool_kernel_size + if (self.grid_h % k_h) != 0 or (self.grid_w % k_w) != 0: + raise ValueError(f"pool_kernel_size {pool_kernel_size} must tile patch_grid_hw {patch_grid_hw}.") + + self.pooled_h = self.grid_h // k_h + self.pooled_w = self.grid_w // k_w + pooled_n_patches = self.pooled_h * self.pooled_w + + self.batch = nn.RMSNorm(in_features) + self.pool = nn.AvgPool2d(kernel_size=pool_kernel_size, stride=pool_kernel_size) + + self.cls_lin = nn.Linear(in_features, out_features) + self.patch_linear = nn.Linear(in_features * pooled_n_patches, out_features) + + def forward(self, x): + + x = self.batch(x) + + B = x.shape[0] + + patch = x[:, 4:-1, :] + cls = x[:, -1, :] + + patch_grid = patch.reshape(B, self.grid_h, self.grid_w, self.in_features).permute(0, 3, 1, 2) + pooled_patches_flat = self.pool(patch_grid).reshape(B, -1) + + patch_out = self.patch_linear(pooled_patches_flat) + cls_out = self.cls_lin(cls) + + xs = torch.stack((patch_out, cls_out), dim=1) + + x_out = torch.mean(xs, dim=1) + + return x_out diff --git a/src/eva/core/models/networks/batch_linear.py b/src/eva/core/models/networks/batch_linear.py new file mode 100644 index 000000000..6490c76a2 --- /dev/null +++ b/src/eva/core/models/networks/batch_linear.py @@ -0,0 +1,23 @@ + + +from torch import nn + +class batch_linear(nn.Module): + + def __init__(self, in_features: int, out_features: int): + super().__init__() + + #self.batch = nn.BatchNorm1d(in_features) + self.batch = nn.RMSNorm(in_features) + self.lin = nn.Linear(in_features, out_features) + + + def forward(self, x): + + x = self.batch(x) + x = self.lin(x) + x = x[:, -1, :] + + return x + + diff --git a/src/eva/core/models/networks/three_linear.py b/src/eva/core/models/networks/three_linear.py new file mode 100644 index 000000000..d208ab887 --- /dev/null +++ b/src/eva/core/models/networks/three_linear.py @@ -0,0 +1,43 @@ + +import torch +from torch import nn + +class three_linear(nn.Module): + + def __init__(self, in_features: int, out_features: int): + super().__init__() + + #self.batch = nn.BatchNorm1d(out_features) + #self.batch = nn.RMSNorm(out_features) + + self.cls_lin = nn.Linear(in_features, out_features) + self.patch_linear = nn.Linear(in_features * 196, out_features) + self.reg_linear = nn.Linear(in_features * 4, out_features) + + def forward(self, x): + + #4 register, 196 patch, 1 cls + registers = x[:, 0:4, :] + registers = registers.reshape(x.shape[0], -1)#Flatten + + patch = x[:, 4:-1, :] + patch = patch.reshape(x.shape[0], -1) + + cls = x[:, -1 , :] + + + registers = self.reg_linear(registers) + patch = self.patch_linear(patch) + cls = self.cls_lin(cls) + + + xs = torch.stack((registers, patch, cls), axis = 1) + + #output shape is batch, 3, embedd + + #Either average the features, or learn some kind of combiner. + x = torch.mean(xs, dim = 1) +# x = self.batch(x) + + return x + diff --git a/src/eva/core/models/networks/tied_double_linear.py b/src/eva/core/models/networks/tied_double_linear.py new file mode 100644 index 000000000..09f62ac87 --- /dev/null +++ b/src/eva/core/models/networks/tied_double_linear.py @@ -0,0 +1,46 @@ +import torch +from torch import nn + +class tied_double_linear(nn.Module): + + def __init__( + self, + in_features: int, + out_features: int, + projection_dim: int = 64, + n_patches: int = 196 + ): + + super().__init__() + + self.batch = nn.RMSNorm(in_features) + self.projection = nn.Linear(in_features, projection_dim) + + self.cls_lin = nn.Linear(in_features, out_features) + self.patch_linear = nn.Linear(projection_dim * n_patches, out_features) + self.reg_linear = nn.Linear(in_features * 4, out_features) + + def forward(self, x): + + x = self.batch(x) + + B = x.shape[0] + + registers = x[:, 0:4, :] + patch = x[:, 4:-1, :] + cls = x[:, -1, :] + + projected_patches = self.projection(patch) + projected_patches_flat = projected_patches.reshape(B, -1) + + registers_flat = registers.reshape(B, -1) + + registers_out = self.reg_linear(registers_flat) + patch_out = self.patch_linear(projected_patches_flat) + cls_out = self.cls_lin(cls) + + xs = torch.stack((registers_out, patch_out, cls_out), dim=1) + + x_out = torch.mean(xs, dim=1) + + return x_out