def log_step(
    msg: str,
    level: Optional[str] = None,
    flush: bool = True,
    logger: Any = None,
) -> None:
    """Print a timestamped message, optionally tagged with a severity level.

    Parameters
    ----------
    msg: message to log
    level: None (plain timestamp + msg), 'info', 'debug', or 'error'
    flush: whether to flush stdout after printing
    logger: if provided and level is 'info', also call logger.info(msg);
        if provided and level is 'error', also call logger.error(msg)
    """
    stamp = time.strftime('%Y-%m-%d %H:%M:%S')
    # Unknown or missing levels fall back to an untagged message.
    tag = {'info': 'INFO: ', 'debug': 'DEBUG: ', 'error': 'ERROR: '}.get(level, '')
    print(f"[{stamp}] {tag}{msg}", flush=flush)
    # Mirror info/error messages to the optional logger; debug stays stdout-only.
    if logger is not None:
        if level == 'info':
            logger.info(msg)
        elif level == 'error':
            logger.error(msg)
@@ -112,7 +114,15 @@ def from_config(cls, config_path: str | Path | dict): # Initialize the LightningModule model_class = cls.MODEL_REGISTRY[model_type] + log_step(f"Creating {model_type} model instance", level='debug') + log_step( + f"About to call {model_class.__name__}.__init__() - this may take several minutes if downloading pretrained weights", + level='debug', + ) + init_start = time.time() model = model_class(config) + init_duration = time.time() - init_start + log_step(f"Model initialization completed in {init_duration:.2f} seconds", level='debug') print(f'Initialized a {model_class} model') diff --git a/beast/cli/commands/train.py b/beast/cli/commands/train.py index 862e719..1c8e20d 100644 --- a/beast/cli/commands/train.py +++ b/beast/cli/commands/train.py @@ -4,6 +4,7 @@ import logging from pathlib import Path +from beast import log_step from beast.cli.types import config_file, output_dir _logger = logging.getLogger('BEAST.CLI.TRAIN') @@ -65,50 +66,73 @@ def register_parser(subparsers): def handle(args): """Handle the train command execution.""" + log_step("Starting train command handler", level='info', logger=_logger) + # Determine output directory + log_step("Determining output directory", level='info', logger=_logger) if not args.output: now = datetime.datetime.now() args.output = Path('runs').resolve() / now.strftime('%Y-%m-%d') / now.strftime('%H-%M-%S') args.output.mkdir(parents=True, exist_ok=True) + log_step(f"Output directory: {args.output}", level='info', logger=_logger) # Set up logging to the model directory + log_step("Setting up model logging", level='info', logger=_logger) model_log_handler = _setup_model_logging(args.output) + log_step("Model logging set up", level='info', logger=_logger) # try: # Load config + log_step(f"Loading config from: {args.config}", level='info', logger=_logger) from beast.io import load_config config = load_config(args.config) + log_step("Config loaded", level='info', logger=_logger) # Apply overrides if 
args.overrides: + log_step("Applying config overrides", level='info', logger=_logger) from beast.io import apply_config_overrides config = apply_config_overrides(config, args.overrides) + log_step("Config overrides applied", level='info', logger=_logger) # Override specific values from command line + log_step("Applying command line overrides", level='info', logger=_logger) if args.data: config['data']['data_dir'] = str(args.data) + log_step(f"Data directory overridden to: {args.data}", level='info', logger=_logger) if args.gpus is not None: config['training']['num_gpus'] = args.gpus + log_step(f"Number of GPUs overridden to: {args.gpus}", level='info', logger=_logger) if args.nodes is not None: config['training']['num_nodes'] = args.nodes + log_step(f"Number of nodes overridden to: {args.nodes}", level='info', logger=_logger) + + # Check for unsupported --checkpoint argument + if hasattr(args, 'checkpoint') and args.checkpoint: + log_step( + f"WARNING: --checkpoint argument provided but not supported: {args.checkpoint}", level='info', logger=_logger) + log_step("Checkpoint resuming is not currently implemented in the CLI", + level='info', logger=_logger) # Initialize model + log_step("Initializing model from config", level='info', logger=_logger) from beast.api.model import Model model = Model.from_config(config) + log_step("Model initialized", level='info', logger=_logger) # if args.resume: # train_kwargs['resume_from_checkpoint'] = args.resume - - _logger.info(f'Training {type(model.model)} model') - _logger.info(f'Data directory: {args.data}') - _logger.info(f'Output directory: {args.output}') + log_step(f'Training {type(model.model)} model', level='info', logger=_logger) + log_step(f'Data directory: {args.data}', level='info', logger=_logger) + log_step(f'Output directory: {args.output}', level='info', logger=_logger) # Run training + log_step("About to call model.train()", level='info', logger=_logger) model.train(output_dir=args.output) - - 
_logger.info(f'Training complete. Model saved to {args.output}') + log_step("model.train() completed", level='info', logger=_logger) + log_step(f'Training complete. Model saved to {args.output}', level='info', logger=_logger) # except Exception as e: diff --git a/beast/data/datamodules.py b/beast/data/datamodules.py index 145204d..2ec88ad 100644 --- a/beast/data/datamodules.py +++ b/beast/data/datamodules.py @@ -1,6 +1,7 @@ """Data modules split a dataset into train, val, and test modules.""" import copy +import multiprocessing import os import lightning.pytorch as pl @@ -155,10 +156,13 @@ def train_dataloader(self) -> torch.utils.data.DataLoader: batch_size=None if self.use_sampler else self.train_batch_size, num_workers=self.num_workers, persistent_workers=True if self.num_workers > 0 else False, + pin_memory=True, # Helps with GPU transfer shuffle=True if not self.use_sampler else False, sampler=self.sampler if self.use_sampler else None, generator=torch.Generator().manual_seed(self.seed), collate_fn=contrastive_collate_fn if self.use_sampler else None, + multiprocessing_context=multiprocessing.get_context( + 'spawn') if self.num_workers > 0 else None, # More stable on HPC ) def val_dataloader(self) -> torch.utils.data.DataLoader: @@ -167,6 +171,9 @@ def val_dataloader(self) -> torch.utils.data.DataLoader: batch_size=self.val_batch_size, num_workers=self.num_workers, persistent_workers=True if self.num_workers > 0 else False, + pin_memory=True, + multiprocessing_context=multiprocessing.get_context( + 'spawn') if self.num_workers > 0 else None, ) def test_dataloader(self) -> torch.utils.data.DataLoader: @@ -174,6 +181,8 @@ def test_dataloader(self) -> torch.utils.data.DataLoader: self.test_dataset, batch_size=self.test_batch_size, num_workers=self.num_workers, + pin_memory=True, + multiprocessing_context='spawn' if self.num_workers > 0 else None, ) def full_labeled_dataloader(self) -> torch.utils.data.DataLoader: diff --git a/beast/data/datasets.py 
b/beast/data/datasets.py index 18aa56a..488d696 100644 --- a/beast/data/datasets.py +++ b/beast/data/datasets.py @@ -1,5 +1,6 @@ """Dataset objects store images and augmentation pipeline.""" +import time from pathlib import Path from typing import Callable @@ -9,6 +10,7 @@ from torchvision import transforms from typeguard import typechecked +from beast import log_step from beast.data.types import ExampleDict _IMAGENET_MEAN = [0.485, 0.456, 0.406] @@ -28,15 +30,29 @@ def __init__(self, data_dir: str | Path, imgaug_pipeline: Callable | None) -> No imgaug_transform: imgaug transform pipeline to apply to images """ + log_step(f"BaseDataset.__init__ called with data_dir: {data_dir}", level='debug') self.data_dir = Path(data_dir) if not self.data_dir.is_dir(): raise ValueError(f'{self.data_dir} is not a directory') + log_step(f"Data directory exists: {self.data_dir}", level='debug') self.imgaug_pipeline = imgaug_pipeline # collect ALL png files in data_dir - self.image_list = sorted(list(self.data_dir.rglob('*.png'))) + scan_start = time.time() + try: + log_step( + f"Starting to scan for PNG files in {self.data_dir} (this may take a while for large directories)...", level='debug') + self.image_list = sorted(list(self.data_dir.rglob('*.png'))) + scan_duration = time.time() - scan_start + log_step( + f"Finished scanning. 
Found {len(self.image_list)} PNG files in {scan_duration:.2f} seconds", level='debug') + except Exception as e: + log_step(f"ERROR during file scanning: {e}", level='error') + raise if len(self.image_list) == 0: raise ValueError(f'{self.data_dir} does not contain image data in png format') + log_step( + f"BaseDataset initialization complete with {len(self.image_list)} images", level='debug') # send image to tensor, resize to canonical dimensions, and normalize pytorch_transform_list = [ diff --git a/beast/models/base.py b/beast/models/base.py index d33ca7b..8f11a7d 100644 --- a/beast/models/base.py +++ b/beast/models/base.py @@ -113,7 +113,17 @@ def evaluate_batch( # multi-GPU training. Performance overhead was found negligible. # log overall supervised loss - self.log(f'{stage}_loss', loss, prog_bar=True, sync_dist=True) + # For training: log on step; for validation: log on epoch + on_step = (stage == 'train') + on_epoch = True # Always log on epoch for both train and val + self.log( + f'{stage}_loss', + loss, + prog_bar=True, + sync_dist=True, + on_step=on_step, + on_epoch=on_epoch + ) # log individual supervised losses for log_dict in log_list: self.log( @@ -121,6 +131,8 @@ def evaluate_batch( log_dict['value'].to(self.device), prog_bar=log_dict.get('prog_bar', False), sync_dist=True, + on_step=on_step, + on_epoch=on_epoch ) return loss diff --git a/beast/models/perceptual.py b/beast/models/perceptual.py new file mode 100644 index 0000000..eef92c3 --- /dev/null +++ b/beast/models/perceptual.py @@ -0,0 +1,50 @@ +import torch +import torchvision +from torch import nn +from typing import Any +# https://github.com/MLReproHub/SMAE/blob/main/src/loss/perceptual.py + + +class Perceptual(nn.Module): + def __init__(self, *, network: nn.Module, criterion: nn.Module): + """Initialize perceptual loss module. + + Parameters + ---------- + network: feature extractor that maps input images to feature tensors + criterion: loss function applied to extracted features (e.g. 
MSELoss) + """ + super(Perceptual, self).__init__() + self.net = network + self.criterion = criterion + self.sigmoid = nn.Sigmoid() + + def forward(self, x_hat: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + x_hat_features = self.sigmoid(self.net(x_hat)) + x_features = self.sigmoid(self.net(x)) + loss = self.criterion(x_hat_features, x_features) + return loss + + +class AlexPerceptual(Perceptual): + def __init__(self, *, device: str | torch.device, **kwargs: Any): + """Perceptual loss using pretrained AlexNet features [Pihlgren et al. 2020]. + + Extracts features from the first five layers of AlexNet (pretrained on ImageNet) + and computes loss between reconstructed and target feature maps. + + Parameters + ---------- + device: device to run the feature extractor on (e.g. 'cuda', 'cpu') + **kwargs: passed to parent; must include criterion (e.g. nn.MSELoss()) + """ + # Load alex net pretrained on IN1k + alex_net = torchvision.models.alexnet(weights='IMAGENET1K_V1') + # Extract features after second relu activation + # Append sigmoid layer to normalize features + perceptual_net = alex_net.features[:5].to(device) + # Don't record gradients for the perceptual net, the gradients will still propagate through. 
+ for parameter in perceptual_net.parameters(): + parameter.requires_grad = False + + super(AlexPerceptual, self).__init__(network=perceptual_net, **kwargs) diff --git a/beast/models/vits.py b/beast/models/vits.py index 0188ec0..f9644ee 100644 --- a/beast/models/vits.py +++ b/beast/models/vits.py @@ -14,6 +14,8 @@ from typeguard import typechecked from beast.models.base import BaseLightningModel +from beast.models.perceptual import AlexPerceptual +from beast import log_step class BatchNormProjector(nn.Module): @@ -43,12 +45,30 @@ def __init__(self, config): super().__init__(config) # Set up ViT architecture vit_mae_config = ViTMAEConfig(**config['model']['model_params']) - self.vit_mae = ViTMAE(vit_mae_config).from_pretrained("facebook/vit-mae-base") + + # Check if we should use pretrained weights or random initialization + use_pretrained = not config['model']['model_params'].get('random_init', False) + if use_pretrained: + log_step( + "Loading pretrained model from 'facebook/vit-mae-base' (this may take several minutes if downloading)...", level='debug') + log_step("Note: Model will be cached locally after first download", level='debug') + self.vit_mae = ViTMAE.from_pretrained("facebook/vit-mae-base", config=vit_mae_config) + else: + log_step("Using random initialization (random_init=True)", level='debug') + self.vit_mae = ViTMAE(vit_mae_config) + log_step("Randomly initialized model created", level='debug') + self.mask_ratio = config['model']['model_params']['mask_ratio'] + # perceptual loss + if config['model']['model_params'].get('use_perceptual_loss', False): + self.perceptual_loss = AlexPerceptual( + device=self.device, + criterion=nn.MSELoss() + ) # contrastive loss - if config['model']['model_params']['use_infoNCE']: + if config['model']['model_params'].get('use_infoNCE', False): self.proj = BatchNormProjector(vit_mae_config) - if self.config['model']['model_params']['temp_scale']: + if config['model']['model_params'].get('temp_scale', False): 
self.temperature = nn.Parameter(torch.ones([]) * np.log(1)) def forward( @@ -56,7 +76,11 @@ def forward( x: Float[torch.Tensor, 'batch channels img_height img_width'], ) -> Dict[str, torch.Tensor]: results_dict = self.vit_mae(pixel_values=x, return_recon=True) - if self.config['model']['model_params']['use_infoNCE']: + if self.config['model']['model_params'].get('use_perceptual_loss', False): + results_dict['perceptual_loss'] = self.perceptual_loss( + results_dict['reconstructions'], x + ) + if self.config['model']['model_params'].get('use_infoNCE', False): cls_token = results_dict['latents'][:, 0, :] proj_hidden = self.proj(cls_token) # normalize projection @@ -85,10 +109,19 @@ def compute_loss( {'name': f'{stage}_mse', 'value': mse_loss.clone()} ] loss = mse_loss - if self.config['model']['model_params']['use_infoNCE']: + if self.config['model']['model_params'].get('use_perceptual_loss', False): + perceptual_loss = kwargs['perceptual_loss'] + log_list.append({ + 'name': f'{stage}_perceptual', + 'value': perceptual_loss.clone() + }) + loss += self.config['model']['model_params'].get( + 'lambda_perceptual', 10.0 + ) * perceptual_loss + if self.config['model']['model_params'].get('use_infoNCE', False): z = kwargs['z'] sim_matrix = z @ z.T - if self.config['model']['model_params']['temp_scale']: + if self.config['model']['model_params'].get('temp_scale', False): sim_matrix /= self.temperature.exp() loss_dict = batch_wise_contrastive_loss(sim_matrix) loss_dict['infoNCE_loss'] *= self.config['model']['model_params']['infoNCE_weight'] @@ -119,9 +152,8 @@ def predict_step(self, batch_dict: dict, batch_idx: int) -> dict: class ViTMAE(ViTMAEForPreTraining): - # Overriding the forward method to return the latent and loss - # This is used for training and inference - # Huggingface Transformer library + """ViT-MAE for masked autoencoding. 
Returns latents, reconstructions, and MSE loss.""" + def forward( self, pixel_values: torch.Tensor, @@ -177,10 +209,12 @@ def forward( logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels) loss = self.forward_loss(pixel_values, logits, mask) + if return_recon: return { 'latents': latent, 'loss': loss, + 'mse_loss': loss, 'reconstructions': self.unpatchify(logits), } return { diff --git a/beast/train.py b/beast/train.py index ee851c9..32ac8df 100644 --- a/beast/train.py +++ b/beast/train.py @@ -12,6 +12,7 @@ import beast from beast.data.augmentations import imgaug_pipeline +from beast import log_step from beast.data.datamodules import BaseDataModule from beast.data.datasets import BaseDataset @@ -48,15 +49,22 @@ def train(config: dict, model, output_dir: str | Path): # Only print from rank 0 if rank_zero_only.rank == 0: + log_step("Entering train() function", level='debug') print(f'output directory: {output_dir}') print(f'model type: {type(model)}') # reset all seeds + if rank_zero_only.rank == 0: + log_step("Resetting seeds", level='debug') reset_seeds(seed=0) # record beast version + if rank_zero_only.rank == 0: + log_step("Recording beast version", level='debug') config['model']['beast_version'] = beast.version + if rank_zero_only.rank == 0: + log_step("Printing config", level='debug') pretty_print_config(config) # ---------------------------------------------------------------------------------- @@ -64,11 +72,15 @@ def train(config: dict, model, output_dir: str | Path): # ---------------------------------------------------------------------------------- # imgaug transform + if rank_zero_only.rank == 0: + log_step("Setting up imgaug pipeline", level='debug') pipe_params = config.get('training', {}).get('imgaug', 'none') if isinstance(pipe_params, str): from beast.data.augmentations import expand_imgaug_str_to_dict pipe_params = expand_imgaug_str_to_dict(pipe_params) imgaug_pipeline_ = imgaug_pipeline(pipe_params) + 
if rank_zero_only.rank == 0: + log_step("Imgaug pipeline created", level='debug') # dataset dataset = BaseDataset( @@ -77,6 +89,8 @@ def train(config: dict, model, output_dir: str | Path): ) # datamodule; breaks up dataset into train/val/test + if rank_zero_only.rank == 0: + log_step("Creating BaseDataModule", level='debug') datamodule = BaseDataModule( dataset=dataset, train_batch_size=config['training']['train_batch_size'], @@ -88,8 +102,12 @@ def train(config: dict, model, output_dir: str | Path): val_probability=config['training'].get('val_probability', 0.05), seed=config['training']['seed'], ) + if rank_zero_only.rank == 0: + log_step("BaseDataModule created", level='debug') # update number of training steps (for learning rate scheduler with step information) + if rank_zero_only.rank == 0: + log_step("Calculating training steps", level='debug') num_epochs = config['training']['num_epochs'] steps_per_epoch = int(np.ceil( len(datamodule.train_dataset) @@ -99,6 +117,9 @@ def train(config: dict, model, output_dir: str | Path): )) model.config['optimizer']['steps_per_epoch'] = steps_per_epoch model.config['optimizer']['total_steps'] = steps_per_epoch * num_epochs + if rank_zero_only.rank == 0: + log_step( + f"Training steps calculated: {steps_per_epoch} steps/epoch, {num_epochs} epochs", level='debug') # ---------------------------------------------------------------------------------- # Save configuration in output directory @@ -106,23 +127,35 @@ def train(config: dict, model, output_dir: str | Path): # Done before training; files will exist even if script dies prematurely. 
# save config file + if rank_zero_only.rank == 0: + log_step(f"Saving config to {output_dir}", level='debug') output_dir.mkdir(parents=True, exist_ok=True) dest_config_file = Path(output_dir) / 'config.yaml' with open(dest_config_file, 'w') as file: yaml.dump(config, file) + if rank_zero_only.rank == 0: + log_step("Config saved", level='debug') # ---------------------------------------------------------------------------------- # Set up and run training # ---------------------------------------------------------------------------------- # logger + if rank_zero_only.rank == 0: + log_step("Creating TensorBoardLogger", level='debug') logger = pl.loggers.TensorBoardLogger('tb_logs', name='') + if rank_zero_only.rank == 0: + log_step("TensorBoardLogger created", level='debug') # early stopping, learning rate monitoring, model checkpointing, backbone unfreezing + if rank_zero_only.rank == 0: + log_step("Setting up callbacks", level='debug') callbacks = get_callbacks( lr_monitor=True, ckpt_every_n_epochs=config['training'].get('ckpt_every_n_epochs', None), ) + if rank_zero_only.rank == 0: + log_step(f"Callbacks created: {len(callbacks)} callbacks", level='debug') # initialize to Trainer defaults. Note max_steps defaults to -1. 
min_epochs = config['training']['num_epochs'] @@ -134,6 +167,12 @@ def train(config: dict, model, output_dir: str | Path): else: use_distributed_sampler = True + if rank_zero_only.rank == 0: + log_step("Creating PyTorch Lightning Trainer", level='debug') + log_step(f" - accelerator: gpu", level='debug') + log_step(f" - devices: {config['training']['num_gpus']}", level='debug') + log_step(f" - num_nodes: {config['training']['num_nodes']}", level='debug') + log_step(f" - max_epochs: {max_epochs}", level='debug') trainer = pl.Trainer( accelerator='gpu', devices=config['training']['num_gpus'], @@ -148,9 +187,15 @@ def train(config: dict, model, output_dir: str | Path): sync_batchnorm=True, use_distributed_sampler=use_distributed_sampler, ) + if rank_zero_only.rank == 0: + log_step("Trainer created", level='debug') # train model! + if rank_zero_only.rank == 0: + log_step("About to call trainer.fit() - this may hang here if there are issues with data loading or GPU setup", level='debug') trainer.fit(model=model, datamodule=datamodule) + if rank_zero_only.rank == 0: + log_step("trainer.fit() completed", level='debug') # when devices > 0, lightning creates a process per device. # kill processes other than the main process, otherwise they all go forward. 
import torch

from beast.models.perceptual import AlexPerceptual, Perceptual


def test_perceptual_forward():
    """Test base Perceptual class forward pass with a simple mock network."""
    # Mock network: an AlexNet-style stem (Conv2d + ReLU). Note the stride-4
    # convolution DOWNSAMPLES spatially (224 -> 56), mirroring AlexNet's first
    # layer; Perceptual only requires matching feature shapes, not input shape.
    mock_net = torch.nn.Sequential(
        torch.nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
        torch.nn.ReLU(inplace=True),
    )
    criterion = torch.nn.MSELoss()
    perceptual = Perceptual(network=mock_net, criterion=criterion)
    x_hat = torch.randn((5, 3, 224, 224))
    x = torch.randn((5, 3, 224, 224))
    loss = perceptual(x_hat, x)
    # MSE over features is a non-negative scalar tensor.
    assert isinstance(loss, torch.Tensor)
    assert loss.ndim == 0
    assert loss.item() >= 0


def test_alex_perceptual_forward():
    """Test AlexPerceptual forward pass with pretrained AlexNet features."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    criterion = torch.nn.MSELoss()
    perceptual = AlexPerceptual(device=device, criterion=criterion)
    x_hat = torch.randn((5, 3, 224, 224), device=device)
    x = torch.randn((5, 3, 224, 224), device=device)
    loss = perceptual(x_hat, x)
    assert isinstance(loss, torch.Tensor)
    assert loss.ndim == 0
    assert loss.item() >= 0


def test_alex_perceptual_different_inputs_produce_different_loss():
    """Identical inputs give ~zero loss; independent random inputs do not."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    criterion = torch.nn.MSELoss()
    perceptual = AlexPerceptual(device=device, criterion=criterion)
    x = torch.randn((2, 3, 224, 224), device=device)
    # Identical inputs: feature maps match exactly, so MSE is numerically zero
    # (tolerance allows for float accumulation order on GPU).
    loss_same = perceptual(x, x)
    assert loss_same.item() < 1e-5
    # Independent random inputs: feature maps differ, so the loss is positive.
    x_hat = torch.randn((2, 3, 224, 224), device=device)
    loss_diff = perceptual(x_hat, x)
    assert loss_diff.item() > 0