Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 101 additions & 7 deletions sae_lens/cache_activations_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.util import str_to_dtype

# Directory names for temporary operations during caching
_TMP_SHARDS_DIR = ".tmp_shards"
_SHUFFLED_DIR = ".shuffled"
_BACKUP_SUFFIX = ".backup"


def _mk_activations_store(
model: HookedRootModule,
Expand All @@ -28,6 +33,7 @@ def _mk_activations_store(
Internal method used in CacheActivationsRunner. Used to create a cached dataset
from a ActivationsStore.
"""
device = torch.device("cpu") # since we're saving to disk
return ActivationsStore(
model=model,
dataset=override_dataset or cfg.dataset_path,
Expand All @@ -42,13 +48,15 @@ def _mk_activations_store(
train_batch_size_tokens=-1,
prepend_bos=cfg.prepend_bos,
normalize_activations="none",
device=torch.device("cpu"), # since we're saving to disk
device=device,
dtype=cfg.dtype,
cached_activations_path=None,
model_kwargs=cfg.model_kwargs,
autocast_lm=cfg.autocast_lm,
dataset_trust_remote_code=cfg.dataset_trust_remote_code,
seqpos_slice=cfg.seqpos_slice,
disable_concat_sequences=cfg.disable_concat_sequences,
sequence_separator_token=cfg.sequence_separator_token,
)


Expand Down Expand Up @@ -176,10 +184,10 @@ def _consolidate_shards(
f"output_dir is not an existing directory: {output_dir}"
)

other_items = [p for p in output_dir.iterdir() if p.name != ".tmp_shards"]
other_items = [p for p in output_dir.iterdir() if p.name != _TMP_SHARDS_DIR]
if other_items:
raise FileExistsError(
f"output_dir must be empty (besides .tmp_shards). Found: {other_items}"
f"output_dir must be empty (besides {_TMP_SHARDS_DIR}). Found: {other_items}"
)

if not (source_dir / first_shard_dir_name).exists():
Expand Down Expand Up @@ -221,7 +229,7 @@ def _consolidate_shards(
"_split": None,
}

# fingerprint is generated from dataset.__getstate__ (not includeing _fingerprint)
# fingerprint is generated from dataset.__getstate__ (not including _fingerprint)
with open(output_dir / "state.json", "w") as f:
json.dump(new_state, f, indent=2)

Expand Down Expand Up @@ -254,7 +262,7 @@ def run(self) -> Dataset:
f"Activations directory ({final_cached_activation_path}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files."
)

tmp_cached_activation_path = final_cached_activation_path / ".tmp_shards/"
tmp_cached_activation_path = final_cached_activation_path / _TMP_SHARDS_DIR
tmp_cached_activation_path.mkdir(exist_ok=False, parents=False)

### Create temporary sharded datasets
Expand Down Expand Up @@ -284,8 +292,93 @@ def run(self) -> Dataset:
)

if self.cfg.shuffle:
logger.info("Shuffling...")
dataset = dataset.shuffle(seed=self.cfg.seed)
# shuffle_across_sequences: shuffle individual activations globally,
# treating the entire dataset as a flat array of (total_tokens, d_in).
# This breaks up sequential patterns within sequences while keeping
# token_ids paired with their corresponding activations.
if self.cfg.shuffle_across_sequences:
logger.info("Shuffling across sequences...")
dataset.set_format("torch")
hook_name = self.cfg.hook_name

# Load all data and flatten
# With torch format, [:] returns tensors directly
all_data = dataset[:]
acts = all_data[hook_name] # (n_seq, context_size, d_in)
token_ids = all_data["token_ids"] # (n_seq, context_size)
n_seq = acts.shape[0]

acts_flat = einops.rearrange(
acts, "n_seq context_size d_in -> (n_seq context_size) d_in"
)
token_ids_flat = einops.rearrange(
token_ids, "n_seq context_size -> (n_seq context_size)"
)

# Shuffle globally with the same permutation for both
generator = torch.Generator().manual_seed(self.cfg.seed)
perm = torch.randperm(acts_flat.shape[0], generator=generator)
acts_flat = acts_flat[perm]
token_ids_flat = token_ids_flat[perm]

# Reshape back to sequences
acts_shuffled = einops.rearrange(
acts_flat,
"(n_seq context_size) d_in -> n_seq context_size d_in",
n_seq=n_seq,
context_size=self.context_size,
)
token_ids_shuffled = einops.rearrange(
token_ids_flat,
"(n_seq context_size) -> n_seq context_size",
n_seq=n_seq,
context_size=self.context_size,
)

# Create new dataset from shuffled data
dataset = Dataset.from_dict(
{
hook_name: acts_shuffled,
"token_ids": token_ids_shuffled.to(torch.int32),
},
features=self.features,
)
else:
# Sequence-level shuffle only: shuffle the order of sequences (rows).
# This branch runs only when shuffle_across_sequences is False, since
# the global (per-token) shuffle above subsumes a row-level shuffle.
logger.info("Shuffling sequences...")
dataset = dataset.shuffle(seed=self.cfg.seed)

# Save the shuffled dataset back to disk using atomic rename with backup
# to prevent data loss if the process crashes mid-operation.
# Note: shuffled_path must be a sibling (not child) of final_cached_activation_path
# so that renaming the parent doesn't invalidate the shuffled path.
shuffled_path = final_cached_activation_path.parent / (
final_cached_activation_path.name + _SHUFFLED_DIR
)
backup_path = final_cached_activation_path.parent / (
final_cached_activation_path.name + _BACKUP_SUFFIX
)

dataset.save_to_disk(str(shuffled_path))

# Atomic swap: rename original to backup, then shuffled to original
try:
final_cached_activation_path.rename(backup_path)
shuffled_path.rename(final_cached_activation_path)
# Success - remove backup
shutil.rmtree(backup_path)
except Exception:
# Rollback: restore from backup if it exists
if backup_path.exists() and not final_cached_activation_path.exists():
backup_path.rename(final_cached_activation_path)
# Clean up shuffled path if it still exists
if shuffled_path.exists():
shutil.rmtree(shuffled_path)
raise

# Reload the dataset from the new location
dataset = Dataset.load_from_disk(str(final_cached_activation_path))

if self.cfg.hf_repo_id:
logger.info("Pushing to Huggingface Hub...")
Expand Down Expand Up @@ -323,6 +416,7 @@ def _create_shard(
) -> Dataset:
hook_names = [self.cfg.hook_name]
acts, token_ids = buffer

acts = einops.rearrange(
acts,
"(bs context_size) d_in -> bs context_size d_in",
Expand Down
16 changes: 15 additions & 1 deletion sae_lens/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,8 @@ class CacheActivationsRunnerConfig:
context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized.
model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
new_cached_activations_path (str, optional): The path to save the activations.
shuffle (bool): Whether to shuffle the dataset.
shuffle (bool): Whether to shuffle the dataset at the sequence level.
shuffle_across_sequences (bool): Whether to shuffle individual activations globally across all sequences, treating the entire dataset as a flat (total_tokens, d_in) array. This breaks up sequential patterns within sequences while keeping token_ids paired with their corresponding activations. Requires shuffle=True.
seed (int): The seed to use for shuffling.
dtype (str): Datatype of activations to be stored.
device (str): The device for the model.
Expand All @@ -496,6 +497,8 @@ class CacheActivationsRunnerConfig:
streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
autocast_lm (bool): Whether to use autocast during activation fetching.
dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
disable_concat_sequences (bool): Whether to disable concatenating sequences and ignore sequences shorter than the context size. If True, disables concatenating and ignores short sequences.
sequence_separator_token (int | Literal["bos", "eos", "sep"] | None): If not `None`, this token will be placed between sequences in a batch to act as a separator. By default, this is the `<bos>` token.
"""

dataset_path: str
Expand All @@ -510,6 +513,7 @@ class CacheActivationsRunnerConfig:
# defaults to "activations/{dataset}/{model}/{hook_name}
new_cached_activations_path: str | None = None
shuffle: bool = True
shuffle_across_sequences: bool = False
seed: int = 42
dtype: str = "float32"
device: str = "cuda" if torch.cuda.is_available() else "cpu"
Expand All @@ -533,6 +537,10 @@ class CacheActivationsRunnerConfig:
streaming: bool = True
autocast_lm: bool = False
dataset_trust_remote_code: bool | None = None
disable_concat_sequences: bool = False
sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
special_token_field(default="bos")
)

def __post_init__(self):
# Automatically determine context_size if dataset is tokenized
Expand Down Expand Up @@ -562,6 +570,12 @@ def __post_init__(self):
self.dataset_path, self.model_name, self.hook_name, None
)

if self.shuffle_across_sequences and not self.shuffle:
raise ValueError(
"shuffle_across_sequences=True requires shuffle=True. "
"Set shuffle=True to enable shuffling across sequences."
)

@property
def sliced_context_size(self) -> int:
if self.seqpos_slice is not None:
Expand Down
Loading
Loading