Skip to content

Commit 72e23be

Browse files
committed
Add a way to cache activations from a layer
This allows one to re-train a sparse autoencoder on the same layer without re-generating all of the activations to train on.
1 parent a23173f commit 72e23be

File tree

1 file changed

+31
-2
lines changed

1 file changed

+31
-2
lines changed

buffer.py

Lines changed: 31 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,8 @@
11
import torch as t
22
import zstandard as zstd
3+
import glob
4+
from datetime import datetime
5+
import os
36
import json
47
import io
58
from nnsight import LanguageModel
@@ -13,6 +16,8 @@ def __init__(self,
1316
data, # generator which yields text data
1417
model, # LanguageModel from which to extract activations
1518
submodules, # submodule of the model from which to extract activations
19+
activation_save_dirs=None, # paths to save cached activations, one per submodule; if an individual path is None, do not cache for that submodule
20+
activation_cache_dirs=None, # directories with cached activations to load
1621
in_feats=None,
1722
out_feats=None,
1823
io='out', # can be 'in', 'out', or 'in_to_out'
@@ -22,9 +27,12 @@ def __init__(self,
2227
out_batch_size=8192, # size of batches in which to return activations
2328
device='cpu' # device on which to store the activations
2429
):
25-
30+
if activation_save_dirs is not None and activation_cache_dirs is not None:
31+
raise ValueError("Cannot specify both activation_save_dirs and activation_cache_dirs because we cannot cache while using cached values. Choose one.")
2632
# dictionary of activations
2733
self.activations = [None for _ in submodules]
34+
if activation_cache_dirs is not None:
35+
self.file_iters = [iter(glob.glob(os.path.join(dir_path, '*.pt'))) for dir_path in (activation_cache_dirs)]
2836
for i, submodule in enumerate(submodules):
2937
if io == 'in':
3038
if in_feats is None:
@@ -49,6 +57,8 @@ def __init__(self,
4957
self.data = data
5058
self.model = model # assumes nnsight model is already on the device
5159
self.submodules = submodules
60+
self.activation_save_dirs = activation_save_dirs
61+
self.activation_cache_dirs = activation_cache_dirs
5262
self.io = io
5363
self.n_ctxs = n_ctxs
5464
self.ctx_len = ctx_len
@@ -63,6 +73,18 @@ def __next__(self):
6373
"""
6474
Return a batch of activations
6575
"""
76+
if self.activation_cache_dirs is not None:
77+
batch_activations = []
78+
for file_iter in self.file_iters:
79+
try:
80+
# Load next activation file from the current iterator
81+
file_path = next(file_iter)
82+
activations = t.load(file_path)
83+
batch_activations.append(activations.to(self.device))
84+
except StopIteration:
85+
# No more files to load, end of iteration
86+
raise StopIteration
87+
return batch_activations
6688
# if buffer is less than half full, refresh
6789
if (~self.read).sum() < self.n_ctxs * self.ctx_len // 2:
6890
self.refresh()
@@ -71,7 +93,14 @@ def __next__(self):
7193
unreads = (~self.read).nonzero().squeeze()
7294
idxs = unreads[t.randperm(len(unreads), device=unreads.device)[:self.out_batch_size]]
7395
self.read[idxs] = True
74-
return [self.activations[i][idxs] for i in range(len(self.activations))]
96+
batch_activations = [self.activations[i][idxs] for i in range(len(self.activations))]
97+
if self.activation_save_dirs is not None:
98+
for i, (activations_batch, path) in enumerate(zip(batch_activations, self.activation_save_dirs)):
99+
if path is not None:
100+
filename = f"activations_{i}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.pt"
101+
filepath = os.path.join(path, filename)
102+
t.save(activations_batch.cpu(), filepath)
103+
return batch_activations
75104

76105
def text_batch(self, batch_size=None):
77106
"""

0 commit comments

Comments
 (0)