Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
898d8d4
turned off contrastive loss
Xinming-Dai Nov 22, 2025
857c0b5
added perceptual loss
Xinming-Dai Nov 18, 2025
ffee13d
added perceptual loss to tensorboard
Xinming-Dai Nov 22, 2025
149fac7
delete perceptual config and resume_beast.sh
Xinming-Dai Feb 18, 2026
efafbaa
delete combined.py
Xinming-Dai Feb 18, 2026
651a16b
add test_alex_perceptual_integration method to test perceptual integr…
Xinming-Dai Feb 18, 2026
8d4a52f
add test for perceptual.py
Xinming-Dai Feb 18, 2026
4a202aa
move _log_step to beast.log_step with level param
Xinming-Dai Feb 23, 2026
7d6e2dc
move import to the top of the file
Xinming-Dai Feb 24, 2026
dc0c0b0
import log_step from beast
Xinming-Dai Feb 24, 2026
aa95a7d
remove import time
Xinming-Dai Feb 24, 2026
d056f02
remove extra logging
Xinming-Dai Feb 24, 2026
23c3077
use log_step
Xinming-Dai Feb 24, 2026
ea3db82
move log_step inside the try block
Xinming-Dai Feb 24, 2026
d13c5bb
change log_step level to "error"
Xinming-Dai Feb 24, 2026
f028040
modify the docstring
Xinming-Dai Feb 24, 2026
02110be
add docstrings and types
Xinming-Dai Feb 24, 2026
cdc7b34
parameter extraction and logging infoNCE loss
Xinming-Dai Feb 24, 2026
ea88d7e
use self.device
Xinming-Dai Feb 24, 2026
84b7356
remove fallback
Xinming-Dai Feb 24, 2026
c626b7c
assert 'perceptual_loss' in kwargs
Xinming-Dai Feb 24, 2026
dea4e47
revert all changes in __init__ of VisionTransformer
Xinming-Dai Feb 24, 2026
a4270da
set use_infoNCE: False
Xinming-Dai Feb 24, 2026
15e7567
revert changes to __get_package_version
Xinming-Dai Feb 24, 2026
25bd8e6
simplify VisionTransformer to match original style
Xinming-Dai Feb 25, 2026
e5e40d3
keep comments
Xinming-Dai Feb 25, 2026
78cdf8e
keep the code for using a randomly initialized model
Xinming-Dai Feb 25, 2026
7a4cecc
Fix formatting and lint issues
Xinming-Dai Feb 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion beast/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,44 @@
# Hacky way to get version from pypackage.toml.
# Adapted from: https://github.com/python-poetry/poetry/issues/273#issuecomment-1877789967
import importlib.metadata
import time
from pathlib import Path
from typing import Any
from typing import Any, Optional

__package_version = "unknown"


def log_step(
    msg: str,
    level: Optional[str] = None,
    flush: bool = True,
    logger: Any = None,
) -> None:
    """Print a timestamped message to stdout, optionally mirroring it to a logger.

    Parameters
    ----------
    msg: message to log
    level: None (plain timestamp + msg), 'info', 'debug', or 'error'
    flush: whether to flush stdout after printing
    logger: if provided together with a level, also forward the message to
        the matching logger method (logger.info / logger.debug / logger.error)
    """
    timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
    if level == 'info':
        print(f"[{timestamp}] INFO: {msg}", flush=flush)
        if logger is not None:
            logger.info(msg)
    elif level == 'debug':
        print(f"[{timestamp}] DEBUG: {msg}", flush=flush)
        # Forward to the logger for parity with the 'info'/'error' branches;
        # previously 'debug' silently ignored a provided logger.
        if logger is not None:
            logger.debug(msg)
    elif level == 'error':
        print(f"[{timestamp}] ERROR: {msg}", flush=flush)
        if logger is not None:
            logger.error(msg)
    else:
        # Unknown or missing level: plain timestamped message.
        print(f"[{timestamp}] {msg}", flush=flush)


def __get_package_version() -> str:
"""Find the version of this package."""

Expand Down
10 changes: 10 additions & 0 deletions beast/api/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import os
import time
from pathlib import Path
from typing import Any

Expand All @@ -12,6 +13,7 @@
from beast.models.resnets import ResnetAutoencoder
from beast.models.vits import VisionTransformer
from beast.train import train
from beast import log_step


# TODO: Replace with contextlib.chdir in python 3.11.
Expand Down Expand Up @@ -112,7 +114,15 @@ def from_config(cls, config_path: str | Path | dict):

# Initialize the LightningModule
model_class = cls.MODEL_REGISTRY[model_type]
log_step(f"Creating {model_type} model instance", level='debug')
log_step(
f"About to call {model_class.__name__}.__init__() - this may take several minutes if downloading pretrained weights",
level='debug',
)
init_start = time.time()
model = model_class(config)
init_duration = time.time() - init_start
log_step(f"Model initialization completed in {init_duration:.2f} seconds", level='debug')

print(f'Initialized a {model_class} model')

Expand Down
36 changes: 30 additions & 6 deletions beast/cli/commands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from pathlib import Path

from beast import log_step
from beast.cli.types import config_file, output_dir

_logger = logging.getLogger('BEAST.CLI.TRAIN')
Expand Down Expand Up @@ -65,50 +66,73 @@ def register_parser(subparsers):
def handle(args):
"""Handle the train command execution."""

log_step("Starting train command handler", level='info', logger=_logger)

# Determine output directory
log_step("Determining output directory", level='info', logger=_logger)
if not args.output:
now = datetime.datetime.now()
args.output = Path('runs').resolve() / now.strftime('%Y-%m-%d') / now.strftime('%H-%M-%S')

args.output.mkdir(parents=True, exist_ok=True)
log_step(f"Output directory: {args.output}", level='info', logger=_logger)

# Set up logging to the model directory
log_step("Setting up model logging", level='info', logger=_logger)
model_log_handler = _setup_model_logging(args.output)
log_step("Model logging set up", level='info', logger=_logger)

# try:

# Load config
log_step(f"Loading config from: {args.config}", level='info', logger=_logger)
from beast.io import load_config
config = load_config(args.config)
log_step("Config loaded", level='info', logger=_logger)

# Apply overrides
if args.overrides:
log_step("Applying config overrides", level='info', logger=_logger)
from beast.io import apply_config_overrides
config = apply_config_overrides(config, args.overrides)
log_step("Config overrides applied", level='info', logger=_logger)

# Override specific values from command line
log_step("Applying command line overrides", level='info', logger=_logger)
if args.data:
config['data']['data_dir'] = str(args.data)
log_step(f"Data directory overridden to: {args.data}", level='info', logger=_logger)
if args.gpus is not None:
config['training']['num_gpus'] = args.gpus
log_step(f"Number of GPUs overridden to: {args.gpus}", level='info', logger=_logger)
if args.nodes is not None:
config['training']['num_nodes'] = args.nodes
log_step(f"Number of nodes overridden to: {args.nodes}", level='info', logger=_logger)

# Check for unsupported --checkpoint argument
if hasattr(args, 'checkpoint') and args.checkpoint:
log_step(
f"WARNING: --checkpoint argument provided but not supported: {args.checkpoint}", level='info', logger=_logger)
log_step("Checkpoint resuming is not currently implemented in the CLI",
level='info', logger=_logger)

# Initialize model
log_step("Initializing model from config", level='info', logger=_logger)
from beast.api.model import Model
model = Model.from_config(config)
log_step("Model initialized", level='info', logger=_logger)

# if args.resume:
# train_kwargs['resume_from_checkpoint'] = args.resume

_logger.info(f'Training {type(model.model)} model')
_logger.info(f'Data directory: {args.data}')
_logger.info(f'Output directory: {args.output}')
log_step(f'Training {type(model.model)} model', level='info', logger=_logger)
log_step(f'Data directory: {args.data}', level='info', logger=_logger)
log_step(f'Output directory: {args.output}', level='info', logger=_logger)

# Run training
log_step("About to call model.train()", level='info', logger=_logger)
model.train(output_dir=args.output)

_logger.info(f'Training complete. Model saved to {args.output}')
log_step("model.train() completed", level='info', logger=_logger)
log_step(f'Training complete. Model saved to {args.output}', level='info', logger=_logger)

# except Exception as e:

Expand Down
9 changes: 9 additions & 0 deletions beast/data/datamodules.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Data modules split a dataset into train, val, and test modules."""

import copy
import multiprocessing
import os

import lightning.pytorch as pl
Expand Down Expand Up @@ -155,10 +156,13 @@ def train_dataloader(self) -> torch.utils.data.DataLoader:
batch_size=None if self.use_sampler else self.train_batch_size,
num_workers=self.num_workers,
persistent_workers=True if self.num_workers > 0 else False,
pin_memory=True, # Helps with GPU transfer
shuffle=True if not self.use_sampler else False,
sampler=self.sampler if self.use_sampler else None,
generator=torch.Generator().manual_seed(self.seed),
collate_fn=contrastive_collate_fn if self.use_sampler else None,
multiprocessing_context=multiprocessing.get_context(
'spawn') if self.num_workers > 0 else None, # More stable on HPC
)

def val_dataloader(self) -> torch.utils.data.DataLoader:
Expand All @@ -167,13 +171,18 @@ def val_dataloader(self) -> torch.utils.data.DataLoader:
batch_size=self.val_batch_size,
num_workers=self.num_workers,
persistent_workers=True if self.num_workers > 0 else False,
pin_memory=True,
multiprocessing_context=multiprocessing.get_context(
'spawn') if self.num_workers > 0 else None,
)

def test_dataloader(self) -> torch.utils.data.DataLoader:
return DataLoader(
self.test_dataset,
batch_size=self.test_batch_size,
num_workers=self.num_workers,
pin_memory=True,
multiprocessing_context='spawn' if self.num_workers > 0 else None,
)

def full_labeled_dataloader(self) -> torch.utils.data.DataLoader:
Expand Down
18 changes: 17 additions & 1 deletion beast/data/datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Dataset objects store images and augmentation pipeline."""

import time
from pathlib import Path
from typing import Callable

Expand All @@ -9,6 +10,7 @@
from torchvision import transforms
from typeguard import typechecked

from beast import log_step
from beast.data.types import ExampleDict

_IMAGENET_MEAN = [0.485, 0.456, 0.406]
Expand All @@ -28,15 +30,29 @@ def __init__(self, data_dir: str | Path, imgaug_pipeline: Callable | None) -> No
imgaug_transform: imgaug transform pipeline to apply to images

"""
log_step(f"BaseDataset.__init__ called with data_dir: {data_dir}", level='debug')
self.data_dir = Path(data_dir)
if not self.data_dir.is_dir():
raise ValueError(f'{self.data_dir} is not a directory')
log_step(f"Data directory exists: {self.data_dir}", level='debug')

self.imgaug_pipeline = imgaug_pipeline
# collect ALL png files in data_dir
self.image_list = sorted(list(self.data_dir.rglob('*.png')))
scan_start = time.time()
try:
log_step(
f"Starting to scan for PNG files in {self.data_dir} (this may take a while for large directories)...", level='debug')
self.image_list = sorted(list(self.data_dir.rglob('*.png')))
scan_duration = time.time() - scan_start
log_step(
f"Finished scanning. Found {len(self.image_list)} PNG files in {scan_duration:.2f} seconds", level='debug')
except Exception as e:
log_step(f"ERROR during file scanning: {e}", level='error')
raise
if len(self.image_list) == 0:
raise ValueError(f'{self.data_dir} does not contain image data in png format')
log_step(
f"BaseDataset initialization complete with {len(self.image_list)} images", level='debug')

# send image to tensor, resize to canonical dimensions, and normalize
pytorch_transform_list = [
Expand Down
14 changes: 13 additions & 1 deletion beast/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,26 @@ def evaluate_batch(
# multi-GPU training. Performance overhead was found negligible.

# log overall supervised loss
self.log(f'{stage}_loss', loss, prog_bar=True, sync_dist=True)
# For training: log on step; for validation: log on epoch
on_step = (stage == 'train')
on_epoch = True # Always log on epoch for both train and val
self.log(
f'{stage}_loss',
loss,
prog_bar=True,
sync_dist=True,
on_step=on_step,
on_epoch=on_epoch
)
# log individual supervised losses
for log_dict in log_list:
self.log(
log_dict['name'],
log_dict['value'].to(self.device),
prog_bar=log_dict.get('prog_bar', False),
sync_dist=True,
on_step=on_step,
on_epoch=on_epoch
)

return loss
Expand Down
50 changes: 50 additions & 0 deletions beast/models/perceptual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch
import torchvision
from torch import nn
from typing import Any
# https://github.com/MLReproHub/SMAE/blob/main/src/loss/perceptual.py


class Perceptual(nn.Module):
    """Perceptual loss: compares two images in a learned feature space."""

    def __init__(self, *, network: nn.Module, criterion: nn.Module):
        """Initialize the perceptual loss module.

        Parameters
        ----------
        network: feature extractor mapping input images to feature tensors
        criterion: loss function applied to the extracted features
            (e.g. MSELoss)
        """
        super().__init__()
        self.net = network
        self.criterion = criterion
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_hat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return the criterion between sigmoid-squashed features of both inputs."""
        # Squash both feature maps through a sigmoid before comparing them.
        squashed = [self.sigmoid(self.net(img)) for img in (x_hat, x)]
        return self.criterion(squashed[0], squashed[1])


class AlexPerceptual(Perceptual):
    """Perceptual loss over pretrained AlexNet features."""

    def __init__(self, *, device: str | torch.device, **kwargs: Any):
        """Perceptual loss using pretrained AlexNet features [Pihlgren et al. 2020].

        Extracts features from the first five layers of AlexNet (pretrained
        on ImageNet) and computes loss between reconstructed and target
        feature maps.

        Parameters
        ----------
        device: device to run the feature extractor on (e.g. 'cuda', 'cpu')
        **kwargs: passed to parent; must include criterion (e.g. nn.MSELoss())
        """
        # AlexNet pretrained on ImageNet-1k; keep only the first five layers,
        # i.e. everything up to and including the second ReLU activation.
        # (The parent class appends a sigmoid to normalize the features.)
        backbone = torchvision.models.alexnet(weights='IMAGENET1K_V1')
        feature_extractor = backbone.features[:5].to(device)
        # Freeze the extractor: its weights are never updated, but gradients
        # still propagate through it to the reconstruction.
        for param in feature_extractor.parameters():
            param.requires_grad = False
        super().__init__(network=feature_extractor, **kwargs)
Loading