Merged (76 commits)
b53df86
Rename some topk_mask vars to mask
danbraunai-apollo Feb 12, 2025
0f4e7f8
Implement gating (untested)
danbraunai-apollo Feb 13, 2025
c784489
Fix grad attributions and calc_recon_mse
danbraunai-apollo Feb 13, 2025
e3c3eb0
Init gate with bias=1 and weights normal dist mean=0 std=0.2
danbraunai-apollo Feb 13, 2025
15b310c
Fix lp sparsity loss
danbraunai-apollo Feb 13, 2025
3aff69b
Add random mask loss
danbraunai-apollo Feb 13, 2025
13b8097
Use relud masks for lp sparsity loss
danbraunai-apollo Feb 13, 2025
0923c0f
Use masked_target_component_acts in calc_act_recon_mse
danbraunai-apollo Feb 13, 2025
3aceb8a
Comment out grad attribution calculation so people don't use it for now
danbraunai-apollo Feb 14, 2025
61247dc
Store gates in model class
danbraunai-apollo Feb 14, 2025
64c3a23
Remove buggy tms deprecated params replacement
danbraunai-apollo Feb 14, 2025
ed32237
Tie the gates for TMS
danbraunai-apollo Feb 14, 2025
60cc056
Plot masks
danbraunai-apollo Feb 14, 2025
bc9505c
Fix resid_mlp test (sensitive to float precision)
danbraunai-apollo Feb 14, 2025
01a03bc
Add init_from_target for tms
danbraunai-apollo Feb 14, 2025
6d6d99f
Support init_from_target for resid_mlp
danbraunai-apollo Feb 14, 2025
c303c14
Normalise lp sparsity by batch size
danbraunai-apollo Feb 14, 2025
41bd85b
Don't copy biases in init_spd_model_from_target_model
danbraunai-apollo Feb 15, 2025
befac1d
Fix resid_mlp init_from_target test
danbraunai-apollo Feb 16, 2025
e7e60a7
Add randrecon to run label
danbraunai-apollo Feb 20, 2025
3845ca3
Permute to identity for plotting mask_vals
danbraunai-apollo Feb 24, 2025
3bb654c
Remove post_relu_act_recon config arg
danbraunai-apollo Feb 27, 2025
ebee911
Remove code from global scope in plotting
danbraunai-apollo Feb 27, 2025
0b3f61d
Handle deprecated 'post_relu_act_recon' arg.
danbraunai-apollo Feb 27, 2025
931b6f3
Use mps if available
danbraunai-apollo Mar 3, 2025
19d7181
Avoid mps as it breaks tms
danbraunai-apollo Mar 3, 2025
8560f1b
Untie gates in TMS
danbraunai-apollo Mar 3, 2025
79391e9
Allow for detached inputs to gates and use target_out in random_mask_…
danbraunai-apollo Mar 4, 2025
cd23609
Add GateMLP
danbraunai-apollo Mar 5, 2025
96939c2
Remove bias_val and train_bias config args
danbraunai-apollo Mar 6, 2025
58eb606
Make calc_masked_target_component_acts einsums clearer
danbraunai-apollo Mar 6, 2025
f536743
Change bias init to 1 in GateMLP
danbraunai-apollo Mar 6, 2025
b6a35cc
Plot unpermuted As
danbraunai-apollo Mar 6, 2025
10cad29
Set in_bias in GateMLP to zeros
danbraunai-apollo Mar 6, 2025
6aa82a8
plot_mask_vals in the root plotting.py instead of in tms experiment
danbraunai-apollo Mar 6, 2025
99da31b
Plot permuted AB matrices
danbraunai-apollo Mar 6, 2025
aa453f7
Take mean over batch only for lp_sparsity_coeff
danbraunai-apollo Mar 6, 2025
f6bc57d
Fix for normalizing by batch only; sum over m dim
danbraunai-apollo Mar 6, 2025
5f216b3
Fix docs for lp sparsity loss
danbraunai-apollo Mar 6, 2025
d1b82fa
Fix return type of lp_sparsity_loss
danbraunai-apollo Mar 6, 2025
e93c5c9
Use Kaiming normal everywhere
danbraunai-apollo Mar 7, 2025
52e6d91
Fix MLP bias init
danbraunai-apollo Mar 7, 2025
244883f
Always init TMS biases to 0
danbraunai-apollo Mar 7, 2025
bddc0ed
Remove init_scale everywhere
danbraunai-apollo Mar 7, 2025
a1d40c4
Fix init_scale deprecation
danbraunai-apollo Mar 7, 2025
c71ace6
Init A and B based on norm of target weights
danbraunai-apollo Mar 7, 2025
3898599
Set Gate biases to 0
danbraunai-apollo Mar 7, 2025
5afdc92
Load env vars when running sweeps too
danbraunai-apollo Mar 17, 2025
e80f874
Add layerwise recon (#263)
danbraunai-apollo Mar 24, 2025
16992e5
Remove transformer-lens dependency
danbraunai-apollo Mar 24, 2025
7f6a94b
Use new random masks for layerwise_random_masks
danbraunai-apollo Mar 24, 2025
5c632f9
Add jaxtyping to dependencies
danbraunai-apollo Mar 24, 2025
5981df6
Add einops dependency
danbraunai-apollo Mar 24, 2025
fcff304
Use calc_recon_mse in calc_random_masks_mse_loss for consistency
danbraunai-apollo Mar 24, 2025
7ac2a42
Set bias to zero in GateMLP mlp_out
danbraunai-apollo Mar 25, 2025
037caf1
WIP: Swap components with Llama nn.Linear modules
danbraunai-apollo Apr 1, 2025
1a1dcaf
Fix nn.Linear shape and handle masked components
danbraunai-apollo Apr 3, 2025
993da44
WIP: Add lm_decomposition script
danbraunai-apollo Apr 3, 2025
fccc189
Fix module paths
danbraunai-apollo Apr 3, 2025
3fcf593
WIP: Add param_match_loss
danbraunai-apollo Apr 4, 2025
aa7cacf
Add layerwise recon losses
danbraunai-apollo Apr 8, 2025
82b505a
Add lp sparsity loss
danbraunai-apollo Apr 8, 2025
96ae954
Minor comment and config clean
danbraunai-apollo Apr 8, 2025
cb12ed1
Make components a submodule of SSModel and update model loading
danbraunai-apollo Apr 10, 2025
d3a7c76
Add SSModel.from_pretrained()
danbraunai-apollo Apr 10, 2025
1425354
WIP: Fix download with weights_only=True
danbraunai-apollo Apr 10, 2025
8ba8ca9
Merge branch 'main' into feature/lm
danbraunai-apollo Apr 14, 2025
7a23520
Calc mask l0 for lms
danbraunai-apollo Apr 14, 2025
2706112
Merge branch 'main' into feature/lm
danbraunai-apollo Apr 14, 2025
0103c0c
Fix missing GateMLP type references
danbraunai-apollo Apr 14, 2025
bcd3e09
Merge branch 'feature/lm' into feature/lm-temp
danbraunai-apollo Apr 16, 2025
60fa3cc
Update component_viz for new model format
danbraunai-apollo Apr 17, 2025
04bcbe1
Plot mean components during apd run
danbraunai-apollo Apr 17, 2025
c2bdda1
Re-organise wandb logging
danbraunai-apollo Apr 17, 2025
072085e
Add streamlit dashboard for lm (#2)
danbraunai-apollo Apr 22, 2025
04a2138
Remove unused set_nested_module_attr function
danbraunai-apollo Apr 22, 2025
24 changes: 24 additions & 0 deletions .vscode/launch.json
@@ -37,5 +37,29 @@
"PYDEVD_DISABLE_FILE_VALIDATION": "1"
}
},
{
"name": "lm",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/spd/experiments/lm/lm_decomposition.py",
"args": "${workspaceFolder}/spd/experiments/lm/lm_config.yaml",
"console": "integratedTerminal",
"justMyCode": true,
"env": {
"PYDEVD_DISABLE_FILE_VALIDATION": "1"
}
},
{
"name": "lm streamlit",
"type": "debugpy",
"request": "launch",
"module": "streamlit",
"args": [
"run",
"${workspaceFolder}/spd/experiments/lm/app.py",
"--server.port",
"2000"
]
}
]
}
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -21,6 +21,10 @@ dependencies = [
"python-dotenv",
"wandb<=0.17.7", # due to https://github.com/wandb/wandb/issues/8248
"sympy",
"streamlit",
"streamlit-antd-components",
"datasets",
"simple-stories-train"
]

[project.optional-dependencies]
18 changes: 17 additions & 1 deletion spd/configs.py
@@ -36,6 +36,20 @@ class ResidualMLPTaskConfig(BaseModel):
pretrained_model_path: ModelPath # e.g. wandb:spd-resid-mlp/runs/j9kmavzi


class LMTaskConfig(BaseModel):
model_config = ConfigDict(extra="forbid", frozen=True)
task_name: Literal["lm"] = "lm"
model_size: str # e.g. "1.25M"
max_seq_len: PositiveInt = 512
buffer_size: PositiveInt = 1000
dataset_name: str = "lennart-finke/SimpleStories"
train_data_split: str = "train"
eval_data_split: str = "test"
n_eval_steps: PositiveInt = 100
# List of fnmatch patterns for nn.Linear modules to decompose
target_module_patterns: list[str] = ["transformer.h.*.mlp.*_proj"]
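The `target_module_patterns` field selects which `nn.Linear` modules to decompose via fnmatch-style globs, as the comment above it notes. A minimal sketch of how such patterns match module names — the example module names here are hypothetical, in the style of a GPT-2-like transformer; the real names depend on the model being decomposed:

```python
from fnmatch import fnmatch

# Hypothetical module names of the kind nn.Module.named_modules() yields.
module_names = [
    "transformer.h.0.mlp.up_proj",
    "transformer.h.0.mlp.down_proj",
    "transformer.h.0.attn.q_proj",
    "transformer.h.1.mlp.up_proj",
]

# The default pattern from LMTaskConfig: any layer index, any MLP projection.
patterns = ["transformer.h.*.mlp.*_proj"]

# Keep the names that match at least one pattern. Note fnmatch's `*`
# matches across dots, so one pattern covers every layer index.
matched = [
    name for name in module_names
    if any(fnmatch(name, p) for p in patterns)
]
print(matched)
```

The attention projection is excluded because the pattern requires a `.mlp.` segment, so only the MLP projections in each block are selected.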


class Config(BaseModel):
model_config = ConfigDict(extra="forbid", frozen=True)
wandb_project: str | None = None
@@ -68,7 +82,9 @@ class Config(BaseModel):
unit_norm_matrices: bool = False
attribution_type: Literal["gradient"] = "gradient"
n_gate_hidden_neurons: PositiveInt | None = None
task_config: TMSTaskConfig | ResidualMLPTaskConfig = Field(..., discriminator="task_name")
task_config: TMSTaskConfig | ResidualMLPTaskConfig | LMTaskConfig = Field(
..., discriminator="task_name"
)

DEPRECATED_CONFIG_KEYS: ClassVar[list[str]] = []
RENAMED_CONFIG_KEYS: ClassVar[dict[str, str]] = {}