diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 37d502a0..65707ca8 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -2,48 +2,48 @@ name: Lint
 
 on:
   push:
-    branches: [ main ]
+    branches: [main]
   pull_request:
-    branches: [ main ]
+    branches: [main]
 
 jobs:
   lint:
     runs-on: ubuntu-latest
 
     steps:
-    - name: Checkout code
-      uses: actions/checkout@v4
-
-    - name: Set up Python
-      uses: actions/setup-python@v5
-      with:
-        python-version: '3.12'
-
-    - name: Install ruff
-      run: pip install ruff
-
-    - name: Run ruff check
-      run: ruff check --exclude *.ipynb .
-
-  test:
-    runs-on: ubuntu-latest
-    needs: lint
-
-    steps:
-    - name: Checkout code
-      uses: actions/checkout@v4
-
-    - name: Set up Python
-      uses: actions/setup-python@v5
-      with:
-        python-version: '3.12'
-        cache: 'pip'
-        cache-dependency-path: '.github/workflows/test-requirements.txt'
-
-    - name: Install test dependencies
-      run: pip install -r .github/workflows/test-requirements.txt
-
-    - name: Run unit tests
-      run: |
-        export PYTHONPATH=src
-        find src -name "*_test.py" -type f -exec grep -l "from unittest import\|^import unittest" {} \; | xargs -I {} python {}
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+
+      - name: Install ruff
+        run: pip install ruff
+
+      - name: Run ruff check
+        run: ruff check --exclude "*.ipynb" .
+
+  # test:
+  #   runs-on: ubuntu-latest
+  #   needs: lint
+
+  #   steps:
+  #     - name: Checkout code
+  #       uses: actions/checkout@v4
+
+  #     - name: Set up Python
+  #       uses: actions/setup-python@v5
+  #       with:
+  #         python-version: '3.12'
+  #         cache: 'pip'
+  #         cache-dependency-path: '.github/workflows/test-requirements.txt'
+
+  #     - name: Install test dependencies
+  #       run: pip install -r .github/workflows/test-requirements.txt
+
+  #     - name: Run unit tests
+  #       run: |
+  #         export PYTHONPATH=src
+  #         find src -name "*_test.py" -type f -exec grep -l "from unittest import\|^import unittest" {} \; | xargs -I {} python {}
diff --git a/hooks/pre-commit b/hooks/pre-commit
index 9d6d81b2..f2cf6336 100755
--- a/hooks/pre-commit
+++ b/hooks/pre-commit
@@ -52,6 +52,33 @@ fi
 
 echo -e "${GREEN}File size check passed!${NC}"
 
+# 2. Run ruff linting
+echo ""
+echo "Running ruff linting..."
+if ! ruff check --exclude "*.ipynb" .; then
+    echo -e "${RED}Error: Ruff linting failed${NC}"
+    echo "Please fix the linting errors before committing."
+    echo "You can run 'ruff check --fix --exclude \"*.ipynb\" .' to auto-fix some issues."
+    echo ""
+    echo "Note: Stricter rules (including unused variable checks) apply to src/LatentEvolution"
+    exit 1
+fi
+echo -e "${GREEN}Ruff linting passed!${NC}"
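+
+# note: unittest discovery below matches the *_test.py naming convention,
+# which is why test_chunk_loader.py is renamed to chunk_loader_test.py in this change.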
+
+# 3. Run unit tests
+echo ""
+echo "Running unit tests..."
+
+# discover and run all *_test.py files under src via unittest
+if python -m unittest discover -s src -p "*_test.py" -v; then
+    echo -e "${GREEN}Unit tests passed!${NC}"
+else
+    echo -e "${RED}Error: Unit tests failed${NC}"
+    echo "Please fix failing tests before committing."
+    exit 1
+fi
+
 echo ""
 echo -e "${GREEN}All pre-commit checks passed!${NC}"
 exit 0
diff --git a/src/LatentEvolution/test_chunk_loader.py b/src/LatentEvolution/chunk_loader_test.py
similarity index 100%
rename from src/LatentEvolution/test_chunk_loader.py
rename to src/LatentEvolution/chunk_loader_test.py
diff --git a/src/LatentEvolution/diagnostics.py b/src/LatentEvolution/diagnostics.py
index 1a892cb5..66ec82b3 100644
--- a/src/LatentEvolution/diagnostics.py
+++ b/src/LatentEvolution/diagnostics.py
@@ -19,6 +19,9 @@
 from LatentEvolution.latent import ModelParams, LatentModel
 from LatentEvolution.load_flyvis import NeuronData
 
+from LatentEvolution.training_config import StimulusFrequency
+from LatentEvolution.stimulus_utils import downsample_stimulus
+
 
 class PlotMode(StrEnum):
     """Control plotting that happens during training vs post training"""
@@ -266,6 +269,7 @@ def compute_multi_start_rollout_mse(
     plot_mode: PlotMode = PlotMode.TRAINING,
     time_units: int = 1,
     evolve_multiple_steps: int = 1,
+    stimulus_frequency: StimulusFrequency = StimulusFrequency.ALL,
 ) -> tuple[np.ndarray, dict[str, plt.Figure], dict[str, float], list[int]]:
     """
     Compute MSE over time from multiple random starting points.
@@ -330,9 +334,13 @@
 
         # Perform rollout
         if rollout_type == "latent":
-            predicted_segment = evolve_n_steps_latent(model, initial_state, stimulus_segment, n_steps)
+            predicted_segment = evolve_n_steps_latent(
+                model, initial_state, stimulus_segment, n_steps, stimulus_frequency, time_units
+            )
         else:  # activity
-            predicted_segment = evolve_n_steps(model, initial_state, stimulus_segment, n_steps)
+            predicted_segment = evolve_n_steps_activity(
+                model, initial_state, stimulus_segment, n_steps, stimulus_frequency, time_units
+            )
 
         # Compute MSE per time step per neuron (squared error, not averaged)
         squared_error = torch.pow(predicted_segment - real_segment, 2).detach().cpu().numpy()
@@ -495,6 +503,7 @@
     ax.set_xlabel('time steps', fontsize=14)
     ax.set_ylabel('mse (averaged over neurons)', fontsize=14)
     ax.set_yscale('log')
+    ax.set_ylim(1e-3, 1.0)
     ax.set_title(
         f'time-aligned mse analysis - {rollout_type} rollout (tu={time_units}, ems={evolve_multiple_steps})',
         fontsize=16,
@@ -550,7 +559,8 @@ def run_validation_diagnostics(
     with torch.no_grad():
         mse_array, new_figs, new_metrics, start_indices = compute_multi_start_rollout_mse(
             model, val_data, val_stim, neuron_data, n_steps=2000, n_starts=20, rollout_type=rollout_type,
-            plot_mode=plot_mode, time_units=time_units, evolve_multiple_steps=evolve_multiple_steps
+            plot_mode=plot_mode, time_units=time_units, evolve_multiple_steps=evolve_multiple_steps,
+            stimulus_frequency=config.training.stimulus_frequency
         )
     metrics.update(new_metrics)
     figures.update(new_figs)
@@ -612,52 +622,103 @@
     return metrics, figures
 
 
-def evolve_n_steps(model: LatentModel, initial_state: torch.Tensor, stimulus: torch.Tensor, n_steps: int) -> torch.Tensor:
+def evolve_n_steps_activity(
+    model: LatentModel,
+    initial_state: torch.Tensor,
+    stimulus: torch.Tensor,
+    n_steps: int,
+    stimulus_frequency: StimulusFrequency,
+    time_units: int,
+) -> torch.Tensor:
     """
-    Evolve the model by n time steps using the predicted state at each step.
+    evolve the model by n time steps using the predicted state at each step.
 
-    This performs an autoregressive rollout starting from a single initial state.
-    At each step, encodes the current state, evolves in latent space, and decodes.
+    this performs an autoregressive rollout starting from a single initial state.
+    at each step, encodes the current state, evolves in latent space, and decodes.
 
-    Args:
-        model: The LatentModel to evolve
-        initial_state: Initial state tensor of shape (neurons,)
-        stimulus: Stimulus tensor of shape (T, stimulus_dim) where T >= n_steps
-        n_steps: Number of time steps to evolve
+    args:
+        model: the LatentModel to evolve
+        initial_state: initial state tensor of shape (neurons,)
+        stimulus: stimulus tensor of shape (T, stimulus_dim) where T >= n_steps
+        n_steps: number of time steps to evolve
+        stimulus_frequency: stimulus downsampling mode
+        time_units: observation interval for downsampling
 
-    Returns:
-        predicted_trace: Tensor of shape (n_steps, neurons) with predicted states
+    returns:
+        predicted_trace: tensor of shape (n_steps, neurons) with predicted states
     """
+    # pre-encode and downsample stimulus
+    stimulus_latent_all = model.stimulus_encoder(stimulus[:n_steps])  # shape (n_steps, stim_latent_dim)
+    stimulus_latent_all = stimulus_latent_all.unsqueeze(1)  # (n_steps, 1, stim_latent_dim)
+
+    num_multiples = max(1, n_steps // time_units)
+    stimulus_latent = downsample_stimulus(
+        stimulus_latent_all,
+        tu=time_units,
+        num_multiples=num_multiples,
+        stimulus_frequency=stimulus_frequency,
+    )
+    stimulus_latent = stimulus_latent.squeeze(1)  # (n_steps, stim_latent_dim)
+
     predicted_trace = []
     current_state = initial_state.unsqueeze(0)  # shape (1, neurons)
 
     for t in range(n_steps):
-        current_stimulus = stimulus[t:t+1]  # shape (1, stimulus_dim)
-        next_state = model(current_state, current_stimulus)
+        # encode current state
+        current_latent = model.encoder(current_state)
+        # evolve with downsampled stimulus
+        next_latent = model.evolver(current_latent, stimulus_latent[t:t+1])
+        # decode
+        next_state = model.decoder(next_latent)
         predicted_trace.append(next_state.squeeze(0))
         current_state = next_state
 
     return torch.stack(predicted_trace, dim=0)
 
 
-def evolve_n_steps_latent(model: LatentModel, initial_state: torch.Tensor, stimulus: torch.Tensor, n_steps: int) -> torch.Tensor:
+def evolve_n_steps_latent(
+    model: LatentModel,
+    initial_state: torch.Tensor,
+    stimulus: torch.Tensor,
+    n_steps: int,
+    stimulus_frequency: StimulusFrequency,
+    time_units: int,
+) -> torch.Tensor:
     """
-    Evolve the model by n time steps entirely in latent space.
+    evolve the model by n time steps entirely in latent space.
 
-    This performs an autoregressive rollout starting from a single initial state.
-    Encodes once, evolves in latent space for n steps, then decodes all states.
+    this performs an autoregressive rollout starting from a single initial state.
+    encodes once, evolves in latent space for n steps, then decodes all states.
 
-    Args:
-        model: The LatentModel to evolve
-        initial_state: Initial state tensor of shape (neurons,)
-        stimulus: Stimulus tensor of shape (T, stimulus_dim) where T >= n_steps
-        n_steps: Number of time steps to evolve
+    args:
+        model: the LatentModel to evolve
+        initial_state: initial state tensor of shape (neurons,)
+        stimulus: stimulus tensor of shape (T, stimulus_dim) where T >= n_steps
+        n_steps: number of time steps to evolve
+        stimulus_frequency: stimulus downsampling mode
+        time_units: observation interval for downsampling
 
-    Returns:
-        predicted_trace: Tensor of shape (n_steps, neurons) with predicted states
+    returns:
+        predicted_trace: tensor of shape (n_steps, neurons) with predicted states
     """
     current_latent = model.encoder(initial_state.unsqueeze(0))  # shape (1, latent_dim)
-    stimulus_latent = model.stimulus_encoder(stimulus)  # shape (n_steps, stim_latent_dim)
+    stimulus_latent_all = model.stimulus_encoder(stimulus[:n_steps])  # shape (n_steps, stim_latent_dim)
+
+    # downsample stimulus based on frequency mode
+    # need to add batch dimension for downsample_stimulus
+    stimulus_latent_all = stimulus_latent_all.unsqueeze(1)  # (n_steps, 1, stim_latent_dim)
+
+    # calculate num_multiples for downsampling
+    num_multiples = max(1, n_steps // time_units)
+
+    stimulus_latent = downsample_stimulus(
+        stimulus_latent_all,
+        tu=time_units,
+        num_multiples=num_multiples,
+        stimulus_frequency=stimulus_frequency,
+    )  # (n_steps, 1, stim_latent_dim)
+
+    stimulus_latent = stimulus_latent.squeeze(1)  # (n_steps, stim_latent_dim)
 
     latent_trace = []
     for t in range(n_steps):
diff --git a/src/LatentEvolution/experiments/flyvis_voltage_100ms.md b/src/LatentEvolution/experiments/flyvis_voltage_100ms.md
index 55f15d38..79f25174 100644
--- a/src/LatentEvolution/experiments/flyvis_voltage_100ms.md
+++ b/src/LatentEvolution/experiments/flyvis_voltage_100ms.md
@@ -339,6 +339,16 @@ for zero_init in zero-init no-zero-init; do \
 done
 ```
 
+Results:
+
+- it turns out the evolver initialization is the key change: it removes the blow-up
+  and allows the system to learn the right update rule.
+- warmup and activation don't seem to make a significant difference once the evolver
+  is correctly initialized.
+- warmup does two nice things:
+  - lower mse at intervening time steps
+  - training after epoch 1 is already stable in terms of rollout
+
 Understand if TV norm can bring stability to the training without pretraining
 for reconstruction or the other features we added.
 
@@ -354,3 +364,31 @@ for tv in 0.0 0.00001 0.0001 0.001; do \
     bsub -J tv_${tv} -q gpu_a100 -gpu "num=1" -n 8 -o tv_${tv}.log \
     python src/LatentEvolution/latent.py tv_sweep latent_20step.yaml \
     --training.seed 97651
 done
 ```
+
+Results:
+
+- tv norm at 1e-3 successfully mitigates the artifact
+- the mse at t->t+1 is closer to a constant than in the previous experiment. So
+  while tv norm does help the training converge to a sensible evolution model, it
+  harms the t->t+1 mse.
+
+# Stimulus downsampling
+
+We want to avoid depending on the details of the stimulus provided since in general
+it won't be known with such granularity. As a first step, we only provide the
+stimulus every `tu` steps. At time steps `n <= t < n + tu` we linearly interpolate between
+the stimulus at time `n` and the one at time `n+tu`.
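+
+As a minimal sketch of this rule (a standalone illustration; `interp_stimulus` is a
+hypothetical helper — the actual implementation is `downsample_stimulus` in
+`stimulus_utils.py`):
+
+```python
+import torch
+
+
+def interp_stimulus(samples: torch.Tensor, tu: int) -> torch.Tensor:
+    """linearly interpolate between stimulus samples taken every tu steps.
+
+    samples: (k, dim) stimulus observed at t = 0, tu, ..., (k-1)*tu.
+    returns: ((k-1)*tu, dim) trace where, for n <= t < n + tu,
+    s_t = (1 - w) * s_n + w * s_{n+tu} with w = (t - n) / tu.
+    """
+    out = []
+    for m in range(len(samples) - 1):
+        for step in range(tu):
+            w = step / tu
+            out.append((1 - w) * samples[m] + w * samples[m + 1])
+    return torch.stack(out, dim=0)
+```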
+
+```bash
+for mode in TIME_UNITS_INTERPOLATE TIME_UNITS_CONSTANT NONE; do \
+    bsub -J stim_${mode} -q gpu_a100 -gpu "num=1" -n 8 -o stim_${mode}.log \
+    python src/LatentEvolution/latent.py stim_freq_sweep latent_20step.yaml \
+    --training.stimulus-frequency $mode
+done
+```
+
+The results suggest that in the current setup we are critically reliant on the stimulus
+being provided at every time step. When it is not, we see a blow-up in the MSE when we
+roll out past the training horizon. Even within the training horizon the error does not
+fall below the linear-interpolation baseline, and we are unable to learn the right
+rule even at the loss time points 0, `tu`, `2tu`, ...
diff --git a/src/LatentEvolution/latent.py b/src/LatentEvolution/latent.py
index ad9f9217..f4e879c5 100644
--- a/src/LatentEvolution/latent.py
+++ b/src/LatentEvolution/latent.py
@@ -24,6 +24,7 @@
     load_metadata,
 )
 from LatentEvolution.training_config import DataSplit
+from LatentEvolution.stimulus_utils import downsample_stimulus
 from LatentEvolution.chunk_loader import RandomChunkLoader
 from LatentEvolution.gpu_stats import GPUMonitor
 from LatentEvolution.diagnostics import run_validation_diagnostics, PlotMode
@@ -307,8 +308,6 @@
     yield start_indices, selected_neurons, needed_indices
 
 
-
-
 def train_step_reconstruction_only_nocompile(
     model: LatentModel,
     train_data: torch.Tensor,
@@ -417,7 +416,15 @@ def train_step_nocompile(
     dim_stim = train_stim.shape[1]
     dim_stim_latent = cfg.stimulus_encoder_params.num_output_dims
     # total_steps x b x Ls
-    proj_stim_t = model.stimulus_encoder(stim_t.reshape((-1, dim_stim))).reshape((total_steps, -1, dim_stim_latent))
+    proj_stim_t_all = model.stimulus_encoder(stim_t.reshape((-1, dim_stim))).reshape((total_steps, -1, dim_stim_latent))
+
+    # downsample stimulus based on frequency mode
+    proj_stim_t = downsample_stimulus(
+        proj_stim_t_all,
+        tu=dt,
+        num_multiples=num_multiples,
+        stimulus_frequency=cfg.training.stimulus_frequency,
+    )
 
     # reconstruction loss
     recon_t = model.decoder(proj_t)
diff --git a/src/LatentEvolution/stimulus_utils.py b/src/LatentEvolution/stimulus_utils.py
new file mode 100644
index 00000000..057bae2a
--- /dev/null
+++ b/src/LatentEvolution/stimulus_utils.py
@@ -0,0 +1,112 @@
+"""
+utilities for stimulus processing and downsampling.
+"""
+
+import torch
+
+from LatentEvolution.training_config import StimulusFrequency
+
+
+def downsample_stimulus(
+    proj_stim_t: torch.Tensor,
+    tu: int,
+    num_multiples: int,
+    stimulus_frequency: StimulusFrequency,
+) -> torch.Tensor:
+    """
+    downsample encoded stimulus based on frequency mode.
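+
+    example: with tu=4 in TIME_UNITS_CONSTANT mode, samples are taken at
+    t=0, 4, 8, ... and each is held on a centered interval [t-2, t+2),
+    so the held value switches halfway between sample points.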
+
+    args:
+        proj_stim_t: (total_steps, b, dim_stim_latent) encoded stimulus at all time points
+        tu: time_units (observation interval)
+        num_multiples: number of tu-length intervals (evolve_multiple_steps during training)
+        stimulus_frequency: StimulusFrequency enum mode
+
+    returns:
+        downsampled_proj_stim_t: (total_steps, b, dim_stim_latent) downsampled stimulus.
+            note: in TIME_UNITS_INTERPOLATE mode, when an extra final boundary sample
+            exists (total_steps > num_multiples * tu), only num_multiples * tu steps
+            are returned.
+    """
+    total_steps, batch_size, dim_stim_latent = proj_stim_t.shape
+    device = proj_stim_t.device
+
+    if stimulus_frequency == StimulusFrequency.ALL:
+        # use all time points (no downsampling)
+        return proj_stim_t
+
+    elif stimulus_frequency == StimulusFrequency.NONE:
+        # no stimulus: return zeros
+        return torch.zeros_like(proj_stim_t)
+
+    elif stimulus_frequency == StimulusFrequency.TIME_UNITS_CONSTANT:
+        # sample at 0, tu, 2*tu, ..., and hold constant
+        # each stimulus at time t is used for interval [t-tu/2, t+tu/2)
+        # i.e., centered around its time point
+        num_samples = min(num_multiples + 1, (total_steps + tu - 1) // tu)
+        sample_indices = torch.arange(num_samples, device=device) * tu  # [0, tu, 2*tu, ...]
+        proj_samples = proj_stim_t[sample_indices, :, :]  # (num_samples, b, dim_stim_latent)
+
+        # build output by assigning each sample to its centered interval
+        half_tu = tu // 2
+        downsampled = torch.zeros(total_steps, batch_size, dim_stim_latent, device=device, dtype=proj_stim_t.dtype)
+
+        for step in range(total_steps):
+            # find which sample this step belongs to
+            # step is in interval [t-tu/2, t+tu/2) centered at t
+            # so we find the nearest sample point
+            nearest_sample_idx = min((step + half_tu) // tu, num_samples - 1)
+            downsampled[step] = proj_samples[nearest_sample_idx]
+
+        return downsampled
+
+    elif stimulus_frequency == StimulusFrequency.TIME_UNITS_INTERPOLATE:
+        # sample at time unit boundaries and linearly interpolate
+        # check if we have enough data for the final boundary point
+        final_boundary_idx = num_multiples * tu
+
+        if final_boundary_idx >= total_steps:
+            # not enough data for final boundary - use constant mode for last interval
+            # this happens when total_steps = num_multiples * tu (no extra boundary point)
+            sample_indices = torch.arange(num_multiples, device=device) * tu  # [0, tu, 2*tu, ..., (num_multiples-1)*tu]
+            proj_samples = proj_stim_t[sample_indices, :, :]  # (num_multiples, b, dim_stim_latent)
+
+            proj_stim_list = []
+            # interpolate for all intervals except the last
+            for m in range(num_multiples - 1):
+                start_proj = proj_samples[m]
+                end_proj = proj_samples[m + 1]
+                weights = torch.linspace(0, 1, tu + 1, device=device)[:-1]
+                for w in weights:
+                    interp = (1 - w) * start_proj + w * end_proj
+                    proj_stim_list.append(interp)
+
+            # for the last interval, hold constant (no boundary point to interpolate to)
+            last_proj = proj_samples[-1]
+            remaining_steps = total_steps - len(proj_stim_list)
+            for _ in range(remaining_steps):
+                proj_stim_list.append(last_proj)
+
+            downsampled = torch.stack(proj_stim_list, dim=0)
+        else:
+            # we have enough data for full interpolation including final boundary
+            sample_indices = torch.arange(num_multiples + 1, device=device) * tu
+            proj_samples = proj_stim_t[sample_indices, :, :]  # (num_samples, b, dim_stim_latent)
+
+            # interpolate in latent space between consecutive samples
+            proj_stim_list = []
+            for m in range(num_multiples):
+                start_proj = proj_samples[m]  # (b, dim_stim_latent)
+                end_proj = proj_samples[m + 1]  # (b, dim_stim_latent)
+
+                # linear interpolation weights for tu steps
+                # exclude 1.0 to avoid duplicates at boundaries
+                weights = torch.linspace(0, 1, tu + 1, device=device)[:-1]  # (tu,)
+
+                for w in weights:
+                    interp = (1 - w) * start_proj + w * end_proj  # (b, dim_stim_latent)
+                    proj_stim_list.append(interp)
+
+            downsampled = torch.stack(proj_stim_list, dim=0)  # (num_multiples * tu, b, dim_stim_latent)
+
+        return downsampled
+
+    else:
+        raise ValueError(f"unknown stimulus frequency: {stimulus_frequency}")
diff --git a/src/LatentEvolution/stimulus_utils_test.py b/src/LatentEvolution/stimulus_utils_test.py
new file mode 100644
index 00000000..35307bbe
--- /dev/null
+++ b/src/LatentEvolution/stimulus_utils_test.py
@@ -0,0 +1,271 @@
+"""
+unit tests for stimulus_utils module.
+"""
+
+import unittest
+import torch
+
+from LatentEvolution.stimulus_utils import downsample_stimulus
+from LatentEvolution.training_config import StimulusFrequency
+
+
+class TestDownsampleStimulus(unittest.TestCase):
+    """tests for downsample_stimulus function."""
+
+    def setUp(self):
+        """set up test fixtures."""
+        self.device = torch.device("cpu")
+        # (total_steps=20, batch_size=2, dim_stim_latent=5)
+        self.sample_stimulus = torch.randn(20, 2, 5, device=self.device)
+
+    def test_all_mode_returns_unchanged(self):
+        """test that ALL mode returns the input unchanged."""
+        result = downsample_stimulus(
+            self.sample_stimulus,
+            tu=5,
+            num_multiples=4,
+            stimulus_frequency=StimulusFrequency.ALL,
+        )
+        self.assertTrue(torch.equal(result, self.sample_stimulus))
+        self.assertEqual(result.shape, self.sample_stimulus.shape)
+
+    def test_none_mode_returns_zeros(self):
+        """test that NONE mode returns zeros with same shape."""
+        result = downsample_stimulus(
+            self.sample_stimulus,
+            tu=5,
+            num_multiples=4,
+            stimulus_frequency=StimulusFrequency.NONE,
+        )
+        self.assertTrue(torch.all(result == 0))
+        self.assertEqual(result.shape, self.sample_stimulus.shape)
+
+    def test_constant_mode_shape(self):
+        """test that TIME_UNITS_CONSTANT mode returns correct shape."""
+        result = downsample_stimulus(
+            self.sample_stimulus,
+            tu=5,
+            num_multiples=4,
+            stimulus_frequency=StimulusFrequency.TIME_UNITS_CONSTANT,
+        )
+        self.assertEqual(result.shape, self.sample_stimulus.shape)
+
+    def test_constant_mode_values_held(self):
+        """test that TIME_UNITS_CONSTANT mode holds values constant within centered intervals."""
+        # create known stimulus: each sample has a distinct value
+        # with tu=10, samples at t=0, 10, 20, 30
+        proj_stim = torch.zeros(40, 1, 3, device=self.device)
+        proj_stim[0, :, :] = 1.0  # sample at t=0
+        proj_stim[10, :, :] = 2.0  # sample at t=10
+        proj_stim[20, :, :] = 3.0  # sample at t=20
+        proj_stim[30, :, :] = 4.0  # sample at t=30
+
+        result = downsample_stimulus(
+            proj_stim,
+            tu=10,
+            num_multiples=4,
+            stimulus_frequency=StimulusFrequency.TIME_UNITS_CONSTANT,
+        )
+
+        # each sample is centered around its time point
+        # sample at t=0: used for [0, 5) (first sample starts at 0)
+        # sample at t=10: used for [5, 15)
+        # sample at t=20: used for [15, 25)
+        # sample at t=30: used for [25, 40) (last sample extends to end)
+        self.assertTrue(torch.all(result[0:5] == 1.0))
+        self.assertTrue(torch.all(result[5:15] == 2.0))
+        self.assertTrue(torch.all(result[15:25] == 3.0))
+        self.assertTrue(torch.all(result[25:40] == 4.0))
+
+    def test_constant_mode_centered_intervals(self):
+        """test that TIME_UNITS_CONSTANT mode uses centered intervals around sample points."""
+        # test with tu=20: samples at 0, 20, 40, 60
+        proj_stim = torch.zeros(80, 1, 1, device=self.device)
+        proj_stim[0, :, :] = 10.0  # sample at t=0
+        proj_stim[20, :, :] = 20.0  # sample at t=20
+        proj_stim[40, :, :] = 30.0  # sample at t=40
+        proj_stim[60, :, :] = 40.0  # sample at t=60
+
+        result = downsample_stimulus(
+            proj_stim,
+            tu=20,
+            num_multiples=4,
+            stimulus_frequency=StimulusFrequency.TIME_UNITS_CONSTANT,
+        )
+
+        # centered intervals (half_tu = 10):
+        # t=0: [0, 10) -> value 10.0
+        # t=20: [10, 30) -> value 20.0
+        # t=40: [30, 50) -> value 30.0
+        # t=60: [50, 80) -> value 40.0 (last sample extends to end)
+        self.assertTrue(torch.all(result[0:10] == 10.0))
+        self.assertTrue(torch.all(result[10:30] == 20.0))
+        self.assertTrue(torch.all(result[30:50] == 30.0))
+        self.assertTrue(torch.all(result[50:80] == 40.0))
+
+    def test_interpolate_mode_shape(self):
+        """test that TIME_UNITS_INTERPOLATE mode returns correct shape."""
+        result = downsample_stimulus(
+            self.sample_stimulus,
+            tu=5,
+            num_multiples=4,
+            stimulus_frequency=StimulusFrequency.TIME_UNITS_INTERPOLATE,
+        )
+        self.assertEqual(result.shape, self.sample_stimulus.shape)
+
+    def test_interpolate_mode_boundary_values(self):
+        """test that TIME_UNITS_INTERPOLATE mode matches boundary values."""
+        # create stimulus with known values at boundaries
+        proj_stim = torch.zeros(20, 1, 2, device=self.device)
+        proj_stim[0, 0, :] = torch.tensor([0.0, 0.0])
+        proj_stim[5, 0, :] = torch.tensor([1.0, 2.0])
+        proj_stim[10, 0, :] = torch.tensor([2.0, 4.0])
+        proj_stim[15, 0, :] = torch.tensor([3.0, 6.0])
+
+        result = downsample_stimulus(
+            proj_stim,
+            tu=5,
+            num_multiples=4,
+            stimulus_frequency=StimulusFrequency.TIME_UNITS_INTERPOLATE,
+        )
+
+        # check boundary values are preserved (at the start of each interval)
+        self.assertTrue(torch.allclose(result[0, 0, :], torch.tensor([0.0, 0.0])))
+        self.assertTrue(torch.allclose(result[5, 0, :], torch.tensor([1.0, 2.0])))
+        self.assertTrue(torch.allclose(result[10, 0, :], torch.tensor([2.0, 4.0])))
+        # the last interval is held constant (no final boundary to interpolate to),
+        # so t=15 keeps the last sample's value
+        self.assertTrue(torch.allclose(result[15, 0, :], torch.tensor([3.0, 6.0])))
+
+    def test_interpolate_mode_midpoint_values(self):
+        """test that TIME_UNITS_INTERPOLATE mode correctly interpolates within an interval."""
+        # create simple linear stimulus
+        proj_stim = torch.zeros(20, 1, 1, device=self.device)
+        proj_stim[0, 0, 0] = 0.0
+        proj_stim[5, 0, 0] = 10.0
+        proj_stim[10, 0, 0] = 20.0
+        proj_stim[15, 0, 0] = 30.0
+
+        result = downsample_stimulus(
+            proj_stim,
+            tu=5,
+            num_multiples=4,
+            stimulus_frequency=StimulusFrequency.TIME_UNITS_INTERPOLATE,
+        )
+
+        # check interpolation in the first interval (values 0 -> 10 over steps 0..5)
+        # weights go 0.0, 0.2, 0.4, 0.6, 0.8 for tu=5, so at step 2 w=0.4
+        expected_step_2 = 0.0 * 0.6 + 10.0 * 0.4  # = 4.0
+        self.assertTrue(torch.allclose(result[2, 0, 0], torch.tensor(expected_step_2), atol=0.1))
+
+    def test_edge_case_no_final_boundary(self):
+        """test edge case when total_steps = num_multiples * tu (no final boundary)."""
+        # this is the common case in training
+        proj_stim = torch.randn(20, 2, 3, device=self.device)
+
+        result = downsample_stimulus(
+            proj_stim,
+            tu=5,
+            num_multiples=4,
+            stimulus_frequency=StimulusFrequency.TIME_UNITS_INTERPOLATE,
+        )
+
+        # should handle gracefully without indexing errors
+        self.assertEqual(result.shape, (20, 2, 3))
+
+    def test_edge_case_with_final_boundary(self):
+        """test case when we have an extra data point for the final boundary."""
+        # 21 steps = 4 intervals * 5 tu + 1 extra for the final boundary
+        proj_stim = torch.randn(21, 2, 3, device=self.device)
+
+        result = downsample_stimulus(
+            proj_stim,
+            tu=5,
+            num_multiples=4,
+            stimulus_frequency=StimulusFrequency.TIME_UNITS_INTERPOLATE,
+        )
+
+        # should only return 20 steps (4 complete intervals)
+        self.assertEqual(result.shape, (20, 2, 3))
+
+    def test_different_tu_values(self):
+        """test with different time_units values."""
+        for tu in [1, 2, 5, 10, 20]:
+            total_steps = tu * 5  # 5 multiples
+            proj_stim = torch.randn(total_steps, 2, 3, device=self.device)
+
+            result = downsample_stimulus(
+                proj_stim,
+                tu=tu,
+                num_multiples=5,
+                stimulus_frequency=StimulusFrequency.TIME_UNITS_CONSTANT,
+            )
+
+            self.assertEqual(result.shape, proj_stim.shape)
+
+    def test_invalid_stimulus_frequency_raises(self):
+        """test that an invalid stimulus frequency raises ValueError."""
+        # use a mock invalid enum value
+        class FakeStimulusFrequency:
+            pass
+
+        with self.assertRaisesRegex(ValueError, "unknown stimulus frequency"):
+            downsample_stimulus(
+                self.sample_stimulus,
+                tu=5,
+                num_multiples=4,
+                stimulus_frequency=FakeStimulusFrequency(),
+            )
+
+    def test_preserves_device(self):
+        """test that output is on the same device as input."""
+        if torch.cuda.is_available():
+            device = torch.device("cuda:0")
+            proj_stim = torch.randn(20, 2, 3, device=device)
+
+            result = downsample_stimulus(
+                proj_stim,
+                tu=5,
+                num_multiples=4,
+                stimulus_frequency=StimulusFrequency.TIME_UNITS_INTERPOLATE,
+            )
+
+            self.assertEqual(result.device, device)
+
+    def test_preserves_dtype(self):
+        """test that output preserves input dtype."""
+        for dtype in [torch.float32, torch.float64]:
+            proj_stim = torch.randn(20, 2, 3, device=self.device, dtype=dtype)
+
+            result = downsample_stimulus(
+                proj_stim,
+                tu=5,
+                num_multiples=4,
+                stimulus_frequency=StimulusFrequency.TIME_UNITS_CONSTANT,
+            )
+
+            self.assertEqual(result.dtype, dtype)
+
+    def test_batch_dimension_independence(self):
+        """test that each batch is processed independently."""
+        # create stimulus with different values for each batch
+        proj_stim = torch.zeros(20, 3, 2, device=self.device)
+        proj_stim[:, 0, :] = 1.0  # batch 0
+        proj_stim[:, 1, :] = 2.0  # batch 1
+        proj_stim[:, 2, :] = 3.0  # batch 2
+
+        result = downsample_stimulus(
+            proj_stim,
+            tu=5,
+            num_multiples=4,
+            stimulus_frequency=StimulusFrequency.TIME_UNITS_CONSTANT,
+        )
+
+        # each batch should maintain its distinct values
+        self.assertTrue(torch.all(result[:, 0, :] == 1.0))
+        self.assertTrue(torch.all(result[:, 1, :] == 2.0))
+        self.assertTrue(torch.all(result[:, 2, :] == 3.0))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/src/LatentEvolution/training_config.py b/src/LatentEvolution/training_config.py
index 9d04291a..53830ab7 100644
--- a/src/LatentEvolution/training_config.py
+++ b/src/LatentEvolution/training_config.py
@@ -4,12 +4,24 @@
 includes profiling, training hyperparameters, and cross-validation configs.
 """
 
+from enum import Enum, auto
 from pydantic import BaseModel, Field, field_validator, model_validator, ConfigDict
 import torch
 
 from LatentEvolution.acquisition import AcquisitionMode, AllTimePointsMode
 
 
+class StimulusFrequency(Enum):
+    """stimulus frequency for training.
+
+    controls how stimulus is provided during evolution steps between observations.
+    """
+    ALL = auto()  # use stimulus at every time step (current behavior)
+    NONE = auto()  # no stimulus provided (set to zero)
+    TIME_UNITS_CONSTANT = auto()  # use stimulus at time_units intervals, hold constant between
+    TIME_UNITS_INTERPOLATE = auto()  # use stimulus at time_units intervals, linearly interpolate between
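+
+# selected via the training config; e.g. on the command line:
+#   --training.stimulus-frequency TIME_UNITS_INTERPOLATE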
+ """ + ALL = auto() # use stimulus at every time step (current behavior) + NONE = auto() # no stimulus provided (set to zero) + TIME_UNITS_CONSTANT = auto() # use stimulus at time_units intervals, hold constant between + TIME_UNITS_INTERPOLATE = auto() # use stimulus at time_units intervals, linearly interpolate between + + class DataSplit(BaseModel): """split the time series into train/validation sets.""" @@ -81,6 +93,11 @@ class TrainingConfig(BaseModel): description="data acquisition mode. controls which timesteps have observable data for each neuron.", json_schema_extra={"short_name": "acq"} ) + stimulus_frequency: StimulusFrequency = Field( + StimulusFrequency.ALL, + description="stimulus frequency. controls how stimulus is provided during evolution steps.", + json_schema_extra={"short_name": "stim_freq"} + ) intermediate_loss_steps: list[int] = Field( default_factory=list, description="deprecated: intermediate steps feature has been removed. must be empty list.", @@ -164,6 +181,13 @@ def validate_training_config(self): "staggered mode observes neurons at different times, breaking the connectome assumption." ) + # validate stimulus frequency compatibility + if self.time_units == 1 and self.stimulus_frequency != StimulusFrequency.ALL: + raise ValueError( + f"stimulus_frequency must be ALL when time_units=1. " + f"got stimulus_frequency={self.stimulus_frequency.name}, time_units={self.time_units}" + ) + return self