Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
53f937e
test: add unit tests for ActivationsStore multi-layer support
mkbehr Mar 30, 2025
2b7a438
Implement multilayer activations store except normalization
mkbehr Apr 5, 2025
d1c603b
CrosscoderSAE implementation, some tests
mkbehr Apr 6, 2025
f8bf44e
more norm tests
mkbehr Apr 6, 2025
ce63c0b
save-and-load support
mkbehr Apr 6, 2025
51bfac7
WIP name
mkbehr Apr 6, 2025
d118887
TrainingCrosscoderSAE implementation, decoder norm scaling test
mkbehr Apr 13, 2025
92ae2bd
test_TrainingCrosscoderSAE_encode_returns_same_value_as_encode_with_h…
mkbehr Apr 13, 2025
91190e5
test_sae_forward
mkbehr Apr 13, 2025
a73e8a7
test_sae_forward_with_mse_loss_norm
mkbehr Apr 13, 2025
f7149f4
mark ghost grads unsupported
mkbehr Apr 13, 2025
229192f
fix hook name in tests
mkbehr Apr 13, 2025
8d36da4
can_add_noise_to_hidden_pre test
mkbehr Apr 13, 2025
5c694e0
b_dec init note
mkbehr Apr 13, 2025
e38de51
fix from_dict
mkbehr Apr 13, 2025
b2ddc70
CrosscoderSAETrainer implementation, one test
mkbehr Apr 13, 2025
529109e
two more CrosscoderSAETrainer tests
mkbehr Apr 13, 2025
3935dbc
test log dict
mkbehr Apr 13, 2025
588de21
test_train_sae_group_on_language_model__runs
mkbehr Apr 13, 2025
2bc00bb
fix TrainingCrosscoderSAEConfig.to_dict
mkbehr Apr 14, 2025
7d54ea4
quick name fixes to satisfy wandb
mkbehr Apr 17, 2025
be7f780
use crosscoder from training runner
mkbehr Apr 17, 2025
0e6acdc
initialize W_dec in TrainingCrosscoderSAE
mkbehr Apr 17, 2025
9cf93ea
temporarily hardcode evals off
mkbehr Apr 17, 2025
c1cfde5
training script
mkbehr Apr 17, 2025
b17b607
add ActivationsStore.hook_names()
mkbehr Apr 17, 2025
5fb5b49
l2/sparsity/variance evals for crosscoders
mkbehr Apr 18, 2025
05512da
tiny-stories-1m experiments
mkbehr Apr 19, 2025
3d1abbd
tiny-stories-28m
mkbehr Apr 19, 2025
9ea92c8
minor fixes
mkbehr Apr 20, 2025
c45b08a
training changes
mkbehr Apr 20, 2025
c098be0
scale W_dec init norm
mkbehr Apr 20, 2025
7c01f2d
scale activations by layer
mkbehr Apr 20, 2025
7957fa9
some training changes
mkbehr Apr 21, 2025
093bc39
clean up some TODOs
mkbehr Apr 27, 2025
09abaab
trim CrosscoderSAETrainer
mkbehr Apr 27, 2025
e2deb2b
TODO notes in crosscoder trainer
mkbehr Apr 28, 2025
f2ea460
Change hook name syntax from {} to {layer}
mkbehr Apr 28, 2025
f8107b0
fix evals_test
mkbehr Apr 28, 2025
41dbb5b
fix activations store test
mkbehr Apr 28, 2025
b4e6c0d
fix test_cache_activations_runner
mkbehr Apr 28, 2025
7c090eb
fix test_crosscoder_sae
mkbehr Apr 28, 2025
eff955d
fix crosscoder sae trainer train step log dict
mkbehr Apr 28, 2025
406202e
Configure crosscoder decoder init norms
mkbehr May 3, 2025
8b397d7
Config rework (most tests fail)
mkbehr May 4, 2025
190d022
fix test_activations_store_multilayer
mkbehr May 4, 2025
c114946
test_crosscoder_sae passes
mkbehr May 4, 2025
cf82b58
training/test*crosscoder* passes
mkbehr May 4, 2025
9a29cba
fix evals
mkbehr May 4, 2025
7e2a40f
fix evals again; all tests pass
mkbehr May 4, 2025
e035b8c
"global" acausal crosscoder script for gpt2-small
mkbehr May 4, 2025
540e23f
remove some TODOs
mkbehr May 4, 2025
109eba8
remove more TODOs
mkbehr May 5, 2025
9989065
enable test_activations_store_normalization_multiple_layers
mkbehr May 6, 2025
c1ad393
Update to new disk loader
mkbehr May 6, 2025
07449ca
test saving multilayer activation norm
mkbehr May 6, 2025
9d40b8c
misc. cleanup
mkbehr May 6, 2025
c631c06
fix format
mkbehr May 7, 2025
bcfe7a5
fix some type errors
mkbehr May 7, 2025
0f7b349
revert changing wandb import line
mkbehr May 7, 2025
750ee92
train crosscoders without override_sae
mkbehr May 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion sae_lens/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class LanguageModelSAERunnerConfig:
model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
hook_names (list[str], optional): The names of multiple hooks to use, in order of evaluation. If this is nonempty, a CrosscoderSAE will be used. hook_name should be a descriptive name, and hook_layer should be the index of the last layer to hook.
hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing.
hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
Expand Down Expand Up @@ -147,6 +148,7 @@ class LanguageModelSAERunnerConfig:
model_name: str = "gelu-2l"
model_class_name: str = "HookedTransformer"
hook_name: str = "blocks.0.hook_mlp_out"
hook_names: list[str] = field(default_factory=list)
hook_eval: str = "NOT_IN_USE"
hook_layer: int = 0
hook_head_index: int | None = None
Expand Down Expand Up @@ -444,6 +446,7 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]:
"device": self.device,
"model_name": self.model_name,
"hook_name": self.hook_name,
"hook_names": self.hook_names,
"hook_layer": self.hook_layer,
"hook_head_index": self.hook_head_index,
"activation_fn_str": self.activation_fn,
Expand Down Expand Up @@ -521,6 +524,7 @@ class CacheActivationsRunnerConfig:
model_name (str): The name of the model to use.
model_batch_size (int): How many prompts are in the batch of the language model when generating activations.
hook_name (str): The name of the hook to use.
hook_names (list[str], optional): The names of multiple hooks to use, in order of evaluation. If this is nonempty, a CrosscoderSAE will be used.
hook_layer (int): The layer of the final hook. Currently only a single hook is supported, so this should be the layer of hook_name.
d_in (int): Dimension of the model.
total_training_tokens (int): Total number of tokens to process.
Expand Down Expand Up @@ -555,6 +559,7 @@ class CacheActivationsRunnerConfig:
d_in: int
training_tokens: int

hook_names: list[str] = field(default_factory=list)
context_size: int = -1 # Required if dataset is not tokenized
model_class_name: str = "HookedTransformer"
# defaults to "activations/{dataset}/{model}/{hook_name}
Expand Down Expand Up @@ -608,8 +613,9 @@ def __post_init__(self):
)

if self.new_cached_activations_path is None:
hook_name_str = self.hook_name
self.new_cached_activations_path = _default_cached_activations_path( # type: ignore
self.dataset_path, self.model_name, self.hook_name, None
self.dataset_path, self.model_name, hook_name_str, None
)

@property
Expand Down
137 changes: 137 additions & 0 deletions sae_lens/crosscoder_sae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from dataclasses import dataclass, field
from typing import Any

import einops
import torch
from jaxtyping import Float

from sae_lens import SAE, SAEConfig
from sae_lens.toolkit.pretrained_sae_loaders import (
PretrainedSaeDiskLoader,
handle_config_defaulting,
sae_lens_disk_loader,
)


@dataclass
class CrosscoderSAEConfig(SAEConfig):
    """Configuration for a :class:`CrosscoderSAE`.

    Extends the base ``SAEConfig`` with the ordered list of hook names
    whose activations the crosscoder reads and reconstructs.
    """

    # Hook points (one per layer) the crosscoder operates over, in
    # evaluation order; its length defines the ``n_layers`` dimension.
    hook_names: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the config, including the crosscoder-specific fields."""
        cfg_dict = super().to_dict()
        cfg_dict["hook_names"] = self.hook_names
        return cfg_dict


class CrosscoderSAE(SAE):
    """
    Sparse autoencoder that acts on multiple layers of activations.

    Where the base ``SAE`` maps a single activation tensor ``[..., d_in]``
    to features, a crosscoder jointly encodes a stack of activations from
    ``len(cfg.hook_names)`` hook points, shape ``[..., n_layers, d_in]``,
    into one shared feature space of size ``d_sae`` and decodes back to
    the full ``[..., n_layers, d_in]`` stack. Weight shapes (fixed by the
    einsum patterns below):
      W_enc: (n_layers, d_in, d_sae)
      W_dec: (d_sae, n_layers, d_in)
      b_dec: (n_layers, d_in)
    """

    def __init__(
        self,
        cfg: CrosscoderSAEConfig,
        use_error_term: bool = False,
    ):
        # Only the "standard" architecture is wired up; gated/jumprelu
        # parameters are not initialized for the multi-layer shapes yet.
        if cfg.architecture != "standard":
            raise NotImplementedError("TODO(mkbehr): support other architectures")

        super().__init__(cfg=cfg, use_error_term=use_error_term)
        self.cfg = cfg

        # hook_z activations would need extra head-dim reshaping that the
        # multi-layer input layout does not yet account for.
        if self.hook_z_reshaping_mode:
            raise NotImplementedError("TODO(mkbehr): support hook_z")

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "CrosscoderSAE":
        """Build a CrosscoderSAE from a plain config dict (inverse of
        ``CrosscoderSAEConfig.to_dict``)."""
        return cls(CrosscoderSAEConfig.from_dict(config_dict))  # type: ignore

    def input_shape(self) -> list[int]:
        # Overrides the base class's [d_in]: inputs carry a leading
        # per-layer axis, one slot per configured hook.
        return [len(self.cfg.hook_names), self.cfg.d_in]

    def encode_standard(
        self, x: Float[torch.Tensor, "... n_layers d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Calculate SAE features from inputs
        """
        # process_sae_in handles dtype casting / b_dec subtraction /
        # runtime normalization as configured on the base class.
        sae_in = self.process_sae_in(x)

        # Contract both the layer and d_in axes into the feature space,
        # so every feature reads from all layers at once.
        hidden_pre = self.hook_sae_acts_pre(
            einops.einsum(
                sae_in,
                self.W_enc,
                "... n_layers d_in, n_layers d_in d_sae -> ... d_sae",
            )
            + self.b_enc
        )
        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... n_layers d_in"]:
        """Decodes SAE feature activation tensor into a reconstructed
        input activation tensor."""
        # Each feature writes to every (layer, d_in) position via W_dec.
        sae_out = self.hook_sae_recons(
            einops.einsum(
                self.apply_finetuning_scaling_factor(feature_acts),
                self.W_dec,
                "... d_sae, d_sae n_layers d_in -> ... n_layers d_in",
            )
            + self.b_dec
        )

        # handle run time activation normalization if needed
        # will fail if you call this twice without calling encode in between.
        sae_out = self.run_time_activation_norm_fn_out(sae_out)

        # handle hook z reshaping if needed.
        return self.reshape_fn_out(sae_out, self.d_head)  # type: ignore

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """Rescale each feature so its decoder direction has unit norm,
        folding the removed scale into the encoder so that feature
        activations (post-fold) are unchanged up to that rescaling."""
        # Per-feature norm taken jointly over (n_layers, d_in);
        # shape (d_sae, 1, 1) for broadcasting against W_dec.
        W_dec_norms = self.W_dec.norm(dim=[-2, -1], keepdim=True)
        self.W_dec.data = self.W_dec.data / W_dec_norms
        # Move the norms onto the encoder's d_sae (last) axis.
        self.W_enc.data = self.W_enc.data * einops.rearrange(
            W_dec_norms, "d_sae 1 1 -> 1 1 d_sae"
        )
        # NOTE: __init__ currently rejects non-"standard" architectures,
        # so the gated/jumprelu branches below are kept for parity with
        # the base class but are presently unreachable.
        if self.cfg.architecture == "gated":
            self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
            self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
            self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
        elif self.cfg.architecture == "jumprelu":
            self.threshold.data = self.threshold.data * W_dec_norms.squeeze()
            self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()
        else:
            self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()

    @torch.no_grad()
    def fold_activation_norm_scaling_factor(
        self, activation_norm_scaling_factor: Float[torch.Tensor, "n_layers"]
    ) -> None:
        """Fold per-layer activation scaling factors into the weights so
        that unscaled activations can be fed directly at inference time.

        ``activation_norm_scaling_factor`` holds one factor per hooked
        layer (shape ``(n_layers,)``).
        """
        # (n_layers,) -> (n_layers, 1, 1) to scale W_enc's per-layer rows.
        self.W_enc.data = self.W_enc.data * activation_norm_scaling_factor.reshape(
            (-1, 1, 1)
        )
        # previously weren't doing this.
        # (n_layers, 1) broadcasts over W_dec's trailing (n_layers, d_in)
        # axes and over b_dec's (n_layers, d_in), undoing the scaling on
        # the reconstruction side.
        self.W_dec.data = self.W_dec.data / activation_norm_scaling_factor.unsqueeze(-1)
        self.b_dec.data = self.b_dec.data / activation_norm_scaling_factor.unsqueeze(-1)

        # once we normalize, we shouldn't need to scale activations.
        self.cfg.normalize_activations = "none"

    @classmethod
    def load_from_disk(
        cls,
        path: str,
        device: str = "cpu",
        dtype: str | None = None,
        converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
    ) -> "CrosscoderSAE":
        """Load a saved CrosscoderSAE from ``path``.

        Mirrors ``SAE.load_from_disk`` but parses the config through
        ``CrosscoderSAEConfig`` so ``hook_names`` round-trips. An explicit
        ``dtype`` (if given) overrides the stored config's dtype.
        """
        overrides = {"dtype": dtype} if dtype is not None else None
        cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
        cfg_dict = handle_config_defaulting(cfg_dict)
        sae_cfg = CrosscoderSAEConfig.from_dict(cfg_dict)
        sae = cls(sae_cfg)  # type: ignore
        sae.process_state_dict_for_loading(state_dict)
        sae.load_state_dict(state_dict)
        return sae
30 changes: 22 additions & 8 deletions sae_lens/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,9 @@ def get_sparsity_and_variance_metrics(
ignore_tokens: set[int | None] = set(),
verbose: bool = False,
) -> tuple[dict[str, Any], dict[str, Any]]:
hook_name = sae.cfg.hook_name
hook_names = (
sae.cfg.hook_names if hasattr(sae.cfg, "hook_names") else [sae.cfg.hook_name]
)
hook_head_index = sae.cfg.hook_head_index

metric_dict = {}
Expand Down Expand Up @@ -434,7 +436,7 @@ def get_sparsity_and_variance_metrics(
_, cache = model.run_with_cache(
batch_tokens,
prepend_bos=False,
names_filter=[hook_name],
names_filter=hook_names,
stop_at_layer=sae.cfg.hook_layer + 1,
**model_kwargs,
)
Expand All @@ -443,11 +445,20 @@ def get_sparsity_and_variance_metrics(
# which will do their own reshaping for hook z.
has_head_dim_key_substrings = ["hook_q", "hook_k", "hook_v", "hook_z"]
if hook_head_index is not None:
original_act = cache[hook_name][:, :, hook_head_index]
elif any(substring in hook_name for substring in has_head_dim_key_substrings):
original_act = cache[hook_name].flatten(-2, -1)
# TODO(mkbehr) support head dimension for multilayer evals
assert len(hook_names) == 1
original_act = cache[hook_names[0]][:, :, hook_head_index]
elif any(
substring in hook_names[0] for substring in has_head_dim_key_substrings
):
# TODO(mkbehr) support head dimension for multilayer evals
original_act = cache[hook_names[0]].flatten(-2, -1)
elif hasattr(sae.cfg, "hook_names"):
# TODO(mkbehr): support head dimension for multilayer evals
layerwise_activations = [cache[hook_name] for hook_name in hook_names]
original_act = torch.stack(layerwise_activations, dim=2)
else:
original_act = cache[hook_name]
original_act = cache[hook_names[0]]

# normalise if necessary (necessary in training only, otherwise we should fold the scaling in)
if activation_store.normalize_activations == "expected_average_only_in":
Expand All @@ -461,14 +472,17 @@ def get_sparsity_and_variance_metrics(
if activation_store.normalize_activations == "expected_average_only_in":
sae_out = activation_store.unscale(sae_out)

flattened_sae_input = einops.rearrange(original_act, "b ctx d -> (b ctx) d")
flattened_sae_input = einops.rearrange(
original_act, "b ctx d ... -> (b ctx) (d ...)"
)
flattened_sae_feature_acts = einops.rearrange(
sae_feature_activations, "b ctx d -> (b ctx) d"
)
flattened_sae_out = einops.rearrange(sae_out, "b ctx d -> (b ctx) d")
flattened_sae_out = einops.rearrange(sae_out, "b ctx d ... -> (b ctx) (d ...)")

# TODO: Clean this up.
# apply mask
# TODO(mkbehr): test mask support w/ multilayer
masked_sae_feature_activations = sae_feature_activations * mask.unsqueeze(-1)
flattened_sae_input = flattened_sae_input[
flattened_mask.to(flattened_sae_input.device)
Expand Down
29 changes: 22 additions & 7 deletions sae_lens/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ def run_time_activation_ln_out(

self.setup() # Required for `HookedRootModule`s

def input_shape(self):
return [self.cfg.d_in]

def initialize_weights_basic(self):
# no config changes encoder bias init for now.
self.b_enc = nn.Parameter(
Expand All @@ -254,22 +257,28 @@ def initialize_weights_basic(self):
self.W_dec = nn.Parameter(
torch.nn.init.kaiming_uniform_(
torch.empty(
self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
self.cfg.d_sae,
*self.input_shape(),
dtype=self.dtype,
device=self.device,
)
)
)

self.W_enc = nn.Parameter(
torch.nn.init.kaiming_uniform_(
torch.empty(
self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
*self.input_shape(),
self.cfg.d_sae,
dtype=self.dtype,
device=self.device,
)
)
)

# methods which change b_dec as a function of the dataset are implemented after init.
self.b_dec = nn.Parameter(
torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
torch.zeros(*self.input_shape(), dtype=self.dtype, device=self.device)
)

# scaling factor for fine-tuning (not to be used in initial training)
Expand All @@ -284,7 +293,10 @@ def initialize_weights_gated(self):
self.W_enc = nn.Parameter(
torch.nn.init.kaiming_uniform_(
torch.empty(
self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
*self.input_shape(),
self.cfg.d_sae,
dtype=self.dtype,
device=self.device,
)
)
)
Expand All @@ -304,13 +316,16 @@ def initialize_weights_gated(self):
self.W_dec = nn.Parameter(
torch.nn.init.kaiming_uniform_(
torch.empty(
self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
self.cfg.d_sae,
*self.input_shape(),
dtype=self.dtype,
device=self.device,
)
)
)

self.b_dec = nn.Parameter(
torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
torch.zeros(*self.input_shape(), dtype=self.dtype, device=self.device)
)

def initialize_weights_jumprelu(self):
Expand Down Expand Up @@ -640,7 +655,7 @@ def from_pretrained(
)
cfg_dict = handle_config_defaulting(cfg_dict)

sae = cls(SAEConfig.from_dict(cfg_dict))
sae = cls.from_dict(cfg_dict)
sae.process_state_dict_for_loading(state_dict)
sae.load_state_dict(state_dict)

Expand Down
Loading