11import torch as t
22import zstandard as zstd
3+ import glob
4+ from datetime import datetime
5+ import os
36import json
47import io
58from nnsight import LanguageModel
@@ -13,6 +16,8 @@ def __init__(self,
1316 data , # generator which yields text data
1417 model , # LanguageModel from which to extract activations
1518 submodules , # submodules of the model from which to extract activations
19+ activation_save_dirs = None , # paths to save cached activations, one per submodule; if an individual path is None, do not cache for that submodule
20+ activation_cache_dirs = None , # directories with cached activations to load
1621 in_feats = None ,
1722 out_feats = None ,
1823 io = 'out' , # can be 'in', 'out', or 'in_to_out'
@@ -22,9 +27,12 @@ def __init__(self,
2227 out_batch_size = 8192 , # size of batches in which to return activations
2328 device = 'cpu' # device on which to store the activations
2429 ):
25-
30+ if activation_save_dirs is not None and activation_cache_dirs is not None :
31+ raise ValueError ("Cannot specify both activation_save_dirs and activation_cache_dirs because we cannot cache while using cached values. Choose one." )
2632 # list of activations, one entry per submodule
2733 self .activations = [None for _ in submodules ]
34+ if activation_cache_dirs is not None :
35+ self .file_iters = [iter (glob .glob (os .path .join (dir_path , '*.pt' ))) for dir_path in (activation_cache_dirs )]
2836 for i , submodule in enumerate (submodules ):
2937 if io == 'in' :
3038 if in_feats is None :
@@ -49,6 +57,8 @@ def __init__(self,
4957 self .data = data
5058 self .model = model # assumes nnsight model is already on the device
5159 self .submodules = submodules
60+ self .activation_save_dirs = activation_save_dirs
61+ self .activation_cache_dirs = activation_cache_dirs
5262 self .io = io
5363 self .n_ctxs = n_ctxs
5464 self .ctx_len = ctx_len
@@ -63,6 +73,18 @@ def __next__(self):
6373 """
6474 Return a batch of activations
6575 """
76+ if self .activation_cache_dirs is not None :
77+ batch_activations = []
78+ for file_iter in self .file_iters :
79+ try :
80+ # Load next activation file from the current iterator
81+ file_path = next (file_iter )
82+ activations = t .load (file_path )
83+ batch_activations .append (activations .to (self .device ))
84+ except StopIteration :
85+ # No more files to load, end of iteration
86+ raise StopIteration
87+ return batch_activations
6688 # if buffer is less than half full, refresh
6789 if (~ self .read ).sum () < self .n_ctxs * self .ctx_len // 2 :
6890 self .refresh ()
@@ -71,7 +93,14 @@ def __next__(self):
7193 unreads = (~ self .read ).nonzero ().squeeze ()
7294 idxs = unreads [t .randperm (len (unreads ), device = unreads .device )[:self .out_batch_size ]]
7395 self .read [idxs ] = True
74- return [self .activations [i ][idxs ] for i in range (len (self .activations ))]
96+ batch_activations = [self .activations [i ][idxs ] for i in range (len (self .activations ))]
97+ if self .activation_save_dirs is not None :
98+ for i , (activations_batch , path ) in enumerate (zip (batch_activations , self .activation_save_dirs )):
99+ if path is not None :
100+ filename = f"activations_{ i } _{ datetime .now ().strftime ('%Y%m%d%H%M%S%f' )} .pt"
101+ filepath = os .path .join (path , filename )
102+ t .save (activations_batch .cpu (), filepath )
103+ return batch_activations
75104
76105 def text_batch (self , batch_size = None ):
77106 """
0 commit comments