@danbraunai-apollo (Contributor) commented Apr 22, 2025

Description

Adds a Streamlit dashboard that shows various properties of components on each token.

[Screenshot: dashboard preview]
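For a rough idea of the shape such a dashboard takes, here is a minimal Streamlit sketch that plots per-token component activations. The data, tensor names, and shapes are illustrative placeholders, not the code in this PR.

```python
# Minimal sketch of a per-token component dashboard, assuming component
# activations are precomputed. `tokens` and `component_acts` are
# hypothetical stand-ins for whatever the real run produces.
import streamlit as st
import torch

st.title("Component properties per token")

# Hypothetical data: activations of shape (seq_len, n_components).
tokens = ["Once", "upon", "a", "time", ",", "the"]
component_acts = torch.rand(len(tokens), 8)

idx = st.selectbox("Token", range(len(tokens)), format_func=lambda i: f"{i}: {tokens[i]}")
st.bar_chart(component_acts[idx].numpy())
st.caption("Per-component activation for the selected token.")
```

Run with `streamlit run dashboard_sketch.py` (assumed filename) to get an interactive per-token view.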

How Has This Been Tested?

None

Does this PR introduce a breaking change?

No

@danbraunai-apollo merged commit 072085e into feature/lm Apr 22, 2025
1 check passed
danbraunai-apollo added a commit that referenced this pull request Apr 22, 2025
* Rename some topk_mask vars to mask

* Implement gating (untested)

* Fix grad attributions and calc_recon_mse

* Init gate with bias=1 and weights normal dist mean=0 std=0.2

* Fix lp sparsity loss

* Add random mask loss

* Use relud masks for lp sparsity loss

* Use masked_target_component_acts in calc_act_recon_mse

* Comment out grad attribution calculation so people don't use it for now

* Store gates in model class

* Remove buggy tms deprecated params replacement

* Tie the gates for TMS

* Plot masks

* Fix resid_mlp test (sensitive to float precision)

* Add init_from_target for tms

* Support init_from_target for resid_mlp

* Normalise lp sparsity by batch size

* Don't copy biases in init_spd_model_from_target_model

* Fix resid_mlp init_from_target test

* Add randrecon to run label

* Permute to identity for plotting mask_vals

* Remove post_relu_act_recon config arg

* Remove code from global scope in plotting

* Handle deprecated 'post_relu_act_recon' arg.

* Use mps if available

* Avoid mps as it breaks tms

* Untie gates in TMS

* Allow for detached inputs to gates and use target_out in random_mask_recon

* Add GateMLP

* Remove bias_val and train_bias config args

* Make calc_masked_target_component_acts einsums clearer

* Change bias init to 1 in GateMLP

* Plot unpermuted As

* Set in_bias in GateMLP to zeros

* Move plot_mask_vals to the root plotting.py instead of the tms experiment

* Plot permuted AB matrices

* Take mean over batch only for lp_sparsity_coeff

* Fix for normalizing by batch only; sum over m dim

* Fix docs for lp sparsity loss

* Fix return type of lp_sparsity_loss

* Use Kaiming normal everywhere

* Fix MLP bias init

* Always init TMS biases to 0

* Remove init_scale everywhere

* Fix init_scale deprecation

* Init A and B based on norm of target weights

* Set Gate biases to 0

* Load env vars when running sweeps too

* Add layerwise recon (#263)

* Add layerwise recon

* Add layerwise_random_recon_loss

* Protect the eyes of mathematicians

* Remove transformer-lens dependency

* Use new random masks for layerwise_random_masks

* Add jaxtyping to dependencies

* Add einops dependency

* Use calc_recon_mse in calc_random_masks_mse_loss for consistency

* Set bias to zero in GateMLP mlp_out

* WIP: Swap components with Llama nn.Linear modules

* Fix nn.Linear shape and handle masked components

* WIP: Add lm_decomposition script

* Fix module paths

* WIP: Add param_match_loss

* Add layerwise recon losses

* Add lp sparsity loss

* Minor comment and config clean

* Make components a submodule of SSModel and update model loading

* Add SSModel.from_pretrained()

* WIP: Fix download with weights_only=True

* Calc mask l0 for lms

* Fix missing GateMLP type references

* Update component_viz for new model format

* Plot mean components during apd run

* Re-organise wandb logging

* Add streamlit dashboard for lm (#2)

* WIP: Add dashboard

* Create base_cache_dir if it doesn't exist

* Functional dashboard

* Add simple-stories-train and datasets to pyproject.toml

* Remove unused set_nested_module_attr function
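Several of the sparsity commits above ("Use relud masks for lp sparsity loss", "Take mean over batch only for lp_sparsity_coeff", "Fix for normalizing by batch only; sum over m dim") suggest a loss of roughly the following shape. This is a hedged sketch with assumed names, shapes, and exponent, not the repository's implementation.

```python
# Sketch of an Lp sparsity penalty as described by the commits above:
# ReLU the mask values, raise to the power p, sum over the component (m)
# dimension, and take the mean over the batch dimension only.
# `mask`, its (batch, m) shape, and the value of p are assumptions.
import torch

def lp_sparsity_loss(mask: torch.Tensor, p: float = 1.0) -> torch.Tensor:
    relu_mask = torch.relu(mask)                 # "Use relud masks for lp sparsity loss"
    per_example = (relu_mask ** p).sum(dim=-1)   # sum over the m dim
    return per_example.mean()                    # normalise by batch size only
```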