Merged
29 commits
0a53fbf
support seqpos slicing
callummcdougall Sep 18, 2024
3ba222b
add basic tests, ensure it's in the SAE config
jbloomAus Sep 20, 2024
b54d188
format
jbloomAus Sep 20, 2024
264a570
fix tests
jbloomAus Sep 20, 2024
48b92c5
fix tests 2
jbloomAus Sep 20, 2024
54d1105
fix: Changing the activations store to handle context sizes smaller t…
zhenningdavidliu Sep 30, 2024
eb04a01
fix: Found bug which allowed for negative context lengths. Removed th…
zhenningdavidliu Sep 30, 2024
cc43814
Update pytest to test new logic for context size of tokenized dataset
decandido Sep 30, 2024
0284000
Reformat code to pass CI tests
decandido Sep 30, 2024
c12550f
Add warning for when context_size is smaller than the dataset context…
decandido Oct 1, 2024
59439bf
feat: adding support for start and end position offsets for token seq…
zhenningdavidliu Oct 1, 2024
ac7ed3b
Add start_pos_offset and end_pos_offset to the SAERunnerConfig
decandido Oct 2, 2024
560ae8a
Add tests for start_pos_offset and end_pos_offset in the LanguageMode…
decandido Oct 2, 2024
93ebea6
feat: start and end position offset support for SAELens.
zhenningdavidliu Oct 2, 2024
340500f
Add test for CacheActivationsRunnerConfig with start and end pos offset
decandido Oct 2, 2024
c436a4f
Test cache activation runner with valid start and end pos offset
decandido Oct 2, 2024
bdbb585
feat: Enabling loading of start and end pos offset from saes. Adding
zhenningdavidliu Oct 2, 2024
7f3b76a
fix: Renaming variables and a test
zhenningdavidliu Oct 3, 2024
755ba75
adds test for position offsets for saes
zhenningdavidliu Oct 3, 2024
d680041
reformats files with black
decandido Oct 3, 2024
776fdd7
Add start and end pos offset to the base sae dict
decandido Oct 3, 2024
0625447
fix test for sae training runner config with position offsets
decandido Oct 3, 2024
f7d6a38
add a benchmark test to train an SAE on OthelloGPT
decandido Oct 3, 2024
9f16ff2
Remove double import from typing
decandido Oct 3, 2024
99ace75
change dead_feature_window to int
decandido Oct 3, 2024
c0dc5bf
remove print statements from test file
decandido Oct 4, 2024
9130ff9
Rebase on seqpos tuple implementation and remove start/end pos offset
decandido Oct 4, 2024
125b275
Reword docstring for seqpos to be clearer.
decandido Oct 9, 2024
552eea6
Added script to train an SAE on othelloGPT
decandido Oct 9, 2024
38 changes: 38 additions & 0 deletions sae_lens/config.py
@@ -60,6 +60,7 @@ class LanguageModelSAERunnerConfig:
store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop.
normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).
seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note that step_size must be > 0.
device (str): The device to use. Usually cuda.
act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
seed (int): The seed to use.
@@ -153,6 +154,7 @@ class LanguageModelSAERunnerConfig:
normalize_activations: str = (
"none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
)
seqpos_slice: tuple[int | None, ...] = (None,)

# Misc
device: str = "cpu"
@@ -355,6 +357,13 @@ def __post_init__(self):
if self.use_ghost_grads:
print("Using Ghost Grads.")

if self.context_size < 0:
raise ValueError(
f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
)

_validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size)

@property
def total_training_tokens(self) -> int:
return self.training_tokens + self.finetuning_tokens
@@ -386,6 +395,7 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]:
"normalize_activations": self.normalize_activations,
"activation_fn_kwargs": self.activation_fn_kwargs,
"model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
"seqpos_slice": self.seqpos_slice,
}

def get_training_sae_cfg_dict(self) -> dict[str, Any]:
@@ -427,6 +437,15 @@ def to_json(self, path: str) -> None:
def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
with open(path + "cfg.json", "r") as f:
cfg = json.load(f)

# Ensure seqpos_slice is a tuple
if "seqpos_slice" in cfg:
if isinstance(cfg["seqpos_slice"], list):
cfg["seqpos_slice"] = tuple(cfg["seqpos_slice"])
elif not isinstance(cfg["seqpos_slice"], tuple):
cfg["seqpos_slice"] = (cfg["seqpos_slice"],)

return cls(**cfg)


@@ -461,6 +480,7 @@ class CacheActivationsRunnerConfig:
store_batch_size_prompts: int = 32
train_batch_size_tokens: int = 4096
normalize_activations: str = "none" # should always be none for activation caching
seqpos_slice: tuple[int | None, ...] = (None,)

# Misc
device: str = "cpu"
@@ -491,6 +511,13 @@ def __post_init__(self):
if self.act_store_device == "with_model":
self.act_store_device = self.device

if self.context_size < 0:
raise ValueError(
f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
)

_validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size)


@dataclass
class ToyModelSAERunnerConfig:
@@ -576,6 +603,17 @@ def _default_cached_activations_path(
return path


def _validate_seqpos(seqpos: tuple[int | None, ...], context_size: int) -> None:
# Ensure that the step size is at least 1
if len(seqpos) == 3:
step_size = seqpos[2] or 1
assert (
step_size >= 1
), f"Ensure the step_size {seqpos[2]=} for sequence slicing is positive."
# Ensure that the choice of seqpos doesn't result in an empty sequence
assert (
len(range(context_size)[slice(*seqpos)]) > 0
), f"The slice {seqpos=} results in an empty sequence for {context_size=}."


@dataclass
class PretokenizeRunnerConfig:
tokenizer_name: str = "gpt2"
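For intuition on what _validate_seqpos accepts, here is a minimal standalone sketch (illustrative only, not part of this diff) of how a seqpos_slice tuple is applied via Python's built-in slice. The (5, -5) value mirrors the Othello settings in the training script below.

context_size = 59
seqpos_slice = (5, -5)  # keep positions 5 .. context_size - 5

# This is exactly the expression _validate_seqpos checks for non-emptiness.
kept = range(context_size)[slice(*seqpos_slice)]
print(len(kept))          # 49
print(kept[0], kept[-1])  # 5 53

# A slice that empties the sequence would trip the assert:
assert len(range(context_size)[slice(30, 10)]) == 0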
6 changes: 6 additions & 0 deletions sae_lens/sae.py
@@ -62,6 +62,7 @@ class SAEConfig:
activation_fn_kwargs: dict[str, Any] = field(default_factory=dict)
neuronpedia_id: Optional[str] = None
model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
seqpos_slice: tuple[int | None, ...] = (None,)

@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
@@ -81,6 +82,10 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
for k, v in config_dict.items()
if k in cls.__dataclass_fields__ # pylint: disable=no-member
}

if "seqpos_slice" in config_dict:
config_dict["seqpos_slice"] = tuple(config_dict["seqpos_slice"])

return cls(**config_dict)

# def __post_init__(self):
@@ -108,6 +113,7 @@ def to_dict(self) -> dict[str, Any]:
"normalize_activations": self.normalize_activations,
"neuronpedia_id": self.neuronpedia_id,
"model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
"seqpos_slice": self.seqpos_slice,
}


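The list-to-tuple coercion in from_dict above exists because JSON has no tuple type: a config saved via to_dict and serialized will round-trip seqpos_slice as a list. A minimal sketch of that round trip:

import json

cfg = {"seqpos_slice": (5, -5)}
restored = json.loads(json.dumps(cfg))
print(restored["seqpos_slice"])         # [5, -5] -- now a list
print(tuple(restored["seqpos_slice"]))  # (5, -5) -- what from_dict restores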
41 changes: 22 additions & 19 deletions sae_lens/training/activations_store.py
@@ -88,6 +88,7 @@ def from_config(
model_kwargs=cfg.model_kwargs,
autocast_lm=cfg.autocast_lm,
dataset_trust_remote_code=cfg.dataset_trust_remote_code,
seqpos_slice=cfg.seqpos_slice,
)

@classmethod
@@ -123,6 +124,7 @@ def from_sae(
dataset_trust_remote_code=sae.cfg.dataset_trust_remote_code,
dtype=sae.cfg.dtype,
device=torch.device(device),
seqpos_slice=sae.cfg.seqpos_slice,
)

def __init__(
@@ -147,6 +149,7 @@ def __init__(
model_kwargs: dict[str, Any] | None = None,
autocast_lm: bool = False,
dataset_trust_remote_code: bool | None = None,
seqpos_slice: tuple[int | None, ...] = (None,),
):
self.model = model
if model_kwargs is None:
@@ -188,6 +191,7 @@ def __init__(
self.dtype = DTYPE_MAP[dtype]
self.cached_activations_path = cached_activations_path
self.autocast_lm = autocast_lm
self.seqpos_slice = seqpos_slice

self.n_dataset_processed = 0

@@ -220,10 +224,6 @@ def __init__(
f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}.
The context_size {ds_context_size} is expected to be larger than or equal to the provided context size {self.context_size}."""
)
if self.context_size < 0:
raise ValueError(
f"The provided context_size is {self.context_size} is negative. Expecting positive context_size"
)
if self.context_size != ds_context_size:
warnings.warn(
f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. Some data will be discarded in this case.""",
@@ -441,37 +441,39 @@ def get_activations(self, batch_tokens: torch.Tensor):
autocast_if_enabled = contextlib.nullcontext()

with autocast_if_enabled:
layerwise_activations = self.model.run_with_cache(
layerwise_activations_cache = self.model.run_with_cache(
batch_tokens,
names_filter=[self.hook_name],
stop_at_layer=self.hook_layer + 1,
prepend_bos=False,
**self.model_kwargs,
)[1]

n_batches, n_context = batch_tokens.shape
layerwise_activations = layerwise_activations_cache[self.hook_name][
:, slice(*self.seqpos_slice)
]

n_batches, n_context = layerwise_activations.shape[:2]

stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))

if self.hook_head_index is not None:
stacked_activations[:, :, 0] = layerwise_activations[self.hook_name][
stacked_activations[:, :, 0] = layerwise_activations[
:, :, self.hook_head_index
]
elif (
layerwise_activations[self.hook_name].ndim > 3
): # if we have a head dimension
elif layerwise_activations.ndim > 3: # if we have a head dimension
try:
stacked_activations[:, :, 0] = layerwise_activations[
self.hook_name
].view(n_batches, n_context, -1)
stacked_activations[:, :, 0] = layerwise_activations.view(
n_batches, n_context, -1
)
except RuntimeError as e:
print(f"Error during view operation: {e}")
print("Attempting to use reshape instead...")
stacked_activations[:, :, 0] = layerwise_activations[
self.hook_name
].reshape(n_batches, n_context, -1)
stacked_activations[:, :, 0] = layerwise_activations.reshape(
n_batches, n_context, -1
)
else:
stacked_activations[:, :, 0] = layerwise_activations[self.hook_name]
stacked_activations[:, :, 0] = layerwise_activations

return stacked_activations

@@ -487,14 +489,15 @@ def get_buffer(
If raise_on_epoch_end is True, when the dataset is exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
"""
context_size = self.context_size
training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
batch_size = self.store_batch_size_prompts
d_in = self.d_in
total_size = batch_size * n_batches_in_buffer
num_layers = 1

if self.cached_activations_path is not None:
# Load the activations from disk
buffer_size = total_size * context_size
buffer_size = total_size * training_context_size
# Initialize an empty tensor with an additional dimension for layers
new_buffer = torch.zeros(
(buffer_size, num_layers, d_in),
@@ -548,7 +551,7 @@ def get_buffer(
refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size)
# Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
new_buffer = torch.zeros(
(total_size, context_size, num_layers, d_in),
(total_size, training_context_size, num_layers, d_in),
dtype=self.dtype, # type: ignore
device=self.device,
)
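The switch from context_size to training_context_size above is what shrinks the buffer: after slicing, each prompt contributes only the kept positions. A small sketch of the arithmetic, assuming the Othello values used in the script below:

context_size = 59
seqpos_slice = (5, -5)
store_batch_size_prompts = 16
n_batches_in_buffer = 32

training_context_size = len(range(context_size)[slice(*seqpos_slice)])  # 49
total_size = store_batch_size_prompts * n_batches_in_buffer             # 512

# Buffer shape: (total_size, training_context_size, num_layers, d_in)
print(total_size * training_context_size)  # 25088 activations per buffer
print(total_size * context_size)           # 30208 if no slicing were applied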
13 changes: 13 additions & 0 deletions sae_lens/training/training_sae.py
@@ -75,6 +75,7 @@ def from_sae_runner_config(
context_size=cfg.context_size,
dataset_path=cfg.dataset_path,
prepend_bos=cfg.prepend_bos,
seqpos_slice=cfg.seqpos_slice,
# Training cfg
l1_coefficient=cfg.l1_coefficient,
lp_norm=cfg.lp_norm,
@@ -99,6 +100,18 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig":
valid_config_dict = {
key: val for key, val in config_dict.items() if key in valid_field_names
}

# Ensure seqpos_slice is a tuple
if "seqpos_slice" in valid_config_dict:
if isinstance(valid_config_dict["seqpos_slice"], list):
valid_config_dict["seqpos_slice"] = tuple(
valid_config_dict["seqpos_slice"]
)
elif not isinstance(valid_config_dict["seqpos_slice"], tuple):
valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],)

return TrainingSAEConfig(**valid_config_dict)

def to_dict(self) -> dict[str, Any]:
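One subtlety in the scalar branch above: a bare int is wrapped into a one-element tuple, and a one-element slice sets only the stop position, while the default (None,) keeps everything. A short illustrative sketch (not part of this diff):

context_size = 59

print(len(range(context_size)[slice(*(None,))]))  # 59 -- default keeps all positions
print(len(range(context_size)[slice(*(5,))]))     # 5  -- scalar 5 keeps the first 5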
107 changes: 107 additions & 0 deletions scripts/training_a_sparse_autoencoder_othelloGPT.py
@@ -0,0 +1,107 @@
import os

import torch

from sae_lens import (
SAE,
HookedSAETransformer,
LanguageModelSAERunnerConfig,
SAETrainingRunner,
upload_saes_to_huggingface,
)

if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"


model_name = "othello-gpt"
model = HookedSAETransformer.from_pretrained(model_name)

dataset_path = "taufeeque/othellogpt"
context_size = 59

layer = 5
training_tokens = int(1e3)
train_batch_size_tokens = 2048
n_steps = int(training_tokens / train_batch_size_tokens)

print(LanguageModelSAERunnerConfig())
runner_cfg = LanguageModelSAERunnerConfig(
#
# Data generation
model_name=model_name,
hook_name=f"blocks.{layer}.mlp.hook_post",
hook_layer=layer,
d_in=model.cfg.d_mlp,
dataset_path=dataset_path,
is_dataset_tokenized=True,
prepend_bos=False,
streaming=True,
train_batch_size_tokens=train_batch_size_tokens,
context_size=context_size,
seqpos_slice=(5, -5),
#
# SAE architecture
architecture="gated",
expansion_factor=8,
b_dec_init_method="zeros",
apply_b_dec_to_input=True,
normalize_sae_decoder=False,
scale_sparsity_penalty_by_decoder_norm=True,
decoder_heuristic_init=True,
init_encoder_as_decoder_transpose=True,
#
# Activations store
n_batches_in_buffer=32,
store_batch_size_prompts=16,
training_tokens=training_tokens,
#
# Training hyperparameters (standard)
lr=2e-4,
adam_beta1=0.9,
adam_beta2=0.999,
lr_scheduler_name="constant",
lr_warm_up_steps=int(0.2 * n_steps),
lr_decay_steps=int(0.2 * n_steps),
#
# Training hyperparameters (SAE-specific)
l1_coefficient=5,
l1_warm_up_steps=int(0.2 * n_steps),
use_ghost_grads=False,
feature_sampling_window=1000,
dead_feature_window=500,
dead_feature_threshold=1e-5,
#
# Logging / evals
log_to_wandb=True,
wandb_project=f"othello_gpt_sae_{layer=}",
wandb_log_frequency=30,
eval_every_n_wandb_logs=10,
checkpoint_path="checkpoints",
#
# Misc.
device=str(device),
seed=42,
n_checkpoints=5,
dtype="float32",
)

# torch.set_grad_enabled(True)
runner = SAETrainingRunner(runner_cfg)
sae = runner.run()

hf_repo_id = "callummcdougall/arena-demos-othellogpt"
sae_id = "blocks.5.mlp.hook_post-v1"

upload_saes_to_huggingface({sae_id: sae}, hf_repo_id=hf_repo_id)

othellogpt_sae = SAE.from_pretrained(
release=hf_repo_id, sae_id=sae_id, device=str(device)
)[0]
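
As a quick post-load sanity check, activations can be round-tripped through the SAE. This is a hedged sketch, not part of the script as merged: the random tensor is a stand-in for real blocks.5.mlp.hook_post activations, and encode/decode are the standard SAE methods.

acts = torch.randn(4, 49, othellogpt_sae.cfg.d_in, device=device)

feature_acts = othellogpt_sae.encode(acts)   # (4, 49, d_sae)
recon = othellogpt_sae.decode(feature_acts)  # (4, 49, d_in)

mse = torch.mean((recon - acts) ** 2)
l0 = (feature_acts > 0).float().sum(-1).mean()  # mean active features per position
print(f"mse={mse.item():.4f}, mean L0={l0.item():.1f}")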