76 changes: 38 additions & 38 deletions .github/workflows/lint.yml
@@ -2,48 +2,48 @@ name: Lint

on:
push:
branches: [ main ]
branches: [main]
pull_request:
branches: [ main ]
branches: [main]

jobs:
lint:
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'

- name: Install ruff
run: pip install ruff

- name: Run ruff check
run: ruff check --exclude *.ipynb .

test:
runs-on: ubuntu-latest
needs: lint

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: 'pip'
cache-dependency-path: '.github/workflows/test-requirements.txt'

- name: Install test dependencies
run: pip install -r .github/workflows/test-requirements.txt

- name: Run unit tests
run: |
export PYTHONPATH=src
find src -name "*_test.py" -type f -exec grep -l "from unittest import\|^import unittest" {} \; | xargs -I {} python {}
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Install ruff
run: pip install ruff

- name: Run ruff check
run: ruff check --exclude *.ipynb .

# test:
# runs-on: ubuntu-latest
# needs: lint

# steps:
# - name: Checkout code
# uses: actions/checkout@v4

# - name: Set up Python
# uses: actions/setup-python@v5
# with:
# python-version: '3.12'
# cache: 'pip'
# cache-dependency-path: '.github/workflows/test-requirements.txt'

# - name: Install test dependencies
# run: pip install -r .github/workflows/test-requirements.txt

# - name: Run unit tests
# run: |
# export PYTHONPATH=src
# find src -name "*_test.py" -type f -exec grep -l "from unittest import\|^import unittest" {} \; | xargs -I {} python {}
27 changes: 27 additions & 0 deletions hooks/pre-commit
@@ -52,6 +52,33 @@ fi

echo -e "${GREEN}File size check passed!${NC}"

# 2. Run ruff linting
echo ""
echo "Running ruff linting..."
if ! ruff check --exclude "*.ipynb" .; then
echo -e "${RED}Error: Ruff linting failed${NC}"
echo "Please fix the linting errors before committing."
echo "You can run 'ruff check --fix --exclude \"*.ipynb\" .' to auto-fix some issues."
echo ""
echo "Note: Stricter rules (including unused variable checks) apply to src/LatentEvolution"
exit 1
fi
echo "All checks passed!"
echo -e "${GREEN}Ruff linting passed!${NC}"

# 3. Run unit tests
echo ""
echo "Running unit tests..."

# Discover and run all *_test.py files under src with python's unittest
if python -m unittest discover -s src -p "*_test.py" -v; then
echo -e "${GREEN}Unit tests passed!${NC}"
else
echo -e "${RED}Error: Unit tests failed${NC}"
echo "Please fix failing tests before committing."
exit 1
fi

echo ""
echo -e "${GREEN}All pre-commit checks passed!${NC}"
exit 0
117 changes: 89 additions & 28 deletions src/LatentEvolution/diagnostics.py
@@ -19,6 +19,9 @@
from LatentEvolution.latent import ModelParams, LatentModel
from LatentEvolution.load_flyvis import NeuronData

from LatentEvolution.training_config import StimulusFrequency
from LatentEvolution.stimulus_utils import downsample_stimulus


class PlotMode(StrEnum):
"""Control plotting that happens during training vs post training"""
@@ -266,6 +269,7 @@ def compute_multi_start_rollout_mse(
plot_mode: PlotMode = PlotMode.TRAINING,
time_units: int = 1,
evolve_multiple_steps: int = 1,
stimulus_frequency: StimulusFrequency = StimulusFrequency.ALL,
) -> tuple[np.ndarray, dict[str, plt.Figure], dict[str, float], list[int]]:
"""
Compute MSE over time from multiple random starting points.
@@ -330,9 +334,13 @@

# Perform rollout
if rollout_type == "latent":
predicted_segment = evolve_n_steps_latent(model, initial_state, stimulus_segment, n_steps)
predicted_segment = evolve_n_steps_latent(
model, initial_state, stimulus_segment, n_steps, stimulus_frequency, time_units
)
else: # activity
predicted_segment = evolve_n_steps(model, initial_state, stimulus_segment, n_steps)
predicted_segment = evolve_n_steps_activity(
model, initial_state, stimulus_segment, n_steps, stimulus_frequency, time_units
)

# Compute MSE per time step per neuron (squared error, not averaged)
squared_error = torch.pow(predicted_segment - real_segment, 2).detach().cpu().numpy()
@@ -495,6 +503,7 @@ def plot_time_aligned_mse(
ax.set_xlabel('time steps', fontsize=14)
ax.set_ylabel('mse (averaged over neurons)', fontsize=14)
ax.set_yscale('log')
ax.set_ylim(1e-3, 1.0)
ax.set_title(
f'time-aligned mse analysis - {rollout_type} rollout (tu={time_units}, ems={evolve_multiple_steps})',
fontsize=16,
@@ -550,7 +559,8 @@ def run_validation_diagnostics(
with torch.no_grad():
mse_array, new_figs, new_metrics, start_indices = compute_multi_start_rollout_mse(
model, val_data, val_stim, neuron_data, n_steps=2000, n_starts=20, rollout_type=rollout_type,
plot_mode=plot_mode, time_units=time_units, evolve_multiple_steps=evolve_multiple_steps
plot_mode=plot_mode, time_units=time_units, evolve_multiple_steps=evolve_multiple_steps,
stimulus_frequency=config.training.stimulus_frequency
)
metrics.update(new_metrics)
figures.update(new_figs)
@@ -612,52 +622,103 @@ def run_validation_diagnostics(
return metrics, figures


def evolve_n_steps(model: LatentModel, initial_state: torch.Tensor, stimulus: torch.Tensor, n_steps: int) -> torch.Tensor:
def evolve_n_steps_activity(
model: LatentModel,
initial_state: torch.Tensor,
stimulus: torch.Tensor,
n_steps: int,
stimulus_frequency: StimulusFrequency,
time_units: int,
) -> torch.Tensor:
"""
Evolve the model by n time steps using the predicted state at each step.
evolve the model by n time steps using the predicted state at each step.

This performs an autoregressive rollout starting from a single initial state.
At each step, encodes the current state, evolves in latent space, and decodes.
this performs an autoregressive rollout starting from a single initial state.
at each step, encodes the current state, evolves in latent space, and decodes.

Args:
model: The LatentModel to evolve
initial_state: Initial state tensor of shape (neurons,)
stimulus: Stimulus tensor of shape (T, stimulus_dim) where T >= n_steps
n_steps: Number of time steps to evolve
args:
model: the LatentModel to evolve
initial_state: initial state tensor of shape (neurons,)
stimulus: stimulus tensor of shape (T, stimulus_dim) where T >= n_steps
n_steps: number of time steps to evolve
stimulus_frequency: stimulus downsampling mode
time_units: observation interval for downsampling

Returns:
predicted_trace: Tensor of shape (n_steps, neurons) with predicted states
returns:
predicted_trace: tensor of shape (n_steps, neurons) with predicted states
"""
# pre-encode and downsample stimulus
stimulus_latent_all = model.stimulus_encoder(stimulus[:n_steps]) # shape (n_steps, stim_latent_dim)
stimulus_latent_all = stimulus_latent_all.unsqueeze(1) # (n_steps, 1, stim_latent_dim)

num_multiples = max(1, n_steps // time_units)
stimulus_latent = downsample_stimulus(
stimulus_latent_all,
tu=time_units,
num_multiples=num_multiples,
stimulus_frequency=stimulus_frequency,
)
stimulus_latent = stimulus_latent.squeeze(1) # (n_steps, stim_latent_dim)

predicted_trace = []
current_state = initial_state.unsqueeze(0) # shape (1, neurons)

for t in range(n_steps):
current_stimulus = stimulus[t:t+1] # shape (1, stimulus_dim)
next_state = model(current_state, current_stimulus)
# encode current state
current_latent = model.encoder(current_state)
# evolve with downsampled stimulus
next_latent = model.evolver(current_latent, stimulus_latent[t:t+1])
# decode
next_state = model.decoder(next_latent)
predicted_trace.append(next_state.squeeze(0))
current_state = next_state

return torch.stack(predicted_trace, dim=0)


def evolve_n_steps_latent(model: LatentModel, initial_state: torch.Tensor, stimulus: torch.Tensor, n_steps: int) -> torch.Tensor:
def evolve_n_steps_latent(
model: LatentModel,
initial_state: torch.Tensor,
stimulus: torch.Tensor,
n_steps: int,
stimulus_frequency: StimulusFrequency,
time_units: int,
) -> torch.Tensor:
"""
Evolve the model by n time steps entirely in latent space.
evolve the model by n time steps entirely in latent space.

This performs an autoregressive rollout starting from a single initial state.
Encodes once, evolves in latent space for n steps, then decodes all states.
this performs an autoregressive rollout starting from a single initial state.
encodes once, evolves in latent space for n steps, then decodes all states.

Args:
model: The LatentModel to evolve
initial_state: Initial state tensor of shape (neurons,)
stimulus: Stimulus tensor of shape (T, stimulus_dim) where T >= n_steps
n_steps: Number of time steps to evolve
args:
model: the LatentModel to evolve
initial_state: initial state tensor of shape (neurons,)
stimulus: stimulus tensor of shape (T, stimulus_dim) where T >= n_steps
n_steps: number of time steps to evolve
stimulus_frequency: stimulus downsampling mode
time_units: observation interval for downsampling

Returns:
predicted_trace: Tensor of shape (n_steps, neurons) with predicted states
returns:
predicted_trace: tensor of shape (n_steps, neurons) with predicted states
"""
current_latent = model.encoder(initial_state.unsqueeze(0)) # shape (1, latent_dim)
stimulus_latent = model.stimulus_encoder(stimulus) # shape (n_steps, stim_latent_dim)
stimulus_latent_all = model.stimulus_encoder(stimulus[:n_steps]) # shape (n_steps, stim_latent_dim)

# downsample stimulus based on frequency mode
# need to add batch dimension for downsample_stimulus
stimulus_latent_all = stimulus_latent_all.unsqueeze(1) # (n_steps, 1, stim_latent_dim)

# calculate num_multiples for downsampling
num_multiples = max(1, n_steps // time_units)

stimulus_latent = downsample_stimulus(
stimulus_latent_all,
tu=time_units,
num_multiples=num_multiples,
stimulus_frequency=stimulus_frequency,
) # (n_steps, 1, stim_latent_dim)

stimulus_latent = stimulus_latent.squeeze(1) # (n_steps, stim_latent_dim)

latent_trace = []
for t in range(n_steps):
38 changes: 38 additions & 0 deletions src/LatentEvolution/experiments/flyvis_voltage_100ms.md
@@ -339,6 +339,16 @@ for zero_init in zero-init no-zero-init; do \
done
```

Results:

- it turns out the evolver initialization is the key change: it removes the blow-up
  and allows the system to learn the right update rule (see the sketch below)
- warmup and activation don't seem to make a significant difference once the evolver
  is correctly initialized
- warmup does two nice things:
  - lower mse at intervening time steps
  - training is already stable in terms of rollout after epoch 1
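
For context, a minimal sketch of what a zero-initialized evolver could look like. This is an assumption, not the code in `latent.py`: it assumes the evolver is a residual MLP over the latent state and the encoded stimulus, and that `zero-init` means zeroing the final linear layer so the evolver starts out as (approximately) the identity map. All names below (`Evolver`, `latent_dim`, `stim_latent_dim`, `hidden_dim`, `zero_init`) are hypothetical.

```python
import torch
import torch.nn as nn


class Evolver(nn.Module):
    """Hypothetical residual evolver: z_{t+1} = z_t + f([z_t, s_t])."""

    def __init__(self, latent_dim: int, stim_latent_dim: int,
                 hidden_dim: int = 128, zero_init: bool = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + stim_latent_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, latent_dim),
        )
        if zero_init:
            # zero the last layer so the initial update is 0 and the evolver
            # starts as the identity map, avoiding the early rollout blow-up
            nn.init.zeros_(self.net[-1].weight)
            nn.init.zeros_(self.net[-1].bias)

    def forward(self, z: torch.Tensor, stim_latent: torch.Tensor) -> torch.Tensor:
        # concatenate latent state and encoded stimulus, predict the residual
        return z + self.net(torch.cat([z, stim_latent], dim=-1))
```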

Understand if TV norm can bring stability to the training without pretraining for
reconstruction or the other features we added.

@@ -354,3 +364,31 @@ for tv in 0.0 0.00001 0.0001 0.001; do \
--training.seed 97651
done
```

Results:

- tv norm at 1e-3 successfully mitigates the artifact
- the mse at t->t+1 is closer to a constant than in the previous experiment: so while
  tv norm does help the training converge to a sensible evolution model, it does hurt
  the t->t+1 mse
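
The exact TV norm used in `latent.py` is not shown in this diff; a common form, sketched here as an assumption, is the mean absolute difference between consecutive states of a trajectory, added to the loss with a swept coefficient (the `tv_weight` name below is hypothetical).

```python
import torch


def tv_penalty(trace: torch.Tensor) -> torch.Tensor:
    """Total-variation penalty on a trajectory of shape (T, batch, dim).

    Penalizes large step-to-step jumps, which is what suppresses the
    rollout artifact described above.
    """
    diffs = trace[1:] - trace[:-1]  # (T-1, batch, dim)
    return diffs.abs().mean()


# illustrative usage: loss = mse + tv_weight * tv_penalty(latent_trace),
# with tv_weight swept over {0.0, 1e-5, 1e-4, 1e-3} as in the command above
```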

# Stimulus downsampling

We want to avoid depending on the details of the provided stimulus, since in general
it won't be known at such granularity. As a first step, we only provide the
stimulus every `tu` steps. At time steps `n <= t < n + tu` we linearly interpolate between
the stimulus at time `n` and the one at time `n + tu`.
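
A minimal sketch of the interpolating mode described above. This is illustrative only, not the actual `downsample_stimulus` implementation (which also handles the constant and no-downsampling modes and a different argument layout): it keeps the stimulus at steps `0, tu, 2*tu, ...` and linearly interpolates every step in between.

```python
import torch


def downsample_interpolate(stim: torch.Tensor, tu: int) -> torch.Tensor:
    """stim: (T, batch, stim_dim). Keeps stim[0], stim[tu], stim[2*tu], ...
    and linearly interpolates the steps in between; returns the same shape.
    For the constant mode one would instead hold stim[n] until n + tu.
    """
    T = stim.shape[0]
    out = stim.clone()
    for n in range(0, T - 1, tu):
        hi = min(n + tu, T - 1)  # next anchor, clamped to the last step
        for t in range(n + 1, hi):
            alpha = (t - n) / (hi - n)
            out[t] = (1 - alpha) * stim[n] + alpha * stim[hi]
    return out
```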

```bash
for mode in TIME_UNITS_INTERPOLATE TIME_UNITS_CONSTANT NONE; do \
bsub -J stim_${mode} -q gpu_a100 -gpu "num=1" -n 8 -o stim_${mode}.log \
python src/LatentEvolution/latent.py stim_freq_sweep latent_20step.yaml \
--training.stimulus-frequency $mode
done
```

The results suggest that in the current setup we are critically reliant on the stimulus
being provided at every time step. When it is not, the MSE blows up once we roll out
past the training horizon. Even within the training horizon the error does not fall
below the linear-interpolation baseline, and we are unable to learn the right rule
even at the loss time points 0, `tu`, `2tu`, ...
13 changes: 10 additions & 3 deletions src/LatentEvolution/latent.py
@@ -24,6 +24,7 @@
load_metadata,
)
from LatentEvolution.training_config import DataSplit
from LatentEvolution.stimulus_utils import downsample_stimulus
from LatentEvolution.chunk_loader import RandomChunkLoader
from LatentEvolution.gpu_stats import GPUMonitor
from LatentEvolution.diagnostics import run_validation_diagnostics, PlotMode
@@ -307,8 +308,6 @@ def make_batches_random(
yield start_indices, selected_neurons, needed_indices




def train_step_reconstruction_only_nocompile(
model: LatentModel,
train_data: torch.Tensor,
@@ -417,7 +416,15 @@ def train_step_nocompile(
dim_stim = train_stim.shape[1]
dim_stim_latent = cfg.stimulus_encoder_params.num_output_dims
# total_steps x b x Ls
proj_stim_t = model.stimulus_encoder(stim_t.reshape((-1, dim_stim))).reshape((total_steps, -1, dim_stim_latent))
proj_stim_t_all = model.stimulus_encoder(stim_t.reshape((-1, dim_stim))).reshape((total_steps, -1, dim_stim_latent))

# downsample stimulus based on frequency mode
proj_stim_t = downsample_stimulus(
proj_stim_t_all,
tu=dt,
num_multiples=num_multiples,
stimulus_frequency=cfg.training.stimulus_frequency,
)

# reconstruction loss
recon_t = model.decoder(proj_t)