-
Notifications
You must be signed in to change notification settings - Fork 226
[WIP] feat: add mlp transcoders #183
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
dtch1997
wants to merge
11
commits into
decoderesearch:main
Choose a base branch
from
dtch1997:feat-transcoder
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
3d021ae
refactor: sae forward pass
dtch1997 99f40fe
fix: detached sae output in forward
dtch1997 a58ccaa
add mlp transcoder
dtch1997 215bac9
xfail test error term as not implemented
dtch1997 1dd7da6
add configs, light refactor of training_sae
dtch1997 406cd42
add transcoder training infra
dtch1997 9a4dccb
fix minor issue with multiple inheritance
dtch1997 a46ce5e
add ipykernel for running tutorials
dtch1997 0648576
add tutorial notebook for training mlp transcoder
dtch1997 7a22a75
fix: various minor bugs
dtch1997 1e020d5
add experiment harness
dtch1997 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| import os | ||
|
|
||
| import torch | ||
| from simple_parsing import ArgumentParser | ||
|
|
||
| from sae_lens.config import LanguageModelTranscoderRunnerConfig | ||
| from sae_lens.sae_training_runner import TranscoderTrainingRunner | ||
|
|
||
|
|
||
| def setup_env_vars(): | ||
| # Set the environment variables for the cache and the dataset. | ||
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | ||
|
|
||
|
|
||
| def get_default_config(): | ||
| if torch.cuda.is_available(): | ||
| device = "cuda" | ||
| elif torch.backends.mps.is_available(): | ||
| device = "mps" | ||
| else: | ||
| device = "cpu" | ||
|
|
||
| # total_training_steps = 20_000 | ||
| total_training_steps = 500 | ||
| batch_size = 4096 | ||
| total_training_tokens = total_training_steps * batch_size | ||
| print(f"Total Training Tokens: {total_training_tokens}") | ||
|
|
||
| lr_warm_up_steps = 0 | ||
| lr_decay_steps = 40_000 | ||
| print(f"lr_decay_steps: {lr_decay_steps}") | ||
| l1_warmup_steps = 10_000 | ||
| print(f"l1_warmup_steps: {l1_warmup_steps}") | ||
|
|
||
| return LanguageModelTranscoderRunnerConfig( | ||
| # Pick a tiny model to make this easier. | ||
| model_name="gelu-1l", | ||
| ## MLP Layer 0 ## | ||
| hook_name="blocks.0.ln2.hook_normalized", | ||
| hook_name_out="blocks.0.hook_mlp_out", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points) | ||
| hook_layer=0, # Only one layer in the model. | ||
| hook_layer_out=0, # Only one layer in the model. | ||
| d_in=512, # the width of the mlp input. | ||
| d_out=512, # the width of the mlp output. | ||
| dataset_path="NeelNanda/c4-tokenized-2b", | ||
| context_size=256, | ||
| is_dataset_tokenized=True, | ||
| prepend_bos=True, # I used to train GPT2 SAEs with a prepended-bos but no longer think we should do this. | ||
| # How big do we want our SAE to be? | ||
| expansion_factor=16, | ||
| # Dataset / Activation Store | ||
| # When we do a proper test | ||
| # training_tokens= 820_000_000, # 200k steps * 4096 batch size ~ 820M tokens (doable overnight on an A100) | ||
| # For now. | ||
| training_tokens=total_training_tokens, # For initial testing I think this is a good number. | ||
| train_batch_size_tokens=4096, | ||
| # Loss Function | ||
| ## Reconstruction Coefficient. | ||
| mse_loss_normalization=None, # MSE Loss Normalization is not mentioned (so we use stanrd MSE Loss). But not we take an average over the batch. | ||
| ## Anthropic does not mention using an Lp norm other than L1. | ||
| l1_coefficient=5, | ||
| lp_norm=1.0, | ||
| # Instead, they multiply the L1 loss contribution | ||
| # from each feature of the activations by the decoder norm of the corresponding feature. | ||
| scale_sparsity_penalty_by_decoder_norm=True, | ||
| # Learning Rate | ||
| lr_scheduler_name="constant", # we set this independently of warmup and decay steps. | ||
| l1_warm_up_steps=l1_warmup_steps, | ||
| lr_warm_up_steps=lr_warm_up_steps, | ||
| lr_decay_steps=lr_warm_up_steps, | ||
| ## No ghost grad term. | ||
| use_ghost_grads=False, | ||
| # Initialization / Architecture | ||
| apply_b_dec_to_input=False, | ||
| # encoder bias zero's. (I'm not sure what it is by default now) | ||
| # decoder bias zero's. | ||
| b_dec_init_method="zeros", | ||
| normalize_sae_decoder=False, | ||
| decoder_heuristic_init=True, | ||
| init_encoder_as_decoder_transpose=True, | ||
| # Optimizer | ||
| lr=4e-5, | ||
| ## adam optimizer has no weight decay by default so worry about this. | ||
| adam_beta1=0.9, | ||
| adam_beta2=0.999, | ||
| # Buffer details won't matter in we cache / shuffle our activations ahead of time. | ||
| n_batches_in_buffer=64, | ||
| store_batch_size_prompts=16, | ||
| normalize_activations="constant_norm_rescale", | ||
| # Feature Store | ||
| feature_sampling_window=1000, | ||
| dead_feature_window=1000, | ||
| dead_feature_threshold=1e-4, | ||
| # performance enhancement: | ||
| compile_sae=True, | ||
| # WANDB | ||
| log_to_wandb=True, # always use wandb unless you are just testing code. | ||
| wandb_project="benchmark", | ||
| wandb_log_frequency=100, | ||
| # Misc | ||
| device=device, | ||
| seed=42, | ||
| n_checkpoints=0, | ||
| checkpoint_path="checkpoints", | ||
| dtype="float32", | ||
| ) | ||
|
|
||
|
|
||
| def run_training(cfg: LanguageModelTranscoderRunnerConfig): | ||
| sae = TranscoderTrainingRunner(cfg).run() | ||
| assert sae is not None | ||
| # know whether or not this works by looking at the dashboard! # know whether or not this works by looking at the dashboard! | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
| parser = ArgumentParser() | ||
| parser.add_arguments( | ||
| LanguageModelTranscoderRunnerConfig, "cfg", default=get_default_config() | ||
| ) | ||
| args = parser.parse_args() | ||
| setup_env_vars() | ||
| run_training(args.cfg) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand why the extra bias is needed. I'm probably just confused and missing something, but it would make the implementation simpler if you don't need it.
I understand that in normal SAEs people sometimes subtract b_dec from the input. This isn't really necessary but has a nice interpretation of choosing a new "0 point" which you can consider as the origin in the feature basis.
For transcoders this makes less sense. Since you aren't reconstructing the same activations you probably don't want to tie the pre-encoder bias with the post-decoder bias.
Thus, in the current implementation we do:
$$z = ReLU(W_{enc}(x - b_{dec}) + b_{enc})$$
$$out = W_{dec} x +b_\text{dec out}$$ $b_{dec}$ and $b_{enc}$ above) into a single bias term. I don't see a good reason why it would result in a more interpretable zero point for the encoder basis either.
and
This isn't any more expressive, you can always fold the first two biases (
Overall I'd recommend dropping the complexity here, which maybe means you can just eliminate the Transcoder class entirely.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this makes sense! i'll try dropping the extra
b_decterm when training. I was initially concerned about supporting the previously-trained checkpoints, but as you say weight folding should solve that.