2 changes: 1 addition & 1 deletion src/LatentEvolution/benchmark_training.py
@@ -13,7 +13,7 @@
from LatentEvolution.load_flyvis import FlyVisSim
from LatentEvolution.latent import ModelParams, LatentModel, train_step, train_step_nocompile
from LatentEvolution.acquisition import compute_neuron_phases, sample_batch_indices
from NeuralGraph.zarr_io import load_column_slice
from LatentEvolution.load_flyvis import load_column_slice


def seed_everything(seed: int):
2 changes: 1 addition & 1 deletion src/LatentEvolution/chunk_streaming.py
@@ -11,7 +11,7 @@
import numpy as np
import torch

from NeuralGraph.zarr_io import load_column_slice
from LatentEvolution.load_flyvis import load_column_slice


# -------------------------------------------------------------------
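Both import hunks above make the same change: load_column_slice now comes from LatentEvolution.load_flyvis instead of NeuralGraph.zarr_io. A minimal sketch of an updated call site, assuming only the new import path; the commented-out call is hypothetical, not the real signature:

from LatentEvolution.load_flyvis import load_column_slice

# hypothetical usage for illustration only -- see load_flyvis.py for the actual signature
# activity = load_column_slice(zarr_path, start_column, end_column)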
86 changes: 85 additions & 1 deletion src/LatentEvolution/eed_model.py
@@ -7,7 +7,9 @@

import torch
import torch.nn as nn
from pydantic import BaseModel, Field, field_validator, ConfigDict
from pydantic import BaseModel, Field, field_validator, model_validator, ConfigDict

from LatentEvolution.training_config import TrainingConfig, ProfileConfig, CrossValidationConfig


# -------------------------------------------------------------------
@@ -225,3 +227,85 @@ def forward(self, proj_t, proj_stim_t):
"""Evolve one time step in latent space."""
proj_t_next = proj_t + self.evolver(torch.cat([proj_t, proj_stim_t], dim=1))
return proj_t_next
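For context, the forward shown above is a residual update in latent space: the next latent state is the current state plus a correction computed from the concatenated latent and stimulus projections. A shape-level sketch under assumed dimensions; the Linear layer stands in for the evolver MLP and is not the module defined in this file:

import torch
import torch.nn as nn

proj_t = torch.zeros(8, 10)       # current latent state (batch=8, latent_dims=10, assumed)
proj_stim_t = torch.zeros(8, 4)   # encoded stimulus projection (size assumed)
evolver = nn.Linear(10 + 4, 10)   # stand-in for the evolver network
proj_t_next = proj_t + evolver(torch.cat([proj_t, proj_stim_t], dim=1))  # residual step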


# -------------------------------------------------------------------
# Model Configuration
# -------------------------------------------------------------------


class ModelParams(BaseModel):
    latent_dims: int = Field(..., json_schema_extra={"short_name": "ld"})
    num_neurons: int
    use_batch_norm: bool = True
    activation: str = Field("ReLU", description="activation function from torch.nn")
    encoder_params: EncoderParams
    decoder_params: DecoderParams
    evolver_params: EvolverParams
    stimulus_encoder_params: StimulusEncoderParams
    training: TrainingConfig
    profiling: ProfileConfig | None = Field(
        None, description="optional profiler configuration to generate chrome traces for performance analysis"
    )
    cross_validation_configs: list[CrossValidationConfig] = Field(
        default_factory=lambda: [CrossValidationConfig(simulation_config="fly_N9_62_0")],
        description="list of datasets to validate on after training"
    )

    model_config = ConfigDict(extra="forbid", validate_assignment=True)

    @field_validator("activation")
    @classmethod
    def validate_activation(cls, v: str) -> str:
        if not hasattr(nn, v):
            raise ValueError(f"unknown activation '{v}' in torch.nn")
        return v

    @model_validator(mode='after')
    def validate_encoder_decoder_symmetry(self):
        """Ensure encoder and decoder have symmetric MLP parameters."""
        if self.encoder_params.num_hidden_units != self.decoder_params.num_hidden_units:
            raise ValueError(
                f"encoder and decoder must have the same num_hidden_units. "
                f"got encoder={self.encoder_params.num_hidden_units}, decoder={self.decoder_params.num_hidden_units}"
            )
        if self.encoder_params.num_hidden_layers != self.decoder_params.num_hidden_layers:
            raise ValueError(
                f"encoder and decoder must have the same num_hidden_layers. "
                f"got encoder={self.encoder_params.num_hidden_layers}, decoder={self.decoder_params.num_hidden_layers}"
            )
        if self.encoder_params.use_input_skips != self.decoder_params.use_input_skips:
            raise ValueError(
                f"encoder and decoder must have the same use_input_skips setting. "
                f"got encoder={self.encoder_params.use_input_skips}, decoder={self.decoder_params.use_input_skips}"
            )
        return self

    def flatten(self, sep: str = ".") -> dict[str, int | float | str | bool]:
        """
        Flatten the ModelParams into a single-level dictionary.

        Args:
            sep: separator to use for nested keys (default: ".")

        Returns:
            A flat dictionary with nested keys joined by the separator.

        Example:
            >>> params.flatten()
            {'latent_dims': 10, 'encoder_params.num_hidden_units': 64, ...}
        """
        def _flatten_dict(
            d: dict[str, int | float | str | bool | dict],
            parent_key: str = "",
        ) -> dict[str, int | float | str | bool]:
            items: list[tuple[str, int | float | str | bool]] = []
            for k, v in d.items():
                new_key = f"{parent_key}{sep}{k}" if parent_key else k
                if isinstance(v, dict):
                    items.extend(_flatten_dict(v, new_key).items())
                else:
                    items.append((new_key, v))
            return dict(items)

        return _flatten_dict(self.model_dump())
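As a quick illustration of what flatten produces, here is the same recursive flattening applied to a stand-in nested dict; the keys below are placeholders, not the real EncoderParams/TrainingConfig schemas:

# stand-in for ModelParams.model_dump(); field names are illustrative only
dump = {
    "latent_dims": 10,
    "use_batch_norm": True,
    "encoder_params": {"num_hidden_units": 64, "num_hidden_layers": 2},
    "training": {"learning_rate": 1e-3},
}

def flatten(d, parent_key="", sep="."):
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten(v, new_key, sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

print(flatten(dump))
# {'latent_dims': 10, 'use_batch_norm': True, 'encoder_params.num_hidden_units': 64,
#  'encoder_params.num_hidden_layers': 2, 'training.learning_rate': 0.001}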