diff --git a/buffer.py b/buffer.py
index 6b81f75..b889004 100644
--- a/buffer.py
+++ b/buffer.py
@@ -1,5 +1,8 @@
 import torch as t
 import zstandard as zstd
+import glob
+from datetime import datetime
+import os
 import json
 import io
 from nnsight import LanguageModel
@@ -13,6 +16,8 @@ def __init__(self,
                  data, # generator which yields text data
                  model, # LanguageModel from which to extract activations
                  submodules, # submodule of the model from which to extract activations
+                 activation_save_dirs=None, # directories in which to save activations, one per submodule; a None entry disables saving for that submodule
+                 activation_cache_dirs=None, # directories with cached activations to load
                  in_feats=None,
                  out_feats=None,
                  io='out', # can be 'in', 'out', or 'in_to_out'
@@ -22,17 +27,20 @@ def __init__(self,
                  out_batch_size=8192, # size of batches in which to return activations
                  device='cpu' # device on which to store the activations
                  ):
-
+        if activation_save_dirs is not None and activation_cache_dirs is not None:
+            raise ValueError("Cannot specify both activation_save_dirs and activation_cache_dirs; saving activations and loading cached ones in the same run is not supported. Choose one.")
-        # dictionary of activations
-        self.activations = {}
-        for submodule in submodules:
+        # list of activations, one per submodule
+        self.activations = [None for _ in submodules]
+        if activation_cache_dirs is not None:
+            self.file_iters = [iter(sorted(glob.glob(os.path.join(dir_path, '*.pt')))) for dir_path in activation_cache_dirs]  # sorted for a deterministic, chronological order
+        for i, submodule in enumerate(submodules):
             if io == 'in':
                 if in_feats is None:
                     try:
                         in_feats = submodule.in_features
                     except:
                         raise ValueError("in_feats cannot be inferred and must be specified directly")
-                self.activations[submodule] = t.empty(0, in_feats, device=device)
+                self.activations[i] = t.empty(0, in_feats, device=device)
 
             elif io == 'out':
                 if out_feats is None:
@@ -40,7 +48,7 @@ def __init__(self,
                         out_feats = submodule.out_features
                     except:
                         raise ValueError("out_feats cannot be inferred and must be specified directly")
-                self.activations[submodule] = t.empty(0, out_feats, device=device)
+                self.activations[i] = t.empty(0, out_feats, device=device)
             elif io == 'in_to_out':
-                raise ValueError("Support for in_to_out is depricated")
+                raise ValueError("Support for in_to_out is deprecated")
         self.read = t.zeros(0, dtype=t.bool, device=device)
@@ -49,6 +57,8 @@ def __init__(self,
         self.data = data
         self.model = model # assumes nnsight model is already on the device
         self.submodules = submodules
+        self.activation_save_dirs = activation_save_dirs
+        self.activation_cache_dirs = activation_cache_dirs
         self.io = io
         self.n_ctxs = n_ctxs
         self.ctx_len = ctx_len
@@ -63,6 +73,18 @@ def __next__(self):
         """
         Return a batch of activations
         """
+        if self.activation_cache_dirs is not None:
+            batch_activations = []
+            for file_iter in self.file_iters:
+                try:
+                    # load the next cached activation file for this submodule
+                    file_path = next(file_iter)
+                    activations = t.load(file_path)
+                    batch_activations.append(activations.to(self.device))
+                except StopIteration:
+                    # no more files to load; end of iteration
+                    raise StopIteration
+            return batch_activations
         # if buffer is less than half full, refresh
         if (~self.read).sum() < self.n_ctxs * self.ctx_len // 2:
             self.refresh()
@@ -71,9 +93,14 @@ def __next__(self):
         unreads = (~self.read).nonzero().squeeze()
         idxs = unreads[t.randperm(len(unreads), device=unreads.device)[:self.out_batch_size]]
         self.read[idxs] = True
-        return {
-            submodule : activations[idxs] for submodule, activations in self.activations.items()
-        }
+        batch_activations = [self.activations[i][idxs] for i in range(len(self.activations))]
+        if self.activation_save_dirs is not None:
+            for i, (activations_batch, path) in enumerate(zip(batch_activations, self.activation_save_dirs)):
+                if path is not None:
+                    filename = f"activations_{i}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.pt"
+                    filepath = os.path.join(path, filename)
+                    t.save(activations_batch.cpu(), filepath)
+        return batch_activations
 
     def text_batch(self, batch_size=None):
         """
@@ -102,34 +129,34 @@ def tokenized_batch(self, batch_size=None):
         )
 
     def refresh(self):
-        for submodule, activations in self.activations.items():
-            self.activations[submodule] = activations[~self.read].contiguous()
+        for i, activations in enumerate(self.activations):
+            self.activations[i] = activations[~self.read].contiguous()
         self._n_activations = (~self.read).sum().item()
 
         while self._n_activations < self.n_ctxs * self.ctx_len:
 
             with self.model.invoke(self.text_batch(), truncation=True, max_length=self.ctx_len) as invoker:
-                hidden_states = {}
-                for submodule in self.submodules:
+                hidden_states = [None for _ in self.submodules]
+                for i, submodule in enumerate(self.submodules):
                     if self.io == 'in':
                         x = submodule.input
                     else:
                         x = submodule.output
                     if (type(x.shape) == tuple):
                         x = x[0]
-                    hidden_states[submodule] = x.save()
+                    hidden_states[i] = x.save()
 
             attn_mask = invoker.input['attention_mask']
 
             self._n_activations += (attn_mask != 0).sum().item()
 
-            for submodule, activations in self.activations.items():
-                self.activations[submodule] = t.cat((
+            for i, activations in enumerate(self.activations):
+                self.activations[i] = t.cat((
                     activations,
-                    hidden_states[submodule].value[attn_mask != 0].to(activations.device)),
+                    hidden_states[i].value[attn_mask != 0].to(activations.device)),
                     dim=0
                 )
-                assert len(self.activations[submodule]) == self._n_activations
+                assert len(self.activations[i]) == self._n_activations
 
         self.read = t.zeros(self._n_activations, dtype=t.bool, device=self.device)
 
diff --git a/training.py b/training.py
index 3b81b5b..74fc68a 100644
--- a/training.py
+++ b/training.py
@@ -148,17 +148,17 @@ def resample_neurons(deads, activations, ae, optimizer):
 
 def trainSAE(
         buffer, # an ActivationBuffer
-        activation_dims, # dictionary of activation dimensions for each submodule (or a single int)
-        dictionary_sizes, # dictionary of dictionary sizes for each submodule (or a single int)
-        lr,
+        activation_dims, # list of activation dimensions for each submodule (or a single int)
+        dictionary_sizes, # list of dictionary sizes for each submodule (or a single int)
+        lrs, # list of learning rates for each submodule (or a single float)
         sparsity_penalty,
         entropy=False,
         steps=None, # if None, train until activations are exhausted
         warmup_steps=1000, # linearly increase the learning rate for this many steps
         resample_steps=None, # how often to resample dead neurons
-        ghost_threshold=None, # how many steps a neuron has to be dead for it to turn into a ghost
+        ghost_thresholds=None, # list of how many steps a neuron has to be dead for it to turn into a ghost (or a single int)
         save_steps=None, # how often to save checkpoints
-        save_dirs=None, # dictionary of directories to save checkpoints to
+        save_dirs=None, # list of directories to save checkpoints to
         checkpoint_offset=0, # if resuming training, the step number of the last checkpoint
         load_dirs=None, # if initializing from a pretrained dictionary, directories to load from
         log_steps=None, # how often to print statistics
@@ -167,23 +167,25 @@ def trainSAE(
     Train and return sparse autoencoders for each submodule in the buffer.
""" if isinstance(activation_dims, int): - activation_dims = {submodule: activation_dims for submodule in buffer.submodules} + activation_dims = [activation_dims for submodule in buffer.submodules] if isinstance(dictionary_sizes, int): - dictionary_sizes = {submodule: dictionary_sizes for submodule in buffer.submodules} - - aes = {} - num_samples_since_activateds = {} - for submodule in buffer.submodules: - ae = AutoEncoder(activation_dims[submodule], dictionary_sizes[submodule]).to(device) + dictionary_sizes = [dictionary_sizes for submodule in buffer.submodules] + if isinstance(lrs, float): + lrs = [lrs for submodule in buffer.submodules] + if isinstance(ghost_thresholds, int): + ghost_thresholds = [ghost_thresholds for submodule in buffer.submodules] + + aes = [None for submodule in buffer.submodules] + num_samples_since_activateds = [None for submodule in buffer.submodules] + for i, submodule in enumerate(buffer.submodules): + ae = AutoEncoder(activation_dims[i], dictionary_sizes[i]).to(device) if load_dirs is not None: - ae.load_state_dict(t.load(os.path.join(load_dirs[submodule]))) - aes[submodule] = ae - num_samples_since_activateds[submodule] = t.zeros(dictionary_sizes[submodule], dtype=int, device=device) + ae.load_state_dict(t.load(os.path.join(load_dirs[i]))) + aes[i] = ae + num_samples_since_activateds[i] = t.zeros(dictionary_sizes[i], dtype=int, device=device) # set up optimizer and scheduler - optimizers = { - submodule: ConstrainedAdam(ae.parameters(), ae.decoder.parameters(), lr=lr) for submodule, ae in aes.items() - } + optimizers = [ConstrainedAdam(ae.parameters(), ae.decoder.parameters(), lr=lrs[i]) for i, ae in enumerate(aes)] if resample_steps is None: def warmup_fn(step): return min(step / warmup_steps, 1.) @@ -191,21 +193,19 @@ def warmup_fn(step): def warmup_fn(step): return min((step % resample_steps) / warmup_steps, 1.) 
-    schedulers = {
-        submodule: t.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn) for submodule, optimizer in optimizers.items()
-    }
+    schedulers = [t.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn) for optimizer in optimizers]
 
     for step, acts in enumerate(tqdm(buffer, total=steps)):
         real_step = step + checkpoint_offset
         if steps is not None and real_step >= steps:
             break
 
-        for submodule, act in acts.items():
+        for i, act in enumerate(acts):
             act = act.to(device)
             ae, num_samples_since_activated, optimizer, scheduler \
-                = aes[submodule], num_samples_since_activateds[submodule], optimizers[submodule], schedulers[submodule]
+                = aes[i], num_samples_since_activateds[i], optimizers[i], schedulers[i]
 
             optimizer.zero_grad()
-            loss = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_threshold)
+            loss = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_thresholds[i])
             loss.backward()
             optimizer.step()
             scheduler.step()
@@ -218,8 +218,8 @@ def warmup_fn(step):
             # logging
             if log_steps is not None and step % log_steps == 0:
                 with t.no_grad():
-                    losses = sae_loss(act, ae, sparsity_penalty, entropy, separate=True, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_threshold)
-                    if ghost_threshold is None:
+                    losses = sae_loss(act, ae, sparsity_penalty, use_entropy=entropy, num_samples_since_activated=num_samples_since_activated, ghost_threshold=ghost_thresholds[i], separate=True)
+                    if ghost_thresholds[i] is None:
                         mse_loss, sparsity_loss = losses
                         print(f"step {step} MSE loss: {mse_loss}, sparsity loss: {sparsity_loss}")
                     else:
@@ -234,11 +234,11 @@ def warmup_fn(step):
 
             # saving
             if save_steps is not None and save_dirs is not None and real_step % save_steps == 0:
-                if not os.path.exists(os.path.join(save_dirs[submodule], "checkpoints")):
-                    os.mkdir(os.path.join(save_dirs[submodule], "checkpoints"))
+                if not os.path.exists(os.path.join(save_dirs[i], "checkpoints")):
+                    os.mkdir(os.path.join(save_dirs[i], "checkpoints"))
                 t.save(
                     ae.state_dict(),
-                    os.path.join(save_dirs[submodule], "checkpoints", f"ae_{real_step}.pt")
+                    os.path.join(save_dirs[i], "checkpoints", f"ae_{real_step}.pt")
                 )
 
-    return aes
\ No newline at end of file
+    return aes
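
For reviewers, a minimal usage sketch of the list-based API introduced above (not part of the patch). The import paths follow the file names in this diff; the model, layer choices, dimensions, and directory names are illustrative assumptions, and the toy corpus stands in for a real text generator.

    from nnsight import LanguageModel
    from buffer import ActivationBuffer
    from training import trainSAE

    model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map='cuda:0')
    submodules = [model.gpt_neox.layers[1].mlp, model.gpt_neox.layers[2].mlp]
    data = iter(["some text"] * 100000)  # toy stand-in for a real text generator

    # first run: extract activations from the model, saving each yielded batch
    # (the save/cache directories must already exist)
    buffer = ActivationBuffer(
        data, model, submodules,
        out_feats=512,
        activation_save_dirs=['acts/layer1', 'acts/layer2'],
    )
    aes = trainSAE(
        buffer,
        activation_dims=512,        # a single int is broadcast to every submodule
        dictionary_sizes=16 * 512,
        lrs=1e-3,
        sparsity_penalty=1e-3,
        save_steps=1000,
        save_dirs=['ckpt/layer1', 'ckpt/layer2'],
    )

    # later run: replay the cached activations instead of re-running the model
    cached_buffer = ActivationBuffer(
        data, model, submodules,
        out_feats=512,
        activation_cache_dirs=['acts/layer1', 'acts/layer2'],
    )

Because each batch is saved exactly as it is yielded, a cached run replays the first run's batches in order (hence the sorted glob over the timestamped filenames in __init__).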