Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions experiments/run_train_transcoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import os

import torch
from simple_parsing import ArgumentParser

from sae_lens.config import LanguageModelTranscoderRunnerConfig
from sae_lens.sae_training_runner import TranscoderTrainingRunner


def setup_env_vars():
    """Configure process-wide environment variables before training starts.

    Turns off HuggingFace tokenizers parallelism (presumably to avoid the
    fork-after-parallelism warnings/deadlocks it can cause — confirm).
    """
    os.environ.update(TOKENIZERS_PARALLELISM="false")


def get_default_config():
    """Build the default config for a quick transcoder-training benchmark run.

    Picks the best available device (cuda > mps > cpu) and returns a
    LanguageModelTranscoderRunnerConfig tuned for a tiny one-layer model.
    """
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # total_training_steps = 20_000
    total_training_steps = 500
    batch_size = 4096
    total_training_tokens = total_training_steps * batch_size
    print(f"Total Training Tokens: {total_training_tokens}")

    lr_warm_up_steps = 0
    lr_decay_steps = 40_000
    print(f"lr_decay_steps: {lr_decay_steps}")
    l1_warmup_steps = 10_000
    print(f"l1_warmup_steps: {l1_warmup_steps}")

    return LanguageModelTranscoderRunnerConfig(
        # Pick a tiny model to make this easier.
        model_name="gelu-1l",
        ## MLP Layer 0 ##
        hook_name="blocks.0.ln2.hook_normalized",
        hook_name_out="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
        hook_layer=0,  # Only one layer in the model.
        hook_layer_out=0,  # Only one layer in the model.
        d_in=512,  # the width of the mlp input.
        d_out=512,  # the width of the mlp output.
        dataset_path="NeelNanda/c4-tokenized-2b",
        context_size=256,
        is_dataset_tokenized=True,
        prepend_bos=True,  # I used to train GPT2 SAEs with a prepended-bos but no longer think we should do this.
        # How big do we want our SAE to be?
        expansion_factor=16,
        # Dataset / Activation Store
        # When we do a proper test
        # training_tokens= 820_000_000, # 200k steps * 4096 batch size ~ 820M tokens (doable overnight on an A100)
        # For now.
        training_tokens=total_training_tokens,  # For initial testing I think this is a good number.
        train_batch_size_tokens=4096,
        # Loss Function
        ## Reconstruction Coefficient.
        mse_loss_normalization=None,  # MSE Loss Normalization is not mentioned (so we use standard MSE Loss). But note we take an average over the batch.
        ## Anthropic does not mention using an Lp norm other than L1.
        l1_coefficient=5,
        lp_norm=1.0,
        # Instead, they multiply the L1 loss contribution
        # from each feature of the activations by the decoder norm of the corresponding feature.
        scale_sparsity_penalty_by_decoder_norm=True,
        # Learning Rate
        lr_scheduler_name="constant",  # we set this independently of warmup and decay steps.
        l1_warm_up_steps=l1_warmup_steps,
        lr_warm_up_steps=lr_warm_up_steps,
        # BUG FIX: this previously passed lr_warm_up_steps (0), so the decay
        # schedule printed above was never actually used.
        lr_decay_steps=lr_decay_steps,
        ## No ghost grad term.
        use_ghost_grads=False,
        # Initialization / Architecture
        apply_b_dec_to_input=False,
        # encoder bias zeros. (I'm not sure what it is by default now)
        # decoder bias zeros.
        b_dec_init_method="zeros",
        normalize_sae_decoder=False,
        decoder_heuristic_init=True,
        init_encoder_as_decoder_transpose=True,
        # Optimizer
        lr=4e-5,
        ## adam optimizer has no weight decay by default so don't worry about this.
        adam_beta1=0.9,
        adam_beta2=0.999,
        # Buffer details won't matter if we cache / shuffle our activations ahead of time.
        n_batches_in_buffer=64,
        store_batch_size_prompts=16,
        normalize_activations="constant_norm_rescale",
        # Feature Store
        feature_sampling_window=1000,
        dead_feature_window=1000,
        dead_feature_threshold=1e-4,
        # performance enhancement:
        compile_sae=True,
        # WANDB
        log_to_wandb=True,  # always use wandb unless you are just testing code.
        wandb_project="benchmark",
        wandb_log_frequency=100,
        # Misc
        device=device,
        seed=42,
        n_checkpoints=0,
        checkpoint_path="checkpoints",
        dtype="float32",
    )


def run_training(cfg: LanguageModelTranscoderRunnerConfig) -> None:
    """Train a transcoder with the given config and sanity-check the result.

    Whether the run actually worked is judged from the (wandb) dashboard,
    when logging is enabled in ``cfg``.
    """
    sae = TranscoderTrainingRunner(cfg).run()
    assert sae is not None


if __name__ == "__main__":
    # Parse a LanguageModelTranscoderRunnerConfig from the CLI, falling back
    # to the benchmark defaults for any unspecified field.
    argparser = ArgumentParser()
    argparser.add_arguments(
        LanguageModelTranscoderRunnerConfig, "cfg", default=get_default_config()
    )
    parsed = argparser.parse_args()

    setup_env_vars()
    run_training(parsed.cfg)
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ mkdocs-section-index = "^0.3.8"
mkdocstrings = "^0.24.1"
mkdocstrings-python = "^1.9.0"


[tool.poetry.group.tutorials.dependencies]
ipykernel = "^6.29.4"
simple-parsing = "^0.1.5"

[tool.poetry.extras]
mamba = ["mamba-lens"]

Expand Down
18 changes: 18 additions & 0 deletions sae_lens/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,24 @@ def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
return cls(**cfg)


@dataclass
class LanguageModelTranscoderRunnerConfig(LanguageModelSAERunnerConfig):
    """Runner config for transcoder training.

    Extends the SAE runner config with the output-side hook details: a
    transcoder reads activations at one hook point and reconstructs
    activations at a different (output) hook point.
    """

    d_out: int = 512
    hook_name_out: str = "blocks.0.hook_mlp_out"
    hook_layer_out: int = 0
    hook_head_index_out: Optional[int] = None

    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
        """Returns the config for the base Transcoder."""
        cfg_dict = dict(super().get_base_sae_cfg_dict())
        cfg_dict["d_out"] = self.d_out
        cfg_dict["hook_name_out"] = self.hook_name_out
        cfg_dict["hook_layer_out"] = self.hook_layer_out
        cfg_dict["hook_head_index_out"] = self.hook_head_index_out
        return cfg_dict


@dataclass
class CacheActivationsRunnerConfig:
"""
Expand Down
147 changes: 99 additions & 48 deletions sae_lens/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,17 @@ def to_dict(self) -> dict[str, Any]:
}


@dataclass
class TranscoderConfig(SAEConfig):
    """Inference-time config for a Transcoder: an SAE variant whose output
    hook point differs from its input hook point."""

    # transcoder-specific forward pass details
    d_out: int  # width of the output activations (used to size b_dec_out; may differ from d_in)

    # transcoder-specific dataset details
    hook_name_out: str  # hook point whose activations are the reconstruction target
    hook_layer_out: int  # layer index of the output hook
    hook_head_index_out: Optional[int]  # presumably the head index for per-head hooks; None otherwise — confirm


class SAE(HookedRootModule):
"""
Core Sparse Autoencoder (SAE) class used for inference. For training, see `TrainingSAE`.
Expand Down Expand Up @@ -216,48 +227,48 @@ def forward(
feature_acts = self.encode(x)
sae_out = self.decode(feature_acts)

if self.use_error_term:
with torch.no_grad():
# Recompute everything without hooks to get true error term
# Otherwise, the output with error term will always equal input, even for causal interventions that affect x_reconstruct
# This is in a no_grad context to detach the error, so we can compute SAE feature gradients (eg for attribution patching). See A.3 in https://arxiv.org/pdf/2403.19647.pdf for more detail
# NOTE: we can't just use `sae_error = input - x_reconstruct.detach()` or something simpler, since this would mean intervening on features would mean ablating features still results in perfect reconstruction.

# move x to correct dtype
x = x.to(self.dtype)

# handle hook z reshaping if needed.
sae_in = self.reshape_fn_in(x) # type: ignore

# handle run time activation normalization if needed
sae_in = self.run_time_activation_norm_fn_in(sae_in)
if not self.use_error_term:
return self.hook_sae_output(sae_out)

# apply b_dec_to_input if using that method.
sae_in_cent = sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input)

# "... d_in, d_in d_sae -> ... d_sae",
hidden_pre = sae_in_cent @ self.W_enc + self.b_enc
feature_acts = self.activation_fn(hidden_pre)
x_reconstruct_clean = self.reshape_fn_out(
self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec
+ self.b_dec,
d_head=self.d_head,
)

sae_out = self.run_time_activation_norm_fn_out(sae_out)
sae_error = self.hook_sae_error(x - x_reconstruct_clean)

return self.hook_sae_output(sae_out + sae_error)

return self.hook_sae_output(sae_out)
# If using error term, compute the error term and add it to the output
with torch.no_grad():
# Recompute everything without hooks to get true error term
# Otherwise, the output with error term will always equal input, even for causal interventions that affect x_reconstruct
# This is in a no_grad context to detach the error, so we can compute SAE feature gradients (eg for attribution patching). See A.3 in https://arxiv.org/pdf/2403.19647.pdf for more detail
# NOTE: we can't just use `sae_error = input - x_reconstruct.detach()` or something simpler, since this would mean intervening on features would mean ablating features still results in perfect reconstruction.
feature_acts_clean = self.encode(x, apply_hooks=False)
x_reconstruct_clean = self.decode(feature_acts_clean, apply_hooks=False)
sae_error = self.hook_sae_error(x - x_reconstruct_clean)
return self.hook_sae_output(sae_out + sae_error)

def encode(
    self, x: Float[torch.Tensor, "... d_in"], apply_hooks: bool = True
) -> Float[torch.Tensor, "... d_sae"]:
    """Calculate SAE feature activations from input activations.

    Args:
        x: Input activations (last dim d_in).
        apply_hooks: When False, skip the sae_input / acts_pre / acts_post
            hooks — used to recompute a "clean" pass (e.g. for the error
            term) without triggering interventions registered on the hooks.

    Returns:
        Feature activations (last dim d_sae).
    """
    sae_in = self.get_sae_in(x)
    if apply_hooks:
        sae_in = self.hook_sae_input(sae_in)

    # "... d_in, d_in d_sae -> ... d_sae",
    hidden_pre = sae_in @ self.W_enc + self.b_enc
    if apply_hooks:
        hidden_pre = self.hook_sae_acts_pre(hidden_pre)

    feature_acts = self.activation_fn(hidden_pre)
    if apply_hooks:
        feature_acts = self.hook_sae_acts_post(feature_acts)

    return feature_acts

def get_sae_in(
self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_in_reshaped"]:
"""Get the input to the SAE.

Fixes dtype, reshapes, normalizes, and applies b_dec if necessary.
"""
# move x to correct dtype
x = x.to(self.dtype)

Expand All @@ -268,31 +279,33 @@ def encode(
x = self.run_time_activation_norm_fn_in(x)

# apply b_dec_to_input if using that method.
sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input))

# "... d_in, d_in d_sae -> ... d_sae",
hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))

return feature_acts
sae_in = x - (self.b_dec * self.cfg.apply_b_dec_to_input)
return sae_in

def decode(
    self, feature_acts: Float[torch.Tensor, "... d_sae"], apply_hooks: bool = True
) -> Float[torch.Tensor, "... d_in"]:
    """Decodes SAE feature activation tensor into a reconstructed input activation tensor.

    Args:
        feature_acts: Feature activations (last dim d_sae).
        apply_hooks: When False, skip the sae_recons hook — used to
            recompute a "clean" pass without triggering hook interventions.
    """
    sae_recons = self.get_sae_recons(feature_acts)
    if apply_hooks:
        sae_recons = self.hook_sae_recons(sae_recons)

    # handle run time activation normalization if needed
    # will fail if you call this twice without calling encode in between.
    sae_recons = self.run_time_activation_norm_fn_out(sae_recons)

    # handle hook z reshaping if needed.
    sae_recons = self.reshape_fn_out(sae_recons, self.d_head)  # type: ignore

    return sae_recons

def get_sae_recons(
    self, feature_acts: Float[torch.Tensor, "... d_sae"]
) -> Float[torch.Tensor, "... d_in"]:
    """Affine decoder map: scaled feature activations through W_dec, plus b_dec.

    This is only the raw linear reconstruction — output normalization,
    hook-z reshaping, and hooks are applied by decode(), not here.
    """
    # "... d_sae, d_sae d_in -> ... d_in"
    return (
        self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + self.b_dec
    )

@torch.no_grad()
def fold_W_dec_norm(self):
Expand Down Expand Up @@ -443,3 +456,41 @@ def tanh_relu(input: torch.Tensor) -> torch.Tensor:
return tanh_relu
else:
raise ValueError(f"Unknown activation function: {activation_fn}")


class Transcoder(SAE):
    """A variant of sparse autoencoders that have different input and output hook points."""

    # cfg is narrowed from SAEConfig to TranscoderConfig (hence the ignore).
    cfg: TranscoderConfig  # type: ignore
    dtype: torch.dtype
    device: torch.device

    def __init__(
        self,
        cfg: TranscoderConfig,
        use_error_term: bool = False,
    ):
        """Build a Transcoder from its config.

        Args:
            cfg: Must be a TranscoderConfig, not a plain SAEConfig.
            use_error_term: Not supported for transcoders.

        Raises:
            NotImplementedError: If use_error_term is True.
        """
        assert isinstance(
            cfg, TranscoderConfig
        ), f"Expected TranscoderConfig, got {cfg}"
        if use_error_term:
            raise NotImplementedError("Error term not yet supported for Transcoder")
        super().__init__(cfg, use_error_term)

    def initialize_weights_basic(self):
        """Initialize the base SAE weights, then add the transcoder-only
        output-space decoder bias (shape d_out)."""
        super().initialize_weights_basic()

        # NOTE: Transcoders have an additional b_dec_out parameter.
        # Reference: https://github.com/jacobdunefsky/transcoder_circuits/blob/7b44d870a5a301ef29eddfd77cb1f4dca854760a/sae_training/sparse_autoencoder.py#L93C1-L97C14
        self.b_dec_out = nn.Parameter(
            torch.zeros(self.cfg.d_out, dtype=self.dtype, device=self.device)
        )
Comment on lines +483 to +487
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why the extra bias is needed. I'm probably just confused and missing something, but it would make the implementation simpler if you don't need it.

I understand that in normal SAEs people sometimes subtract b_dec from the input. This isn't really necessary but has a nice interpretation of choosing a new "0 point" which you can consider as the origin in the feature basis.

For transcoders this makes less sense. Since you aren't reconstructing the same activations you probably don't want to tie the pre-encoder bias with the post-decoder bias.

Thus, in the current implementation we do:
$$z = ReLU(W_{enc}(x - b_{dec}) + b_{enc})$$
and
$$out = W_{dec} z + b_\text{dec out}$$
This isn't any more expressive, you can always fold the first two biases ($b_{dec}$ and $b_{enc}$ above) into a single bias term. I don't see a good reason why it would result in a more interpretable zero point for the encoder basis either.

Overall I'd recommend dropping the complexity here, which maybe means you can just eliminate the Transcoder class entirely.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this makes sense! i'll try dropping the extra b_dec term when training. I was initially concerned about supporting the previously-trained checkpoints, but as you say weight folding should solve that.


def get_sae_recons(
    self, feature_acts: Float[torch.Tensor, "... d_sae"]
) -> Float[torch.Tensor, "... d_out"]:
    """Decode feature activations into the transcoder's output space.

    Overrides the base SAE version: the bias added after W_dec is the
    transcoder-specific b_dec_out (length d_out), not b_dec.
    """
    # NOTE: b_dec_out instead of b_dec
    return (
        self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec
        + self.b_dec_out
    )
Loading