diff --git a/examples/early-testing/test_residual_pruning.py b/examples/early-testing/test_residual_pruning.py
new file mode 100644
index 0000000..b5d13d7
--- /dev/null
+++ b/examples/early-testing/test_residual_pruning.py
@@ -0,0 +1,74 @@
+#!/usr/bin/env python3
+"""Simple test script for residual stream pruning functionality."""
+
+from taker import Model
+from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem
+from taker.prune import prune_and_evaluate
+from taker.activations import get_midlayer_data
+from taker.eval import evaluate_all
+import torch
+import wandb
+
+hook_config = """
+post_decoder: mask, collect
+"""
+
+c = PruningConfig("gpt2",
+    dtype="fp32", # MPS compatibility
+    wandb_entity = "seperability",
+    wandb_project = "bens-tests",
+    wandb_run_name = "gpt2 residual stream prune test",
+    token_limit = 100,
+    # Residual stream pruning only
+    ff_frac = 0.0,
+    attn_frac = 0.0,
+    residual_frac = 0.1, # Prune 10% of residual dimensions
+    residual_scoring = "abs",
+    focus = "civil",
+    cripple = "toxic",
+    recalculate_activations = False,
+    collection_sample_size = 100,
+    eval_sample_size = 100,
+    n_steps = 3,
+)
+
+m = Model("gpt2", hook_config=hook_config)
+m.hooks.enable_collect_hooks(["post_decoder"], run_assert=True)
+
+# Get initial activations
+focus_data = get_midlayer_data(m, "civil", 100, collect_residual=True, calculate_residual=True,
+    collect_ff=False, calculate_ff=False, collect_attn=False, calculate_attn=False)
+cripple_data = get_midlayer_data(m, "toxic", 100, collect_residual=True, calculate_residual=True,
+    collect_ff=False, calculate_ff=False, collect_attn=False, calculate_attn=False)
+
+history = RunDataHistory(list(c.datasets))
+wandb.init(
+    project=c.wandb_project,
+    entity=c.wandb_entity,
+    name=c.wandb_run_name,
+)
+wandb.config.update(c.to_dict(), allow_val_change=True)
+
+torch.set_grad_enabled(False)
+
+# Run the pruning
+with torch.no_grad():
+    # Evaluate without pruning first
+    data = RunDataItem()
+    eval_out = evaluate_all(m, c.eval_sample_size, c.datasets,
+        dataset_tokens_to_skip=c.collection_sample_size)
+    data.update(eval_out)
+    history.add(data)
+
+    for i in range(c.n_steps):
+        print(f"Step {i}")
+        data = prune_and_evaluate(m, c, focus_data, cripple_data, i)
+        history.add(data)
+        print(f"Residual dimensions pruned: {data.deletions['residual_del']}")
+        print(f"Residual threshold: {data.deletions['residual_threshold']}")
+
+        # Get new activations after pruning for next iteration
+        focus_data = get_midlayer_data(m, "civil", 100, collect_residual=True, calculate_residual=True,
+            collect_ff=False, calculate_ff=False, collect_attn=False, calculate_attn=False)
+        cripple_data = get_midlayer_data(m, "toxic", 100, collect_residual=True, calculate_residual=True,
+            collect_ff=False, calculate_ff=False, collect_attn=False, calculate_attn=False)
\ No newline at end of file
diff --git a/src/taker/activations.py b/src/taker/activations.py
index c1be387..2691f79 100644
--- a/src/taker/activations.py
+++ b/src/taker/activations.py
@@ -73,10 +73,12 @@ def get_midlayer_data(opt: Model,
         calculate_ff: bool = True,
         calculate_attn: bool = True,
         calculate_sae: bool = False,
+        calculate_residual: bool = False,
         collect_ff: bool = False,
         collect_attn: bool = False,
         collect_ids: bool = False,
         collect_sae: bool = False,
+        collect_residual: bool = False,
         dataset_texts_to_skip: int = None,
         random_subset_frac: float = None,
         eval_config: EvalConfig = None,
@@ -101,8 +103,9 @@ def get_midlayer_data(opt: Model,
 
     do_ff   = calculate_ff or collect_ff
     do_attn = calculate_attn or collect_attn
-    do_collect = collect_ff or collect_attn or collect_ids or collect_sae
+    do_collect = collect_ff or collect_attn or collect_ids or collect_sae or collect_residual
     do_sae = calculate_sae or collect_sae
+    do_residual = calculate_residual or collect_residual
 
     # Get things ready for collection
     opt.hooks.disable_all_collect_hooks()
@@ -122,6 +125,12 @@ def get_midlayer_data(opt: Model,
         if attn_peak is not None:
             attn_data_peak_centered = ActivationCollector(attn_shape, opt.output_device)
 
+    # residual stream activation collector
+    if do_residual:
+        residual_shape = (opt.cfg.n_layers, opt.cfg.d_model)
+        residual_data = ActivationCollector(residual_shape, opt.output_device, collect_residual)
+        opt.hooks.enable_collect_hooks(["post_decoder"])
+
     if do_sae:
         sae_hook_points = [point for point, layers in opt.hooks.hook_config.hook_points.items()
             if 'all' in layers and any('sae' in hook for hook in layers['all'])]
@@ -160,6 +169,11 @@ def get_midlayer_data(opt: Model,
             if do_attn:
                 attn_acts = opt.collect_recent_attn_pre_out()
                 attn_acts = einops.rearrange(attn_acts, "b l t nh dh -> (b t) l nh dh")
+            if do_residual:
+                residual_acts = opt.hooks.get_all_layer_data("post_decoder", "collect")
+                # Stack across layers: (n_layers, batch, seq_len, d_model)
+                residual_acts = torch.stack([act for act in residual_acts if act is not None])
+                residual_acts = einops.rearrange(residual_acts, "l b t d -> (b t) l d")
             if do_sae:
                 sae_acts = {}
                 for sae_hook in sae_hook_points:
@@ -201,6 +215,8 @@ def get_midlayer_data(opt: Model,
                     attn_data.add_all(attn_acts[criteria_indices])
                     if attn_peak is not None:
                         attn_data_peak_centered.add_all((attn_acts - attn_peak)[criteria_indices])
+                if do_residual:
+                    residual_data.add_all(residual_acts[criteria_indices])
                 if do_sae:
                     for sae_hook in sae_hook_points:
                         sae_data[sae_hook].add_all(sae_acts[sae_hook])
@@ -230,6 +246,10 @@ def get_midlayer_data(opt: Model,
             orig=attn_data.summary(dtype=opt.dtype),
             peak_centered = attn_data_peak_centered.summary(dtype=opt.dtype, allow_nan=True) if attn_peak is not None else None,
         )
+    if calculate_residual:
+        output["residual"] = ActivationSummaryHolder(
+            orig=residual_data.summary(dtype=opt.dtype),
+        )
     if calculate_sae:
         output["sae"] = {}
         for sae_hook in sae_hook_points:
@@ -242,6 +262,8 @@ def get_midlayer_data(opt: Model,
         output["raw"]["mlp"] = ff_data.get_raw()
     if collect_attn:
         output["raw"]["attn"] = attn_data.get_raw()
+    if collect_residual:
+        output["raw"]["residual"] = residual_data.get_raw()
     if collect_sae:
         output["raw"]["sae"] = {sae_hook: sae_data[sae_hook].get_raw() for sae_hook in sae_hook_points}
     if collect_ids:
diff --git a/src/taker/data_classes.py b/src/taker/data_classes.py
index b95c914..3c37c6e 100644
--- a/src/taker/data_classes.py
+++ b/src/taker/data_classes.py
@@ -506,6 +506,7 @@ class ActivationOverview:
     mlp: Optional[ActivationSummaryHolder] = None
     sae: Optional[Dict[str, ActivationSummaryHolder]] = None
     attn: Optional[ActivationSummaryHolder] = None
+    residual: Optional[ActivationSummaryHolder] = None
     raw: Optional[dict] = None
     misc_data: Optional[dict] = None
 
@@ -610,6 +611,9 @@ class PruningConfig:
     sae_eps: float = 0.001
     attn_frac: float = 0.0
     attn_eps: float = 1e-4
+    residual_frac: float = 0.0
+    residual_eps: float = 0.001
+    residual_scoring: str = "abs"
     dtype: str = "fp16"
     use_accelerator: bool = True
     model_device: Optional[str] = None
diff --git a/src/taker/hooks.py b/src/taker/hooks.py
index 8fef503..9ee3c56 100644
--- a/src/taker/hooks.py
+++ b/src/taker/hooks.py
@@ -235,6 +235,9 @@ def delete_mlp_neurons(self, remove_indices, layer: int = None):
     def delete_attn_neurons(self, remove_indices, layer: int = None):
         return self["attn_pre_out"].delete_neurons(remove_indices, layer)
 
+    def delete_residual_dimensions(self, remove_indices, layer: int = None):
+        return self["post_decoder"].delete_neurons(remove_indices, layer)
+
     def reset_neuron_replace(self):
         [h.reset() for h in self.neuron_replace.values()]
 
diff --git a/src/taker/prune.py b/src/taker/prune.py
index a03bdc3..b3ffea0 100644
--- a/src/taker/prune.py
+++ b/src/taker/prune.py
@@ -37,12 +37,13 @@ def prune_and_evaluate(
     """
     c = copy.deepcopy(pruning_config)
 
-    # Find out what we are doing 
+    # Find out what we are doing
    do_ff   = pruning_config.ff_frac > 0
     do_attn = pruning_config.attn_frac > 0
     do_sae = pruning_config.sae_frac > 0
-    if not do_ff and not do_attn and not do_sae:
-        raise ValueError("Must prune at least one of FF or Attention or SAE")
+    do_residual = pruning_config.residual_frac > 0
+    if not do_ff and not do_attn and not do_sae and not do_residual:
+        raise ValueError("Must prune at least one of FF, Attention, SAE, or Residual")
 
     if do_attn and pruning_config.attn_mode not in ["pre-out", "value"]:
         raise NotImplementedError("attn_mode must be 'pre-out' or 'value'")
@@ -51,11 +52,16 @@ def prune_and_evaluate(
         sae_enabled = False
         if pruning_config.sae_frac > 0:
             sae_enabled = True
+        residual_enabled = pruning_config.residual_frac > 0
 
         focus_out   = get_midlayer_data( opt, pruning_config.focus,
-            pruning_config.collection_sample_size, pruning_config.attn_mode, calculate_sae=sae_enabled, collect_sae=sae_enabled )
+            pruning_config.collection_sample_size, pruning_config.attn_mode,
+            calculate_sae=sae_enabled, collect_sae=sae_enabled,
+            calculate_residual=residual_enabled, collect_residual=residual_enabled )
         cripple_out = get_midlayer_data( opt, pruning_config.cripple,
-            pruning_config.collection_sample_size, pruning_config.attn_mode, calculate_sae=sae_enabled, collect_sae=sae_enabled )
+            pruning_config.collection_sample_size, pruning_config.attn_mode,
+            calculate_sae=sae_enabled, collect_sae=sae_enabled,
+            calculate_residual=residual_enabled, collect_residual=residual_enabled )
 
     # Otherwise, import activation data, and adjust the "pruning fraction"
     else:
@@ -88,11 +94,19 @@ def score_and_prune( opt: Model,
     ff_frac, ff_eps     = pruning_config.ff_frac, pruning_config.ff_eps
     sae_frac, sae_eps   = pruning_config.sae_frac, pruning_config.sae_eps
     attn_frac, attn_eps = pruning_config.attn_frac, pruning_config.attn_eps
+    residual_frac, residual_eps = pruning_config.residual_frac, pruning_config.residual_eps
 
     do_ff   = ff_frac > 0
     do_attn = attn_frac > 0
     do_sae = sae_frac > 0
+    do_residual = residual_frac > 0
     act_subset = pruning_config.scoring_normalization
+
+    # Initialize variables to avoid undefined variable issues
+    ff_scores, ff_criteria, ff_threshold = None, None, 0
+    attn_scores, attn_criteria, attn_threshold = None, None, 0
+    residual_scores, residual_criteria, residual_threshold = None, None, 0
+
     if do_ff > 0:
         ff_focus_data    = focus_activations_data.mlp[act_subset]
         ff_cripple_data  = cripple_activations_data.mlp[act_subset]
@@ -114,6 +128,22 @@ def score_and_prune( opt: Model,
             opt.hooks[sae_hook].delete_neurons(sae_criteria)
 
+    # Residual stream pruning logic
+    if do_residual > 0:
+        # Get activation data for residual stream
+        residual_focus_data   = focus_activations_data.residual.orig
+        residual_cripple_data = cripple_activations_data.residual.orig
+
+        # Score residual stream dimensions
+        residual_scoring_fn = score_indices_by(pruning_config.residual_scoring)
+        residual_scores = residual_scoring_fn(opt, residual_focus_data, residual_cripple_data, residual_eps)
+
+        # Determine which dimensions to prune
+        residual_criteria, residual_threshold = get_top_frac(residual_scores, residual_frac)
+
+        # Perform the actual pruning
+        opt.hooks.delete_residual_dimensions(residual_criteria)
+
     # Get the top fraction of Attention activations and prune
     if do_attn > 0:
         attn_focus_data    = focus_activations_data.attn[act_subset]
@@ -153,8 +183,10 @@ def score_and_prune( opt: Model,
         "ff_scores": ff_scores if do_ff else None,
         # FIXME: doesn't return attn_std_mean
         "attn_scores": attn_scores if do_attn else None,
+        "residual_scores": residual_scores if do_residual else None,
         "ff_criteria": ff_criteria if do_ff else None,
         "attn_criteria": attn_criteria if do_attn else None,
+        "residual_criteria": residual_criteria if do_residual else None,
     }
 
     if save:
@@ -169,13 +201,16 @@ def score_and_prune( opt: Model,
     data.update({'deletions': {
         "ff_threshold": ff_threshold if do_ff else 0,
         "attn_threshold": attn_threshold if do_attn else 0,
+        "residual_threshold": residual_threshold if do_residual else 0,
         "ff_del": float( torch.sum(ff_criteria) ) if do_ff else 0,
         "attn_del": float( torch.sum(attn_criteria) ) if do_attn else 0,
+        "residual_del": float( torch.sum(residual_criteria) ) if do_residual else 0,
     }})
 
     data.update({'deletions_per_layer': {
         'ff': ff_criteria.sum(dim=-1).tolist() if do_ff else [],
         'attn': attn_criteria.sum(dim=-1).tolist() if do_attn else [],
+        'residual': residual_criteria.sum(dim=-1).tolist() if do_residual else [],
     }})
 
     # Save removals and scores to history
@@ -338,10 +373,17 @@ def run_pruning(c: PruningConfig):
 
     # Non-iteratively get activations, then iteratively prune and evaluate
     else:
+        sae_enabled = c.sae_frac > 0
+        residual_enabled = c.residual_frac > 0
+
         focus_out   = get_midlayer_data(opt, c.focus,
-            c.collection_sample_size, c.attn_mode)
+            c.collection_sample_size, c.attn_mode,
+            calculate_sae=sae_enabled, collect_sae=sae_enabled,
+            calculate_residual=residual_enabled, collect_residual=residual_enabled)
         cripple_out = get_midlayer_data(opt, c.cripple,
-            c.collection_sample_size, c.attn_mode)
+            c.collection_sample_size, c.attn_mode,
+            calculate_sae=sae_enabled, collect_sae=sae_enabled,
+            calculate_residual=residual_enabled, collect_residual=residual_enabled)
 
         for i in range(c.n_steps):
             data = prune_and_evaluate(opt, c, focus_out, cripple_out, i)
             history.add(data)