Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions src/LatentEvolution/eed_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ class EvolverParams(BaseModel):
time_units: int = Field(
1, description="DEPRECATED: Use training.time_units instead. Kept for backwards compatibility."
)
zero_init: bool = Field(
True, description="If True, zero-initialize final layer so evolver starts as identity (z_{t+1} = z_t). Provides stability but may slow dynamics learning."
)
tv_reg_loss: float = Field(
0.0, description="total variation regularization on evolver updates. penalizes ||Δz|| to stabilize dynamics and prevent explosive rollouts. typical range: 1e-5 to 1e-3."
)
model_config = ConfigDict(extra="forbid", validate_assignment=True)

@field_validator("activation")
Expand Down Expand Up @@ -215,13 +221,14 @@ def __init__(self, latent_dims: int, stim_dims: int, evolver_params: EvolverPara
)
)

# zero-init final layer so evolver starts as identity (z_{t+1} = z_t)
if evolver_params.use_input_skips:
nn.init.zeros_(self.evolver.output_layer.weight)
nn.init.zeros_(self.evolver.output_layer.bias)
else:
nn.init.zeros_(self.evolver.layers[-1].weight)
nn.init.zeros_(self.evolver.layers[-1].bias)
# optionally zero-init final layer so evolver starts as identity (z_{t+1} = z_t)
if evolver_params.zero_init:
if evolver_params.use_input_skips:
nn.init.zeros_(self.evolver.output_layer.weight)
nn.init.zeros_(self.evolver.output_layer.bias)
else:
nn.init.zeros_(self.evolver.layers[-1].weight)
nn.init.zeros_(self.evolver.layers[-1].bias)

def forward(self, proj_t, proj_stim_t):
"""Evolve one time step in latent space."""
Expand Down
42 changes: 42 additions & 0 deletions src/LatentEvolution/experiments/flyvis_voltage_100ms.md
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,45 @@ for ems in 1 2 3 5; do \
python src/LatentEvolution/latent.py ems_sweep latent_20step.yaml \
--training.evolve-multiple-steps $ems
```

## Derivative experiment

Prior to changing the evolver (tanh activation, initialized to 0) and using
pretraining for reconstruction, we observed an instability in training in the
`tu20_seed_sweep_20260115_f6144bd` experiment. We want to revisit and see if
adding TV norm regularization can rescue that phenotype.

The goal is to understand which feature contributes to the stability: we
changed many things at the same time to reach a working point, so each change
needs to be ablated in isolation.

```bash
for zero_init in zero-init no-zero-init; do \
for activation in Tanh ReLU; do \
for warmup in 0 10; do \
name="z${zero_init}_${activation}_w${warmup}"
bsub -J $name -q gpu_a100 -gpu "num=1" -n 8 -o ${name}.log \
python src/LatentEvolution/latent.py test_stability latent_20step.yaml \
--evolver_params.${zero_init} \
--evolver_params.activation $activation \
--training.reconstruction-warmup-epochs $warmup \
--training.seed 35235
done
done
done
```

The goal is to understand whether TV-norm regularization alone can stabilize
training, without reconstruction pretraining or the other features we added.

```bash

for tv in 0.0 0.00001 0.0001 0.001; do \
bsub -J tv${tv} -q gpu_a100 -gpu "num=1" -n 8 -o tv${tv}.log \
python src/LatentEvolution/latent.py tv_sweep latent_20step.yaml \
--evolver_params.no-zero-init \
--evolver_params.activation ReLU \
--training.reconstruction-warmup-epochs 0 \
--evolver_params.tv-reg-loss $tv \
--training.seed 97651
done
```
146 changes: 92 additions & 54 deletions src/LatentEvolution/latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,22 @@ def load_val_only(


class LossType(Enum):
    """Semantic keys for the loss components produced by EED train steps.

    Train steps return a dict[LossType, Tensor] so callers access components
    by name instead of tuple position.

    Benchmarking notes (GPU, batch 256, realistic encoder/decoder ops):
    dict-with-enum-keys costs +0.5% to +1.4% over a plain tuple and a
    namedtuple costs +0.8% to +2.0%; torch.compile still yields a 1.87-1.93x
    speedup regardless of return type. Conclusion: the overhead is negligible
    and semantic access is worth it.
    """

    TOTAL = 1
    RECON = 2
    EVOLVE = 3
    REG = 4
    AUG_LOSS = 5
    TV_LOSS = 6


# -------------------------------------------------------------------
Expand Down Expand Up @@ -307,7 +317,7 @@ def train_step_reconstruction_only_nocompile(
_selected_neurons: torch.Tensor,
_needed_indices: torch.Tensor,
cfg: ModelParams
):
) -> dict[LossType, torch.Tensor]:
"""Train encoder/decoder only with reconstruction loss."""
device = train_data.device
loss_fn = getattr(torch.nn.functional, cfg.training.loss_function)
Expand All @@ -331,11 +341,15 @@ def train_step_reconstruction_only_nocompile(
recon_t = model.decoder(proj_t)
recon_loss = loss_fn(recon_t, x_t)

# return same tuple format for compatibility
evolve_loss = torch.tensor(0.0, device=device)
aug_loss = torch.tensor(0.0, device=device)
loss = recon_loss + reg_loss
return (loss, recon_loss, evolve_loss, reg_loss, aug_loss)
# total loss
total_loss = recon_loss + reg_loss

# return only computed losses (no evolve, aug, tv during warmup)
return {
LossType.TOTAL: total_loss,
LossType.RECON: recon_loss,
LossType.REG: reg_loss,
}


train_step_reconstruction_only = torch.compile(
Expand All @@ -351,10 +365,13 @@ def train_step_nocompile(
selected_neurons: torch.Tensor,
needed_indices: torch.Tensor,
cfg: ModelParams
):
) -> dict[LossType, torch.Tensor]:

device=train_data.device

# initialize loss dict
losses: dict[LossType, torch.Tensor] = {}

# Get loss function from config
loss_fn = getattr(torch.nn.functional, cfg.training.loss_function)

Expand All @@ -369,6 +386,10 @@ def train_step_nocompile(
if cfg.evolver_params.l1_reg_loss > 0.:
for p in model.evolver.parameters():
reg_loss += torch.abs(p).mean()*cfg.evolver_params.l1_reg_loss
losses[LossType.REG] = reg_loss

# total variation regularization on evolver updates
tv_loss = torch.tensor(0.0, device=device)


# b = batch size
Expand Down Expand Up @@ -400,11 +421,17 @@ def train_step_nocompile(

# reconstruction loss
recon_t = model.decoder(proj_t)
recon_loss = loss_fn(recon_t, x_t)
losses[LossType.RECON] = loss_fn(recon_t, x_t)

# Evolve by 1 time step. This is a special case since we may opt to apply
# a connectome constraint via data augmentation.
proj_t = model.evolver(proj_t, proj_stim_t[0])
if cfg.evolver_params.tv_reg_loss > 0.:
# compute delta_z explicitly for TV norm
delta_z = model.evolver.evolver(torch.cat([proj_t, proj_stim_t[0]], dim=1))
tv_loss += torch.abs(delta_z).mean() * cfg.evolver_params.tv_reg_loss
proj_t = proj_t + delta_z
else:
proj_t = model.evolver(proj_t, proj_stim_t[0])
# apply connectome loss after evolving by 1 time step
aug_loss = torch.tensor(0.0, device=device)
if (
Expand All @@ -419,11 +446,18 @@ def train_step_nocompile(
proj_t_aug = model.evolver(model.encoder(x_t_aug), proj_stim_t[0])
pred_t_plus_1_aug = model.decoder(proj_t_aug)
aug_loss += cfg.training.unconnected_to_zero.loss_coeff * loss_fn(pred_t_plus_1_aug[:, selected_neurons], pred_t_plus_1[:, selected_neurons])
losses[LossType.AUG_LOSS] = aug_loss

# evolve for remaining dt-1 time steps (first window)
evolve_loss = torch.tensor(0.0, device=device)
for i in range(1, dt):
proj_t = model.evolver(proj_t, proj_stim_t[i])
if cfg.evolver_params.tv_reg_loss > 0.:
for i in range(1, dt):
delta_z = model.evolver.evolver(torch.cat([proj_t, proj_stim_t[i]], dim=1))
tv_loss += torch.abs(delta_z).mean() * cfg.evolver_params.tv_reg_loss
proj_t = proj_t + delta_z
else:
for i in range(1, dt):
proj_t = model.evolver(proj_t, proj_stim_t[i])

# loss at first multiple (dt)
pred_t_plus_dt = model.decoder(proj_t)
Expand All @@ -433,20 +467,40 @@ def train_step_nocompile(
evolve_loss = evolve_loss + loss_fn(pred_t_plus_dt, x_t_plus_dt)

# additional multiples (2, 3, ..., num_multiples)
for m in range(2, num_multiples + 1):
# evolve dt more steps
start_idx = (m - 1) * dt
for i in range(dt):
proj_t = model.evolver(proj_t, proj_stim_t[start_idx + i])
# loss at this multiple
pred = model.decoder(proj_t)
# target at t + m*dt
target_indices_m = observation_indices + m * dt
x_target = train_data[target_indices_m, neuron_indices] # (b, N)
evolve_loss = evolve_loss + loss_fn(pred, x_target)

loss = evolve_loss + recon_loss + reg_loss + aug_loss
return (loss, recon_loss, evolve_loss, reg_loss, aug_loss)
if cfg.evolver_params.tv_reg_loss > 0.:
for m in range(2, num_multiples + 1):
# evolve dt more steps
start_idx = (m - 1) * dt
for i in range(dt):
delta_z = model.evolver.evolver(torch.cat([proj_t, proj_stim_t[start_idx + i]], dim=1))
tv_loss += torch.abs(delta_z).mean() * cfg.evolver_params.tv_reg_loss
proj_t = proj_t + delta_z
# loss at this multiple
pred = model.decoder(proj_t)
# target at t + m*dt
target_indices_m = observation_indices + m * dt
x_target = train_data[target_indices_m, neuron_indices] # (b, N)
evolve_loss = evolve_loss + loss_fn(pred, x_target)
else:
for m in range(2, num_multiples + 1):
# evolve dt more steps
start_idx = (m - 1) * dt
for i in range(dt):
proj_t = model.evolver(proj_t, proj_stim_t[start_idx + i])
# loss at this multiple
pred = model.decoder(proj_t)
# target at t + m*dt
target_indices_m = observation_indices + m * dt
x_target = train_data[target_indices_m, neuron_indices] # (b, N)
evolve_loss = evolve_loss + loss_fn(pred, x_target)

losses[LossType.EVOLVE] = evolve_loss
losses[LossType.TV_LOSS] = tv_loss

# compute total loss
losses[LossType.TOTAL] = sum(losses.values())

return losses

train_step = torch.compile(train_step_nocompile, fullgraph=True, mode="reduce-overhead")

Expand Down Expand Up @@ -606,28 +660,21 @@ def train(cfg: ModelParams, run_dir: Path):
needed_indices = torch.empty(0, dtype=torch.long, device=device)

# use nocompile version for warmup
loss_tuple = train_step_reconstruction_only_nocompile(
losses = train_step_reconstruction_only_nocompile(
model, chunk_data, chunk_stim, observation_indices,
selected_neurons, needed_indices, cfg
)
loss_tuple[0].backward()
losses[LossType.TOTAL].backward()
optimizer.step()
warmup_losses.accumulate({
LossType.TOTAL: loss_tuple[0],
LossType.RECON: loss_tuple[1],
LossType.EVOLVE: loss_tuple[2],
LossType.REG: loss_tuple[3],
LossType.AUG_LOSS: loss_tuple[4],
})
warmup_losses.accumulate(losses)

warmup_epoch_duration = (datetime.now() - warmup_epoch_start).total_seconds()
mean_warmup = warmup_losses.mean()
print(f"warmup {warmup_epoch+1}/{recon_warmup_epochs} | recon loss: {mean_warmup[LossType.RECON]:.4e} | duration: {warmup_epoch_duration:.2f}s")

# log to tensorboard
writer.add_scalar("ReconWarmup/loss", mean_warmup[LossType.TOTAL], warmup_epoch)
writer.add_scalar("ReconWarmup/recon_loss", mean_warmup[LossType.RECON], warmup_epoch)
writer.add_scalar("ReconWarmup/reg_loss", mean_warmup[LossType.REG], warmup_epoch)
# log to tensorboard (only computed losses during warmup)
for loss_type, loss_value in mean_warmup.items():
writer.add_scalar(f"ReconWarmup/{loss_type.name.lower()}", loss_value, warmup_epoch)
writer.add_scalar("ReconWarmup/epoch_duration", warmup_epoch_duration, warmup_epoch)

model.evolver.requires_grad_(True)
Expand Down Expand Up @@ -738,13 +785,13 @@ def train(cfg: ModelParams, run_dir: Path):

# training step (timing for latency tracking)
forward_start = time.time()
loss_tuple = train_step_fn(
loss_dict = train_step_fn(
model, chunk_data, chunk_stim, observation_indices, selected_neurons, needed_indices, cfg
)
forward_time = time.time() - forward_start

backward_start = time.time()
loss_tuple[0].backward()
loss_dict[LossType.TOTAL].backward()
backward_time = time.time() - backward_start

step_start = time.time()
Expand All @@ -753,13 +800,7 @@ def train(cfg: ModelParams, run_dir: Path):
optimizer.step()
step_time = time.time() - step_start

losses.accumulate({
LossType.TOTAL: loss_tuple[0],
LossType.RECON: loss_tuple[1],
LossType.EVOLVE: loss_tuple[2],
LossType.REG: loss_tuple[3],
LossType.AUG_LOSS: loss_tuple[4],
})
losses.accumulate(loss_dict)

# sample timing every 10 batches
if batch_in_chunk % 10 == 0:
Expand All @@ -779,12 +820,9 @@ def train(cfg: ModelParams, run_dir: Path):
f"Duration: {epoch_duration:.2f}s (Total: {total_elapsed:.1f}s)"
)

# log to tensorboard
writer.add_scalar("Loss/train", mean_losses[LossType.TOTAL], epoch)
writer.add_scalar("Loss/train_recon", mean_losses[LossType.RECON], epoch)
writer.add_scalar("Loss/train_evolve", mean_losses[LossType.EVOLVE], epoch)
writer.add_scalar("Loss/train_reg", mean_losses[LossType.REG], epoch)
writer.add_scalar("Loss/train_aug_loss", mean_losses[LossType.AUG_LOSS], epoch)
# log to tensorboard (programmatically iterate over all losses)
for loss_type, loss_value in mean_losses.items():
writer.add_scalar(f"Loss/train_{loss_type.name.lower()}", loss_value, epoch)
writer.add_scalar("Time/epoch_duration", epoch_duration, epoch)
writer.add_scalar("Time/total_elapsed", total_elapsed, epoch)

Expand Down
2 changes: 2 additions & 0 deletions src/LatentEvolution/latent_1step.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ evolver_params:
l1_reg_loss: 0.0
activation: ReLU
use_input_skips: true # Use MLPWithSkips (input fed to each hidden layer via concatenation)
zero_init: true # Zero-initialize final layer so evolver starts as identity (z_{t+1} = z_t)
tv_reg_loss: 0.0 # Total variation regularization on evolver updates (typical: 1e-5 to 1e-3)

stimulus_encoder_params:
num_input_dims: 1736
Expand Down
2 changes: 2 additions & 0 deletions src/LatentEvolution/latent_20step.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ evolver_params:
l1_reg_loss: 0.0
activation: Tanh
use_input_skips: true # Use MLPWithSkips (input fed to each hidden layer via concatenation)
zero_init: true # Zero-initialize final layer so evolver starts as identity (z_{t+1} = z_t)
tv_reg_loss: 0.0 # Total variation regularization on evolver updates (typical: 1e-5 to 1e-3)

stimulus_encoder_params:
num_input_dims: 1736
Expand Down
2 changes: 2 additions & 0 deletions src/LatentEvolution/latent_5step.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ evolver_params:
l1_reg_loss: 0.0
activation: Tanh
use_input_skips: true # Use MLPWithSkips (input fed to each hidden layer via concatenation)
zero_init: true # Zero-initialize final layer so evolver starts as identity (z_{t+1} = z_t)
tv_reg_loss: 0.0 # Total variation regularization on evolver updates (typical: 1e-5 to 1e-3)

stimulus_encoder_params:
num_input_dims: 1736
Expand Down