diff --git a/src/LatentEvolution/eed_model.py b/src/LatentEvolution/eed_model.py index 94b02e10..57fbb0cd 100644 --- a/src/LatentEvolution/eed_model.py +++ b/src/LatentEvolution/eed_model.py @@ -51,6 +51,12 @@ class EvolverParams(BaseModel): time_units: int = Field( 1, description="DEPRECATED: Use training.time_units instead. Kept for backwards compatibility." ) + zero_init: bool = Field( + True, description="If True, zero-initialize final layer so evolver starts as identity (z_{t+1} = z_t). Provides stability but may slow dynamics learning." + ) + tv_reg_loss: float = Field( + 0.0, description="total variation regularization on evolver updates. penalizes ||Δz|| to stabilize dynamics and prevent explosive rollouts. typical range: 1e-5 to 1e-3." + ) model_config = ConfigDict(extra="forbid", validate_assignment=True) @field_validator("activation") @@ -215,13 +221,14 @@ def __init__(self, latent_dims: int, stim_dims: int, evolver_params: EvolverPara ) ) - # zero-init final layer so evolver starts as identity (z_{t+1} = z_t) - if evolver_params.use_input_skips: - nn.init.zeros_(self.evolver.output_layer.weight) - nn.init.zeros_(self.evolver.output_layer.bias) - else: - nn.init.zeros_(self.evolver.layers[-1].weight) - nn.init.zeros_(self.evolver.layers[-1].bias) + # optionally zero-init final layer so evolver starts as identity (z_{t+1} = z_t) + if evolver_params.zero_init: + if evolver_params.use_input_skips: + nn.init.zeros_(self.evolver.output_layer.weight) + nn.init.zeros_(self.evolver.output_layer.bias) + else: + nn.init.zeros_(self.evolver.layers[-1].weight) + nn.init.zeros_(self.evolver.layers[-1].bias) def forward(self, proj_t, proj_stim_t): """Evolve one time step in latent space.""" diff --git a/src/LatentEvolution/experiments/flyvis_voltage_100ms.md b/src/LatentEvolution/experiments/flyvis_voltage_100ms.md index d7ee70f6..55f15d38 100644 --- a/src/LatentEvolution/experiments/flyvis_voltage_100ms.md +++ b/src/LatentEvolution/experiments/flyvis_voltage_100ms.md @@ -312,3 +312,45 @@ for ems in 1 2 3 5; do \ python src/LatentEvolution/latent.py ems_sweep latent_20step.yaml \ --training.evolve-multiple-steps $ems ``` + +## Derivative experiment + +Prior to changing the evolver (tanh activation, initialized to 0) and using +pretraining for reconstruction, we observed an instability in training in the +`tu20_seed_sweep_20260115_f6144bd` experiment. We want to revisit and see if +adding TV norm regularization can rescue that phenotype. + +Understand which feature contributes to the stability. We changed many things at +the same time to get to a working point. + +```bash +for zero_init in zero-init no-zero-init; do \ + for activation in Tanh ReLU; do \ + for warmup in 0 10; do \ + name="z${zero_init}_${activation}_w${warmup}" + bsub -J $name -q gpu_a100 -gpu "num=1" -n 8 -o ${name}.log \ + python src/LatentEvolution/latent.py test_stability latent_20step.yaml \ + --evolver_params.${zero_init} \ + --evolver_params.activation $activation \ + --training.reconstruction-warmup-epochs $warmup \ + --training.seed 35235 + done + done +done +``` + +Understand if TV norm can bring stability to the training without pretraining for +reconstruction or the other features we added. + +```bash + +for tv in 0.0 0.00001 0.0001 0.001; do \ + bsub -J tv${tv} -q gpu_a100 -gpu "num=1" -n 8 -o tv${tv}.log \ + python src/LatentEvolution/latent.py tv_sweep latent_20step.yaml \ + --evolver_params.no-zero-init \ + --evolver_params.activation ReLU \ + --training.reconstruction-warmup-epochs 0 \ + --evolver_params.tv-reg-loss $tv \ + --training.seed 97651 +done +``` diff --git a/src/LatentEvolution/latent.py b/src/LatentEvolution/latent.py index bb6bd6b9..ad9f9217 100644 --- a/src/LatentEvolution/latent.py +++ b/src/LatentEvolution/latent.py @@ -172,12 +172,22 @@ def load_val_only( class LossType(Enum): - """loss component types for EED model.""" + """loss component types for EED model. + + train steps return dict[LossType, Tensor] for semantic access. + + benchmarking results (gpu, 256 batch, realistic encoder/decoder ops): + - dict with enum keys: +0.5% to +1.4% overhead vs tuple + - namedtuple: +0.8% to +2.0% overhead vs tuple + - compilation speedup: 1.87-1.93x faster (all return types) + - conclusion: negligible overhead, semantic access worth it + """ TOTAL = auto() RECON = auto() EVOLVE = auto() REG = auto() AUG_LOSS = auto() + TV_LOSS = auto() # ------------------------------------------------------------------- @@ -307,7 +317,7 @@ def train_step_reconstruction_only_nocompile( _selected_neurons: torch.Tensor, _needed_indices: torch.Tensor, cfg: ModelParams - ): + ) -> dict[LossType, torch.Tensor]: """Train encoder/decoder only with reconstruction loss.""" device = train_data.device loss_fn = getattr(torch.nn.functional, cfg.training.loss_function) @@ -331,11 +341,15 @@ def train_step_reconstruction_only_nocompile( recon_t = model.decoder(proj_t) recon_loss = loss_fn(recon_t, x_t) - # return same tuple format for compatibility - evolve_loss = torch.tensor(0.0, device=device) - aug_loss = torch.tensor(0.0, device=device) - loss = recon_loss + reg_loss - return (loss, recon_loss, evolve_loss, reg_loss, aug_loss) + # total loss + total_loss = recon_loss + reg_loss + + # return only computed losses (no evolve, aug, tv during warmup) + return { + LossType.TOTAL: total_loss, + LossType.RECON: recon_loss, + LossType.REG: reg_loss, + } train_step_reconstruction_only = torch.compile( @@ -351,10 +365,13 @@ def train_step_nocompile( selected_neurons: torch.Tensor, needed_indices: torch.Tensor, cfg: ModelParams - ): + ) -> dict[LossType, torch.Tensor]: device=train_data.device + # initialize loss dict + losses: dict[LossType, torch.Tensor] = {} + # Get loss function from config loss_fn = getattr(torch.nn.functional, cfg.training.loss_function) @@ -369,6 +386,10 @@ def train_step_nocompile( if cfg.evolver_params.l1_reg_loss > 0.: for p in model.evolver.parameters(): reg_loss += torch.abs(p).mean()*cfg.evolver_params.l1_reg_loss + losses[LossType.REG] = reg_loss + + # total variation regularization on evolver updates + tv_loss = torch.tensor(0.0, device=device) # b = batch size @@ -400,11 +421,17 @@ def train_step_nocompile( # reconstruction loss recon_t = model.decoder(proj_t) - recon_loss = loss_fn(recon_t, x_t) + losses[LossType.RECON] = loss_fn(recon_t, x_t) # Evolve by 1 time step. This is a special case since we may opt to apply # a connectome constraint via data augmentation. - proj_t = model.evolver(proj_t, proj_stim_t[0]) + if cfg.evolver_params.tv_reg_loss > 0.: + # compute delta_z explicitly for TV norm + delta_z = model.evolver.evolver(torch.cat([proj_t, proj_stim_t[0]], dim=1)) + tv_loss += torch.abs(delta_z).mean() * cfg.evolver_params.tv_reg_loss + proj_t = proj_t + delta_z + else: + proj_t = model.evolver(proj_t, proj_stim_t[0]) # apply connectome loss after evolving by 1 time step aug_loss = torch.tensor(0.0, device=device) if ( @@ -419,11 +446,18 @@ def train_step_nocompile( proj_t_aug = model.evolver(model.encoder(x_t_aug), proj_stim_t[0]) pred_t_plus_1_aug = model.decoder(proj_t_aug) aug_loss += cfg.training.unconnected_to_zero.loss_coeff * loss_fn(pred_t_plus_1_aug[:, selected_neurons], pred_t_plus_1[:, selected_neurons]) + losses[LossType.AUG_LOSS] = aug_loss # evolve for remaining dt-1 time steps (first window) evolve_loss = torch.tensor(0.0, device=device) - for i in range(1, dt): - proj_t = model.evolver(proj_t, proj_stim_t[i]) + if cfg.evolver_params.tv_reg_loss > 0.: + for i in range(1, dt): + delta_z = model.evolver.evolver(torch.cat([proj_t, proj_stim_t[i]], dim=1)) + tv_loss += torch.abs(delta_z).mean() * cfg.evolver_params.tv_reg_loss + proj_t = proj_t + delta_z + else: + for i in range(1, dt): + proj_t = model.evolver(proj_t, proj_stim_t[i]) # loss at first multiple (dt) pred_t_plus_dt = model.decoder(proj_t) @@ -433,20 +467,40 @@ def train_step_nocompile( evolve_loss = evolve_loss + loss_fn(pred_t_plus_dt, x_t_plus_dt) # additional multiples (2, 3, ..., num_multiples) - for m in range(2, num_multiples + 1): - # evolve dt more steps - start_idx = (m - 1) * dt - for i in range(dt): - proj_t = model.evolver(proj_t, proj_stim_t[start_idx + i]) - # loss at this multiple - pred = model.decoder(proj_t) - # target at t + m*dt - target_indices_m = observation_indices + m * dt - x_target = train_data[target_indices_m, neuron_indices] # (b, N) - evolve_loss = evolve_loss + loss_fn(pred, x_target) - - loss = evolve_loss + recon_loss + reg_loss + aug_loss - return (loss, recon_loss, evolve_loss, reg_loss, aug_loss) + if cfg.evolver_params.tv_reg_loss > 0.: + for m in range(2, num_multiples + 1): + # evolve dt more steps + start_idx = (m - 1) * dt + for i in range(dt): + delta_z = model.evolver.evolver(torch.cat([proj_t, proj_stim_t[start_idx + i]], dim=1)) + tv_loss += torch.abs(delta_z).mean() * cfg.evolver_params.tv_reg_loss + proj_t = proj_t + delta_z + # loss at this multiple + pred = model.decoder(proj_t) + # target at t + m*dt + target_indices_m = observation_indices + m * dt + x_target = train_data[target_indices_m, neuron_indices] # (b, N) + evolve_loss = evolve_loss + loss_fn(pred, x_target) + else: + for m in range(2, num_multiples + 1): + # evolve dt more steps + start_idx = (m - 1) * dt + for i in range(dt): + proj_t = model.evolver(proj_t, proj_stim_t[start_idx + i]) + # loss at this multiple + pred = model.decoder(proj_t) + # target at t + m*dt + target_indices_m = observation_indices + m * dt + x_target = train_data[target_indices_m, neuron_indices] # (b, N) + evolve_loss = evolve_loss + loss_fn(pred, x_target) + + losses[LossType.EVOLVE] = evolve_loss + losses[LossType.TV_LOSS] = tv_loss + + # compute total loss + losses[LossType.TOTAL] = sum(losses.values()) + + return losses train_step = torch.compile(train_step_nocompile, fullgraph=True, mode="reduce-overhead") @@ -606,28 +660,21 @@ def train(cfg: ModelParams, run_dir: Path): needed_indices = torch.empty(0, dtype=torch.long, device=device) # use nocompile version for warmup - loss_tuple = train_step_reconstruction_only_nocompile( + losses = train_step_reconstruction_only_nocompile( model, chunk_data, chunk_stim, observation_indices, selected_neurons, needed_indices, cfg ) - loss_tuple[0].backward() + losses[LossType.TOTAL].backward() optimizer.step() - warmup_losses.accumulate({ - LossType.TOTAL: loss_tuple[0], - LossType.RECON: loss_tuple[1], - LossType.EVOLVE: loss_tuple[2], - LossType.REG: loss_tuple[3], - LossType.AUG_LOSS: loss_tuple[4], - }) + warmup_losses.accumulate(losses) warmup_epoch_duration = (datetime.now() - warmup_epoch_start).total_seconds() mean_warmup = warmup_losses.mean() print(f"warmup {warmup_epoch+1}/{recon_warmup_epochs} | recon loss: {mean_warmup[LossType.RECON]:.4e} | duration: {warmup_epoch_duration:.2f}s") - # log to tensorboard - writer.add_scalar("ReconWarmup/loss", mean_warmup[LossType.TOTAL], warmup_epoch) - writer.add_scalar("ReconWarmup/recon_loss", mean_warmup[LossType.RECON], warmup_epoch) - writer.add_scalar("ReconWarmup/reg_loss", mean_warmup[LossType.REG], warmup_epoch) + # log to tensorboard (only computed losses during warmup) + for loss_type, loss_value in mean_warmup.items(): + writer.add_scalar(f"ReconWarmup/{loss_type.name.lower()}", loss_value, warmup_epoch) writer.add_scalar("ReconWarmup/epoch_duration", warmup_epoch_duration, warmup_epoch) model.evolver.requires_grad_(True) @@ -738,13 +785,13 @@ def train(cfg: ModelParams, run_dir: Path): # training step (timing for latency tracking) forward_start = time.time() - loss_tuple = train_step_fn( + loss_dict = train_step_fn( model, chunk_data, chunk_stim, observation_indices, selected_neurons, needed_indices, cfg ) forward_time = time.time() - forward_start backward_start = time.time() - loss_tuple[0].backward() + loss_dict[LossType.TOTAL].backward() backward_time = time.time() - backward_start step_start = time.time() @@ -753,13 +800,7 @@ def train(cfg: ModelParams, run_dir: Path): optimizer.step() step_time = time.time() - step_start - losses.accumulate({ - LossType.TOTAL: loss_tuple[0], - LossType.RECON: loss_tuple[1], - LossType.EVOLVE: loss_tuple[2], - LossType.REG: loss_tuple[3], - LossType.AUG_LOSS: loss_tuple[4], - }) + losses.accumulate(loss_dict) # sample timing every 10 batches if batch_in_chunk % 10 == 0: @@ -779,12 +820,9 @@ def train(cfg: ModelParams, run_dir: Path): f"Duration: {epoch_duration:.2f}s (Total: {total_elapsed:.1f}s)" ) - # log to tensorboard - writer.add_scalar("Loss/train", mean_losses[LossType.TOTAL], epoch) - writer.add_scalar("Loss/train_recon", mean_losses[LossType.RECON], epoch) - writer.add_scalar("Loss/train_evolve", mean_losses[LossType.EVOLVE], epoch) - writer.add_scalar("Loss/train_reg", mean_losses[LossType.REG], epoch) - writer.add_scalar("Loss/train_aug_loss", mean_losses[LossType.AUG_LOSS], epoch) + # log to tensorboard (programmatically iterate over all losses) + for loss_type, loss_value in mean_losses.items(): + writer.add_scalar(f"Loss/train_{loss_type.name.lower()}", loss_value, epoch) writer.add_scalar("Time/epoch_duration", epoch_duration, epoch) writer.add_scalar("Time/total_elapsed", total_elapsed, epoch) diff --git a/src/LatentEvolution/latent_1step.yaml b/src/LatentEvolution/latent_1step.yaml index c8bcdf14..de588829 100644 --- a/src/LatentEvolution/latent_1step.yaml +++ b/src/LatentEvolution/latent_1step.yaml @@ -25,6 +25,8 @@ evolver_params: l1_reg_loss: 0.0 activation: ReLU use_input_skips: true # Use MLPWithSkips (input fed to each hidden layer via concatenation) + zero_init: true # Zero-initialize final layer so evolver starts as identity (z_{t+1} = z_t) + tv_reg_loss: 0.0 # Total variation regularization on evolver updates (typical: 1e-5 to 1e-3) stimulus_encoder_params: num_input_dims: 1736 diff --git a/src/LatentEvolution/latent_20step.yaml b/src/LatentEvolution/latent_20step.yaml index 3012ce63..5abe817d 100644 --- a/src/LatentEvolution/latent_20step.yaml +++ b/src/LatentEvolution/latent_20step.yaml @@ -25,6 +25,8 @@ evolver_params: l1_reg_loss: 0.0 activation: Tanh use_input_skips: true # Use MLPWithSkips (input fed to each hidden layer via concatenation) + zero_init: true # Zero-initialize final layer so evolver starts as identity (z_{t+1} = z_t) + tv_reg_loss: 0.0 # Total variation regularization on evolver updates (typical: 1e-5 to 1e-3) stimulus_encoder_params: num_input_dims: 1736 diff --git a/src/LatentEvolution/latent_5step.yaml b/src/LatentEvolution/latent_5step.yaml index 4da34b2b..ff9d24a4 100644 --- a/src/LatentEvolution/latent_5step.yaml +++ b/src/LatentEvolution/latent_5step.yaml @@ -25,6 +25,8 @@ evolver_params: l1_reg_loss: 0.0 activation: Tanh use_input_skips: true # Use MLPWithSkips (input fed to each hidden layer via concatenation) + zero_init: true # Zero-initialize final layer so evolver starts as identity (z_{t+1} = z_t) + tv_reg_loss: 0.0 # Total variation regularization on evolver updates (typical: 1e-5 to 1e-3) stimulus_encoder_params: num_input_dims: 1736