feat: configurable evolver features and refactored loss returns #128
Merged
Conversation
vijk777 (Collaborator) commented on Jan 23, 2026
- add tv norm regularization
- add feature flag for evolver zero initialization
- refactor to use dict[LossType, Tensor] returns from the train step, which makes the code cleaner. no detectable compute overhead.
add a `zero_init` flag to EvolverParams to control whether the evolver
starts as the identity function (z_{t+1} = z_t). this provides training
stability but may slow dynamics learning.
reconstruction_warmup_epochs was already configurable in TrainingConfig;
it freezes the evolver while training the encoder/decoder on the reconstruction loss.
both features can now be toggled via config or CLI overrides:
- --evolver_params.zero_init false
- --training.reconstruction_warmup_epochs 10
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
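A minimal sketch of how a `zero_init` flag like this is typically wired up: the final layer of the evolver network is zero-initialized so the predicted update Δz starts at zero and the evolver begins as the identity map (z_{t+1} = z_t). Apart from `EvolverParams` and `zero_init`, the field names, layer sizes, and `stim_dim` argument are assumptions, not the repo's actual implementation.

```python
import torch
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class EvolverParams:
    latent_dim: int = 256
    hidden_dim: int = 512   # hypothetical field name/size
    zero_init: bool = True  # start evolver as identity: z_{t+1} = z_t

class Evolver(nn.Module):
    """Predicts the latent update delta_z, so z_{t+1} = z_t + delta_z."""

    def __init__(self, params: EvolverParams, stim_dim: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(params.latent_dim + stim_dim, params.hidden_dim),
            nn.ReLU(),
            nn.Linear(params.hidden_dim, params.latent_dim),
        )
        if params.zero_init:
            # zero the final layer so delta_z == 0 at initialization,
            # i.e. the evolver starts as the identity map
            nn.init.zeros_(self.net[-1].weight)
            nn.init.zeros_(self.net[-1].bias)

    def forward(self, z_t: torch.Tensor, stim: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([z_t, stim], dim=-1))  # delta_z
```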
implement option B TV-norm regularization: directly penalize the
magnitude of evolver updates (Δz) using the L1 norm. this stabilizes
dynamics and prevents explosive rollouts during long-horizon evolution.
changes:
- add `tv_reg_loss` parameter to EvolverParams (default: 0.0)
- compute tv loss as ||Δz||₁ at each evolver step
- add TV_LOSS to LossType enum and logging
- conditional computation: only compute delta_z explicitly when tv_reg_loss > 0
- update all config files with tv_reg_loss (default 0.0, typical: 1e-5 to 1e-3)
implementation:
- when tv_reg_loss > 0: explicitly compute delta_z = evolver(z_t, stim)
then accumulate tv_loss += ||delta_z||₁ * coeff before updating z_{t+1} = z_t + delta_z
- when tv_reg_loss = 0: use original path for efficiency
typical usage:
python latent.py exp latent_20step.yaml --evolver_params.tv_reg_loss 0.0001
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
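A minimal sketch of the conditional Δz path described above: only the `tv_reg_loss` coefficient and the ||Δz||₁ accumulation come from the commit; the rollout function, its signature, and the tensor shapes are illustrative assumptions.

```python
import torch

def rollout(evolver, z0, stims, tv_reg_loss: float = 0.0):
    """Roll the latent forward, optionally accumulating a TV penalty on delta_z."""
    z = z0
    tv_loss = z0.new_zeros(())  # scalar accumulator on the same device/dtype
    zs = []
    for stim in stims:
        if tv_reg_loss > 0:
            # explicit delta_z path: penalize ||delta_z||_1, then take the step
            delta_z = evolver(z, stim)
            tv_loss = tv_loss + tv_reg_loss * delta_z.abs().sum(dim=-1).mean()
            z = z + delta_z
        else:
            # original path: no extra tensors kept around when the penalty is off
            z = z + evolver(z, stim)
        zs.append(z)
    return torch.stack(zs, dim=1), tv_loss  # [batch, steps, latent], scalar
```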
add benchmark comparing tuple, namedtuple, and dict return types with torch.compile to determine the best approach for loss returns.
results (cpu, reduce-overhead mode):
- tuple: 236.69 ± 24.02 µs/iter (baseline)
- namedtuple: 245.24 ± 21.82 µs/iter (+3.6% overhead)
- dict (enum keys): 244.83 ± 19.62 µs/iter (+3.4% overhead)
- dict (str keys): 322.72 ± 157.59 µs/iter (+36.4%, high variance)
conclusion: namedtuple has negligible overhead (<4%) and provides semantic access, type safety, and flexibility to omit unused fields.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
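A rough sketch of how such a return-type microbenchmark might be set up with torch.compile; the step functions, sizes, and timing loop here are illustrative, not the benchmark script added in this PR.

```python
import time
import torch

def step_tuple(x):
    recon = (x ** 2).mean()
    reg = x.abs().mean()
    return recon + reg, recon, reg

def step_dict(x):
    recon = (x ** 2).mean()
    reg = x.abs().mean()
    return {"total": recon + reg, "recon": recon, "reg": reg}

def bench(fn, x, iters=200, warmup=20):
    compiled = torch.compile(fn, mode="reduce-overhead")
    for _ in range(warmup):          # warm up so compilation is excluded
        compiled(x)
    if x.is_cuda:
        torch.cuda.synchronize()     # needed for accurate GPU timing
    t0 = time.perf_counter()
    for _ in range(iters):
        compiled(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1e6  # µs/iter

x = torch.randn(256, 1000)
print(f"tuple: {bench(step_tuple, x):.1f} µs/iter")
print(f"dict:  {bench(step_dict, x):.1f} µs/iter")
```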
fix benchmark to use realistic training step computation instead of trivial mean/std operations. previous version was too small and showed compiled code being slower than uncompiled (nonsensical).
changes:
- simulate encoder/decoder with matrix multiplies and relu
- add multiple loss computations (recon, l1 reg, temporal smoothness)
- use batch_size=256, neurons=1000, latent=256 (realistic sizes)
- ensure proper cuda synchronization
- add compilation speedup metrics
this should show proper speedup from torch.compile and an accurate overhead comparison between tuple, namedtuple, and dict returns.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
replace tuple returns with dict[LossType, Tensor] for semantic access and programmatic tensorboard logging.
changes:
- train_step_nocompile: returns dict, built incrementally with losses[LossType.X] = value
- train_step_reconstruction_only_nocompile: returns dict with only the computed losses (total, recon, reg)
- loss accumulation: updated to work with dict instead of tuple indexing
- tensorboard logging: now programmatic, iterating with loss_type.name.lower()
benefits:
- semantic access: losses[LossType.RECON] instead of loss_tuple[1]
- flexible returns: warmup only returns computed losses
- programmatic logging: automatically logs all loss components
- type safe: enum keys prevent typos
benchmark showed dict with enum keys has <2% overhead vs tuple on gpu.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
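A minimal sketch of the dict[LossType, Tensor] pattern and the programmatic TensorBoard logging it enables, assuming a simplified LossType enum and loss computation; the actual train step in this repo computes different terms and the function signatures here are hypothetical.

```python
from enum import Enum, auto

import torch
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter

class LossType(Enum):
    TOTAL = auto()
    RECON = auto()
    REG = auto()
    TV_LOSS = auto()

def train_step(x: Tensor, x_hat: Tensor, z: Tensor, tv_loss: Tensor) -> dict[LossType, Tensor]:
    # build the dict incrementally; callers can omit components they don't compute
    losses: dict[LossType, Tensor] = {}
    losses[LossType.RECON] = torch.nn.functional.mse_loss(x_hat, x)
    losses[LossType.REG] = z.abs().mean()
    losses[LossType.TV_LOSS] = tv_loss
    losses[LossType.TOTAL] = sum(losses.values())
    return losses

# programmatic logging: every returned loss component gets its own scalar tag
writer = SummaryWriter()

def log_losses(losses: dict[LossType, Tensor], step: int) -> None:
    for loss_type, value in losses.items():
        writer.add_scalar(f"loss/{loss_type.name.lower()}", value.item(), step)
```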
test completed; results confirmed dict/namedtuple have <2% overhead. keeping results in git history for reference.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
document that dict[LossType, Tensor] has <2% overhead vs tuple, based on gpu benchmarking with realistic computation.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>