From 898d8d496c695bbfd69b7be32a6c19f81c64b9e8 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Sat, 22 Nov 2025 13:56:19 -0600 Subject: [PATCH 01/28] turned off contrastive loss --- configs/vit.yaml | 2 + configs/vit_perceptual.yaml | 65 +++++++++++++++++++++ test_scripts/run_besat_resume_perceptual.sh | 60 +++++++++++++++++++ 3 files changed, 127 insertions(+) create mode 100644 configs/vit_perceptual.yaml create mode 100644 test_scripts/run_besat_resume_perceptual.sh diff --git a/configs/vit.yaml b/configs/vit.yaml index 5cfde2a..38c8154 100644 --- a/configs/vit.yaml +++ b/configs/vit.yaml @@ -29,6 +29,8 @@ model: random_init: False # use random initialization instead of pretrained weights use_infoNCE: False # use InfoNCE loss infoNCE_weight: 0.03 # weight for InfoNCE loss + use_perceptual_loss: False # use perceptual loss (AlexNet features) + lambda_perceptual: 10.0 # weight for perceptual loss # training configuration training: diff --git a/configs/vit_perceptual.yaml b/configs/vit_perceptual.yaml new file mode 100644 index 0000000..6e955d1 --- /dev/null +++ b/configs/vit_perceptual.yaml @@ -0,0 +1,65 @@ +# model configuration +model: + seed: 0 + checkpoint: null # load weights from checkpoint + model_class: vit + model_params: + hidden_size: 768 + num_hidden_layers: 12 + num_attention_heads: 12 + intermediate_size: 3072 + hidden_act: "gelu" + hidden_dropout_prob: 0.0 + attention_probs_dropout_prob: 0.0 + initializer_range: 0.02 + layer_norm_eps: 1.e-12 + image_size: 224 # usually 224 + patch_size: 16 # default is 16, we use large patch size + num_channels: 3 # 3 for RGB + qkv_bias: True + decoder_num_attention_heads: 16 + decoder_hidden_size: 512 + decoder_num_hidden_layers: 8 + decoder_intermediate_size: 2048 + mask_ratio: 0.75 # 0 for no masking, usually 0.75 (MAE) + norm_pix_loss: False + + embed_size: 768 # projected embedding size, used for contrastive learning + temp_scale: False # temperature scaling for contrastive loss + random_init: False # use random initialization instead of pretrained weights + use_infoNCE: False # use InfoNCE loss + infoNCE_weight: 0.03 # weight for InfoNCE loss + use_perceptual_loss: True # use perceptual loss (AlexNet features) + lambda_perceptual: 10.0 # weight for perceptual loss + +# training configuration +training: + seed: 0 + imgaug: default # default | top-down + train_batch_size: 128 # per GPU + val_batch_size: 1024 + test_batch_size: 128 + num_epochs: 800 + num_workers: 8 # Number of CPU workers for the DataLoader + num_gpus: 1 + num_nodes: 1 + # frequency to log training metrics + log_every_n_steps: 10 + # frequency to log validation metrics + check_val_every_n_epoch: 5 + ckpt_every_n_epochs: 200 + +# optimizer configuration +optimizer: + type: Adam + accumulate_grad_batches: 1 + lr: 5.e-5 + wd: 0.05 + warmup_pct: 0.15 # cosine/linear + gamma: 0.95 # step + div_factor: 10 # cosine + scheduler: cosine # step/cosine/linear + +# data configuration +data: + data_dir: /PATH/TO/DATA \ No newline at end of file diff --git a/test_scripts/run_besat_resume_perceptual.sh b/test_scripts/run_besat_resume_perceptual.sh new file mode 100644 index 0000000..0c74223 --- /dev/null +++ b/test_scripts/run_besat_resume_perceptual.sh @@ -0,0 +1,60 @@ +#!/bin/bash +#SBATCH -A bfsr-delta-gpu +#SBATCH -p gpuA40x4,gpuA100x4,gpuA40x4-preempt,gpuA100x4-preempt +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --gpus-per-task=1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=100G +#SBATCH -t 08:00:00 +#SBATCH -J beast_train +#SBATCH -o /work/nvme/bfsr/xdai3/runs/beast_train_%j.out +#SBATCH -e /work/nvme/bfsr/xdai3/runs/beast_train_%j.err + + +# --- Setup environment --- +source ~/.bashrc +module load ffmpeg +conda activate beast +cd /u/xdai3/beast + +# Set multiprocessing temp directory to a more stable location (avoid /tmp cleanup issues) +export TMPDIR="/work/nvme/bfsr/xdai3/tmp/${SLURM_JOB_ID:-$USER}" +mkdir -p "$TMPDIR" +echo "TMPDIR set to: $TMPDIR" + +# --- Define paths --- +CONFIG="configs/vit_perceptual.yaml" +DATA="/work/nvme/bfsr/xdai3/raw_data/beast/test_video1" +CHECKPOINT="/work/nvme/bfsr/xdai3/runs/beast_train_13711940/tb_logs/version_0/checkpoints/epoch=244-step=2695-best.ckpt" + +# Define unique output directory per job (using Slurm job name + ID) +OUTPUT_DIR="/work/nvme/bfsr/xdai3/runs/${SLURM_JOB_NAME}_${SLURM_JOB_ID}" +mkdir -p "$OUTPUT_DIR" + +echo "---------------------------------------" +echo "Job name: $SLURM_JOB_NAME" +echo "Job ID: $SLURM_JOB_ID" +echo "Running on node(s): $SLURM_NODELIST" +echo "Output directory: $OUTPUT_DIR" +echo "---------------------------------------" + +# --- Run BEAST --- +echo "[$(date +'%Y-%m-%d %H:%M:%S')] Starting BEAST training..." + +if [ -f "$CHECKPOINT" ]; then + echo "[$(date +'%Y-%m-%d %H:%M:%S')] Found checkpoint: $CHECKPOINT" + echo "[$(date +'%Y-%m-%d %H:%M:%S')] Resuming BEAST training from checkpoint..." + echo "[$(date +'%Y-%m-%d %H:%M:%S')] About to call: beast train --config \"$CONFIG\" --data \"$DATA\" --checkpoint \"$CHECKPOINT\" --output \"$OUTPUT_DIR\"" + # Note: --checkpoint argument may not be supported, checking if it causes issues + beast train --config "$CONFIG" --data "$DATA" --checkpoint "$CHECKPOINT" --output "$OUTPUT_DIR" 2>&1 | tee "$OUTPUT_DIR/training_output.log" +else + echo "[$(date +'%Y-%m-%d %H:%M:%S')] No checkpoint found. Starting new training run." + echo "[$(date +'%Y-%m-%d %H:%M:%S')] CONFIG=$CONFIG DATA=$DATA OUTPUT_DIR=$OUTPUT_DIR" + echo "[$(date +'%Y-%m-%d %H:%M:%S')] About to call: beast train --config \"$CONFIG\" --data \"$DATA\" --output \"$OUTPUT_DIR\"" + beast train --config "$CONFIG" --data "$DATA" --output "$OUTPUT_DIR" 2>&1 | tee "$OUTPUT_DIR/training_output.log" +fi + +echo "[$(date +'%Y-%m-%d %H:%M:%S')] BEAST training completed." + +conda deactivate From 857c0b5bfa3dbda70952b613140c93ef705c9c41 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 18 Nov 2025 16:16:08 -0600 Subject: [PATCH 02/28] added perceptual loss --- beast/__init__.py | 15 +++-- beast/api/model.py | 10 ++++ beast/cli/commands/train.py | 29 ++++++++++ beast/data/datamodules.py | 7 +++ beast/data/datasets.py | 20 ++++++- beast/models/base.py | 14 ++++- beast/models/combined.py | 54 ++++++++++++++++++ beast/models/perceptual.py | 36 ++++++++++++ beast/models/vits.py | 107 +++++++++++++++++++++++++++++++++--- beast/train.py | 55 ++++++++++++++++++ configs/vit.yaml | 2 +- 11 files changed, 332 insertions(+), 17 deletions(-) create mode 100644 beast/models/combined.py create mode 100644 beast/models/perceptual.py diff --git a/beast/__init__.py b/beast/__init__.py index 3b95ecd..f30be8e 100644 --- a/beast/__init__.py +++ b/beast/__init__.py @@ -26,14 +26,17 @@ def __get_package_version() -> str: # Fall back on getting it from a local pyproject.toml. # This works in a development environment where the # package has not been installed from a distribution. - import warnings + try: + import warnings + import toml - import toml + warnings.warn('beast not pip-installed, getting version from pyproject.toml.') - warnings.warn('beast not pip-installed, getting version from pyproject.toml.') - - pyproject_toml_file = Path(__file__).parent.parent / 'pyproject.toml' - __package_version = toml.load(pyproject_toml_file)['project']['version'] + pyproject_toml_file = Path(__file__).parent.parent / 'pyproject.toml' + __package_version = toml.load(pyproject_toml_file)['project']['version'] + except (ImportError, FileNotFoundError, KeyError): + # If toml is not available or file doesn't exist, use a default version + __package_version = "dev" return __package_version diff --git a/beast/api/model.py b/beast/api/model.py index b8a13c3..ddd1e98 100644 --- a/beast/api/model.py +++ b/beast/api/model.py @@ -111,8 +111,18 @@ def from_config(cls, config_path: str | Path | dict): raise ValueError(f'Unknown model type: {model_type}') # Initialize the LightningModule + import time + def _log_step(msg): + timestamp = time.strftime('%Y-%m-%d %H:%M:%S') + print(f"[{timestamp}] MODEL DEBUG: {msg}", flush=True) + + _log_step(f"Creating {model_type} model instance") model_class = cls.MODEL_REGISTRY[model_type] + _log_step(f"About to call {model_class.__name__}.__init__() - this may take several minutes if downloading pretrained weights") + init_start = time.time() model = model_class(config) + init_duration = time.time() - init_start + _log_step(f"Model initialization completed in {init_duration:.2f} seconds") print(f'Initialized a {model_class} model') diff --git a/beast/cli/commands/train.py b/beast/cli/commands/train.py index 862e719..232b0d7 100644 --- a/beast/cli/commands/train.py +++ b/beast/cli/commands/train.py @@ -65,38 +65,65 @@ def register_parser(subparsers): def handle(args): """Handle the train command execution.""" + import time + def _log_step(msg): + timestamp = time.strftime('%Y-%m-%d %H:%M:%S') + print(f"[{timestamp}] CLI DEBUG: {msg}", flush=True) + _logger.info(msg) + + _log_step("Starting train command handler") + # Determine output directory + _log_step("Determining output directory") if not args.output: now = datetime.datetime.now() args.output = Path('runs').resolve() / now.strftime('%Y-%m-%d') / now.strftime('%H-%M-%S') args.output.mkdir(parents=True, exist_ok=True) + _log_step(f"Output directory: {args.output}") # Set up logging to the model directory + _log_step("Setting up model logging") model_log_handler = _setup_model_logging(args.output) + _log_step("Model logging set up") # try: # Load config + _log_step(f"Loading config from: {args.config}") from beast.io import load_config config = load_config(args.config) + _log_step("Config loaded") # Apply overrides if args.overrides: + _log_step("Applying config overrides") from beast.io import apply_config_overrides config = apply_config_overrides(config, args.overrides) + _log_step("Config overrides applied") # Override specific values from command line + _log_step("Applying command line overrides") if args.data: config['data']['data_dir'] = str(args.data) + _log_step(f"Data directory overridden to: {args.data}") if args.gpus is not None: config['training']['num_gpus'] = args.gpus + _log_step(f"Number of GPUs overridden to: {args.gpus}") if args.nodes is not None: config['training']['num_nodes'] = args.nodes + _log_step(f"Number of nodes overridden to: {args.nodes}") + + # Check for unsupported --checkpoint argument + if hasattr(args, 'checkpoint') and args.checkpoint: + _log_step(f"WARNING: --checkpoint argument provided but not supported: {args.checkpoint}") + _log_step("Checkpoint resuming is not currently implemented in the CLI") # Initialize model + _log_step("Initializing model from config") from beast.api.model import Model model = Model.from_config(config) + _log_step("Model initialized") # if args.resume: # train_kwargs['resume_from_checkpoint'] = args.resume @@ -106,7 +133,9 @@ def handle(args): _logger.info(f'Output directory: {args.output}') # Run training + _log_step("About to call model.train()") model.train(output_dir=args.output) + _log_step("model.train() completed") _logger.info(f'Training complete. Model saved to {args.output}') diff --git a/beast/data/datamodules.py b/beast/data/datamodules.py index 145204d..f0ba13c 100644 --- a/beast/data/datamodules.py +++ b/beast/data/datamodules.py @@ -1,6 +1,7 @@ """Data modules split a dataset into train, val, and test modules.""" import copy +import multiprocessing import os import lightning.pytorch as pl @@ -155,10 +156,12 @@ def train_dataloader(self) -> torch.utils.data.DataLoader: batch_size=None if self.use_sampler else self.train_batch_size, num_workers=self.num_workers, persistent_workers=True if self.num_workers > 0 else False, + pin_memory=True, # Helps with GPU transfer shuffle=True if not self.use_sampler else False, sampler=self.sampler if self.use_sampler else None, generator=torch.Generator().manual_seed(self.seed), collate_fn=contrastive_collate_fn if self.use_sampler else None, + multiprocessing_context=multiprocessing.get_context('spawn') if self.num_workers > 0 else None, # More stable on HPC ) def val_dataloader(self) -> torch.utils.data.DataLoader: @@ -167,6 +170,8 @@ def val_dataloader(self) -> torch.utils.data.DataLoader: batch_size=self.val_batch_size, num_workers=self.num_workers, persistent_workers=True if self.num_workers > 0 else False, + pin_memory=True, + multiprocessing_context=multiprocessing.get_context('spawn') if self.num_workers > 0 else None, ) def test_dataloader(self) -> torch.utils.data.DataLoader: @@ -174,6 +179,8 @@ def test_dataloader(self) -> torch.utils.data.DataLoader: self.test_dataset, batch_size=self.test_batch_size, num_workers=self.num_workers, + pin_memory=True, + multiprocessing_context='spawn' if self.num_workers > 0 else None, ) def full_labeled_dataloader(self) -> torch.utils.data.DataLoader: diff --git a/beast/data/datasets.py b/beast/data/datasets.py index 18aa56a..b03ea02 100644 --- a/beast/data/datasets.py +++ b/beast/data/datasets.py @@ -1,5 +1,6 @@ """Dataset objects store images and augmentation pipeline.""" +import time from pathlib import Path from typing import Callable @@ -11,6 +12,12 @@ from beast.data.types import ExampleDict + +def _debug_log(msg: str, flush: bool = True): + """Debug logging function with timestamp.""" + timestamp = time.strftime('%Y-%m-%d %H:%M:%S') + print(f"[{timestamp}] DATASET DEBUG: {msg}", flush=flush) + _IMAGENET_MEAN = [0.485, 0.456, 0.406] _IMAGENET_STD = [0.229, 0.224, 0.225] @@ -28,15 +35,26 @@ def __init__(self, data_dir: str | Path, imgaug_pipeline: Callable | None) -> No imgaug_transform: imgaug transform pipeline to apply to images """ + _debug_log(f"BaseDataset.__init__ called with data_dir: {data_dir}") self.data_dir = Path(data_dir) if not self.data_dir.is_dir(): raise ValueError(f'{self.data_dir} is not a directory') + _debug_log(f"Data directory exists: {self.data_dir}") self.imgaug_pipeline = imgaug_pipeline # collect ALL png files in data_dir - self.image_list = sorted(list(self.data_dir.rglob('*.png'))) + _debug_log(f"Starting to scan for PNG files in {self.data_dir} (this may take a while for large directories)...") + scan_start = time.time() + try: + self.image_list = sorted(list(self.data_dir.rglob('*.png'))) + scan_duration = time.time() - scan_start + _debug_log(f"Finished scanning. Found {len(self.image_list)} PNG files in {scan_duration:.2f} seconds") + except Exception as e: + _debug_log(f"ERROR during file scanning: {e}") + raise if len(self.image_list) == 0: raise ValueError(f'{self.data_dir} does not contain image data in png format') + _debug_log(f"BaseDataset initialization complete with {len(self.image_list)} images") # send image to tensor, resize to canonical dimensions, and normalize pytorch_transform_list = [ diff --git a/beast/models/base.py b/beast/models/base.py index d33ca7b..bb295a1 100644 --- a/beast/models/base.py +++ b/beast/models/base.py @@ -113,7 +113,17 @@ def evaluate_batch( # multi-GPU training. Performance overhead was found negligible. # log overall supervised loss - self.log(f'{stage}_loss', loss, prog_bar=True, sync_dist=True) + # For training: log on step; for validation: log on epoch + on_step = (stage == 'train') + on_epoch = True # Always log on epoch for both train and val + self.log( + f'{stage}_loss', + loss, + prog_bar=True, + sync_dist=True, + on_step=on_step, + on_epoch=on_epoch + ) # log individual supervised losses for log_dict in log_list: self.log( @@ -121,6 +131,8 @@ def evaluate_batch( log_dict['value'].to(self.device), prog_bar=log_dict.get('prog_bar', False), sync_dist=True, + on_step=on_step, + on_epoch=on_epoch ) return loss diff --git a/beast/models/combined.py b/beast/models/combined.py new file mode 100644 index 0000000..b129fbb --- /dev/null +++ b/beast/models/combined.py @@ -0,0 +1,54 @@ +from typing import Tuple + +import torch +from torch import Tensor, tensor, zeros +from torch import nn + +from loss.patch_wise import PatchWise +from loss.perceptual import SqueezePerceptual +# https://github.com/MLReproHub/SMAE/blob/main/src/loss/perceptual.p + +class CombinedLoss(nn.Module): + def __init__(self, pixel_criterion: nn.Module, perceptual_criterion: nn.Module, lambda_perceptual: float = 1.0, + calibration_steps: int = 10, device='cuda'): + super(CombinedLoss, self).__init__() + self.pixel_loss = PatchWise(criterion=pixel_criterion) + self.perceptual_loss = SqueezePerceptual(criterion=perceptual_criterion) + self.register_buffer('lambda_perceptual', tensor(lambda_perceptual)) + # 0 Calibration steps disables calibration + self.register_buffer('remaining_calibration_steps', tensor(calibration_steps)) + self.register_buffer('calibration_points', zeros((calibration_steps, 2))) + + @property + def w_pixel(self) -> float: + return self.pixel_scale.detach() + + @property + def w_perceptual(self) -> float: + return self.perceptual_scale.detach() * self.lambda_perceptual.detach() + + def forward(self, x_hat, x, patches_hat, patches) -> Tuple[Tensor, Tensor, Tensor]: + pixel_loss = self.pixel_loss(patches_hat, patches) + perceptual_loss = self.perceptual_loss(x_hat, x) + + # Calibrate the losses to be balanced during the first batches. + if self.remaining_calibration_steps > 0: + self.calibrate(pixel_loss, perceptual_loss) + + pixel_loss = self.pixel_scale * pixel_loss + perceptual_loss = self.lambda_perceptual * self.perceptual_scale * perceptual_loss + combined_loss = pixel_loss + perceptual_loss + + return combined_loss, pixel_loss, perceptual_loss + + def calibrate(self, initial_pixel_loss, initial_perceptual_loss): + # Set the latest calibration point + self.calibration_points[-self.remaining_calibration_steps] = 1 / tensor((initial_pixel_loss, + initial_perceptual_loss)) + c = self.calibration_points + # Update the scaling factors by taking the mean over all the populated calibration points + pixel_scale, perceptual_scale = c[c.nonzero(as_tuple=True)].reshape(-1, 2).mean(dim=0) + self.register_buffer('pixel_scale', pixel_scale) + self.register_buffer('perceptual_scale', perceptual_scale) + # Decrement the remaining calibration steps + self.register_buffer('remaining_calibration_steps', self.remaining_calibration_steps - 1) \ No newline at end of file diff --git a/beast/models/perceptual.py b/beast/models/perceptual.py new file mode 100644 index 0000000..c2aa1c9 --- /dev/null +++ b/beast/models/perceptual.py @@ -0,0 +1,36 @@ +import torchvision +from torch import nn +# https://github.com/MLReproHub/SMAE/blob/main/src/loss/perceptual.py + + +class Perceptual(nn.Module): + + def __init__(self, *, network, criterion): + super(Perceptual, self).__init__() + self.net = network + self.criterion = criterion + self.sigmoid = nn.Sigmoid() + + def forward(self, x_hat, x): + x_hat_features = self.sigmoid(self.net(x_hat)) + x_features = self.sigmoid(self.net(x)) + loss = self.criterion(x_hat_features, x_features) + return loss + + +class AlexPerceptual(Perceptual): + """ + Implements perceptual loss with a pre-trained alex net [Pihlgren et al. 2020] + """ + + def __init__(self, *, device, **kwargs): + # Load alex net pretrained on IN1k + alex_net = torchvision.models.alexnet(weights='IMAGENET1K_V1') + # Extract features after second relu activation + # Append sigmoid layer to normalize features + perceptual_net = alex_net.features[:5].to(device) + # Don't record gradients for the perceptual net, the gradients will still propagate through. + for parameter in perceptual_net.parameters(): + parameter.requires_grad = False + + super(AlexPerceptual, self).__init__(network=perceptual_net, **kwargs) \ No newline at end of file diff --git a/beast/models/vits.py b/beast/models/vits.py index 0188ec0..1649ad2 100644 --- a/beast/models/vits.py +++ b/beast/models/vits.py @@ -1,5 +1,6 @@ """Vision transformer autoencoder implementation.""" +import time from typing import Dict, Optional import numpy as np @@ -14,6 +15,13 @@ from typeguard import typechecked from beast.models.base import BaseLightningModel +from beast.models.perceptual import AlexPerceptual + + +def _debug_log(msg: str, flush: bool = True): + """Debug logging function with timestamp.""" + timestamp = time.strftime('%Y-%m-%d %H:%M:%S') + print(f"[{timestamp}] VIT DEBUG: {msg}", flush=flush) class BatchNormProjector(nn.Module): @@ -42,14 +50,52 @@ class VisionTransformer(BaseLightningModel): def __init__(self, config): super().__init__(config) # Set up ViT architecture + _debug_log("Creating ViTMAEConfig") vit_mae_config = ViTMAEConfig(**config['model']['model_params']) - self.vit_mae = ViTMAE(vit_mae_config).from_pretrained("facebook/vit-mae-base") + _debug_log("ViTMAEConfig created") + + # Get perceptual loss parameters from config + use_perceptual_loss = config['model']['model_params'].get('use_perceptual_loss', False) + lambda_perceptual = config['model']['model_params'].get('lambda_perceptual', 1.0) + device = config['model']['model_params'].get('device', 'cuda') + + if use_perceptual_loss: + _debug_log(f"Perceptual loss enabled with lambda={lambda_perceptual}") + + # Check if we should use pretrained weights or random initialization + use_pretrained = not config['model']['model_params'].get('random_init', False) + + if use_pretrained: + _debug_log("Loading pretrained model from 'facebook/vit-mae-base' (this may take several minutes if downloading)...") + _debug_log("Note: Model will be cached locally after first download") + load_start = time.time() + self.vit_mae = ViTMAE( + vit_mae_config, + use_perceptual_loss=use_perceptual_loss, + lambda_perceptual=lambda_perceptual, + device=device + ).from_pretrained("facebook/vit-mae-base") + load_duration = time.time() - load_start + _debug_log(f"Pretrained model loaded in {load_duration:.2f} seconds") + else: + _debug_log("Using random initialization (random_init=True)") + self.vit_mae = ViTMAE( + vit_mae_config, + use_perceptual_loss=use_perceptual_loss, + lambda_perceptual=lambda_perceptual, + device=device + ) + _debug_log("Randomly initialized model created") + self.mask_ratio = config['model']['model_params']['mask_ratio'] # contrastive loss if config['model']['model_params']['use_infoNCE']: + _debug_log("Setting up InfoNCE projection layer") self.proj = BatchNormProjector(vit_mae_config) if self.config['model']['model_params']['temp_scale']: self.temperature = nn.Parameter(torch.ones([]) * np.log(1)) + _debug_log("InfoNCE projection layer created") + _debug_log("VisionTransformer initialization complete") def forward( self, @@ -79,12 +125,31 @@ def compute_loss( **kwargs, ) -> tuple[torch.tensor, list[dict]]: assert 'loss' in kwargs, "Loss is not in the kwargs" - mse_loss = kwargs['loss'] + loss = kwargs['loss'] # add all losses here for logging + # Get MSE loss directly from model output if available, otherwise use combined loss + if 'mse_loss' in kwargs: + mse_loss = kwargs['mse_loss'] + else: + # Fallback: if perceptual loss is available, extract MSE by subtraction + if 'perceptual_loss' in kwargs: + perceptual_loss = kwargs['perceptual_loss'] + mse_loss = loss - self.vit_mae.lambda_perceptual * perceptual_loss + else: + mse_loss = loss + log_list = [ - {'name': f'{stage}_mse', 'value': mse_loss.clone()} + {'name': f'{stage}_mse', 'value': mse_loss.detach().clone(), 'prog_bar': True}, ] - loss = mse_loss + + if 'perceptual_loss' in kwargs: + perceptual_loss = kwargs['perceptual_loss'] + log_list.append({ + 'name': f'{stage}_perceptual', + 'value': perceptual_loss.detach().clone(), + 'prog_bar': True + }) + if self.config['model']['model_params']['use_infoNCE']: z = kwargs['z'] sim_matrix = z @ z.T @@ -94,11 +159,13 @@ def compute_loss( loss_dict['infoNCE_loss'] *= self.config['model']['model_params']['infoNCE_weight'] log_list.append({ 'name': f'{stage}_infoNCE', - 'value': loss_dict['infoNCE_loss'] + 'value': loss_dict['infoNCE_loss'].detach().clone(), + 'prog_bar': True }) log_list.append({ 'name': f'{stage}_infoNCE_percent_correct', - 'value': loss_dict['percent_correct'] + 'value': loss_dict['percent_correct'].detach().clone(), + 'prog_bar': False }) loss += loss_dict['infoNCE_loss'] return loss, log_list @@ -122,6 +189,17 @@ class ViTMAE(ViTMAEForPreTraining): # Overriding the forward method to return the latent and loss # This is used for training and inference # Huggingface Transformer library + def __init__(self, config, use_perceptual_loss: bool = False, lambda_perceptual: float = 1.0, device='cuda'): + super().__init__(config) + self.use_perceptual_loss = use_perceptual_loss + self.lambda_perceptual = lambda_perceptual + if use_perceptual_loss: + # Initialize AlexPerceptual with MSE criterion + self.perceptual_loss = AlexPerceptual( + device=device, + criterion=nn.MSELoss() + ) + def forward( self, pixel_values: torch.Tensor, @@ -176,13 +254,26 @@ def forward( decoder_outputs = self.decoder(latent, ids_restore) logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels) - loss = self.forward_loss(pixel_values, logits, mask) + mse_loss = self.forward_loss(pixel_values, logits, mask) + + # Compute perceptual loss if enabled and we have reconstructions + perceptual_loss_value = None + loss = mse_loss + if self.use_perceptual_loss and return_recon: + reconstructions = self.unpatchify(logits) + perceptual_loss_value = self.perceptual_loss(reconstructions, pixel_values) + loss = mse_loss + self.lambda_perceptual * perceptual_loss_value + if return_recon: - return { + result = { 'latents': latent, 'loss': loss, + 'mse_loss': mse_loss, 'reconstructions': self.unpatchify(logits), } + if perceptual_loss_value is not None: + result['perceptual_loss'] = perceptual_loss_value + return result return { 'latents': cls_latent, 'loss': loss, diff --git a/beast/train.py b/beast/train.py index ee851c9..b5e410e 100644 --- a/beast/train.py +++ b/beast/train.py @@ -1,6 +1,7 @@ import os import random import sys +import time from pathlib import Path import lightning.pytorch as pl @@ -16,6 +17,12 @@ from beast.data.datasets import BaseDataset +def _debug_log(msg: str, flush: bool = True): + """Debug logging function with timestamp.""" + timestamp = time.strftime('%Y-%m-%d %H:%M:%S') + print(f"[{timestamp}] DEBUG: {msg}", flush=flush) + + @typechecked def reset_seeds(seed: int = 0) -> None: os.environ["PYTHONHASHSEED"] = str(seed) @@ -48,15 +55,22 @@ def train(config: dict, model, output_dir: str | Path): # Only print from rank 0 if rank_zero_only.rank == 0: + _debug_log("Entering train() function") print(f'output directory: {output_dir}') print(f'model type: {type(model)}') # reset all seeds + if rank_zero_only.rank == 0: + _debug_log("Resetting seeds") reset_seeds(seed=0) # record beast version + if rank_zero_only.rank == 0: + _debug_log("Recording beast version") config['model']['beast_version'] = beast.version + if rank_zero_only.rank == 0: + _debug_log("Printing config") pretty_print_config(config) # ---------------------------------------------------------------------------------- @@ -64,19 +78,30 @@ def train(config: dict, model, output_dir: str | Path): # ---------------------------------------------------------------------------------- # imgaug transform + if rank_zero_only.rank == 0: + _debug_log("Setting up imgaug pipeline") pipe_params = config.get('training', {}).get('imgaug', 'none') if isinstance(pipe_params, str): from beast.data.augmentations import expand_imgaug_str_to_dict pipe_params = expand_imgaug_str_to_dict(pipe_params) imgaug_pipeline_ = imgaug_pipeline(pipe_params) + if rank_zero_only.rank == 0: + _debug_log("Imgaug pipeline created") # dataset + if rank_zero_only.rank == 0: + _debug_log(f"Creating BaseDataset with data_dir: {config['data']['data_dir']}") + _debug_log("WARNING: This may take a long time if data directory is large (scanning for PNG files)") dataset = BaseDataset( data_dir=config['data']['data_dir'], imgaug_pipeline=imgaug_pipeline_, ) + if rank_zero_only.rank == 0: + _debug_log(f"BaseDataset created. Found {len(dataset)} images") # datamodule; breaks up dataset into train/val/test + if rank_zero_only.rank == 0: + _debug_log("Creating BaseDataModule") datamodule = BaseDataModule( dataset=dataset, train_batch_size=config['training']['train_batch_size'], @@ -88,8 +113,12 @@ def train(config: dict, model, output_dir: str | Path): val_probability=config['training'].get('val_probability', 0.05), seed=config['training']['seed'], ) + if rank_zero_only.rank == 0: + _debug_log("BaseDataModule created") # update number of training steps (for learning rate scheduler with step information) + if rank_zero_only.rank == 0: + _debug_log("Calculating training steps") num_epochs = config['training']['num_epochs'] steps_per_epoch = int(np.ceil( len(datamodule.train_dataset) @@ -99,6 +128,8 @@ def train(config: dict, model, output_dir: str | Path): )) model.config['optimizer']['steps_per_epoch'] = steps_per_epoch model.config['optimizer']['total_steps'] = steps_per_epoch * num_epochs + if rank_zero_only.rank == 0: + _debug_log(f"Training steps calculated: {steps_per_epoch} steps/epoch, {num_epochs} epochs") # ---------------------------------------------------------------------------------- # Save configuration in output directory @@ -106,23 +137,35 @@ def train(config: dict, model, output_dir: str | Path): # Done before training; files will exist even if script dies prematurely. # save config file + if rank_zero_only.rank == 0: + _debug_log(f"Saving config to {output_dir}") output_dir.mkdir(parents=True, exist_ok=True) dest_config_file = Path(output_dir) / 'config.yaml' with open(dest_config_file, 'w') as file: yaml.dump(config, file) + if rank_zero_only.rank == 0: + _debug_log("Config saved") # ---------------------------------------------------------------------------------- # Set up and run training # ---------------------------------------------------------------------------------- # logger + if rank_zero_only.rank == 0: + _debug_log("Creating TensorBoardLogger") logger = pl.loggers.TensorBoardLogger('tb_logs', name='') + if rank_zero_only.rank == 0: + _debug_log("TensorBoardLogger created") # early stopping, learning rate monitoring, model checkpointing, backbone unfreezing + if rank_zero_only.rank == 0: + _debug_log("Setting up callbacks") callbacks = get_callbacks( lr_monitor=True, ckpt_every_n_epochs=config['training'].get('ckpt_every_n_epochs', None), ) + if rank_zero_only.rank == 0: + _debug_log(f"Callbacks created: {len(callbacks)} callbacks") # initialize to Trainer defaults. Note max_steps defaults to -1. min_epochs = config['training']['num_epochs'] @@ -134,6 +177,12 @@ def train(config: dict, model, output_dir: str | Path): else: use_distributed_sampler = True + if rank_zero_only.rank == 0: + _debug_log("Creating PyTorch Lightning Trainer") + _debug_log(f" - accelerator: gpu") + _debug_log(f" - devices: {config['training']['num_gpus']}") + _debug_log(f" - num_nodes: {config['training']['num_nodes']}") + _debug_log(f" - max_epochs: {max_epochs}") trainer = pl.Trainer( accelerator='gpu', devices=config['training']['num_gpus'], @@ -148,9 +197,15 @@ def train(config: dict, model, output_dir: str | Path): sync_batchnorm=True, use_distributed_sampler=use_distributed_sampler, ) + if rank_zero_only.rank == 0: + _debug_log("Trainer created") # train model! + if rank_zero_only.rank == 0: + _debug_log("About to call trainer.fit() - this may hang here if there are issues with data loading or GPU setup") trainer.fit(model=model, datamodule=datamodule) + if rank_zero_only.rank == 0: + _debug_log("trainer.fit() completed") # when devices > 0, lightning creates a process per device. # kill processes other than the main process, otherwise they all go forward. diff --git a/configs/vit.yaml b/configs/vit.yaml index 38c8154..16b906b 100644 --- a/configs/vit.yaml +++ b/configs/vit.yaml @@ -27,7 +27,7 @@ model: embed_size: 768 # projected embedding size, used for contrastive learning temp_scale: False # temperature scaling for contrastive loss random_init: False # use random initialization instead of pretrained weights - use_infoNCE: False # use InfoNCE loss + use_infoNCE: True # use InfoNCE loss infoNCE_weight: 0.03 # weight for InfoNCE loss use_perceptual_loss: False # use perceptual loss (AlexNet features) lambda_perceptual: 10.0 # weight for perceptual loss From ffee13d899fb62ceaa642111cceb7ac50a4a613c Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Sat, 22 Nov 2025 11:13:30 -0600 Subject: [PATCH 03/28] added perceptual loss to tensorboard --- beast/models/vits.py | 45 +++++++++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/beast/models/vits.py b/beast/models/vits.py index 1649ad2..c3a4383 100644 --- a/beast/models/vits.py +++ b/beast/models/vits.py @@ -69,12 +69,17 @@ def __init__(self, config): _debug_log("Loading pretrained model from 'facebook/vit-mae-base' (this may take several minutes if downloading)...") _debug_log("Note: Model will be cached locally after first download") load_start = time.time() - self.vit_mae = ViTMAE( - vit_mae_config, - use_perceptual_loss=use_perceptual_loss, - lambda_perceptual=lambda_perceptual, - device=device - ).from_pretrained("facebook/vit-mae-base") + # Load pretrained weights first (from_pretrained creates a new instance) + self.vit_mae = ViTMAE.from_pretrained("facebook/vit-mae-base", config=vit_mae_config) + # Set perceptual loss settings after loading (from_pretrained doesn't preserve custom init params) + self.vit_mae.use_perceptual_loss = use_perceptual_loss + self.vit_mae.lambda_perceptual = lambda_perceptual + if use_perceptual_loss: + # Initialize perceptual loss module on the correct device + self.vit_mae.perceptual_loss = AlexPerceptual( + device=device, + criterion=nn.MSELoss() + ) load_duration = time.time() - load_start _debug_log(f"Pretrained model loaded in {load_duration:.2f} seconds") else: @@ -142,13 +147,27 @@ def compute_loss( {'name': f'{stage}_mse', 'value': mse_loss.detach().clone(), 'prog_bar': True}, ] - if 'perceptual_loss' in kwargs: - perceptual_loss = kwargs['perceptual_loss'] - log_list.append({ - 'name': f'{stage}_perceptual', - 'value': perceptual_loss.detach().clone(), - 'prog_bar': True - }) + # Always log perceptual loss if it's enabled in config + if self.config['model']['model_params'].get('use_perceptual_loss', False): + if 'perceptual_loss' in kwargs: + perceptual_loss = kwargs['perceptual_loss'] + # Ensure it's a tensor and on the correct device + if not isinstance(perceptual_loss, torch.Tensor): + perceptual_loss = torch.tensor(perceptual_loss, device=loss.device, dtype=loss.dtype) + log_list.append({ + 'name': f'{stage}_perceptual', + 'value': perceptual_loss.detach().clone(), + 'prog_bar': True + }) + else: + # This shouldn't happen if perceptual loss is properly initialized + # Log 0 as fallback (shouldn't occur with the fix above) + perceptual_loss_value = torch.tensor(0.0, device=loss.device, dtype=loss.dtype) + log_list.append({ + 'name': f'{stage}_perceptual', + 'value': perceptual_loss_value, + 'prog_bar': True + }) if self.config['model']['model_params']['use_infoNCE']: z = kwargs['z'] From 149fac7e81a4baa869d195e309a76e6b6dbddfcb Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Wed, 18 Feb 2026 11:34:55 -0600 Subject: [PATCH 04/28] delete perceptual config and resume_beast.sh --- configs/vit_perceptual.yaml | 65 --------------------- test_scripts/run_besat_resume_perceptual.sh | 60 ------------------- 2 files changed, 125 deletions(-) delete mode 100644 configs/vit_perceptual.yaml delete mode 100644 test_scripts/run_besat_resume_perceptual.sh diff --git a/configs/vit_perceptual.yaml b/configs/vit_perceptual.yaml deleted file mode 100644 index 6e955d1..0000000 --- a/configs/vit_perceptual.yaml +++ /dev/null @@ -1,65 +0,0 @@ -# model configuration -model: - seed: 0 - checkpoint: null # load weights from checkpoint - model_class: vit - model_params: - hidden_size: 768 - num_hidden_layers: 12 - num_attention_heads: 12 - intermediate_size: 3072 - hidden_act: "gelu" - hidden_dropout_prob: 0.0 - attention_probs_dropout_prob: 0.0 - initializer_range: 0.02 - layer_norm_eps: 1.e-12 - image_size: 224 # usually 224 - patch_size: 16 # default is 16, we use large patch size - num_channels: 3 # 3 for RGB - qkv_bias: True - decoder_num_attention_heads: 16 - decoder_hidden_size: 512 - decoder_num_hidden_layers: 8 - decoder_intermediate_size: 2048 - mask_ratio: 0.75 # 0 for no masking, usually 0.75 (MAE) - norm_pix_loss: False - - embed_size: 768 # projected embedding size, used for contrastive learning - temp_scale: False # temperature scaling for contrastive loss - random_init: False # use random initialization instead of pretrained weights - use_infoNCE: False # use InfoNCE loss - infoNCE_weight: 0.03 # weight for InfoNCE loss - use_perceptual_loss: True # use perceptual loss (AlexNet features) - lambda_perceptual: 10.0 # weight for perceptual loss - -# training configuration -training: - seed: 0 - imgaug: default # default | top-down - train_batch_size: 128 # per GPU - val_batch_size: 1024 - test_batch_size: 128 - num_epochs: 800 - num_workers: 8 # Number of CPU workers for the DataLoader - num_gpus: 1 - num_nodes: 1 - # frequency to log training metrics - log_every_n_steps: 10 - # frequency to log validation metrics - check_val_every_n_epoch: 5 - ckpt_every_n_epochs: 200 - -# optimizer configuration -optimizer: - type: Adam - accumulate_grad_batches: 1 - lr: 5.e-5 - wd: 0.05 - warmup_pct: 0.15 # cosine/linear - gamma: 0.95 # step - div_factor: 10 # cosine - scheduler: cosine # step/cosine/linear - -# data configuration -data: - data_dir: /PATH/TO/DATA \ No newline at end of file diff --git a/test_scripts/run_besat_resume_perceptual.sh b/test_scripts/run_besat_resume_perceptual.sh deleted file mode 100644 index 0c74223..0000000 --- a/test_scripts/run_besat_resume_perceptual.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash -#SBATCH -A bfsr-delta-gpu -#SBATCH -p gpuA40x4,gpuA100x4,gpuA40x4-preempt,gpuA100x4-preempt -#SBATCH --nodes=1 -#SBATCH --ntasks=1 -#SBATCH --gpus-per-task=1 -#SBATCH --cpus-per-task=4 -#SBATCH --mem=100G -#SBATCH -t 08:00:00 -#SBATCH -J beast_train -#SBATCH -o /work/nvme/bfsr/xdai3/runs/beast_train_%j.out -#SBATCH -e /work/nvme/bfsr/xdai3/runs/beast_train_%j.err - - -# --- Setup environment --- -source ~/.bashrc -module load ffmpeg -conda activate beast -cd /u/xdai3/beast - -# Set multiprocessing temp directory to a more stable location (avoid /tmp cleanup issues) -export TMPDIR="/work/nvme/bfsr/xdai3/tmp/${SLURM_JOB_ID:-$USER}" -mkdir -p "$TMPDIR" -echo "TMPDIR set to: $TMPDIR" - -# --- Define paths --- -CONFIG="configs/vit_perceptual.yaml" -DATA="/work/nvme/bfsr/xdai3/raw_data/beast/test_video1" -CHECKPOINT="/work/nvme/bfsr/xdai3/runs/beast_train_13711940/tb_logs/version_0/checkpoints/epoch=244-step=2695-best.ckpt" - -# Define unique output directory per job (using Slurm job name + ID) -OUTPUT_DIR="/work/nvme/bfsr/xdai3/runs/${SLURM_JOB_NAME}_${SLURM_JOB_ID}" -mkdir -p "$OUTPUT_DIR" - -echo "---------------------------------------" -echo "Job name: $SLURM_JOB_NAME" -echo "Job ID: $SLURM_JOB_ID" -echo "Running on node(s): $SLURM_NODELIST" -echo "Output directory: $OUTPUT_DIR" -echo "---------------------------------------" - -# --- Run BEAST --- -echo "[$(date +'%Y-%m-%d %H:%M:%S')] Starting BEAST training..." - -if [ -f "$CHECKPOINT" ]; then - echo "[$(date +'%Y-%m-%d %H:%M:%S')] Found checkpoint: $CHECKPOINT" - echo "[$(date +'%Y-%m-%d %H:%M:%S')] Resuming BEAST training from checkpoint..." - echo "[$(date +'%Y-%m-%d %H:%M:%S')] About to call: beast train --config \"$CONFIG\" --data \"$DATA\" --checkpoint \"$CHECKPOINT\" --output \"$OUTPUT_DIR\"" - # Note: --checkpoint argument may not be supported, checking if it causes issues - beast train --config "$CONFIG" --data "$DATA" --checkpoint "$CHECKPOINT" --output "$OUTPUT_DIR" 2>&1 | tee "$OUTPUT_DIR/training_output.log" -else - echo "[$(date +'%Y-%m-%d %H:%M:%S')] No checkpoint found. Starting new training run." - echo "[$(date +'%Y-%m-%d %H:%M:%S')] CONFIG=$CONFIG DATA=$DATA OUTPUT_DIR=$OUTPUT_DIR" - echo "[$(date +'%Y-%m-%d %H:%M:%S')] About to call: beast train --config \"$CONFIG\" --data \"$DATA\" --output \"$OUTPUT_DIR\"" - beast train --config "$CONFIG" --data "$DATA" --output "$OUTPUT_DIR" 2>&1 | tee "$OUTPUT_DIR/training_output.log" -fi - -echo "[$(date +'%Y-%m-%d %H:%M:%S')] BEAST training completed." - -conda deactivate From efafbaa1c8112679cae56bece510e164b5864806 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Wed, 18 Feb 2026 13:15:52 -0600 Subject: [PATCH 05/28] delete combined.py --- beast/models/combined.py | 54 ---------------------------------------- 1 file changed, 54 deletions(-) delete mode 100644 beast/models/combined.py diff --git a/beast/models/combined.py b/beast/models/combined.py deleted file mode 100644 index b129fbb..0000000 --- a/beast/models/combined.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Tuple - -import torch -from torch import Tensor, tensor, zeros -from torch import nn - -from loss.patch_wise import PatchWise -from loss.perceptual import SqueezePerceptual -# https://github.com/MLReproHub/SMAE/blob/main/src/loss/perceptual.p - -class CombinedLoss(nn.Module): - def __init__(self, pixel_criterion: nn.Module, perceptual_criterion: nn.Module, lambda_perceptual: float = 1.0, - calibration_steps: int = 10, device='cuda'): - super(CombinedLoss, self).__init__() - self.pixel_loss = PatchWise(criterion=pixel_criterion) - self.perceptual_loss = SqueezePerceptual(criterion=perceptual_criterion) - self.register_buffer('lambda_perceptual', tensor(lambda_perceptual)) - # 0 Calibration steps disables calibration - self.register_buffer('remaining_calibration_steps', tensor(calibration_steps)) - self.register_buffer('calibration_points', zeros((calibration_steps, 2))) - - @property - def w_pixel(self) -> float: - return self.pixel_scale.detach() - - @property - def w_perceptual(self) -> float: - return self.perceptual_scale.detach() * self.lambda_perceptual.detach() - - def forward(self, x_hat, x, patches_hat, patches) -> Tuple[Tensor, Tensor, Tensor]: - pixel_loss = self.pixel_loss(patches_hat, patches) - perceptual_loss = self.perceptual_loss(x_hat, x) - - # Calibrate the losses to be balanced during the first batches. - if self.remaining_calibration_steps > 0: - self.calibrate(pixel_loss, perceptual_loss) - - pixel_loss = self.pixel_scale * pixel_loss - perceptual_loss = self.lambda_perceptual * self.perceptual_scale * perceptual_loss - combined_loss = pixel_loss + perceptual_loss - - return combined_loss, pixel_loss, perceptual_loss - - def calibrate(self, initial_pixel_loss, initial_perceptual_loss): - # Set the latest calibration point - self.calibration_points[-self.remaining_calibration_steps] = 1 / tensor((initial_pixel_loss, - initial_perceptual_loss)) - c = self.calibration_points - # Update the scaling factors by taking the mean over all the populated calibration points - pixel_scale, perceptual_scale = c[c.nonzero(as_tuple=True)].reshape(-1, 2).mean(dim=0) - self.register_buffer('pixel_scale', pixel_scale) - self.register_buffer('perceptual_scale', perceptual_scale) - # Decrement the remaining calibration steps - self.register_buffer('remaining_calibration_steps', self.remaining_calibration_steps - 1) \ No newline at end of file From 651a16b2e4eec30bbb80f6f66e3ba1c7faad1ac4 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Wed, 18 Feb 2026 13:26:27 -0600 Subject: [PATCH 06/28] add test_alex_perceptual_integration method to test perceptual integration --- tests/models/test_vits.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/models/test_vits.py b/tests/models/test_vits.py index e4e0da9..f67d491 100644 --- a/tests/models/test_vits.py +++ b/tests/models/test_vits.py @@ -39,3 +39,9 @@ def test_vit_autoencoder_contrastive_integration(config_vit, run_model_test): config = copy.deepcopy(config_vit) config['model']['model_params']['use_infoNCE'] = True run_model_test(config=config) + +def test_alex_perceptual_integration(config_vit, run_model_test): + """Test ViT autoencoder with AlexNet perceptual loss enabled.""" + config = copy.deepcopy(config_vit) + config['model']['model_params']['use_perceptual_loss'] = True + run_model_test(config=config) \ No newline at end of file From 8d4a52ff89f3673fffe61b532a3dd4fb81ef1535 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Wed, 18 Feb 2026 13:30:33 -0600 Subject: [PATCH 07/28] add test for perceptual.py --- tests/models/test_perceptual.py | 45 +++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 tests/models/test_perceptual.py diff --git a/tests/models/test_perceptual.py b/tests/models/test_perceptual.py new file mode 100644 index 0000000..8f2d1bb --- /dev/null +++ b/tests/models/test_perceptual.py @@ -0,0 +1,45 @@ +import torch +from beast.models.perceptual import AlexPerceptual, Perceptual + + +def test_perceptual_forward(): + """Test base Perceptual class forward pass with a simple mock network.""" + # Mock network: Conv2d that preserves spatial dimensions for AlexNet-like feature output + mock_net = torch.nn.Sequential( + torch.nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + torch.nn.ReLU(inplace=True), + ) + criterion = torch.nn.MSELoss() + perceptual = Perceptual(network=mock_net, criterion=criterion) + x_hat = torch.randn((5, 3, 224, 224)) + x = torch.randn((5, 3, 224, 224)) + loss = perceptual(x_hat, x) + assert isinstance(loss, torch.Tensor) + assert loss.ndim == 0 + assert loss.item() >= 0 + + +def test_alex_perceptual_forward(): + """Test AlexPerceptual forward pass with pretrained AlexNet features.""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + criterion = torch.nn.MSELoss() + perceptual = AlexPerceptual(device=device, criterion=criterion) + x_hat = torch.randn((5, 3, 224, 224), device=device) + x = torch.randn((5, 3, 224, 224), device=device) + loss = perceptual(x_hat, x) + assert isinstance(loss, torch.Tensor) + assert loss.ndim == 0 + assert loss.item() >= 0 + + +def test_alex_perceptual_different_inputs_produce_different_loss(): + """Test that different inputs produce different loss values.""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + criterion = torch.nn.MSELoss() + perceptual = AlexPerceptual(device=device, criterion=criterion) + x = torch.randn((2, 3, 224, 224), device=device) + loss_same = perceptual(x, x) + assert loss_same.item() < 1e-5 + x_hat = torch.randn((2, 3, 224, 224), device=device) + loss_diff = perceptual(x_hat, x) + assert loss_diff.item() > 0 \ No newline at end of file From 4a202aad56513fbcc03b88224b870dc7f1c3187c Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Sun, 22 Feb 2026 20:11:43 -0600 Subject: [PATCH 08/28] move _log_step to beast.log_step with level param --- beast/__init__.py | 33 +++++++++++++++++++- beast/api/model.py | 15 ++++----- beast/cli/commands/train.py | 45 ++++++++++++--------------- beast/data/datasets.py | 19 +++++------- beast/models/vits.py | 29 ++++++++---------- beast/train.py | 61 +++++++++++++++++-------------------- 6 files changed, 107 insertions(+), 95 deletions(-) diff --git a/beast/__init__.py b/beast/__init__.py index f30be8e..b518b3c 100644 --- a/beast/__init__.py +++ b/beast/__init__.py @@ -1,12 +1,43 @@ # Hacky way to get version from pypackage.toml. # Adapted from: https://github.com/python-poetry/poetry/issues/273#issuecomment-1877789967 import importlib.metadata +import time from pathlib import Path -from typing import Any +from typing import Any, Optional __package_version = "unknown" +def log_step( + msg: str, + level: Optional[str] = None, + flush: bool = True, + logger: Any = None, +) -> None: + """Unified logging function with optional level. + + Parameters + ---------- + msg : str + Message to log. + level : Optional[str] + Log level: None (plain timestamp + msg), 'info', or 'debug'. + flush : bool + Whether to flush stdout after printing. + logger : Any + If provided and level is 'info', also call logger.info(msg). + """ + timestamp = time.strftime('%Y-%m-%d %H:%M:%S') + if level == 'info': + print(f"[{timestamp}] INFO: {msg}", flush=flush) + if logger is not None: + logger.info(msg) + elif level == 'debug': + print(f"[{timestamp}] DEBUG: {msg}", flush=flush) + else: + print(f"[{timestamp}] {msg}", flush=flush) + + def __get_package_version() -> str: """Find the version of this package.""" diff --git a/beast/api/model.py b/beast/api/model.py index ddd1e98..42be61f 100644 --- a/beast/api/model.py +++ b/beast/api/model.py @@ -7,6 +7,7 @@ import yaml from typeguard import typechecked +import beast from beast.inference import predict_images, predict_video from beast.models.base import BaseLightningModel from beast.models.resnets import ResnetAutoencoder @@ -112,17 +113,17 @@ def from_config(cls, config_path: str | Path | dict): # Initialize the LightningModule import time - def _log_step(msg): - timestamp = time.strftime('%Y-%m-%d %H:%M:%S') - print(f"[{timestamp}] MODEL DEBUG: {msg}", flush=True) - - _log_step(f"Creating {model_type} model instance") + model_class = cls.MODEL_REGISTRY[model_type] - _log_step(f"About to call {model_class.__name__}.__init__() - this may take several minutes if downloading pretrained weights") + beast.log_step(f"Creating {model_type} model instance", level='debug') + beast.log_step( + f"About to call {model_class.__name__}.__init__() - this may take several minutes if downloading pretrained weights", + level='debug', + ) init_start = time.time() model = model_class(config) init_duration = time.time() - init_start - _log_step(f"Model initialization completed in {init_duration:.2f} seconds") + beast.log_step(f"Model initialization completed in {init_duration:.2f} seconds", level='debug') print(f'Initialized a {model_class} model') diff --git a/beast/cli/commands/train.py b/beast/cli/commands/train.py index 232b0d7..9582985 100644 --- a/beast/cli/commands/train.py +++ b/beast/cli/commands/train.py @@ -4,6 +4,7 @@ import logging from pathlib import Path +from beast import log_step from beast.cli.types import config_file, output_dir _logger = logging.getLogger('BEAST.CLI.TRAIN') @@ -65,65 +66,59 @@ def register_parser(subparsers): def handle(args): """Handle the train command execution.""" - import time - def _log_step(msg): - timestamp = time.strftime('%Y-%m-%d %H:%M:%S') - print(f"[{timestamp}] CLI DEBUG: {msg}", flush=True) - _logger.info(msg) - - _log_step("Starting train command handler") + log_step("Starting train command handler", level='info', logger=_logger) # Determine output directory - _log_step("Determining output directory") + log_step("Determining output directory", level='info', logger=_logger) if not args.output: now = datetime.datetime.now() args.output = Path('runs').resolve() / now.strftime('%Y-%m-%d') / now.strftime('%H-%M-%S') args.output.mkdir(parents=True, exist_ok=True) - _log_step(f"Output directory: {args.output}") + log_step(f"Output directory: {args.output}", level='info', logger=_logger) # Set up logging to the model directory - _log_step("Setting up model logging") + log_step("Setting up model logging", level='info', logger=_logger) model_log_handler = _setup_model_logging(args.output) - _log_step("Model logging set up") + log_step("Model logging set up", level='info', logger=_logger) # try: # Load config - _log_step(f"Loading config from: {args.config}") + log_step(f"Loading config from: {args.config}", level='info', logger=_logger) from beast.io import load_config config = load_config(args.config) - _log_step("Config loaded") + log_step("Config loaded", level='info', logger=_logger) # Apply overrides if args.overrides: - _log_step("Applying config overrides") + log_step("Applying config overrides", level='info', logger=_logger) from beast.io import apply_config_overrides config = apply_config_overrides(config, args.overrides) - _log_step("Config overrides applied") + log_step("Config overrides applied", level='info', logger=_logger) # Override specific values from command line - _log_step("Applying command line overrides") + log_step("Applying command line overrides", level='info', logger=_logger) if args.data: config['data']['data_dir'] = str(args.data) - _log_step(f"Data directory overridden to: {args.data}") + log_step(f"Data directory overridden to: {args.data}", level='info', logger=_logger) if args.gpus is not None: config['training']['num_gpus'] = args.gpus - _log_step(f"Number of GPUs overridden to: {args.gpus}") + log_step(f"Number of GPUs overridden to: {args.gpus}", level='info', logger=_logger) if args.nodes is not None: config['training']['num_nodes'] = args.nodes - _log_step(f"Number of nodes overridden to: {args.nodes}") + log_step(f"Number of nodes overridden to: {args.nodes}", level='info', logger=_logger) # Check for unsupported --checkpoint argument if hasattr(args, 'checkpoint') and args.checkpoint: - _log_step(f"WARNING: --checkpoint argument provided but not supported: {args.checkpoint}") - _log_step("Checkpoint resuming is not currently implemented in the CLI") + log_step(f"WARNING: --checkpoint argument provided but not supported: {args.checkpoint}", level='info', logger=_logger) + log_step("Checkpoint resuming is not currently implemented in the CLI", level='info', logger=_logger) # Initialize model - _log_step("Initializing model from config") + log_step("Initializing model from config", level='info', logger=_logger) from beast.api.model import Model model = Model.from_config(config) - _log_step("Model initialized") + log_step("Model initialized", level='info', logger=_logger) # if args.resume: # train_kwargs['resume_from_checkpoint'] = args.resume @@ -133,9 +128,9 @@ def _log_step(msg): _logger.info(f'Output directory: {args.output}') # Run training - _log_step("About to call model.train()") + log_step("About to call model.train()", level='info', logger=_logger) model.train(output_dir=args.output) - _log_step("model.train() completed") + log_step("model.train() completed", level='info', logger=_logger) _logger.info(f'Training complete. Model saved to {args.output}') diff --git a/beast/data/datasets.py b/beast/data/datasets.py index b03ea02..d93efbc 100644 --- a/beast/data/datasets.py +++ b/beast/data/datasets.py @@ -10,14 +10,9 @@ from torchvision import transforms from typeguard import typechecked +from beast import log_step from beast.data.types import ExampleDict - -def _debug_log(msg: str, flush: bool = True): - """Debug logging function with timestamp.""" - timestamp = time.strftime('%Y-%m-%d %H:%M:%S') - print(f"[{timestamp}] DATASET DEBUG: {msg}", flush=flush) - _IMAGENET_MEAN = [0.485, 0.456, 0.406] _IMAGENET_STD = [0.229, 0.224, 0.225] @@ -35,26 +30,26 @@ def __init__(self, data_dir: str | Path, imgaug_pipeline: Callable | None) -> No imgaug_transform: imgaug transform pipeline to apply to images """ - _debug_log(f"BaseDataset.__init__ called with data_dir: {data_dir}") + log_step(f"BaseDataset.__init__ called with data_dir: {data_dir}", level='debug') self.data_dir = Path(data_dir) if not self.data_dir.is_dir(): raise ValueError(f'{self.data_dir} is not a directory') - _debug_log(f"Data directory exists: {self.data_dir}") + log_step(f"Data directory exists: {self.data_dir}", level='debug') self.imgaug_pipeline = imgaug_pipeline # collect ALL png files in data_dir - _debug_log(f"Starting to scan for PNG files in {self.data_dir} (this may take a while for large directories)...") + log_step(f"Starting to scan for PNG files in {self.data_dir} (this may take a while for large directories)...", level='debug') scan_start = time.time() try: self.image_list = sorted(list(self.data_dir.rglob('*.png'))) scan_duration = time.time() - scan_start - _debug_log(f"Finished scanning. Found {len(self.image_list)} PNG files in {scan_duration:.2f} seconds") + log_step(f"Finished scanning. Found {len(self.image_list)} PNG files in {scan_duration:.2f} seconds", level='debug') except Exception as e: - _debug_log(f"ERROR during file scanning: {e}") + log_step(f"ERROR during file scanning: {e}", level='debug') raise if len(self.image_list) == 0: raise ValueError(f'{self.data_dir} does not contain image data in png format') - _debug_log(f"BaseDataset initialization complete with {len(self.image_list)} images") + log_step(f"BaseDataset initialization complete with {len(self.image_list)} images", level='debug') # send image to tensor, resize to canonical dimensions, and normalize pytorch_transform_list = [ diff --git a/beast/models/vits.py b/beast/models/vits.py index c3a4383..54d1aec 100644 --- a/beast/models/vits.py +++ b/beast/models/vits.py @@ -14,16 +14,11 @@ ) from typeguard import typechecked +from beast import log_step from beast.models.base import BaseLightningModel from beast.models.perceptual import AlexPerceptual -def _debug_log(msg: str, flush: bool = True): - """Debug logging function with timestamp.""" - timestamp = time.strftime('%Y-%m-%d %H:%M:%S') - print(f"[{timestamp}] VIT DEBUG: {msg}", flush=flush) - - class BatchNormProjector(nn.Module): def __init__(self, config): super().__init__() @@ -50,9 +45,9 @@ class VisionTransformer(BaseLightningModel): def __init__(self, config): super().__init__(config) # Set up ViT architecture - _debug_log("Creating ViTMAEConfig") + log_step("Creating ViTMAEConfig", level='debug') vit_mae_config = ViTMAEConfig(**config['model']['model_params']) - _debug_log("ViTMAEConfig created") + log_step("ViTMAEConfig created", level='debug') # Get perceptual loss parameters from config use_perceptual_loss = config['model']['model_params'].get('use_perceptual_loss', False) @@ -60,14 +55,14 @@ def __init__(self, config): device = config['model']['model_params'].get('device', 'cuda') if use_perceptual_loss: - _debug_log(f"Perceptual loss enabled with lambda={lambda_perceptual}") + log_step(f"Perceptual loss enabled with lambda={lambda_perceptual}", level='debug') # Check if we should use pretrained weights or random initialization use_pretrained = not config['model']['model_params'].get('random_init', False) if use_pretrained: - _debug_log("Loading pretrained model from 'facebook/vit-mae-base' (this may take several minutes if downloading)...") - _debug_log("Note: Model will be cached locally after first download") + log_step("Loading pretrained model from 'facebook/vit-mae-base' (this may take several minutes if downloading)...", level='debug') + log_step("Note: Model will be cached locally after first download", level='debug') load_start = time.time() # Load pretrained weights first (from_pretrained creates a new instance) self.vit_mae = ViTMAE.from_pretrained("facebook/vit-mae-base", config=vit_mae_config) @@ -81,26 +76,26 @@ def __init__(self, config): criterion=nn.MSELoss() ) load_duration = time.time() - load_start - _debug_log(f"Pretrained model loaded in {load_duration:.2f} seconds") + log_step(f"Pretrained model loaded in {load_duration:.2f} seconds", level='debug') else: - _debug_log("Using random initialization (random_init=True)") + log_step("Using random initialization (random_init=True)", level='debug') self.vit_mae = ViTMAE( vit_mae_config, use_perceptual_loss=use_perceptual_loss, lambda_perceptual=lambda_perceptual, device=device ) - _debug_log("Randomly initialized model created") + log_step("Randomly initialized model created", level='debug') self.mask_ratio = config['model']['model_params']['mask_ratio'] # contrastive loss if config['model']['model_params']['use_infoNCE']: - _debug_log("Setting up InfoNCE projection layer") + log_step("Setting up InfoNCE projection layer", level='debug') self.proj = BatchNormProjector(vit_mae_config) if self.config['model']['model_params']['temp_scale']: self.temperature = nn.Parameter(torch.ones([]) * np.log(1)) - _debug_log("InfoNCE projection layer created") - _debug_log("VisionTransformer initialization complete") + log_step("InfoNCE projection layer created", level='debug') + log_step("VisionTransformer initialization complete", level='debug') def forward( self, diff --git a/beast/train.py b/beast/train.py index b5e410e..94d561d 100644 --- a/beast/train.py +++ b/beast/train.py @@ -13,16 +13,11 @@ import beast from beast.data.augmentations import imgaug_pipeline +from beast import log_step from beast.data.datamodules import BaseDataModule from beast.data.datasets import BaseDataset -def _debug_log(msg: str, flush: bool = True): - """Debug logging function with timestamp.""" - timestamp = time.strftime('%Y-%m-%d %H:%M:%S') - print(f"[{timestamp}] DEBUG: {msg}", flush=flush) - - @typechecked def reset_seeds(seed: int = 0) -> None: os.environ["PYTHONHASHSEED"] = str(seed) @@ -55,22 +50,22 @@ def train(config: dict, model, output_dir: str | Path): # Only print from rank 0 if rank_zero_only.rank == 0: - _debug_log("Entering train() function") + log_step("Entering train() function", level='debug') print(f'output directory: {output_dir}') print(f'model type: {type(model)}') # reset all seeds if rank_zero_only.rank == 0: - _debug_log("Resetting seeds") + log_step("Resetting seeds", level='debug') reset_seeds(seed=0) # record beast version if rank_zero_only.rank == 0: - _debug_log("Recording beast version") + log_step("Recording beast version", level='debug') config['model']['beast_version'] = beast.version if rank_zero_only.rank == 0: - _debug_log("Printing config") + log_step("Printing config", level='debug') pretty_print_config(config) # ---------------------------------------------------------------------------------- @@ -79,29 +74,29 @@ def train(config: dict, model, output_dir: str | Path): # imgaug transform if rank_zero_only.rank == 0: - _debug_log("Setting up imgaug pipeline") + log_step("Setting up imgaug pipeline", level='debug') pipe_params = config.get('training', {}).get('imgaug', 'none') if isinstance(pipe_params, str): from beast.data.augmentations import expand_imgaug_str_to_dict pipe_params = expand_imgaug_str_to_dict(pipe_params) imgaug_pipeline_ = imgaug_pipeline(pipe_params) if rank_zero_only.rank == 0: - _debug_log("Imgaug pipeline created") + log_step("Imgaug pipeline created", level='debug') # dataset if rank_zero_only.rank == 0: - _debug_log(f"Creating BaseDataset with data_dir: {config['data']['data_dir']}") - _debug_log("WARNING: This may take a long time if data directory is large (scanning for PNG files)") + log_step(f"Creating BaseDataset with data_dir: {config['data']['data_dir']}", level='debug') + log_step("WARNING: This may take a long time if data directory is large (scanning for PNG files)", level='debug') dataset = BaseDataset( data_dir=config['data']['data_dir'], imgaug_pipeline=imgaug_pipeline_, ) if rank_zero_only.rank == 0: - _debug_log(f"BaseDataset created. Found {len(dataset)} images") + log_step(f"BaseDataset created. Found {len(dataset)} images", level='debug') # datamodule; breaks up dataset into train/val/test if rank_zero_only.rank == 0: - _debug_log("Creating BaseDataModule") + log_step("Creating BaseDataModule", level='debug') datamodule = BaseDataModule( dataset=dataset, train_batch_size=config['training']['train_batch_size'], @@ -114,11 +109,11 @@ def train(config: dict, model, output_dir: str | Path): seed=config['training']['seed'], ) if rank_zero_only.rank == 0: - _debug_log("BaseDataModule created") + log_step("BaseDataModule created", level='debug') # update number of training steps (for learning rate scheduler with step information) if rank_zero_only.rank == 0: - _debug_log("Calculating training steps") + log_step("Calculating training steps", level='debug') num_epochs = config['training']['num_epochs'] steps_per_epoch = int(np.ceil( len(datamodule.train_dataset) @@ -129,7 +124,7 @@ def train(config: dict, model, output_dir: str | Path): model.config['optimizer']['steps_per_epoch'] = steps_per_epoch model.config['optimizer']['total_steps'] = steps_per_epoch * num_epochs if rank_zero_only.rank == 0: - _debug_log(f"Training steps calculated: {steps_per_epoch} steps/epoch, {num_epochs} epochs") + log_step(f"Training steps calculated: {steps_per_epoch} steps/epoch, {num_epochs} epochs", level='debug') # ---------------------------------------------------------------------------------- # Save configuration in output directory @@ -138,13 +133,13 @@ def train(config: dict, model, output_dir: str | Path): # save config file if rank_zero_only.rank == 0: - _debug_log(f"Saving config to {output_dir}") + log_step(f"Saving config to {output_dir}", level='debug') output_dir.mkdir(parents=True, exist_ok=True) dest_config_file = Path(output_dir) / 'config.yaml' with open(dest_config_file, 'w') as file: yaml.dump(config, file) if rank_zero_only.rank == 0: - _debug_log("Config saved") + log_step("Config saved", level='debug') # ---------------------------------------------------------------------------------- # Set up and run training @@ -152,20 +147,20 @@ def train(config: dict, model, output_dir: str | Path): # logger if rank_zero_only.rank == 0: - _debug_log("Creating TensorBoardLogger") + log_step("Creating TensorBoardLogger", level='debug') logger = pl.loggers.TensorBoardLogger('tb_logs', name='') if rank_zero_only.rank == 0: - _debug_log("TensorBoardLogger created") + log_step("TensorBoardLogger created", level='debug') # early stopping, learning rate monitoring, model checkpointing, backbone unfreezing if rank_zero_only.rank == 0: - _debug_log("Setting up callbacks") + log_step("Setting up callbacks", level='debug') callbacks = get_callbacks( lr_monitor=True, ckpt_every_n_epochs=config['training'].get('ckpt_every_n_epochs', None), ) if rank_zero_only.rank == 0: - _debug_log(f"Callbacks created: {len(callbacks)} callbacks") + log_step(f"Callbacks created: {len(callbacks)} callbacks", level='debug') # initialize to Trainer defaults. Note max_steps defaults to -1. min_epochs = config['training']['num_epochs'] @@ -178,11 +173,11 @@ def train(config: dict, model, output_dir: str | Path): use_distributed_sampler = True if rank_zero_only.rank == 0: - _debug_log("Creating PyTorch Lightning Trainer") - _debug_log(f" - accelerator: gpu") - _debug_log(f" - devices: {config['training']['num_gpus']}") - _debug_log(f" - num_nodes: {config['training']['num_nodes']}") - _debug_log(f" - max_epochs: {max_epochs}") + log_step("Creating PyTorch Lightning Trainer", level='debug') + log_step(f" - accelerator: gpu", level='debug') + log_step(f" - devices: {config['training']['num_gpus']}", level='debug') + log_step(f" - num_nodes: {config['training']['num_nodes']}", level='debug') + log_step(f" - max_epochs: {max_epochs}", level='debug') trainer = pl.Trainer( accelerator='gpu', devices=config['training']['num_gpus'], @@ -198,14 +193,14 @@ def train(config: dict, model, output_dir: str | Path): use_distributed_sampler=use_distributed_sampler, ) if rank_zero_only.rank == 0: - _debug_log("Trainer created") + log_step("Trainer created", level='debug') # train model! if rank_zero_only.rank == 0: - _debug_log("About to call trainer.fit() - this may hang here if there are issues with data loading or GPU setup") + log_step("About to call trainer.fit() - this may hang here if there are issues with data loading or GPU setup", level='debug') trainer.fit(model=model, datamodule=datamodule) if rank_zero_only.rank == 0: - _debug_log("trainer.fit() completed") + log_step("trainer.fit() completed", level='debug') # when devices > 0, lightning creates a process per device. # kill processes other than the main process, otherwise they all go forward. From 7d6e2dcc716b3df8b6e989a6b80dd5de1e4cca20 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 13:52:15 -0600 Subject: [PATCH 09/28] move import to the top of the file --- beast/api/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/beast/api/model.py b/beast/api/model.py index 42be61f..fd7835b 100644 --- a/beast/api/model.py +++ b/beast/api/model.py @@ -1,5 +1,6 @@ import contextlib import os +import time from pathlib import Path from typing import Any @@ -112,8 +113,6 @@ def from_config(cls, config_path: str | Path | dict): raise ValueError(f'Unknown model type: {model_type}') # Initialize the LightningModule - import time - model_class = cls.MODEL_REGISTRY[model_type] beast.log_step(f"Creating {model_type} model instance", level='debug') beast.log_step( From dc0c0b0bc71f8bf606f033dfbb4f3832f1197204 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 13:57:28 -0600 Subject: [PATCH 10/28] import log_step from beast --- beast/api/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/beast/api/model.py b/beast/api/model.py index fd7835b..876861e 100644 --- a/beast/api/model.py +++ b/beast/api/model.py @@ -8,12 +8,12 @@ import yaml from typeguard import typechecked -import beast from beast.inference import predict_images, predict_video from beast.models.base import BaseLightningModel from beast.models.resnets import ResnetAutoencoder from beast.models.vits import VisionTransformer from beast.train import train +from beast import log_step # TODO: Replace with contextlib.chdir in python 3.11. @@ -114,15 +114,15 @@ def from_config(cls, config_path: str | Path | dict): # Initialize the LightningModule model_class = cls.MODEL_REGISTRY[model_type] - beast.log_step(f"Creating {model_type} model instance", level='debug') - beast.log_step( + log_step(f"Creating {model_type} model instance", level='debug') + log_step( f"About to call {model_class.__name__}.__init__() - this may take several minutes if downloading pretrained weights", level='debug', ) init_start = time.time() model = model_class(config) init_duration = time.time() - init_start - beast.log_step(f"Model initialization completed in {init_duration:.2f} seconds", level='debug') + log_step(f"Model initialization completed in {init_duration:.2f} seconds", level='debug') print(f'Initialized a {model_class} model') From aa95a7d150170d00316db8d6f623183539f7a865 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 14:08:47 -0600 Subject: [PATCH 11/28] remove import time --- beast/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/beast/train.py b/beast/train.py index 94d561d..640ef94 100644 --- a/beast/train.py +++ b/beast/train.py @@ -1,7 +1,6 @@ import os import random import sys -import time from pathlib import Path import lightning.pytorch as pl From d056f02145f4b2168be37178eb18ec5c87aff5b3 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 14:27:56 -0600 Subject: [PATCH 12/28] remove extra logging --- beast/train.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/beast/train.py b/beast/train.py index 640ef94..22c02bc 100644 --- a/beast/train.py +++ b/beast/train.py @@ -83,15 +83,10 @@ def train(config: dict, model, output_dir: str | Path): log_step("Imgaug pipeline created", level='debug') # dataset - if rank_zero_only.rank == 0: - log_step(f"Creating BaseDataset with data_dir: {config['data']['data_dir']}", level='debug') - log_step("WARNING: This may take a long time if data directory is large (scanning for PNG files)", level='debug') dataset = BaseDataset( data_dir=config['data']['data_dir'], imgaug_pipeline=imgaug_pipeline_, ) - if rank_zero_only.rank == 0: - log_step(f"BaseDataset created. Found {len(dataset)} images", level='debug') # datamodule; breaks up dataset into train/val/test if rank_zero_only.rank == 0: From 23c307776a368ef97265f0847095c9e8cfdd2fab Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 14:32:44 -0600 Subject: [PATCH 13/28] use log_step --- beast/cli/commands/train.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/beast/cli/commands/train.py b/beast/cli/commands/train.py index 9582985..080a7e8 100644 --- a/beast/cli/commands/train.py +++ b/beast/cli/commands/train.py @@ -122,17 +122,15 @@ def handle(args): # if args.resume: # train_kwargs['resume_from_checkpoint'] = args.resume - - _logger.info(f'Training {type(model.model)} model') - _logger.info(f'Data directory: {args.data}') - _logger.info(f'Output directory: {args.output}') + log_step(f'Training {type(model.model)} model', level='info', logger=_logger) + log_step(f'Data directory: {args.data}', level='info', logger=_logger) + log_step(f'Output directory: {args.output}', level='info', logger=_logger) # Run training log_step("About to call model.train()", level='info', logger=_logger) model.train(output_dir=args.output) log_step("model.train() completed", level='info', logger=_logger) - - _logger.info(f'Training complete. Model saved to {args.output}') + log_step(f'Training complete. Model saved to {args.output}', level='info', logger=_logger) # except Exception as e: From ea3db8241d15821433aa02019a403b8bc376819f Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 14:36:12 -0600 Subject: [PATCH 14/28] move log_step inside the try block --- beast/data/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/beast/data/datasets.py b/beast/data/datasets.py index d93efbc..c011520 100644 --- a/beast/data/datasets.py +++ b/beast/data/datasets.py @@ -38,9 +38,9 @@ def __init__(self, data_dir: str | Path, imgaug_pipeline: Callable | None) -> No self.imgaug_pipeline = imgaug_pipeline # collect ALL png files in data_dir - log_step(f"Starting to scan for PNG files in {self.data_dir} (this may take a while for large directories)...", level='debug') scan_start = time.time() try: + log_step(f"Starting to scan for PNG files in {self.data_dir} (this may take a while for large directories)...", level='debug') self.image_list = sorted(list(self.data_dir.rglob('*.png'))) scan_duration = time.time() - scan_start log_step(f"Finished scanning. Found {len(self.image_list)} PNG files in {scan_duration:.2f} seconds", level='debug') From d13c5bb09c82b75c6a6b653c586fb232f7bcbb44 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 14:38:37 -0600 Subject: [PATCH 15/28] change log_step level to "error" --- beast/__init__.py | 7 ++++++- beast/data/datasets.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/beast/__init__.py b/beast/__init__.py index b518b3c..fc43098 100644 --- a/beast/__init__.py +++ b/beast/__init__.py @@ -21,11 +21,12 @@ def log_step( msg : str Message to log. level : Optional[str] - Log level: None (plain timestamp + msg), 'info', or 'debug'. + Log level: None (plain timestamp + msg), 'info', 'debug', or 'error'. flush : bool Whether to flush stdout after printing. logger : Any If provided and level is 'info', also call logger.info(msg). + If provided and level is 'error', also call logger.error(msg). """ timestamp = time.strftime('%Y-%m-%d %H:%M:%S') if level == 'info': @@ -34,6 +35,10 @@ def log_step( logger.info(msg) elif level == 'debug': print(f"[{timestamp}] DEBUG: {msg}", flush=flush) + elif level == 'error': + print(f"[{timestamp}] ERROR: {msg}", flush=flush) + if logger is not None: + logger.error(msg) else: print(f"[{timestamp}] {msg}", flush=flush) diff --git a/beast/data/datasets.py b/beast/data/datasets.py index c011520..80dff31 100644 --- a/beast/data/datasets.py +++ b/beast/data/datasets.py @@ -45,7 +45,7 @@ def __init__(self, data_dir: str | Path, imgaug_pipeline: Callable | None) -> No scan_duration = time.time() - scan_start log_step(f"Finished scanning. Found {len(self.image_list)} PNG files in {scan_duration:.2f} seconds", level='debug') except Exception as e: - log_step(f"ERROR during file scanning: {e}", level='debug') + log_step(f"ERROR during file scanning: {e}", level='error') raise if len(self.image_list) == 0: raise ValueError(f'{self.data_dir} does not contain image data in png format') From f0280403011f1f6ce4407223d9d91fda5b82de5e Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 15:36:10 -0600 Subject: [PATCH 16/28] modify the docstring --- beast/__init__.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/beast/__init__.py b/beast/__init__.py index fc43098..10f9b7c 100644 --- a/beast/__init__.py +++ b/beast/__init__.py @@ -18,15 +18,11 @@ def log_step( Parameters ---------- - msg : str - Message to log. - level : Optional[str] - Log level: None (plain timestamp + msg), 'info', 'debug', or 'error'. - flush : bool - Whether to flush stdout after printing. - logger : Any - If provided and level is 'info', also call logger.info(msg). - If provided and level is 'error', also call logger.error(msg). + msg: message to log + level: None (plain timestamp + msg), 'info', 'debug', or 'error' + flush: whether to flush stdout after printing + logger: if provided and level is 'info', also call logger.info(msg); + if provided and level is 'error', also call logger.error(msg) """ timestamp = time.strftime('%Y-%m-%d %H:%M:%S') if level == 'info': From 02110be747b4de4a99a4ec39158140a4a15a6801 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 15:41:41 -0600 Subject: [PATCH 17/28] add docstrings and types --- beast/models/perceptual.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/beast/models/perceptual.py b/beast/models/perceptual.py index c2aa1c9..4304286 100644 --- a/beast/models/perceptual.py +++ b/beast/models/perceptual.py @@ -1,17 +1,25 @@ +import torch import torchvision from torch import nn +from typing import Any # https://github.com/MLReproHub/SMAE/blob/main/src/loss/perceptual.py class Perceptual(nn.Module): - - def __init__(self, *, network, criterion): + def __init__(self, *, network: nn.Module, criterion: nn.Module): + """Initialize perceptual loss module. + + Parameters + ---------- + network: feature extractor that maps input images to feature tensors + criterion: loss function applied to extracted features (e.g. MSELoss) + """ super(Perceptual, self).__init__() self.net = network self.criterion = criterion self.sigmoid = nn.Sigmoid() - def forward(self, x_hat, x): + def forward(self, x_hat: torch.Tensor, x: torch.Tensor) -> torch.Tensor: x_hat_features = self.sigmoid(self.net(x_hat)) x_features = self.sigmoid(self.net(x)) loss = self.criterion(x_hat_features, x_features) @@ -19,11 +27,17 @@ def forward(self, x_hat, x): class AlexPerceptual(Perceptual): - """ - Implements perceptual loss with a pre-trained alex net [Pihlgren et al. 2020] - """ + def __init__(self, *, device: str | torch.device, **kwargs: Any): + """Perceptual loss using pretrained AlexNet features [Pihlgren et al. 2020]. + + Extracts features from the first five layers of AlexNet (pretrained on ImageNet) + and computes loss between reconstructed and target feature maps. - def __init__(self, *, device, **kwargs): + Parameters + ---------- + device: device to run the feature extractor on (e.g. 'cuda', 'cpu') + **kwargs: passed to parent; must include criterion (e.g. nn.MSELoss()) + """ # Load alex net pretrained on IN1k alex_net = torchvision.models.alexnet(weights='IMAGENET1K_V1') # Extract features after second relu activation From cdc7b34bf90af4ba2361b781299e0472f5b7c184 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 15:49:36 -0600 Subject: [PATCH 18/28] parameter extraction and logging infoNCE loss --- beast/models/vits.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/beast/models/vits.py b/beast/models/vits.py index 54d1aec..6913c1b 100644 --- a/beast/models/vits.py +++ b/beast/models/vits.py @@ -51,12 +51,19 @@ def __init__(self, config): # Get perceptual loss parameters from config use_perceptual_loss = config['model']['model_params'].get('use_perceptual_loss', False) - lambda_perceptual = config['model']['model_params'].get('lambda_perceptual', 1.0) + lambda_perceptual = config['model']['model_params'].get('lambda_perceptual', 10.0) device = config['model']['model_params'].get('device', 'cuda') - + + # Get InfoNCE loss parameters from config + use_infoNCE = config['model']['model_params'].get('use_infoNCE', False) + temp_scale = config['model']['model_params'].get('temp_scale', False) + if use_perceptual_loss: log_step(f"Perceptual loss enabled with lambda={lambda_perceptual}", level='debug') - + + if use_infoNCE: + log_step(f"InfoNCE loss enabled with temp_scale={temp_scale}", level='debug') + # Check if we should use pretrained weights or random initialization use_pretrained = not config['model']['model_params'].get('random_init', False) @@ -89,10 +96,10 @@ def __init__(self, config): self.mask_ratio = config['model']['model_params']['mask_ratio'] # contrastive loss - if config['model']['model_params']['use_infoNCE']: + if use_infoNCE: log_step("Setting up InfoNCE projection layer", level='debug') self.proj = BatchNormProjector(vit_mae_config) - if self.config['model']['model_params']['temp_scale']: + if temp_scale: self.temperature = nn.Parameter(torch.ones([]) * np.log(1)) log_step("InfoNCE projection layer created", level='debug') log_step("VisionTransformer initialization complete", level='debug') From ea88d7ecb18bb55e2aa8bd836364b115f7b7e415 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 15:51:35 -0600 Subject: [PATCH 19/28] use self.device --- beast/models/vits.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/beast/models/vits.py b/beast/models/vits.py index 6913c1b..3b1fe66 100644 --- a/beast/models/vits.py +++ b/beast/models/vits.py @@ -52,7 +52,6 @@ def __init__(self, config): # Get perceptual loss parameters from config use_perceptual_loss = config['model']['model_params'].get('use_perceptual_loss', False) lambda_perceptual = config['model']['model_params'].get('lambda_perceptual', 10.0) - device = config['model']['model_params'].get('device', 'cuda') # Get InfoNCE loss parameters from config use_infoNCE = config['model']['model_params'].get('use_infoNCE', False) @@ -79,7 +78,7 @@ def __init__(self, config): if use_perceptual_loss: # Initialize perceptual loss module on the correct device self.vit_mae.perceptual_loss = AlexPerceptual( - device=device, + device=self.device, criterion=nn.MSELoss() ) load_duration = time.time() - load_start @@ -90,7 +89,7 @@ def __init__(self, config): vit_mae_config, use_perceptual_loss=use_perceptual_loss, lambda_perceptual=lambda_perceptual, - device=device + device=self.device ) log_step("Randomly initialized model created", level='debug') From 84b7356f47c552bc6b3a921933249054af3636e4 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 15:54:50 -0600 Subject: [PATCH 20/28] remove fallback --- beast/models/vits.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/beast/models/vits.py b/beast/models/vits.py index 3b1fe66..dc344cb 100644 --- a/beast/models/vits.py +++ b/beast/models/vits.py @@ -131,19 +131,10 @@ def compute_loss( **kwargs, ) -> tuple[torch.tensor, list[dict]]: assert 'loss' in kwargs, "Loss is not in the kwargs" + assert 'mse_loss' in kwargs, "mse_loss must be provided by model outputs for logging" loss = kwargs['loss'] - # add all losses here for logging - # Get MSE loss directly from model output if available, otherwise use combined loss - if 'mse_loss' in kwargs: - mse_loss = kwargs['mse_loss'] - else: - # Fallback: if perceptual loss is available, extract MSE by subtraction - if 'perceptual_loss' in kwargs: - perceptual_loss = kwargs['perceptual_loss'] - mse_loss = loss - self.vit_mae.lambda_perceptual * perceptual_loss - else: - mse_loss = loss - + mse_loss = kwargs['mse_loss'] + log_list = [ {'name': f'{stage}_mse', 'value': mse_loss.detach().clone(), 'prog_bar': True}, ] From c626b7c4a878dc2149a383441882acc087bb0957 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 16:00:49 -0600 Subject: [PATCH 21/28] assert 'perceptual_loss' in kwargs --- beast/models/vits.py | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/beast/models/vits.py b/beast/models/vits.py index dc344cb..f301c8d 100644 --- a/beast/models/vits.py +++ b/beast/models/vits.py @@ -141,25 +141,15 @@ def compute_loss( # Always log perceptual loss if it's enabled in config if self.config['model']['model_params'].get('use_perceptual_loss', False): - if 'perceptual_loss' in kwargs: - perceptual_loss = kwargs['perceptual_loss'] - # Ensure it's a tensor and on the correct device - if not isinstance(perceptual_loss, torch.Tensor): - perceptual_loss = torch.tensor(perceptual_loss, device=loss.device, dtype=loss.dtype) - log_list.append({ - 'name': f'{stage}_perceptual', - 'value': perceptual_loss.detach().clone(), - 'prog_bar': True - }) - else: - # This shouldn't happen if perceptual loss is properly initialized - # Log 0 as fallback (shouldn't occur with the fix above) - perceptual_loss_value = torch.tensor(0.0, device=loss.device, dtype=loss.dtype) - log_list.append({ - 'name': f'{stage}_perceptual', - 'value': perceptual_loss_value, - 'prog_bar': True - }) + assert 'perceptual_loss' in kwargs, ( + "perceptual_loss must be in model outputs when use_perceptual_loss is enabled" + ) + perceptual_loss = kwargs['perceptual_loss'] + log_list.append({ + 'name': f'{stage}_perceptual', + 'value': perceptual_loss.detach().clone(), + 'prog_bar': True + }) if self.config['model']['model_params']['use_infoNCE']: z = kwargs['z'] From dea4e475f9d0f8c16b887c3a3dac710af24d1eb8 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 16:26:09 -0600 Subject: [PATCH 22/28] revert all changes in __init__ of VisionTransformer --- beast/models/vits.py | 86 ++++++++++++++------------------------------ 1 file changed, 27 insertions(+), 59 deletions(-) diff --git a/beast/models/vits.py b/beast/models/vits.py index f301c8d..d760fbc 100644 --- a/beast/models/vits.py +++ b/beast/models/vits.py @@ -70,31 +70,25 @@ def __init__(self, config): log_step("Loading pretrained model from 'facebook/vit-mae-base' (this may take several minutes if downloading)...", level='debug') log_step("Note: Model will be cached locally after first download", level='debug') load_start = time.time() - # Load pretrained weights first (from_pretrained creates a new instance) self.vit_mae = ViTMAE.from_pretrained("facebook/vit-mae-base", config=vit_mae_config) - # Set perceptual loss settings after loading (from_pretrained doesn't preserve custom init params) - self.vit_mae.use_perceptual_loss = use_perceptual_loss - self.vit_mae.lambda_perceptual = lambda_perceptual - if use_perceptual_loss: - # Initialize perceptual loss module on the correct device - self.vit_mae.perceptual_loss = AlexPerceptual( - device=self.device, - criterion=nn.MSELoss() - ) load_duration = time.time() - load_start log_step(f"Pretrained model loaded in {load_duration:.2f} seconds", level='debug') else: log_step("Using random initialization (random_init=True)", level='debug') - self.vit_mae = ViTMAE( - vit_mae_config, - use_perceptual_loss=use_perceptual_loss, - lambda_perceptual=lambda_perceptual, - device=self.device - ) + self.vit_mae = ViTMAE(vit_mae_config) log_step("Randomly initialized model created", level='debug') - + self.mask_ratio = config['model']['model_params']['mask_ratio'] - # contrastive loss + self.use_perceptual_loss = use_perceptual_loss + self.lambda_perceptual = lambda_perceptual + if use_perceptual_loss: + log_step("Setting up perceptual loss (AlexNet features)", level='debug') + self.perceptual_loss = AlexPerceptual( + device=self.device, + criterion=nn.MSELoss() + ) + log_step("Perceptual loss created", level='debug') + if use_infoNCE: log_step("Setting up InfoNCE projection layer", level='debug') self.proj = BatchNormProjector(vit_mae_config) @@ -108,6 +102,11 @@ def forward( x: Float[torch.Tensor, 'batch channels img_height img_width'], ) -> Dict[str, torch.Tensor]: results_dict = self.vit_mae(pixel_values=x, return_recon=True) + + if self.config['model']['model_params']['use_perceptual_loss']: + reconstructions = results_dict['reconstructions'] + results_dict['perceptual_loss'] = self.perceptual_loss(reconstructions, x) + if self.config['model']['model_params']['use_infoNCE']: cls_token = results_dict['latents'][:, 0, :] proj_hidden = self.proj(cls_token) @@ -139,12 +138,12 @@ def compute_loss( {'name': f'{stage}_mse', 'value': mse_loss.detach().clone(), 'prog_bar': True}, ] - # Always log perceptual loss if it's enabled in config - if self.config['model']['model_params'].get('use_perceptual_loss', False): + if self.config['model']['model_params']['use_perceptual_loss']: assert 'perceptual_loss' in kwargs, ( "perceptual_loss must be in model outputs when use_perceptual_loss is enabled" ) perceptual_loss = kwargs['perceptual_loss'] + loss = loss + self.lambda_perceptual * perceptual_loss log_list.append({ 'name': f'{stage}_perceptual', 'value': perceptual_loss.detach().clone(), @@ -187,20 +186,8 @@ def predict_step(self, batch_dict: dict, batch_idx: int) -> dict: class ViTMAE(ViTMAEForPreTraining): - # Overriding the forward method to return the latent and loss - # This is used for training and inference - # Huggingface Transformer library - def __init__(self, config, use_perceptual_loss: bool = False, lambda_perceptual: float = 1.0, device='cuda'): - super().__init__(config) - self.use_perceptual_loss = use_perceptual_loss - self.lambda_perceptual = lambda_perceptual - if use_perceptual_loss: - # Initialize AlexPerceptual with MSE criterion - self.perceptual_loss = AlexPerceptual( - device=device, - criterion=nn.MSELoss() - ) - + """ViT-MAE for masked autoencoding. Returns latents, reconstructions, and MSE loss.""" + def forward( self, pixel_values: torch.Tensor, @@ -212,7 +199,6 @@ def forward( return_latent: bool = False, return_recon: bool = False, ) -> Dict[str, torch.Tensor]: - # Setting default for return_dict based on the configuration return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (self.training or self.config.mask_ratio > 0) or return_recon: outputs = self.vit( @@ -225,16 +211,12 @@ def forward( ) latent = outputs.last_hidden_state else: - # use for fine-tuning, or inference - # mask_ratio = 0 embedding_output, mask, ids_restore = self.vit.embeddings(pixel_values) - embedding_output_ = embedding_output[:, 1:, :] # no cls token - # unshuffle the embedding output + embedding_output_ = embedding_output[:, 1:, :] index = ids_restore.unsqueeze(-1).repeat( 1, 1, embedding_output_.shape[2] ).to(embedding_output_.device) embedding_output_ = torch.gather(embedding_output_, dim=1, index=index) - # add cls token back embedding_output = torch.cat((embedding_output[:, :1, :], embedding_output_), dim=1) encoder_outputs = self.vit.encoder( embedding_output, @@ -243,38 +225,24 @@ def forward( sequence_output = encoder_outputs[0] latent = self.vit.layernorm(sequence_output) if not return_latent: - # return the cls token and 0 loss if not return_latent return latent[:, 0], 0 if return_latent: return latent - # extract cls latent - cls_latent = latent[:, 0] # shape (batch_size, hidden_size) + cls_latent = latent[:, 0] ids_restore = outputs.ids_restore mask = outputs.mask decoder_outputs = self.decoder(latent, ids_restore) logits = decoder_outputs.logits - # shape (batch_size, num_patches, patch_size*patch_size*num_channels) - mse_loss = self.forward_loss(pixel_values, logits, mask) - - # Compute perceptual loss if enabled and we have reconstructions - perceptual_loss_value = None - loss = mse_loss - if self.use_perceptual_loss and return_recon: - reconstructions = self.unpatchify(logits) - perceptual_loss_value = self.perceptual_loss(reconstructions, pixel_values) - loss = mse_loss + self.lambda_perceptual * perceptual_loss_value - + loss = self.forward_loss(pixel_values, logits, mask) + if return_recon: - result = { + return { 'latents': latent, 'loss': loss, - 'mse_loss': mse_loss, + 'mse_loss': loss, 'reconstructions': self.unpatchify(logits), } - if perceptual_loss_value is not None: - result['perceptual_loss'] = perceptual_loss_value - return result return { 'latents': cls_latent, 'loss': loss, From a4270da6552e6695afcd87edebb70f9d2b44d523 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 16:27:26 -0600 Subject: [PATCH 23/28] set use_infoNCE: False --- configs/vit.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/vit.yaml b/configs/vit.yaml index 16b906b..38c8154 100644 --- a/configs/vit.yaml +++ b/configs/vit.yaml @@ -27,7 +27,7 @@ model: embed_size: 768 # projected embedding size, used for contrastive learning temp_scale: False # temperature scaling for contrastive loss random_init: False # use random initialization instead of pretrained weights - use_infoNCE: True # use InfoNCE loss + use_infoNCE: False # use InfoNCE loss infoNCE_weight: 0.03 # weight for InfoNCE loss use_perceptual_loss: False # use perceptual loss (AlexNet features) lambda_perceptual: 10.0 # weight for perceptual loss From 15e7567193f4a2d69d206e214530853565fd4584 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Tue, 24 Feb 2026 16:34:22 -0600 Subject: [PATCH 24/28] revert changes to __get_package_version --- beast/__init__.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/beast/__init__.py b/beast/__init__.py index 10f9b7c..04449b7 100644 --- a/beast/__init__.py +++ b/beast/__init__.py @@ -58,17 +58,14 @@ def __get_package_version() -> str: # Fall back on getting it from a local pyproject.toml. # This works in a development environment where the # package has not been installed from a distribution. - try: - import warnings - import toml + import warnings - warnings.warn('beast not pip-installed, getting version from pyproject.toml.') + import toml - pyproject_toml_file = Path(__file__).parent.parent / 'pyproject.toml' - __package_version = toml.load(pyproject_toml_file)['project']['version'] - except (ImportError, FileNotFoundError, KeyError): - # If toml is not available or file doesn't exist, use a default version - __package_version = "dev" + warnings.warn('beast not pip-installed, getting version from pyproject.toml.') + + pyproject_toml_file = Path(__file__).parent.parent / 'pyproject.toml' + __package_version = toml.load(pyproject_toml_file)['project']['version'] return __package_version From 25bd8e619b5cd06611caaca03863bc8ad13ee123 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Wed, 25 Feb 2026 08:21:28 -0600 Subject: [PATCH 25/28] simplify VisionTransformer to match original style --- beast/models/vits.py | 95 +++++++++++--------------------------------- 1 file changed, 24 insertions(+), 71 deletions(-) diff --git a/beast/models/vits.py b/beast/models/vits.py index d760fbc..af38ffe 100644 --- a/beast/models/vits.py +++ b/beast/models/vits.py @@ -1,6 +1,5 @@ """Vision transformer autoencoder implementation.""" -import time from typing import Dict, Optional import numpy as np @@ -14,7 +13,6 @@ ) from typeguard import typechecked -from beast import log_step from beast.models.base import BaseLightningModel from beast.models.perceptual import AlexPerceptual @@ -45,69 +43,31 @@ class VisionTransformer(BaseLightningModel): def __init__(self, config): super().__init__(config) # Set up ViT architecture - log_step("Creating ViTMAEConfig", level='debug') vit_mae_config = ViTMAEConfig(**config['model']['model_params']) - log_step("ViTMAEConfig created", level='debug') - - # Get perceptual loss parameters from config - use_perceptual_loss = config['model']['model_params'].get('use_perceptual_loss', False) - lambda_perceptual = config['model']['model_params'].get('lambda_perceptual', 10.0) - - # Get InfoNCE loss parameters from config - use_infoNCE = config['model']['model_params'].get('use_infoNCE', False) - temp_scale = config['model']['model_params'].get('temp_scale', False) - - if use_perceptual_loss: - log_step(f"Perceptual loss enabled with lambda={lambda_perceptual}", level='debug') - - if use_infoNCE: - log_step(f"InfoNCE loss enabled with temp_scale={temp_scale}", level='debug') - - # Check if we should use pretrained weights or random initialization - use_pretrained = not config['model']['model_params'].get('random_init', False) - - if use_pretrained: - log_step("Loading pretrained model from 'facebook/vit-mae-base' (this may take several minutes if downloading)...", level='debug') - log_step("Note: Model will be cached locally after first download", level='debug') - load_start = time.time() - self.vit_mae = ViTMAE.from_pretrained("facebook/vit-mae-base", config=vit_mae_config) - load_duration = time.time() - load_start - log_step(f"Pretrained model loaded in {load_duration:.2f} seconds", level='debug') - else: - log_step("Using random initialization (random_init=True)", level='debug') - self.vit_mae = ViTMAE(vit_mae_config) - log_step("Randomly initialized model created", level='debug') - + self.vit_mae = ViTMAE.from_pretrained("facebook/vit-mae-base", config=vit_mae_config) self.mask_ratio = config['model']['model_params']['mask_ratio'] - self.use_perceptual_loss = use_perceptual_loss - self.lambda_perceptual = lambda_perceptual - if use_perceptual_loss: - log_step("Setting up perceptual loss (AlexNet features)", level='debug') + # perceptual loss + if config['model']['model_params'].get('use_perceptual_loss', False): self.perceptual_loss = AlexPerceptual( device=self.device, criterion=nn.MSELoss() ) - log_step("Perceptual loss created", level='debug') - - if use_infoNCE: - log_step("Setting up InfoNCE projection layer", level='debug') + # contrastive loss + if config['model']['model_params'].get('use_infoNCE', False): self.proj = BatchNormProjector(vit_mae_config) - if temp_scale: + if config['model']['model_params'].get('temp_scale', False): self.temperature = nn.Parameter(torch.ones([]) * np.log(1)) - log_step("InfoNCE projection layer created", level='debug') - log_step("VisionTransformer initialization complete", level='debug') def forward( self, x: Float[torch.Tensor, 'batch channels img_height img_width'], ) -> Dict[str, torch.Tensor]: results_dict = self.vit_mae(pixel_values=x, return_recon=True) - - if self.config['model']['model_params']['use_perceptual_loss']: - reconstructions = results_dict['reconstructions'] - results_dict['perceptual_loss'] = self.perceptual_loss(reconstructions, x) - - if self.config['model']['model_params']['use_infoNCE']: + if self.config['model']['model_params'].get('use_perceptual_loss', False): + results_dict['perceptual_loss'] = self.perceptual_loss( + results_dict['reconstructions'], x + ) + if self.config['model']['model_params'].get('use_infoNCE', False): cls_token = results_dict['latents'][:, 0, :] proj_hidden = self.proj(cls_token) # normalize projection @@ -130,42 +90,35 @@ def compute_loss( **kwargs, ) -> tuple[torch.tensor, list[dict]]: assert 'loss' in kwargs, "Loss is not in the kwargs" - assert 'mse_loss' in kwargs, "mse_loss must be provided by model outputs for logging" - loss = kwargs['loss'] - mse_loss = kwargs['mse_loss'] - + mse_loss = kwargs['loss'] + # add all losses here for logging log_list = [ - {'name': f'{stage}_mse', 'value': mse_loss.detach().clone(), 'prog_bar': True}, + {'name': f'{stage}_mse', 'value': mse_loss.clone()} ] - - if self.config['model']['model_params']['use_perceptual_loss']: - assert 'perceptual_loss' in kwargs, ( - "perceptual_loss must be in model outputs when use_perceptual_loss is enabled" - ) + loss = mse_loss + if self.config['model']['model_params'].get('use_perceptual_loss', False): perceptual_loss = kwargs['perceptual_loss'] - loss = loss + self.lambda_perceptual * perceptual_loss log_list.append({ 'name': f'{stage}_perceptual', - 'value': perceptual_loss.detach().clone(), - 'prog_bar': True + 'value': perceptual_loss.clone() }) - - if self.config['model']['model_params']['use_infoNCE']: + loss += self.config['model']['model_params'].get( + 'lambda_perceptual', 10.0 + ) * perceptual_loss + if self.config['model']['model_params'].get('use_infoNCE', False): z = kwargs['z'] sim_matrix = z @ z.T - if self.config['model']['model_params']['temp_scale']: + if self.config['model']['model_params'].get('temp_scale', False): sim_matrix /= self.temperature.exp() loss_dict = batch_wise_contrastive_loss(sim_matrix) loss_dict['infoNCE_loss'] *= self.config['model']['model_params']['infoNCE_weight'] log_list.append({ 'name': f'{stage}_infoNCE', - 'value': loss_dict['infoNCE_loss'].detach().clone(), - 'prog_bar': True + 'value': loss_dict['infoNCE_loss'] }) log_list.append({ 'name': f'{stage}_infoNCE_percent_correct', - 'value': loss_dict['percent_correct'].detach().clone(), - 'prog_bar': False + 'value': loss_dict['percent_correct'] }) loss += loss_dict['infoNCE_loss'] return loss, log_list From e5e40d33df3b3ca3a104315a746c9ee22b4c2f18 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Wed, 25 Feb 2026 12:33:12 -0600 Subject: [PATCH 26/28] keep comments --- beast/models/vits.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/beast/models/vits.py b/beast/models/vits.py index af38ffe..4bd92ab 100644 --- a/beast/models/vits.py +++ b/beast/models/vits.py @@ -152,6 +152,7 @@ def forward( return_latent: bool = False, return_recon: bool = False, ) -> Dict[str, torch.Tensor]: + # Setting default for return_dict based on the configuration return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (self.training or self.config.mask_ratio > 0) or return_recon: outputs = self.vit( @@ -164,12 +165,16 @@ def forward( ) latent = outputs.last_hidden_state else: + # use for fine-tuning, or inference + # mask_ratio = 0 embedding_output, mask, ids_restore = self.vit.embeddings(pixel_values) - embedding_output_ = embedding_output[:, 1:, :] + embedding_output_ = embedding_output[:, 1:, :] # no cls token + # unshuffle the embedding output index = ids_restore.unsqueeze(-1).repeat( 1, 1, embedding_output_.shape[2] ).to(embedding_output_.device) embedding_output_ = torch.gather(embedding_output_, dim=1, index=index) + # add cls token back embedding_output = torch.cat((embedding_output[:, :1, :], embedding_output_), dim=1) encoder_outputs = self.vit.encoder( embedding_output, @@ -178,15 +183,18 @@ def forward( sequence_output = encoder_outputs[0] latent = self.vit.layernorm(sequence_output) if not return_latent: + # return the cls token and 0 loss if not return_latent return latent[:, 0], 0 if return_latent: return latent - cls_latent = latent[:, 0] + # extract cls latent + cls_latent = latent[:, 0] # shape (batch_size, hidden_size) ids_restore = outputs.ids_restore mask = outputs.mask decoder_outputs = self.decoder(latent, ids_restore) logits = decoder_outputs.logits + # shape (batch_size, num_patches, patch_size*patch_size*num_channels) loss = self.forward_loss(pixel_values, logits, mask) if return_recon: From 78cdf8ee25d4ca44455a60dd8258fcc42736fb81 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Wed, 25 Feb 2026 12:44:49 -0600 Subject: [PATCH 27/28] keep the code for using a randomly initialized model --- beast/models/vits.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/beast/models/vits.py b/beast/models/vits.py index 4bd92ab..703c897 100644 --- a/beast/models/vits.py +++ b/beast/models/vits.py @@ -15,6 +15,7 @@ from beast.models.base import BaseLightningModel from beast.models.perceptual import AlexPerceptual +from beast import log_step class BatchNormProjector(nn.Module): @@ -44,7 +45,18 @@ def __init__(self, config): super().__init__(config) # Set up ViT architecture vit_mae_config = ViTMAEConfig(**config['model']['model_params']) - self.vit_mae = ViTMAE.from_pretrained("facebook/vit-mae-base", config=vit_mae_config) + + # Check if we should use pretrained weights or random initialization + use_pretrained = not config['model']['model_params'].get('random_init', False) + if use_pretrained: + log_step("Loading pretrained model from 'facebook/vit-mae-base' (this may take several minutes if downloading)...", level='debug') + log_step("Note: Model will be cached locally after first download", level='debug') + self.vit_mae = ViTMAE.from_pretrained("facebook/vit-mae-base", config=vit_mae_config) + else: + log_step("Using random initialization (random_init=True)", level='debug') + self.vit_mae = ViTMAE(vit_mae_config) + log_step("Randomly initialized model created", level='debug') + self.mask_ratio = config['model']['model_params']['mask_ratio'] # perceptual loss if config['model']['model_params'].get('use_perceptual_loss', False): From 7a4cecc3c722812b3d8830c45df59c75a3a06620 Mon Sep 17 00:00:00 2001 From: Mia Dai Date: Thu, 26 Feb 2026 14:16:40 -0600 Subject: [PATCH 28/28] Fix formatting and lint issues --- beast/cli/commands/train.py | 6 ++++-- beast/data/datamodules.py | 6 ++++-- beast/data/datasets.py | 9 ++++++--- beast/models/base.py | 6 +++--- beast/models/perceptual.py | 2 +- beast/models/vits.py | 7 ++++--- beast/train.py | 3 ++- tests/models/test_perceptual.py | 12 ++++++------ tests/models/test_vits.py | 3 ++- 9 files changed, 32 insertions(+), 22 deletions(-) diff --git a/beast/cli/commands/train.py b/beast/cli/commands/train.py index 080a7e8..1c8e20d 100644 --- a/beast/cli/commands/train.py +++ b/beast/cli/commands/train.py @@ -111,8 +111,10 @@ def handle(args): # Check for unsupported --checkpoint argument if hasattr(args, 'checkpoint') and args.checkpoint: - log_step(f"WARNING: --checkpoint argument provided but not supported: {args.checkpoint}", level='info', logger=_logger) - log_step("Checkpoint resuming is not currently implemented in the CLI", level='info', logger=_logger) + log_step( + f"WARNING: --checkpoint argument provided but not supported: {args.checkpoint}", level='info', logger=_logger) + log_step("Checkpoint resuming is not currently implemented in the CLI", + level='info', logger=_logger) # Initialize model log_step("Initializing model from config", level='info', logger=_logger) diff --git a/beast/data/datamodules.py b/beast/data/datamodules.py index f0ba13c..2ec88ad 100644 --- a/beast/data/datamodules.py +++ b/beast/data/datamodules.py @@ -161,7 +161,8 @@ def train_dataloader(self) -> torch.utils.data.DataLoader: sampler=self.sampler if self.use_sampler else None, generator=torch.Generator().manual_seed(self.seed), collate_fn=contrastive_collate_fn if self.use_sampler else None, - multiprocessing_context=multiprocessing.get_context('spawn') if self.num_workers > 0 else None, # More stable on HPC + multiprocessing_context=multiprocessing.get_context( + 'spawn') if self.num_workers > 0 else None, # More stable on HPC ) def val_dataloader(self) -> torch.utils.data.DataLoader: @@ -171,7 +172,8 @@ def val_dataloader(self) -> torch.utils.data.DataLoader: num_workers=self.num_workers, persistent_workers=True if self.num_workers > 0 else False, pin_memory=True, - multiprocessing_context=multiprocessing.get_context('spawn') if self.num_workers > 0 else None, + multiprocessing_context=multiprocessing.get_context( + 'spawn') if self.num_workers > 0 else None, ) def test_dataloader(self) -> torch.utils.data.DataLoader: diff --git a/beast/data/datasets.py b/beast/data/datasets.py index 80dff31..488d696 100644 --- a/beast/data/datasets.py +++ b/beast/data/datasets.py @@ -40,16 +40,19 @@ def __init__(self, data_dir: str | Path, imgaug_pipeline: Callable | None) -> No # collect ALL png files in data_dir scan_start = time.time() try: - log_step(f"Starting to scan for PNG files in {self.data_dir} (this may take a while for large directories)...", level='debug') + log_step( + f"Starting to scan for PNG files in {self.data_dir} (this may take a while for large directories)...", level='debug') self.image_list = sorted(list(self.data_dir.rglob('*.png'))) scan_duration = time.time() - scan_start - log_step(f"Finished scanning. Found {len(self.image_list)} PNG files in {scan_duration:.2f} seconds", level='debug') + log_step( + f"Finished scanning. Found {len(self.image_list)} PNG files in {scan_duration:.2f} seconds", level='debug') except Exception as e: log_step(f"ERROR during file scanning: {e}", level='error') raise if len(self.image_list) == 0: raise ValueError(f'{self.data_dir} does not contain image data in png format') - log_step(f"BaseDataset initialization complete with {len(self.image_list)} images", level='debug') + log_step( + f"BaseDataset initialization complete with {len(self.image_list)} images", level='debug') # send image to tensor, resize to canonical dimensions, and normalize pytorch_transform_list = [ diff --git a/beast/models/base.py b/beast/models/base.py index bb295a1..8f11a7d 100644 --- a/beast/models/base.py +++ b/beast/models/base.py @@ -117,9 +117,9 @@ def evaluate_batch( on_step = (stage == 'train') on_epoch = True # Always log on epoch for both train and val self.log( - f'{stage}_loss', - loss, - prog_bar=True, + f'{stage}_loss', + loss, + prog_bar=True, sync_dist=True, on_step=on_step, on_epoch=on_epoch diff --git a/beast/models/perceptual.py b/beast/models/perceptual.py index 4304286..eef92c3 100644 --- a/beast/models/perceptual.py +++ b/beast/models/perceptual.py @@ -47,4 +47,4 @@ def __init__(self, *, device: str | torch.device, **kwargs: Any): for parameter in perceptual_net.parameters(): parameter.requires_grad = False - super(AlexPerceptual, self).__init__(network=perceptual_net, **kwargs) \ No newline at end of file + super(AlexPerceptual, self).__init__(network=perceptual_net, **kwargs) diff --git a/beast/models/vits.py b/beast/models/vits.py index 703c897..f9644ee 100644 --- a/beast/models/vits.py +++ b/beast/models/vits.py @@ -45,18 +45,19 @@ def __init__(self, config): super().__init__(config) # Set up ViT architecture vit_mae_config = ViTMAEConfig(**config['model']['model_params']) - + # Check if we should use pretrained weights or random initialization use_pretrained = not config['model']['model_params'].get('random_init', False) if use_pretrained: - log_step("Loading pretrained model from 'facebook/vit-mae-base' (this may take several minutes if downloading)...", level='debug') + log_step( + "Loading pretrained model from 'facebook/vit-mae-base' (this may take several minutes if downloading)...", level='debug') log_step("Note: Model will be cached locally after first download", level='debug') self.vit_mae = ViTMAE.from_pretrained("facebook/vit-mae-base", config=vit_mae_config) else: log_step("Using random initialization (random_init=True)", level='debug') self.vit_mae = ViTMAE(vit_mae_config) log_step("Randomly initialized model created", level='debug') - + self.mask_ratio = config['model']['model_params']['mask_ratio'] # perceptual loss if config['model']['model_params'].get('use_perceptual_loss', False): diff --git a/beast/train.py b/beast/train.py index 22c02bc..32ac8df 100644 --- a/beast/train.py +++ b/beast/train.py @@ -118,7 +118,8 @@ def train(config: dict, model, output_dir: str | Path): model.config['optimizer']['steps_per_epoch'] = steps_per_epoch model.config['optimizer']['total_steps'] = steps_per_epoch * num_epochs if rank_zero_only.rank == 0: - log_step(f"Training steps calculated: {steps_per_epoch} steps/epoch, {num_epochs} epochs", level='debug') + log_step( + f"Training steps calculated: {steps_per_epoch} steps/epoch, {num_epochs} epochs", level='debug') # ---------------------------------------------------------------------------------- # Save configuration in output directory diff --git a/tests/models/test_perceptual.py b/tests/models/test_perceptual.py index 8f2d1bb..3a077b0 100644 --- a/tests/models/test_perceptual.py +++ b/tests/models/test_perceptual.py @@ -15,8 +15,8 @@ def test_perceptual_forward(): x = torch.randn((5, 3, 224, 224)) loss = perceptual(x_hat, x) assert isinstance(loss, torch.Tensor) - assert loss.ndim == 0 - assert loss.item() >= 0 + assert loss.ndim == 0 + assert loss.item() >= 0 def test_alex_perceptual_forward(): @@ -28,8 +28,8 @@ def test_alex_perceptual_forward(): x = torch.randn((5, 3, 224, 224), device=device) loss = perceptual(x_hat, x) assert isinstance(loss, torch.Tensor) - assert loss.ndim == 0 - assert loss.item() >= 0 + assert loss.ndim == 0 + assert loss.item() >= 0 def test_alex_perceptual_different_inputs_produce_different_loss(): @@ -39,7 +39,7 @@ def test_alex_perceptual_different_inputs_produce_different_loss(): perceptual = AlexPerceptual(device=device, criterion=criterion) x = torch.randn((2, 3, 224, 224), device=device) loss_same = perceptual(x, x) - assert loss_same.item() < 1e-5 + assert loss_same.item() < 1e-5 x_hat = torch.randn((2, 3, 224, 224), device=device) loss_diff = perceptual(x_hat, x) - assert loss_diff.item() > 0 \ No newline at end of file + assert loss_diff.item() > 0 diff --git a/tests/models/test_vits.py b/tests/models/test_vits.py index f67d491..db4760a 100644 --- a/tests/models/test_vits.py +++ b/tests/models/test_vits.py @@ -40,8 +40,9 @@ def test_vit_autoencoder_contrastive_integration(config_vit, run_model_test): config['model']['model_params']['use_infoNCE'] = True run_model_test(config=config) + def test_alex_perceptual_integration(config_vit, run_model_test): """Test ViT autoencoder with AlexNet perceptual loss enabled.""" config = copy.deepcopy(config_vit) config['model']['model_params']['use_perceptual_loss'] = True - run_model_test(config=config) \ No newline at end of file + run_model_test(config=config)