From ca1e8a113038832bdbf4c65a2b346357f6516c4c Mon Sep 17 00:00:00 2001 From: Vijay Kumar Date: Thu, 22 Jan 2026 14:28:16 -0800 Subject: [PATCH] refactor: modularize training code for LatentEvolution MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit extract shared training utilities and configuration classes into separate modules to prepare for latent trajectory model implementation. created new modules: - training_utils.py: LossAccumulator (enum-based), seed_everything, get_device - training_config.py: DataSplit, ProfileConfig, TrainingConfig, CrossValidationConfig moved code between modules: - ModelParams: latent.py → eed_model.py (model-specific config) - load_column_slice, load_metadata: zarr_io.py → load_flyvis.py (flyvis-specific) - DataSplit: load_flyvis.py → training_config.py (training concept) updated imports: - chunk_streaming.py, post_run_analyze.py, benchmark_training.py: zarr_io → load_flyvis - latent.py: added load_dataset and load_val_only functions, updated to use new modules - zarr_io.py: removed flyvis-specific dependencies benefits: - no circular dependencies - proper module layering (general doesn't depend on specific) - shared code extracted for reuse across models - LossAccumulator is now generic and reusable 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/LatentEvolution/benchmark_training.py | 2 +- src/LatentEvolution/chunk_streaming.py | 2 +- src/LatentEvolution/eed_model.py | 86 +++- src/LatentEvolution/latent.py | 559 ++++++---------------- src/LatentEvolution/load_flyvis.py | 99 ++-- src/LatentEvolution/post_run_analyze.py | 2 +- src/LatentEvolution/training_config.py | 176 +++++++ src/LatentEvolution/training_utils.py | 62 +++ src/NeuralGraph/zarr_io.py | 84 +--- 9 files changed, 558 insertions(+), 514 deletions(-) create mode 100644 src/LatentEvolution/training_config.py create mode 100644 src/LatentEvolution/training_utils.py diff --git a/src/LatentEvolution/benchmark_training.py b/src/LatentEvolution/benchmark_training.py index 5fce0b61..c3c18748 100644 --- a/src/LatentEvolution/benchmark_training.py +++ b/src/LatentEvolution/benchmark_training.py @@ -13,7 +13,7 @@ from LatentEvolution.load_flyvis import FlyVisSim from LatentEvolution.latent import ModelParams, LatentModel, train_step, train_step_nocompile from LatentEvolution.acquisition import compute_neuron_phases, sample_batch_indices -from NeuralGraph.zarr_io import load_column_slice +from LatentEvolution.load_flyvis import load_column_slice def seed_everything(seed: int): diff --git a/src/LatentEvolution/chunk_streaming.py b/src/LatentEvolution/chunk_streaming.py index 3ddfacde..01435403 100644 --- a/src/LatentEvolution/chunk_streaming.py +++ b/src/LatentEvolution/chunk_streaming.py @@ -11,7 +11,7 @@ import numpy as np import torch -from NeuralGraph.zarr_io import load_column_slice +from LatentEvolution.load_flyvis import load_column_slice # ------------------------------------------------------------------- diff --git a/src/LatentEvolution/eed_model.py b/src/LatentEvolution/eed_model.py index 54129904..94b02e10 100644 --- a/src/LatentEvolution/eed_model.py +++ b/src/LatentEvolution/eed_model.py @@ -7,7 +7,9 @@ import torch import torch.nn as nn -from pydantic import BaseModel, Field, field_validator, ConfigDict +from pydantic import BaseModel, Field, field_validator, model_validator, ConfigDict + +from LatentEvolution.training_config import TrainingConfig, ProfileConfig, CrossValidationConfig # 
------------------------------------------------------------------- @@ -225,3 +227,85 @@ def forward(self, proj_t, proj_stim_t): """Evolve one time step in latent space.""" proj_t_next = proj_t + self.evolver(torch.cat([proj_t, proj_stim_t], dim=1)) return proj_t_next + + +# ------------------------------------------------------------------- +# Model Configuration +# ------------------------------------------------------------------- + + +class ModelParams(BaseModel): + latent_dims: int = Field(..., json_schema_extra={"short_name": "ld"}) + num_neurons: int + use_batch_norm: bool = True + activation: str = Field("ReLU", description="activation function from torch.nn") + encoder_params: EncoderParams + decoder_params: DecoderParams + evolver_params: EvolverParams + stimulus_encoder_params: StimulusEncoderParams + training: TrainingConfig + profiling: ProfileConfig | None = Field( + None, description="optional profiler configuration to generate chrome traces for performance analysis" + ) + cross_validation_configs: list[CrossValidationConfig] = Field( + default_factory=lambda: [CrossValidationConfig(simulation_config="fly_N9_62_0")], + description="list of datasets to validate on after training" + ) + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + @field_validator("activation") + @classmethod + def validate_activation(cls, v: str) -> str: + if not hasattr(nn, v): + raise ValueError(f"unknown activation '{v}' in torch.nn") + return v + + @model_validator(mode='after') + def validate_encoder_decoder_symmetry(self): + """ensure encoder and decoder have symmetric mlp parameters.""" + if self.encoder_params.num_hidden_units != self.decoder_params.num_hidden_units: + raise ValueError( + f"encoder and decoder must have the same num_hidden_units. " + f"got encoder={self.encoder_params.num_hidden_units}, decoder={self.decoder_params.num_hidden_units}" + ) + if self.encoder_params.num_hidden_layers != self.decoder_params.num_hidden_layers: + raise ValueError( + f"encoder and decoder must have the same num_hidden_layers. " + f"got encoder={self.encoder_params.num_hidden_layers}, decoder={self.decoder_params.num_hidden_layers}" + ) + if self.encoder_params.use_input_skips != self.decoder_params.use_input_skips: + raise ValueError( + f"encoder and decoder must have the same use_input_skips setting. " + f"got encoder={self.encoder_params.use_input_skips}, decoder={self.decoder_params.use_input_skips}" + ) + return self + + def flatten(self, sep: str = ".") -> dict[str, int | float | str | bool]: + """ + flatten the modelparams into a single-level dictionary. + + args: + sep: separator to use for nested keys (default: ".") + + returns: + a flat dictionary with nested keys joined by the separator. 
+ + example: + >>> params.flatten() + {'latent_dims': 10, 'encoder_params.num_hidden_units': 64, ...} + """ + def _flatten_dict( + d: dict[str, int | float | str | bool | dict], + parent_key: str = "", + ) -> dict[str, int | float | str | bool]: + items: list[tuple[str, int | float | str | bool]] = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(_flatten_dict(v, new_key).items()) + else: + items.append((new_key, v)) + return dict(items) + + return _flatten_dict(self.model_dump()) diff --git a/src/LatentEvolution/latent.py b/src/LatentEvolution/latent.py index 76126fe2..bb6bd6b9 100644 --- a/src/LatentEvolution/latent.py +++ b/src/LatentEvolution/latent.py @@ -6,8 +6,7 @@ from pathlib import Path from typing import Callable, Iterator from datetime import datetime -from dataclasses import dataclass -import random +from enum import Enum, auto import sys import re import time @@ -17,10 +16,15 @@ from torch.utils.tensorboard import SummaryWriter import yaml import tyro -import numpy as np -from pydantic import BaseModel, Field, field_validator, model_validator, ConfigDict -from LatentEvolution.load_flyvis import NeuronData, FlyVisSim, DataSplit, load_connectome_graph -from NeuralGraph.zarr_io import load_column_slice, load_metadata +from LatentEvolution.load_flyvis import ( + NeuronData, + FlyVisSim, + load_connectome_graph, + load_column_slice, + load_metadata, +) +from LatentEvolution.training_config import DataSplit +from LatentEvolution.chunk_loader import RandomChunkLoader from LatentEvolution.gpu_stats import GPUMonitor from LatentEvolution.diagnostics import run_validation_diagnostics, PlotMode from LatentEvolution.hparam_paths import create_run_directory, get_git_commit_hash @@ -29,243 +33,151 @@ MLPWithSkips, MLPParams, Evolver, - EvolverParams, - EncoderParams, - DecoderParams, - StimulusEncoderParams, + ModelParams, +) +from LatentEvolution.training_utils import ( + LossAccumulator, + seed_everything, + get_device, ) -from LatentEvolution.chunk_loader import RandomChunkLoader from LatentEvolution.chunk_streaming import ( - create_zarr_loader, calculate_chunk_params, ChunkLatencyStats, + create_zarr_loader, ) from LatentEvolution.acquisition import ( - AcquisitionMode, - AllTimePointsMode, compute_neuron_phases, sample_batch_indices, ) # ------------------------------------------------------------------- -# Pydantic Config Classes +# Data Loading # ------------------------------------------------------------------- -class ProfileConfig(BaseModel): - """Configuration for PyTorch profiler to generate Chrome traces.""" - wait: int = Field( - 1, description="Number of epochs to skip before starting profiler warmup" - ) - warmup: int = Field( - 1, description="Number of epochs for profiler warmup" - ) - active: int = Field( - 1, description="Number of epochs to actively profile" - ) - repeat: int = Field( - 0, description="Number of times to repeat the profiling cycle" - ) - record_shapes: bool = Field( - True, description="Record tensor shapes in the trace" - ) - profile_memory: bool = Field( - True, description="Profile memory usage" - ) - with_stack: bool = Field( - False, description="Record source code stack traces (increases overhead)" - ) - model_config = ConfigDict(extra="forbid", validate_assignment=True) +def load_dataset( + simulation_config: str, + column_to_model: str, + data_split: DataSplit, + num_input_dims: int, + device: torch.device, + chunk_size: int = 65536, + time_units: int = 1, +): + """ + 
load dataset from zarr with chunked streaming for training data. + + training data is streamed in chunks via RandomChunkLoader to reduce GPU memory. + validation data is loaded directly to GPU (small enough to fit). + + args: + simulation_config: name of simulation config (e.g., "fly_N9_62_1") + column_to_model: column name to model (e.g., "VOLTAGE", "CALCIUM") + data_split: DataSplit object with train/val time ranges + num_input_dims: number of stimulus input dimensions to keep + device: pytorch device to load data onto + chunk_size: chunk size for streaming (default: 65536 = 64K) + time_units: alignment constraint for chunk starts (default: 1) + returns: + tuple of (chunk_loader, val_data, val_stim, neuron_data, train_total_timesteps) + """ + data_path = f"graphs_data/fly/{simulation_config}/x_list_0" + column_idx = FlyVisSim[column_to_model].value -class UnconnectedToZeroConfig(BaseModel): - """Augmentation: add synthetic unconnected neurons with zero activity.""" - num_neurons: int = Field(0, description="Number of unconnected neurons to add") - loss_coeff: float = Field(1.0, description="Scalar weighting of the loss for unconnected neurons") - model_config = ConfigDict(extra="forbid", validate_assignment=True) + # load val data directly to GPU (small enough to fit) + val_data = torch.from_numpy( + load_column_slice(data_path, column_idx, data_split.validation_start, data_split.validation_end) + ).to(device) -class TrainingConfig(BaseModel): - time_units: int = Field( - 1, - description="Observation interval: activity data available every N steps. Evolver unrolled N times during training.", - json_schema_extra={"short_name": "tu"} - ) - acquisition_mode: AcquisitionMode = Field( - default_factory=AllTimePointsMode, - description="Data acquisition mode. Controls which timesteps have observable data for each neuron.", - json_schema_extra={"short_name": "acq"} - ) - intermediate_loss_steps: list[int] = Field( - default_factory=list, - description="DEPRECATED: Intermediate steps feature has been removed. Must be empty list.", - json_schema_extra={"short_name": "ils"} - ) - evolve_multiple_steps: int = Field( - 1, - description="Number of time_units multiples to evolve. 
Loss applied at each multiple.", - json_schema_extra={"short_name": "ems"} - ) - epochs: int = Field(10, json_schema_extra={"short_name": "ep"}) - batch_size: int = Field(32, json_schema_extra={"short_name": "bs"}) - learning_rate: float = Field(1e-3, json_schema_extra={"short_name": "lr"}) - optimizer: str = Field("Adam", description="Optimizer name from torch.optim", json_schema_extra={"short_name": "opt"}) - train_step: str = Field("train_step", description="Compiled train step function") - simulation_config: str - column_to_model: str = "CALCIUM" - use_tf32_matmul: bool = Field( - False, description="Enable fast tf32 multiplication on certain NVIDIA GPUs" - ) - seed: int = Field(42, json_schema_extra={"short_name": "seed"}) - data_split: DataSplit - data_passes_per_epoch: int = 1 - diagnostics_freq_epochs: int = Field( - 0, description="Run validation diagnostics every N epochs (0 = only at end of training)" - ) - save_checkpoint_every_n_epochs: int = Field( - 10, description="Save model checkpoint every N epochs (0 = disabled)" - ) - save_best_checkpoint: bool = Field( - True, description="Save checkpoint when validation loss improves" - ) - loss_function: str = Field( - "mse_loss", description="Loss function name from torch.nn.functional (e.g., 'mse_loss', 'huber_loss', 'l1_loss')" - ) - grad_clip_max_norm: float = Field( - 0.0, description="Max gradient norm for clipping (0 = disabled)", json_schema_extra={"short_name": "gc"} - ) - reconstruction_warmup_epochs: int = Field( - 0, description="Number of warmup epochs to train encoder/decoder only (reconstruction loss) before the main training loop. These are additional epochs, not counted in 'epochs'.", json_schema_extra={"short_name": "recon_wu"} - ) - unconnected_to_zero: UnconnectedToZeroConfig = Field(default_factory=UnconnectedToZeroConfig) - early_stop_intervening_mse: bool = Field( - False, description="Enable early stopping based on max intervening MSE metric (0 to tu-1)", json_schema_extra={"short_name": "es_int"} - ) - early_stop_patience_epochs: int = Field( - 10, description="Number of epochs to wait for 10% improvement in max intervening MSE before stopping", json_schema_extra={"short_name": "es_patience"} + val_stim = torch.from_numpy( + load_column_slice(data_path, FlyVisSim.STIMULUS.value, data_split.validation_start, data_split.validation_end, neuron_limit=num_input_dims) + ).to(device) + + # load neuron metadata + metadata = load_metadata(data_path) + neuron_data = NeuronData.from_metadata(metadata) + + # create chunk loader for training data (streams from disk -> GPU) + train_total_timesteps = data_split.train_end - data_split.train_start + + # create zarr loading function + zarr_load_fn = create_zarr_loader( + data_path=data_path, + column_idx=column_idx, + stim_column_idx=FlyVisSim.STIMULUS.value, + num_stim_dims=num_input_dims, ) - early_stop_min_divergence: int = Field( - 1000, description="Minimum first divergence step required for early stopping to activate", json_schema_extra={"short_name": "es_min_div"} + + # wrap to offset by train_start + def offset_load_fn(start: int, end: int): + return zarr_load_fn(data_split.train_start + start, data_split.train_start + end) + + # create chunk loader + chunk_loader = RandomChunkLoader( + load_fn=offset_load_fn, + total_timesteps=train_total_timesteps, + chunk_size=chunk_size, + device=device, + prefetch=6, # buffer 6 chunks ahead for better overlap + seed=None, # will be set per epoch in training loop + time_units=time_units, ) - model_config = 
ConfigDict(extra="forbid", validate_assignment=True) - - @field_validator("optimizer") - @classmethod - def validate_optimizer(cls, v: str) -> str: - if not hasattr(torch.optim, v): - raise ValueError(f"Unknown optimizer '{v}' in torch.optim") - return v - - @field_validator("loss_function") - @classmethod - def validate_loss_function(cls, v: str) -> str: - if not hasattr(torch.nn.functional, v): - raise ValueError(f"Unknown loss function '{v}' in torch.nn.functional") - return v - - @model_validator(mode='after') - def validate_training_config(self): - if len(self.intermediate_loss_steps) > 0: - raise ValueError("intermediate_loss_steps is deprecated and must be empty list") - if self.evolve_multiple_steps < 1: - raise ValueError("evolve_multiple_steps must be >= 1") - - # validate acquisition mode compatibility - from LatentEvolution.acquisition import StaggeredRandomMode - if isinstance(self.acquisition_mode, StaggeredRandomMode): - if self.unconnected_to_zero.num_neurons > 0: - raise ValueError( - "unconnected_to_zero augmentation is incompatible with staggered_random acquisition mode. " - "staggered mode observes neurons at different times, breaking the connectome assumption." - ) - return self + return chunk_loader, val_data, val_stim, neuron_data, train_total_timesteps -class CrossValidationConfig(BaseModel): - """Configuration for cross-dataset validation.""" - simulation_config: str - name: str | None = None # Optional human-readable name - data_split: DataSplit | None = None # data split +def load_val_only( + simulation_config: str, + column_to_model: str, + data_split: DataSplit, + num_input_dims: int, + device: torch.device +): + """ + load only validation data for cross-validation (memory efficient). - model_config = ConfigDict(extra="forbid", validate_assignment=True) + streams data directly from zarr to device memory. 
+ args: + simulation_config: name of simulation config (e.g., "fly_N9_62_1") + column_to_model: column name to model (e.g., "VOLTAGE", "CALCIUM") + data_split: DataSplit object with train/val time ranges + num_input_dims: number of stimulus input dimensions to keep + device: pytorch device to load data onto -class ModelParams(BaseModel): - latent_dims: int = Field(..., json_schema_extra={"short_name": "ld"}) - num_neurons: int - use_batch_norm: bool = True - activation: str = Field("ReLU", description="Activation function from torch.nn") - encoder_params: EncoderParams - decoder_params: DecoderParams - evolver_params: EvolverParams - stimulus_encoder_params: StimulusEncoderParams - training: TrainingConfig - profiling: ProfileConfig | None = Field( - None, description="Optional profiler configuration to generate Chrome traces for performance analysis" - ) - cross_validation_configs: list[CrossValidationConfig] = Field( - default_factory=lambda: [CrossValidationConfig(simulation_config="fly_N9_62_0")], - description="List of datasets to validate on after training" - ) + returns: + tuple of (val_data, val_stim) + """ + data_path = f"graphs_data/fly/{simulation_config}/x_list_0" + column_idx = FlyVisSim[column_to_model].value - model_config = ConfigDict(extra="forbid", validate_assignment=True) - - @field_validator("activation") - @classmethod - def validate_activation(cls, v: str) -> str: - if not hasattr(nn, v): - raise ValueError(f"Unknown activation '{v}' in torch.nn") - return v - - @model_validator(mode='after') - def validate_encoder_decoder_symmetry(self): - """Ensure encoder and decoder have symmetric MLP parameters.""" - if self.encoder_params.num_hidden_units != self.decoder_params.num_hidden_units: - raise ValueError( - f"Encoder and decoder must have the same num_hidden_units. " - f"Got encoder={self.encoder_params.num_hidden_units}, decoder={self.decoder_params.num_hidden_units}" - ) - if self.encoder_params.num_hidden_layers != self.decoder_params.num_hidden_layers: - raise ValueError( - f"Encoder and decoder must have the same num_hidden_layers. " - f"Got encoder={self.encoder_params.num_hidden_layers}, decoder={self.decoder_params.num_hidden_layers}" - ) - if self.encoder_params.use_input_skips != self.decoder_params.use_input_skips: - raise ValueError( - f"Encoder and decoder must have the same use_input_skips setting. " - f"Got encoder={self.encoder_params.use_input_skips}, decoder={self.decoder_params.use_input_skips}" - ) - return self - - def flatten(self, sep: str = ".") -> dict[str, int | float | str | bool]: - """ - Flatten the ModelParams into a single-level dictionary. - - Args: - sep: Separator to use for nested keys (default: ".") - - Returns: - A flat dictionary with nested keys joined by the separator. 
- - Example: - >>> params.flatten() - {'latent_dims': 10, 'encoder_params.num_hidden_units': 64, ...} - """ - def _flatten_dict( - d: dict[str, int | float | str | bool | dict], - parent_key: str = "", - ) -> dict[str, int | float | str | bool]: - items: list[tuple[str, int | float | str | bool]] = [] - for k, v in d.items(): - new_key = f"{parent_key}{sep}{k}" if parent_key else k - if isinstance(v, dict): - items.extend(_flatten_dict(v, new_key).items()) - else: - items.append((new_key, v)) - return dict(items) - - return _flatten_dict(self.model_dump()) + # load val data column slice directly to device + val_data = torch.from_numpy( + load_column_slice(data_path, column_idx, data_split.validation_start, data_split.validation_end) + ).to(device) + + # load val stimulus slice directly to device + val_stim = torch.from_numpy( + load_column_slice(data_path, FlyVisSim.STIMULUS.value, data_split.validation_start, data_split.validation_end, neuron_limit=num_input_dims) + ).to(device) + + return val_data, val_stim + + +# ------------------------------------------------------------------- +# Loss Types +# ------------------------------------------------------------------- + + +class LossType(Enum): + """loss component types for EED model.""" + TOTAL = auto() + RECON = auto() + EVOLVE = auto() + REG = auto() + AUG_LOSS = auto() # ------------------------------------------------------------------- @@ -385,64 +297,6 @@ def make_batches_random( yield start_indices, selected_neurons, needed_indices -# ------------------------------------------------------------------- -# Utilities -# ------------------------------------------------------------------- - - -@dataclass -class LossComponents: - """Accumulator for tracking loss components.""" - total: float = 0.0 - recon: float = 0.0 - evolve: float = 0.0 - reg: float = 0.0 - aug_loss: float = 0.0 - count: int = 0 - - def accumulate(self, *losses): - """Add losses from one batch (total, recon, evolve, reg, aug_loss).""" - self.total += losses[0].detach().item() - self.recon += losses[1].detach().item() - self.evolve += losses[2].detach().item() - self.reg += losses[3].detach().item() - self.aug_loss += losses[4].detach().item() - self.count += 1 - - def mean(self) -> 'LossComponents': - """Return a new LossComponents with mean values.""" - if self.count == 0: - return LossComponents(count=0) - return LossComponents( - total=self.total / self.count, - recon=self.recon / self.count, - evolve=self.evolve / self.count, - reg=self.reg / self.count, - aug_loss=self.aug_loss / self.count, - count=self.count, - ) - - -def seed_everything(seed: int): - """Set all random seeds for reproducibility.""" - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - - -def get_device() -> torch.device: - """Cross-platform device selection.""" - if torch.backends.mps.is_available() and torch.backends.mps.is_built(): - print("Using Apple MPS backend for training.") - return torch.device("mps") - elif torch.cuda.is_available(): - print(f"Using CUDA device: {torch.cuda.get_device_name(0)}") - return torch.device("cuda") - else: - print("Using CPU for training.") - return torch.device("cpu") def train_step_reconstruction_only_nocompile( @@ -596,121 +450,6 @@ def train_step_nocompile( train_step = torch.compile(train_step_nocompile, fullgraph=True, mode="reduce-overhead") -# ------------------------------------------------------------------- -# Data Loading and Evaluation -# 
------------------------------------------------------------------- - - -def load_dataset( - simulation_config: str, - column_to_model: str, - data_split: DataSplit, - num_input_dims: int, - device: torch.device, - chunk_size: int = 65536, - time_units: int = 1, -): - """ - load dataset from zarr with chunked streaming for training data. - - training data is streamed in chunks via RandomChunkLoader to reduce GPU memory. - validation data is loaded directly to GPU (small enough to fit). - - args: - simulation_config: name of simulation config (e.g., "fly_N9_62_1") - column_to_model: column name to model (e.g., "VOLTAGE", "CALCIUM") - data_split: DataSplit object with train/val time ranges - num_input_dims: number of stimulus input dimensions to keep - device: pytorch device to load data onto - chunk_size: chunk size for streaming (default: 65536 = 64K) - time_units: alignment constraint for chunk starts (default: 1) - - returns: - tuple of (chunk_loader, val_data, val_stim, neuron_data, train_total_timesteps) - """ - data_path = f"graphs_data/fly/{simulation_config}/x_list_0" - column_idx = FlyVisSim[column_to_model].value - - # load val data directly to GPU (small enough to fit) - val_data = torch.from_numpy( - load_column_slice(data_path, column_idx, data_split.validation_start, data_split.validation_end) - ).to(device) - - val_stim = torch.from_numpy( - load_column_slice(data_path, FlyVisSim.STIMULUS.value, data_split.validation_start, data_split.validation_end, neuron_limit=num_input_dims) - ).to(device) - - # load neuron metadata - metadata = load_metadata(data_path) - neuron_data = NeuronData.from_metadata(metadata) - - # create chunk loader for training data (streams from disk -> GPU) - train_total_timesteps = data_split.train_end - data_split.train_start - - # create zarr loading function - zarr_load_fn = create_zarr_loader( - data_path=data_path, - column_idx=column_idx, - stim_column_idx=FlyVisSim.STIMULUS.value, - num_stim_dims=num_input_dims, - ) - - # wrap to offset by train_start - def offset_load_fn(start: int, end: int): - return zarr_load_fn(data_split.train_start + start, data_split.train_start + end) - - # create chunk loader - chunk_loader = RandomChunkLoader( - load_fn=offset_load_fn, - total_timesteps=train_total_timesteps, - chunk_size=chunk_size, - device=device, - prefetch=6, # buffer 6 chunks ahead for better overlap - seed=None, # will be set per epoch in training loop - time_units=time_units, - ) - - return chunk_loader, val_data, val_stim, neuron_data, train_total_timesteps - - -def load_val_only( - simulation_config: str, - column_to_model: str, - data_split: DataSplit, - num_input_dims: int, - device: torch.device -): - """ - Load only validation data for cross-validation (memory efficient). - - Streams data directly from zarr to device memory. 
- - Args: - simulation_config: Name of simulation config (e.g., "fly_N9_62_1") - column_to_model: Column name to model (e.g., "VOLTAGE", "CALCIUM") - data_split: DataSplit object with train/val time ranges - num_input_dims: Number of stimulus input dimensions to keep - device: PyTorch device to load data onto - - Returns: - Tuple of (val_data, val_stim) - """ - data_path = f"graphs_data/fly/{simulation_config}/x_list_0" - column_idx = FlyVisSim[column_to_model].value - - # load val data column slice directly to device - val_data = torch.from_numpy( - load_column_slice(data_path, column_idx, data_split.validation_start, data_split.validation_end) - ).to(device) - - # load val stimulus slice directly to device - val_stim = torch.from_numpy( - load_column_slice(data_path, FlyVisSim.STIMULUS.value, data_split.validation_start, data_split.validation_end, neuron_limit=num_input_dims) - ).to(device) - - return val_data, val_stim - - # ------------------------------------------------------------------- # Training # ------------------------------------------------------------------- @@ -837,7 +576,7 @@ def train(cfg: ModelParams, run_dir: Path): for warmup_epoch in range(recon_warmup_epochs): warmup_epoch_start = datetime.now() - warmup_losses = LossComponents() + warmup_losses = LossAccumulator(LossType) # start loading chunks for this warmup epoch chunk_loader.start_epoch(num_chunks=chunks_per_epoch) @@ -873,16 +612,22 @@ def train(cfg: ModelParams, run_dir: Path): ) loss_tuple[0].backward() optimizer.step() - warmup_losses.accumulate(*loss_tuple) + warmup_losses.accumulate({ + LossType.TOTAL: loss_tuple[0], + LossType.RECON: loss_tuple[1], + LossType.EVOLVE: loss_tuple[2], + LossType.REG: loss_tuple[3], + LossType.AUG_LOSS: loss_tuple[4], + }) warmup_epoch_duration = (datetime.now() - warmup_epoch_start).total_seconds() mean_warmup = warmup_losses.mean() - print(f"warmup {warmup_epoch+1}/{recon_warmup_epochs} | recon loss: {mean_warmup.recon:.4e} | duration: {warmup_epoch_duration:.2f}s") + print(f"warmup {warmup_epoch+1}/{recon_warmup_epochs} | recon loss: {mean_warmup[LossType.RECON]:.4e} | duration: {warmup_epoch_duration:.2f}s") # log to tensorboard - writer.add_scalar("ReconWarmup/loss", mean_warmup.total, warmup_epoch) - writer.add_scalar("ReconWarmup/recon_loss", mean_warmup.recon, warmup_epoch) - writer.add_scalar("ReconWarmup/reg_loss", mean_warmup.reg, warmup_epoch) + writer.add_scalar("ReconWarmup/loss", mean_warmup[LossType.TOTAL], warmup_epoch) + writer.add_scalar("ReconWarmup/recon_loss", mean_warmup[LossType.RECON], warmup_epoch) + writer.add_scalar("ReconWarmup/reg_loss", mean_warmup[LossType.REG], warmup_epoch) writer.add_scalar("ReconWarmup/epoch_duration", warmup_epoch_duration, warmup_epoch) model.evolver.requires_grad_(True) @@ -943,7 +688,7 @@ def train(cfg: ModelParams, run_dir: Path): for epoch in range(cfg.training.epochs): epoch_start = datetime.now() gpu_monitor.sample_epoch_start() - losses = LossComponents() + losses = LossAccumulator(LossType) # start loading chunks for this epoch chunk_loader.start_epoch(num_chunks=chunks_per_epoch) @@ -1008,7 +753,13 @@ def train(cfg: ModelParams, run_dir: Path): optimizer.step() step_time = time.time() - step_start - losses.accumulate(*loss_tuple) + losses.accumulate({ + LossType.TOTAL: loss_tuple[0], + LossType.RECON: loss_tuple[1], + LossType.EVOLVE: loss_tuple[2], + LossType.REG: loss_tuple[3], + LossType.AUG_LOSS: loss_tuple[4], + }) # sample timing every 10 batches if batch_in_chunk % 10 == 0: @@ -1024,16 +775,16 @@ def 
train(cfg: ModelParams, run_dir: Path): print( f"Epoch {epoch+1}/{cfg.training.epochs} | " - f"Train Loss: {mean_losses.total:.4e} | " + f"Train Loss: {mean_losses[LossType.TOTAL]:.4e} | " f"Duration: {epoch_duration:.2f}s (Total: {total_elapsed:.1f}s)" ) # log to tensorboard - writer.add_scalar("Loss/train", mean_losses.total, epoch) - writer.add_scalar("Loss/train_recon", mean_losses.recon, epoch) - writer.add_scalar("Loss/train_evolve", mean_losses.evolve, epoch) - writer.add_scalar("Loss/train_reg", mean_losses.reg, epoch) - writer.add_scalar("Loss/train_aug_loss", mean_losses.aug_loss, epoch) + writer.add_scalar("Loss/train", mean_losses[LossType.TOTAL], epoch) + writer.add_scalar("Loss/train_recon", mean_losses[LossType.RECON], epoch) + writer.add_scalar("Loss/train_evolve", mean_losses[LossType.EVOLVE], epoch) + writer.add_scalar("Loss/train_reg", mean_losses[LossType.REG], epoch) + writer.add_scalar("Loss/train_aug_loss", mean_losses[LossType.AUG_LOSS], epoch) writer.add_scalar("Time/epoch_duration", epoch_duration, epoch) writer.add_scalar("Time/total_elapsed", total_elapsed, epoch) @@ -1172,7 +923,7 @@ def train(cfg: ModelParams, run_dir: Path): metrics.update( { - "final_train_loss": mean_losses.total, + "final_train_loss": mean_losses[LossType.TOTAL], "commit_hash": commit_hash, "training_duration_seconds": round(total_training_duration, 2), "avg_epoch_duration_seconds": round(avg_epoch_duration, 2), diff --git a/src/LatentEvolution/load_flyvis.py b/src/LatentEvolution/load_flyvis.py index 610f573d..4ad86902 100644 --- a/src/LatentEvolution/load_flyvis.py +++ b/src/LatentEvolution/load_flyvis.py @@ -1,9 +1,11 @@ """Module to load flyvis simulation data.""" from enum import IntEnum +from pathlib import Path + import numpy as np -from pydantic import BaseModel, field_validator, ConfigDict import torch +import tensorstore as ts class FlyVisSim(IntEnum): """Column interpretation in flyvis simulation outputs.""" @@ -90,32 +92,6 @@ def from_metadata(cls, metadata: np.ndarray) -> "NeuronData": return obj -class DataSplit(BaseModel): - """Split the time series into train/validation sets.""" - - train_start: int - train_end: int - validation_start: int - validation_end: int - - model_config = ConfigDict(extra="forbid", validate_assignment=True) - - @field_validator("*") - @classmethod - def check_non_negative(cls, v: int) -> int: - if v < 0: - raise ValueError("Indices in data_split must be non-negative.") - return v - - @field_validator("train_end") - @classmethod - def check_order(cls, v, info): - # very basic ordering sanity check - d = info.data - if "train_start" in d and v <= d["train_start"]: - raise ValueError("train_end must be greater than train_start.") - return v - def load_connectome_graph(data_path: str): """FlyVis connectome. @@ -129,3 +105,72 @@ def load_connectome_graph(data_path: str): weights = torch.load(f"{data_path}/weights.pt", map_location="cpu").numpy() wmat = torch.sparse_coo_tensor(edge_index, weights).to_sparse_csr() return wmat + + +# mapping from dynamic column indices to timeseries column position +_DYNAMIC_COL_TO_TS = {col.value: i for i, col in enumerate(DYNAMIC_COLUMNS)} + + +def load_column_slice( + path: str | Path, + column: int, + time_start: int, + time_end: int, + neuron_limit: int | None = None, +) -> np.ndarray: + """load a time series column slice directly from zarr format. + + this avoids loading the full (T, N, 9) array when you only need one column. 
+ + args: + path: base path to zarr data (without extension) + column: dynamic column index (VOLTAGE=3, STIMULUS=4, CALCIUM=7, FLUORESCENCE=8) + time_start: start time index + time_end: end time index + neuron_limit: optional limit on neurons (first N) + + returns: + numpy array of shape (time_end - time_start, N) or (time_end - time_start, neuron_limit) + + raises: + AssertionError: if column is a static column (use load_metadata instead) + """ + assert column in _DYNAMIC_COL_TO_TS, ( + f"column {column} is static, use load_metadata() instead" + ) + + path = Path(path) + base_path = path.with_suffix('') if path.suffix in ('.npy', '.zarr') else path + + ts_col = _DYNAMIC_COL_TO_TS[column] + neuron_slice = slice(None, neuron_limit) if neuron_limit else slice(None) + + ts_path = base_path / 'timeseries.zarr' + spec = { + 'driver': 'zarr', + 'kvstore': {'driver': 'file', 'path': str(ts_path)}, + } + store = ts.open(spec).result() + data = store[time_start:time_end, neuron_slice, ts_col].read().result() + + return np.ascontiguousarray(data) + + +def load_metadata(path: str | Path) -> np.ndarray: + """load metadata from V2 zarr format. + + args: + path: base path to zarr data + + returns: + numpy array of shape (N, 5) with static columns + """ + path = Path(path) + base_path = path.with_suffix('') if path.suffix in ('.npy', '.zarr') else path + meta_path = base_path / 'metadata.zarr' + + spec = { + 'driver': 'zarr', + 'kvstore': {'driver': 'file', 'path': str(meta_path)}, + } + return ts.open(spec).result().read().result() diff --git a/src/LatentEvolution/post_run_analyze.py b/src/LatentEvolution/post_run_analyze.py index baafe407..e36744da 100644 --- a/src/LatentEvolution/post_run_analyze.py +++ b/src/LatentEvolution/post_run_analyze.py @@ -20,7 +20,7 @@ from LatentEvolution.latent import LatentModel, ModelParams, get_device, load_val_only from LatentEvolution.diagnostics import PlotMode, run_validation_diagnostics from LatentEvolution.load_flyvis import NeuronData -from NeuralGraph.zarr_io import load_metadata +from LatentEvolution.load_flyvis import load_metadata def main(run_dir: Path, epoch: int | None = None) -> None: diff --git a/src/LatentEvolution/training_config.py b/src/LatentEvolution/training_config.py new file mode 100644 index 00000000..9d04291a --- /dev/null +++ b/src/LatentEvolution/training_config.py @@ -0,0 +1,176 @@ +""" +shared training configuration classes for neural dynamics models. + +includes profiling, training hyperparameters, and cross-validation configs. 
+""" + +from pydantic import BaseModel, Field, field_validator, model_validator, ConfigDict +import torch + +from LatentEvolution.acquisition import AcquisitionMode, AllTimePointsMode + + +class DataSplit(BaseModel): + """split the time series into train/validation sets.""" + + train_start: int + train_end: int + validation_start: int + validation_end: int + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + @field_validator("*") + @classmethod + def check_non_negative(cls, v: int) -> int: + if v < 0: + raise ValueError("indices in data_split must be non-negative.") + return v + + @field_validator("train_end") + @classmethod + def check_order(cls, v, info): + # very basic ordering sanity check + d = info.data + if "train_start" in d and v <= d["train_start"]: + raise ValueError("train_end must be greater than train_start.") + return v + + +class ProfileConfig(BaseModel): + """configuration for pytorch profiler to generate chrome traces.""" + wait: int = Field( + 1, description="number of epochs to skip before starting profiler warmup" + ) + warmup: int = Field( + 1, description="number of epochs for profiler warmup" + ) + active: int = Field( + 1, description="number of epochs to actively profile" + ) + repeat: int = Field( + 0, description="number of times to repeat the profiling cycle" + ) + record_shapes: bool = Field( + True, description="record tensor shapes in the trace" + ) + profile_memory: bool = Field( + True, description="profile memory usage" + ) + with_stack: bool = Field( + False, description="record source code stack traces (increases overhead)" + ) + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + +class UnconnectedToZeroConfig(BaseModel): + """augmentation: add synthetic unconnected neurons with zero activity.""" + num_neurons: int = Field(0, description="number of unconnected neurons to add") + loss_coeff: float = Field(1.0, description="scalar weighting of the loss for unconnected neurons") + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + +class TrainingConfig(BaseModel): + time_units: int = Field( + 1, + description="observation interval: activity data available every n steps. evolver unrolled n times during training.", + json_schema_extra={"short_name": "tu"} + ) + acquisition_mode: AcquisitionMode = Field( + default_factory=AllTimePointsMode, + description="data acquisition mode. controls which timesteps have observable data for each neuron.", + json_schema_extra={"short_name": "acq"} + ) + intermediate_loss_steps: list[int] = Field( + default_factory=list, + description="deprecated: intermediate steps feature has been removed. must be empty list.", + json_schema_extra={"short_name": "ils"} + ) + evolve_multiple_steps: int = Field( + 1, + description="number of time_units multiples to evolve. 
loss applied at each multiple.", + json_schema_extra={"short_name": "ems"} + ) + epochs: int = Field(10, json_schema_extra={"short_name": "ep"}) + batch_size: int = Field(32, json_schema_extra={"short_name": "bs"}) + learning_rate: float = Field(1e-3, json_schema_extra={"short_name": "lr"}) + optimizer: str = Field("Adam", description="optimizer name from torch.optim", json_schema_extra={"short_name": "opt"}) + train_step: str = Field("train_step", description="compiled train step function") + simulation_config: str + column_to_model: str = "CALCIUM" + use_tf32_matmul: bool = Field( + False, description="enable fast tf32 multiplication on certain nvidia gpus" + ) + seed: int = Field(42, json_schema_extra={"short_name": "seed"}) + data_split: DataSplit + data_passes_per_epoch: int = 1 + diagnostics_freq_epochs: int = Field( + 0, description="run validation diagnostics every n epochs (0 = only at end of training)" + ) + save_checkpoint_every_n_epochs: int = Field( + 10, description="save model checkpoint every n epochs (0 = disabled)" + ) + save_best_checkpoint: bool = Field( + True, description="save checkpoint when validation loss improves" + ) + loss_function: str = Field( + "mse_loss", description="loss function name from torch.nn.functional (e.g., 'mse_loss', 'huber_loss', 'l1_loss')" + ) + grad_clip_max_norm: float = Field( + 0.0, description="max gradient norm for clipping (0 = disabled)", json_schema_extra={"short_name": "gc"} + ) + reconstruction_warmup_epochs: int = Field( + 0, description="number of warmup epochs to train encoder/decoder only (reconstruction loss) before the main training loop. these are additional epochs, not counted in 'epochs'.", json_schema_extra={"short_name": "recon_wu"} + ) + unconnected_to_zero: UnconnectedToZeroConfig = Field(default_factory=lambda: UnconnectedToZeroConfig()) + early_stop_intervening_mse: bool = Field( + False, description="enable early stopping based on max intervening mse metric (0 to tu-1)", json_schema_extra={"short_name": "es_int"} + ) + early_stop_patience_epochs: int = Field( + 10, description="number of epochs to wait for 10% improvement in max intervening mse before stopping", json_schema_extra={"short_name": "es_patience"} + ) + early_stop_min_divergence: int = Field( + 1000, description="minimum first divergence step required for early stopping to activate", json_schema_extra={"short_name": "es_min_div"} + ) + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + @field_validator("optimizer") + @classmethod + def validate_optimizer(cls, v: str) -> str: + if not hasattr(torch.optim, v): + raise ValueError(f"unknown optimizer '{v}' in torch.optim") + return v + + @field_validator("loss_function") + @classmethod + def validate_loss_function(cls, v: str) -> str: + if not hasattr(torch.nn.functional, v): + raise ValueError(f"unknown loss function '{v}' in torch.nn.functional") + return v + + @model_validator(mode='after') + def validate_training_config(self): + if len(self.intermediate_loss_steps) > 0: + raise ValueError("intermediate_loss_steps is deprecated and must be empty list") + if self.evolve_multiple_steps < 1: + raise ValueError("evolve_multiple_steps must be >= 1") + + # validate acquisition mode compatibility + from LatentEvolution.acquisition import StaggeredRandomMode + if isinstance(self.acquisition_mode, StaggeredRandomMode): + if self.unconnected_to_zero.num_neurons > 0: + raise ValueError( + "unconnected_to_zero augmentation is incompatible with staggered_random acquisition mode. 
" + "staggered mode observes neurons at different times, breaking the connectome assumption." + ) + + return self + + +class CrossValidationConfig(BaseModel): + """configuration for cross-dataset validation.""" + simulation_config: str + name: str | None = None # optional human-readable name + data_split: DataSplit | None = None # data split + + model_config = ConfigDict(extra="forbid", validate_assignment=True) diff --git a/src/LatentEvolution/training_utils.py b/src/LatentEvolution/training_utils.py new file mode 100644 index 00000000..9ce0cb8f --- /dev/null +++ b/src/LatentEvolution/training_utils.py @@ -0,0 +1,62 @@ +""" +shared training utilities for neural dynamics models. + +includes loss accumulation, seeding, and device selection. +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict +import random + +import numpy as np +import torch + + +@dataclass +class LossAccumulator: + """generic loss accumulator with enum-keyed dict storage.""" + loss_types: type[Enum] # enum class defining loss component types + components: Dict[Enum, float] = field(init=False) + count: int = field(init=False, default=0) + + def __post_init__(self): + self.components = {lt: 0.0 for lt in self.loss_types} + + def accumulate(self, loss_dict: Dict[Enum, torch.Tensor]) -> None: + """accumulate losses from dict mapping enum -> tensor.""" + for loss_type, value in loss_dict.items(): + self.components[loss_type] += value.detach().item() + self.count += 1 + + def mean(self) -> Dict[Enum, float]: + """return mean of each component.""" + if self.count == 0: + return {k: 0.0 for k in self.components} + return {k: v / self.count for k, v in self.components.items()} + + def __getitem__(self, loss_type: Enum) -> float: + """allow bracket access: losses[LossType.TOTAL]""" + return self.components[loss_type] + + +def seed_everything(seed: int): + """set all random seeds for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def get_device() -> torch.device: + """cross-platform device selection.""" + if torch.backends.mps.is_available() and torch.backends.mps.is_built(): + print("using apple mps backend for training.") + return torch.device("mps") + elif torch.cuda.is_available(): + print(f"using cuda device: {torch.cuda.get_device_name(0)}") + return torch.device("cuda") + else: + print("using cpu for training.") + return torch.device("cpu") diff --git a/src/NeuralGraph/zarr_io.py b/src/NeuralGraph/zarr_io.py index e45b6202..39482c2a 100644 --- a/src/NeuralGraph/zarr_io.py +++ b/src/NeuralGraph/zarr_io.py @@ -15,17 +15,12 @@ import numpy as np import tensorstore as ts -from LatentEvolution.load_flyvis import STATIC_COLUMNS, DYNAMIC_COLUMNS - -# chunking strategies for flyvis data -# data is accessed by column (axis 2), so use small chunks along that axis -# typical access: x_list[:, :, 3] (all time, all neurons, single column) -FLYVIS_CHUNKS = (2000, 14011, 1) # ~112MB per chunk, optimized for column access - -# derive column lists from canonical definitions -STATIC_COLS = [col.value for col in STATIC_COLUMNS] -DYNAMIC_COLS = [col.value for col in DYNAMIC_COLUMNS] +# flyvis format constants (for V2 writer compatibility) +# static columns: INDEX=0, XPOS=1, YPOS=2, GROUP_TYPE=5, TYPE=6 +# dynamic columns: VOLTAGE=3, STIMULUS=4, CALCIUM=7, FLUORESCENCE=8 +STATIC_COLS = [0, 1, 2, 5, 6] +DYNAMIC_COLS = [3, 4, 7, 8] N_STATIC = len(STATIC_COLS) N_DYNAMIC = len(DYNAMIC_COLS) @@ 
-491,75 +486,6 @@ def load_simulation_data(path: str | Path) -> np.ndarray: return _load_zarr_v2(base_path) -# mapping from dynamic column indices to timeseries column position -_DYNAMIC_COL_TO_TS = {col.value: i for i, col in enumerate(DYNAMIC_COLUMNS)} - - -def load_column_slice( - path: str | Path, - column: int, - time_start: int, - time_end: int, - neuron_limit: int | None = None, -) -> np.ndarray: - """load a time series column slice directly from zarr format. - - this avoids loading the full (T, N, 9) array when you only need one column. - - args: - path: base path to zarr data (without extension) - column: dynamic column index (VOLTAGE=3, STIMULUS=4, CALCIUM=7, FLUORESCENCE=8) - time_start: start time index - time_end: end time index - neuron_limit: optional limit on neurons (first N) - - returns: - numpy array of shape (time_end - time_start, N) or (time_end - time_start, neuron_limit) - - raises: - AssertionError: if column is a static column (use load_metadata instead) - """ - assert column in _DYNAMIC_COL_TO_TS, ( - f"column {column} is static, use load_metadata() instead" - ) - - path = Path(path) - base_path = path.with_suffix('') if path.suffix in ('.npy', '.zarr') else path - - ts_col = _DYNAMIC_COL_TO_TS[column] - neuron_slice = slice(None, neuron_limit) if neuron_limit else slice(None) - - ts_path = base_path / 'timeseries.zarr' - spec = { - 'driver': 'zarr', - 'kvstore': {'driver': 'file', 'path': str(ts_path)}, - } - store = ts.open(spec).result() - data = store[time_start:time_end, neuron_slice, ts_col].read().result() - - return np.ascontiguousarray(data) - - -def load_metadata(path: str | Path) -> np.ndarray: - """load metadata from V2 zarr format. - - args: - path: base path to zarr data - - returns: - numpy array of shape (N, 5) with static columns - """ - path = Path(path) - base_path = path.with_suffix('') if path.suffix in ('.npy', '.zarr') else path - meta_path = base_path / 'metadata.zarr' - - spec = { - 'driver': 'zarr', - 'kvstore': {'driver': 'file', 'path': str(meta_path)}, - } - return ts.open(spec).result().read().result() - - def load_zarr_lazy(path: str | Path) -> ts.TensorStore: """load zarr file as tensorstore handle for lazy access.
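
Usage sketch (illustrative, not part of the patch): the LossAccumulator extracted into training_utils.py is keyed by whatever Enum the caller supplies, so another model can define its own loss components (as latent.py now does with LossType) and reuse the same accumulation logic. A minimal example under that assumption, with stand-in constant tensors in place of real per-batch losses; MyLossType is a hypothetical enum, not one defined in this patch:

    from enum import Enum, auto

    import torch

    from LatentEvolution.training_utils import LossAccumulator, get_device, seed_everything


    class MyLossType(Enum):
        # hypothetical loss components for a model outside this patch
        TOTAL = auto()
        RECON = auto()


    seed_everything(42)
    device = get_device()

    acc = LossAccumulator(MyLossType)
    for _ in range(3):
        # real code would pass the loss tensors returned by its train step
        acc.accumulate({
            MyLossType.TOTAL: torch.tensor(1.5, device=device),
            MyLossType.RECON: torch.tensor(0.5, device=device),
        })

    means = acc.mean()                  # {MyLossType.TOTAL: 1.5, MyLossType.RECON: 0.5}
    running = acc[MyLossType.TOTAL]     # bracket access returns the running sum: 4.5

Note that bracket access (__getitem__) exposes the running sums, while mean() returns per-component averages as a plain dict; the training loop in latent.py logs averages taken from the dict returned by mean().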