From f8c1e3113343ff94740f006ffa2e64ac65603f86 Mon Sep 17 00:00:00 2001 From: nnebp Date: Thu, 16 Jan 2025 13:39:45 -0500 Subject: [PATCH 1/9] added sae to get_midlayer_data --- src/taker/activations.py | 45 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/src/taker/activations.py b/src/taker/activations.py index cfc1721..3c0ee22 100644 --- a/src/taker/activations.py +++ b/src/taker/activations.py @@ -72,9 +72,11 @@ def get_midlayer_data(opt: Model, skip_input_or_output: str = "output", calculate_ff: bool = True, calculate_attn: bool = True, + calculate_sae: bool = False, collect_ff: bool = False, collect_attn: bool = False, collect_ids: bool = False, + collect_sae: bool = False, dataset_texts_to_skip: int = None, random_subset_frac: float = None, eval_config: EvalConfig = None, @@ -100,6 +102,7 @@ def get_midlayer_data(opt: Model, do_ff = calculate_ff or collect_ff do_attn = calculate_attn or collect_attn do_collect = collect_ff or collect_attn or collect_ids + do_sae = calculate_sae or collect_sae # Get things ready for collection opt.hooks.disable_all_collect_hooks() @@ -118,6 +121,16 @@ def get_midlayer_data(opt: Model, opt.hooks.enable_collect_hooks(["attn_pre_out"]) if attn_peak is not None: attn_data_peak_centered = ActivationCollector(attn_shape, opt.output_device) + + if do_sae: + sae_hook_points = [point for point, layers in opt.hooks.hook_config.hook_points.items() + if 'all' in layers and any('sae' in hook for hook in layers['all'])] + sae_data = dict.fromkeys(sae_hook_points) + for sae_hook in sae_hook_points: + #assuming all layers have the hook + sae_shape = (opt.cfg.n_layers, opt.hooks.neuron_sae_encode[f"layer_0_{sae_hook}"].sae_config["d_sae"]) + sae_data[sae_hook] = ActivationCollector(sae_shape, opt.output_device, collect_sae) + opt.hooks.enable_collect_hooks([sae_hook_points]) if do_collect: criteria_raw = [] @@ -143,10 +156,25 @@ def get_midlayer_data(opt: Model, if do_ff: ff_acts = 
opt.collect_recent_mlp_pre_out() + #FIXME: delete + print("ff_acts shape before ",ff_acts.shape) ff_acts = einops.rearrange(ff_acts, "b l t d -> (b t) l d") + #FIXME: delete + print("ff_acts shape after ",ff_acts.shape) if do_attn: attn_acts = opt.collect_recent_attn_pre_out() attn_acts = einops.rearrange(attn_acts, "b l t nh dh -> (b t) l nh dh") + if do_sae: + sae_acts = {} + for sae_hook in sae_hook_points: + acts = [opt.hooks.collects[f"layer_{i}_{sae_hook}"].activation.to(opt.device) for i in range(opt.cfg.n_layers)] + #FIXME: delete + #TODO: rearrange + print("sae act shape ",acts[0].shape) + sae_acts[sae_hook] = torch.stack(acts) + print("sae acts before shape ",sae_acts[sae_hook].shape) + sae_acts[sae_hook] = einops.rearrange(sae_acts[sae_hook], "l b t d -> (b t) l d") + print("sae acts after shape ",sae_acts[sae_hook].shape) # set up criteria for filtering which activations we actually want ids = einops.rearrange(input_ids, "b t -> (b t)") @@ -175,6 +203,11 @@ def get_midlayer_data(opt: Model, criteria_indices = criteria.nonzero().flatten() if do_ff: ff_data.add_all(ff_acts[criteria_indices]) + #FIXME: delete + copy_acts = ff_acts[criteria_indices].clone() + print("copy_acts shape ",copy_acts.shape) + print("ff_data raw shape", ff_data.shape) + print("ff_acts criteria shape ",(ff_acts[criteria_indices]).shape) if ff_peak is not None: ff_data_peak_centered.add_all((ff_acts - ff_peak)[criteria_indices]) @@ -182,6 +215,9 @@ def get_midlayer_data(opt: Model, attn_data.add_all(attn_acts[criteria_indices]) if attn_peak is not None: attn_data_peak_centered.add_all((attn_acts - attn_peak)[criteria_indices]) + if do_sae: + for sae_hook in sae_hook_points: + sae_data[sae_hook].add_all(sae_acts[sae_hook]) if do_collect: for criterion in criteria: criteria_raw.append(criterion.cpu()) @@ -208,13 +244,20 @@ def get_midlayer_data(opt: Model, orig=attn_data.summary(dtype=opt.dtype), peak_centered = attn_data_peak_centered.summary(dtype=opt.dtype, allow_nan=True) if 
attn_peak is not None else None, ) - + if calculate_sae: + output["sae"] = {} + for sae_hook in sae_hook_points: + output["sae"][sae_hook] = ActivationSummaryHolder( + orig=sae_data[sae_hook].summary(dtype=opt.dtype), + ) if do_collect: output["raw"] = {"criteria": torch.stack(criteria_raw)} if collect_ff: output["raw"]["mlp"] = ff_data.get_raw() if collect_attn: output["raw"]["attn"] = attn_data.get_raw() + if collect_sae: + output["raw"]["sae"] = {sae_hook: sae_data[sae_hook].get_raw() for sae_hook in sae_hook_points} if collect_ids: output["raw"]["input_ids"] = torch.stack(input_id_data) if len(output_id_data): From 69846226019a503e0b1765f50e6092390cef66d9 Mon Sep 17 00:00:00 2001 From: nnebp Date: Thu, 16 Jan 2025 22:41:10 +0000 Subject: [PATCH 2/9] remove debug prints and comments --- src/taker/activations.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/taker/activations.py b/src/taker/activations.py index 3c0ee22..847f09e 100644 --- a/src/taker/activations.py +++ b/src/taker/activations.py @@ -156,11 +156,7 @@ def get_midlayer_data(opt: Model, if do_ff: ff_acts = opt.collect_recent_mlp_pre_out() - #FIXME: delete - print("ff_acts shape before ",ff_acts.shape) ff_acts = einops.rearrange(ff_acts, "b l t d -> (b t) l d") - #FIXME: delete - print("ff_acts shape after ",ff_acts.shape) if do_attn: attn_acts = opt.collect_recent_attn_pre_out() attn_acts = einops.rearrange(attn_acts, "b l t nh dh -> (b t) l nh dh") @@ -168,13 +164,8 @@ def get_midlayer_data(opt: Model, sae_acts = {} for sae_hook in sae_hook_points: acts = [opt.hooks.collects[f"layer_{i}_{sae_hook}"].activation.to(opt.device) for i in range(opt.cfg.n_layers)] - #FIXME: delete - #TODO: rearrange - print("sae act shape ",acts[0].shape) sae_acts[sae_hook] = torch.stack(acts) - print("sae acts before shape ",sae_acts[sae_hook].shape) sae_acts[sae_hook] = einops.rearrange(sae_acts[sae_hook], "l b t d -> (b t) l d") - print("sae acts after shape ",sae_acts[sae_hook].shape) # set up 
criteria for filtering which activations we actually want ids = einops.rearrange(input_ids, "b t -> (b t)") @@ -203,11 +194,6 @@ def get_midlayer_data(opt: Model, criteria_indices = criteria.nonzero().flatten() if do_ff: ff_data.add_all(ff_acts[criteria_indices]) - #FIXME: delete - copy_acts = ff_acts[criteria_indices].clone() - print("copy_acts shape ",copy_acts.shape) - print("ff_data raw shape", ff_data.shape) - print("ff_acts criteria shape ",(ff_acts[criteria_indices]).shape) if ff_peak is not None: ff_data_peak_centered.add_all((ff_acts - ff_peak)[criteria_indices]) From 927ce79366b201458b1e26cb9f4ab6f221f5aa5f Mon Sep 17 00:00:00 2001 From: nnebp Date: Thu, 16 Jan 2025 22:41:58 +0000 Subject: [PATCH 3/9] make sae dict in ActivationOverview more clear --- src/taker/data_classes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/taker/data_classes.py b/src/taker/data_classes.py index 16c5ea3..377df54 100644 --- a/src/taker/data_classes.py +++ b/src/taker/data_classes.py @@ -500,6 +500,7 @@ class ActivationOverview: """Output from activation collection on multiple possible parts""" texts_viewed: int mlp: Optional[ActivationSummaryHolder] = None + sae: Optional[Dict[str, ActivationSummaryHolder]] = None attn: Optional[ActivationSummaryHolder] = None raw: Optional[dict] = None misc_data: Optional[dict] = None From df0b5ed0dd5b5b481fadcc1d005cd2d834455d16 Mon Sep 17 00:00:00 2001 From: nnebp Date: Fri, 21 Feb 2025 12:14:12 -0500 Subject: [PATCH 4/9] fixed bug enabling sae hooks --- src/taker/activations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/taker/activations.py b/src/taker/activations.py index 847f09e..9bd8bd2 100644 --- a/src/taker/activations.py +++ b/src/taker/activations.py @@ -130,7 +130,7 @@ def get_midlayer_data(opt: Model, #assuming all layers have the hook sae_shape = (opt.cfg.n_layers, opt.hooks.neuron_sae_encode[f"layer_0_{sae_hook}"].sae_config["d_sae"]) sae_data[sae_hook] = ActivationCollector(sae_shape, 
opt.output_device, collect_sae) - opt.hooks.enable_collect_hooks([sae_hook_points]) + opt.hooks.enable_collect_hooks([sae_hook]) if do_collect: criteria_raw = [] From 4332bf3c298dc0e945d8b79844b4bc852c12aa08 Mon Sep 17 00:00:00 2001 From: nnebp Date: Fri, 7 Mar 2025 11:26:34 -0500 Subject: [PATCH 5/9] SAE pruning code in progress --- examples/early-testing/sae_prune.py | 36 +++++++ examples/early-testing/sae_test_delete_me.py | 102 +++++++++++++++++++ src/taker/data_classes.py | 4 + src/taker/prune.py | 22 +++- 4 files changed, 162 insertions(+), 2 deletions(-) create mode 100644 examples/early-testing/sae_prune.py create mode 100644 examples/early-testing/sae_test_delete_me.py diff --git a/examples/early-testing/sae_prune.py b/examples/early-testing/sae_prune.py new file mode 100644 index 0000000..e7adecd --- /dev/null +++ b/examples/early-testing/sae_prune.py @@ -0,0 +1,36 @@ + +from taker import Model +import torch +from taker.activations import get_midlayer_data +from taker.prune import prune_and_evaluate +from taker.data_classes import PruningConfig + +#TODO: +# 1. copy pruning code in here +# 2. add hooks and other things needed for all the sae stuff +# 3. 
grab the sae activations + +hook_config = """ +pre_decoder: mask, sae_encode, collect, sae_decode +mlp_pre_out: collect +""" +c = PruningConfig("nickypro/tinyllama-15m", + attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False, + ff_frac=0.0, attn_frac=0.0, sae_frac=0.1, + token_limit=1000, focus="pile", cripple="code", wandb_entity="seperability", recalculate_activations=False, + wandb_project="bens-tests", wandb_run_name="test notebook2", n_steps=10, scoring_normalization="peak_centered") +m = Model("gpt2", hook_config=hook_config) + +for layer in range(m.cfg.n_layers): + sae_hook = m.hooks.neuron_sae_encode[f"layer_{layer}_pre_decoder"] + sae_hook.load_sae_from_pretrained("gpt2-small-res-jb", f"blocks.{layer}.hook_resid_pre") + +m.hooks.enable_collect_hooks(["mlp_pre_out"], run_assert=True) +m.hooks.enable_collect_hooks(["pre_decoder"], run_assert=True) + +focus_data = get_midlayer_data( m, "civil", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) +cripple_data = get_midlayer_data( m, "toxic", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) + +with torch.no_grad(): + for i in range(c.n_steps): + data = prune_and_evaluate(m, c, focus_data, cripple_data, i) \ No newline at end of file diff --git a/examples/early-testing/sae_test_delete_me.py b/examples/early-testing/sae_test_delete_me.py new file mode 100644 index 0000000..77945ee --- /dev/null +++ b/examples/early-testing/sae_test_delete_me.py @@ -0,0 +1,102 @@ + +# Pre-residual hooks +from taker import Model +import torch +from taker.activations import get_midlayer_data + +#pre_decoder: sae_encode, collect, sae_decode +hook_config = """ +pre_decoder: sae_encode, collect, sae_decode +mlp_pre_out: collect +""" +m = Model("gpt2", hook_config=hook_config) +#m = Model("gpt2") +#print(len(m.hooks["mlp_pre_out"]["collect"])) +print("model info:") 
+print(m.cfg.n_layers) +print(m.cfg.d_model) +print(m.cfg.n_heads) +print(m.cfg.d_mlp) +#--------------------- +sae_hook_points = [point for point, layers in m.hooks.hook_config.hook_points.items() + if 'all' in layers and any('sae' in hook for hook in layers['all'])] +sae_dict = dict.fromkeys(sae_hook_points) +#--------------------- +print(m.hooks.hook_config.hook_points) +#print(m.hooks.hook_config.hook_points.items(), " hook points") +print(sae_hook_points, " hook points") +for layer in range(m.cfg.n_layers): + sae_hook = m.hooks.neuron_sae_encode[f"layer_{layer}_pre_decoder"] + sae_hook.load_sae_from_pretrained("gpt2-small-res-jb", f"blocks.{layer}.hook_resid_pre") + #print(sae_hook.sae_config["d_sae"]) + +#testing enabling hhooks +#m.hooks.enable_collect_hooks(["pre_decoder", "mlp_pre_out"], run_assert=True) +m.hooks.enable_collect_hooks(["mlp_pre_out"], run_assert=True) +m.hooks.enable_collect_hooks(["pre_decoder"], run_assert=True) + +#m.get_outputs_embeds("Hello, world!") +#m.get_outputs_embeds("kill all humans!") + +print("sae data: ") +# Working with layers +#layer_0 = m.layers[0] + +#stuff = layer_0 +#print(stuff) +cripple_data = get_midlayer_data( m, "toxic", 100, collect_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) +sae_acts = (m.hooks.collects["layer_0_pre_decoder"].activation) + +#test deletes +print("deleting neurons") +#print("hooks ", m.hooks) +#m.hooks.collects["layer_0_pre_decoder"].delete_neurons([]) +#m.hooks["mlp_pre_out"].delete_neurons([]) +m.hooks["pre_decoder"].delete_neurons([]) +print(m.hooks["mlp_pre_out"][2]) +#m.hooks["pre_decoder"][2].delete_neurons([]) + + +print("here we go") +print(sae_acts.shape) +print(sae_acts) +print(torch.count_nonzero(sae_acts), " non zero") +print(sae_acts[sae_acts != 0]) +#exit() +#cripple_data = get_midlayer_data( m, "toxic", 100, collect_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) +exit() +#print("hooks") 
+#print(m.hooks) +for layer in range(m.cfg.n_layers): + #act = m.hooks.collects[f"layer_{layer}_pre_decoder"].activation + #print(act.shape) + #print(m.hooks.neuron_sae_encode) + for sae_hook in sae_hook_points: + print(f"layer: {layer}") + print(m.hooks.neuron_sae_encode[f"layer_{layer}_{sae_hook}"].sae_config["d_sae"], " ", layer) + print(m.hooks.neuron_sae_encode[f"layer_{layer}_{sae_hook}"].sae_config, " ", layer) + #print(m.hooks.collects[f"layer_{layer}_{sae_hook}"].activation) + #print(act, " act") + #act = m.hooks.collects[f"layer_{layer}_pre_decoder"].activation + #print(m.hooks.collects[f"layer_{layer}_{sae_hook}"]) + #print(m.hooks.collects) #DOESNT WORK. its by layer + + #print(m.hooks["mlp_pre_out"]["collect"]) + #print(m.hooks["pre_decoder"]["collect"]) + + #print(m.hooks["pre_decoder"]) + +#print(len(m.hooks["pre_decoder"]["collect"])) +#print(m.hooks["pre_decoder"]["collect"]) + + +print("--------------------") +print(m.hooks.collects["layer_0_mlp_pre_out"].activation.shape) +print(m.hooks["mlp_pre_out"]["collect"][0].shape) #if you do this before the .collects call it will be empty. 
other way around is fine + +print(m.hooks.collects["layer_0_pre_decoder"].activation.shape) +print((m.hooks.collects["layer_0_pre_decoder"].activation != 0).sum()) +#print sae config +print(m.hooks.neuron_sae_encode["layer_0_pre_decoder"].sae_config["d_in"]) +#print model width +print(m.cfg.d_model) \ No newline at end of file diff --git a/src/taker/data_classes.py b/src/taker/data_classes.py index 800e56c..b95c914 100644 --- a/src/taker/data_classes.py +++ b/src/taker/data_classes.py @@ -606,6 +606,8 @@ class PruningConfig: token_limit: int = None ff_frac: float = 0.1 ff_eps: float = 0.001 + sae_frac: float = 0.0 + sae_eps: float = 0.001 attn_frac: float = 0.0 attn_eps: float = 1e-4 dtype: str = "fp16" @@ -623,6 +625,8 @@ class PruningConfig: attn_offset_mode: str = "zero" ff_offset_mode: str = "zero" + sae_scoring: str = "abs" + attn_scoring: str = "abs" attn_mode: str = "pre-out" svd_attn: bool = False diff --git a/src/taker/prune.py b/src/taker/prune.py index 8b5588d..644490d 100644 --- a/src/taker/prune.py +++ b/src/taker/prune.py @@ -40,8 +40,9 @@ def prune_and_evaluate( # Find out what we are doing do_ff = pruning_config.ff_frac > 0 do_attn = pruning_config.attn_frac > 0 - if not do_ff and not do_attn: - raise ValueError("Must prune at least one of FF or Attention") + do_sae = pruning_config.sae_frac > 0 + if not do_ff and not do_attn and not do_sae: + raise ValueError("Must prune at least one of FF or Attention or SAE") if do_attn and pruning_config.attn_mode not in ["pre-out", "value"]: raise NotImplementedError("attn_mode must be 'pre-out' or 'value'") @@ -81,9 +82,11 @@ def score_and_prune( opt: Model, ): # Get the top fraction FF activations and prune ff_frac, ff_eps = pruning_config.ff_frac, pruning_config.ff_eps + sae_frac, sae_eps = pruning_config.sae_frac, pruning_config.sae_eps attn_frac, attn_eps = pruning_config.attn_frac, pruning_config.attn_eps do_ff = ff_frac > 0 do_attn = attn_frac > 0 + do_sae = sae_frac > 0 act_subset = 
pruning_config.scoring_normalization if do_ff > 0: @@ -94,6 +97,21 @@ def score_and_prune( opt: Model, ff_scores = ff_scoring_fn(opt, ff_focus_data, ff_cripple_data, ff_eps) ff_criteria, ff_threshold = get_top_frac(ff_scores, ff_frac) opt.hooks.delete_mlp_neurons(ff_criteria) + if do_sae > 0: + sae_hook_points = [point for point, layers in opt.hooks.hook_config.hook_points.items() + if 'all' in layers and any('sae' in hook for hook in layers['all'])] + for sae_hook in sae_hook_points: + #FIXME: delete. work in progress + #print("sae activationOverview") + #print(focus_activations_data) + sae_focus_data = focus_activations_data.sae[sae_hook] + sae_cripple_data = cripple_activations_data.sae[sae_hook] + sae_scoring_fn = score_indices_by(pruning_config.sae_scoring) + + sae_scores = sae_scoring_fn(opt, sae_focus_data.orig, sae_cripple_data.orig, sae_eps) + sae_criteria, sae_threshold = get_top_frac(sae_scores, sae_frac) + + opt.hooks[sae_hook].delete_neurons(sae_criteria) # Get the top fraction of Attention activations and prune if do_attn > 0: From b47e110980bfe25c47ce32daf7196046cc383102 Mon Sep 17 00:00:00 2001 From: nnebp Date: Mon, 10 Mar 2025 18:34:15 -0400 Subject: [PATCH 6/9] fix datasets for sae testing --- examples/early-testing/sae_prune.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/early-testing/sae_prune.py b/examples/early-testing/sae_prune.py index e7adecd..5750017 100644 --- a/examples/early-testing/sae_prune.py +++ b/examples/early-testing/sae_prune.py @@ -11,13 +11,13 @@ # 3. 
grab the sae activations hook_config = """ -pre_decoder: mask, sae_encode, collect, sae_decode +pre_decoder: sae_encode, mask, collect, sae_decode mlp_pre_out: collect """ c = PruningConfig("nickypro/tinyllama-15m", attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False, ff_frac=0.0, attn_frac=0.0, sae_frac=0.1, - token_limit=1000, focus="pile", cripple="code", wandb_entity="seperability", recalculate_activations=False, + token_limit=100, focus="civil", cripple="toxic", wandb_entity="seperability", recalculate_activations=False, wandb_project="bens-tests", wandb_run_name="test notebook2", n_steps=10, scoring_normalization="peak_centered") m = Model("gpt2", hook_config=hook_config) @@ -33,4 +33,5 @@ with torch.no_grad(): for i in range(c.n_steps): + print(f"Step {i}") data = prune_and_evaluate(m, c, focus_data, cripple_data, i) \ No newline at end of file From 9b3e679d1fda69d244307db72fe2fa2a4d320e17 Mon Sep 17 00:00:00 2001 From: nnebp Date: Wed, 7 May 2025 23:37:34 +0000 Subject: [PATCH 7/9] Non working attempt at sae pruning. 
committing to test outside of runpod --- examples/early-testing/sae_prune.py | 50 +++++++++++++++------ examples/early-testing/sae_prune_gemma.py | 53 ++++++++++++++++++++++ examples/early-testing/sae_prune_gpt2.py | 55 +++++++++++++++++++++++ src/taker/prune.py | 8 +++- 4 files changed, 150 insertions(+), 16 deletions(-) create mode 100644 examples/early-testing/sae_prune_gemma.py create mode 100644 examples/early-testing/sae_prune_gpt2.py diff --git a/examples/early-testing/sae_prune.py b/examples/early-testing/sae_prune.py index 5750017..4ef920f 100644 --- a/examples/early-testing/sae_prune.py +++ b/examples/early-testing/sae_prune.py @@ -1,37 +1,59 @@ +import wandb +import torch +import sys +sys.path.append('/root/taker/src') from taker import Model -import torch from taker.activations import get_midlayer_data -from taker.prune import prune_and_evaluate -from taker.data_classes import PruningConfig +from taker.prune import prune_and_evaluate, evaluate_all +from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem #TODO: # 1. copy pruning code in here # 2. add hooks and other things needed for all the sae stuff # 3. 
grab the sae activations +#hook_config = """ +#pre_decoder: sae_encode, mask, collect, sae_decode +#""" hook_config = """ -pre_decoder: sae_encode, mask, collect, sae_decode -mlp_pre_out: collect +attn_pre_out: sae_encode, collect, sae_decode """ -c = PruningConfig("nickypro/tinyllama-15m", +c = PruningConfig("does it matter?", attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False, ff_frac=0.0, attn_frac=0.0, sae_frac=0.1, - token_limit=100, focus="civil", cripple="toxic", wandb_entity="seperability", recalculate_activations=False, - wandb_project="bens-tests", wandb_run_name="test notebook2", n_steps=10, scoring_normalization="peak_centered") -m = Model("gpt2", hook_config=hook_config) + token_limit=512, focus="pile", cripple="code", wandb_entity="seperability", recalculate_activations=False, + wandb_project="bens-tests", wandb_run_name="sae prune test", n_steps=10) +m = Model("google/gemma-2-2b", hook_config=hook_config) for layer in range(m.cfg.n_layers): - sae_hook = m.hooks.neuron_sae_encode[f"layer_{layer}_pre_decoder"] - sae_hook.load_sae_from_pretrained("gpt2-small-res-jb", f"blocks.{layer}.hook_resid_pre") + sae_hook = m.hooks.neuron_sae_encode[f"layer_{layer}_attn_pre_out"] + sae_hook.load_sae_from_pretrained("gemma-scope-2b-pt-att-canonical", f"layer_{layer}/width_65k/canonical") + #sae_hook.load_sae_from_pretrained("gpt2-small-res-jb", f"blocks.{layer}.hook_resid_pre") m.hooks.enable_collect_hooks(["mlp_pre_out"], run_assert=True) m.hooks.enable_collect_hooks(["pre_decoder"], run_assert=True) -focus_data = get_midlayer_data( m, "civil", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) -cripple_data = get_midlayer_data( m, "toxic", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) +focus_data = get_midlayer_data( m, "pile", 10, collect_sae=True, calculate_sae=True, collect_attn=False, 
collect_ff=False, calculate_attn=False, calculate_ff=False) +cripple_data = get_midlayer_data( m, "code", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=False, calculate_attn=False, calculate_ff=False) + +history = RunDataHistory(c.datasets) +wandb.init( + project=c.wandb_project, + entity=c.wandb_entity, + name=c.wandb_run_name, + ) +wandb.config.update(c.to_dict(), allow_val_change=True) with torch.no_grad(): + #evaluate without pruning first + data = RunDataItem() + eval_out = evaluate_all(m, c.eval_sample_size, c.datasets, + dataset_tokens_to_skip=c.collection_sample_size) + data.update(eval_out) + history.add(data) + for i in range(c.n_steps): print(f"Step {i}") - data = prune_and_evaluate(m, c, focus_data, cripple_data, i) \ No newline at end of file + data = prune_and_evaluate(m, c, focus_data, cripple_data, i) + history.add(data) \ No newline at end of file diff --git a/examples/early-testing/sae_prune_gemma.py b/examples/early-testing/sae_prune_gemma.py new file mode 100644 index 0000000..b64192a --- /dev/null +++ b/examples/early-testing/sae_prune_gemma.py @@ -0,0 +1,53 @@ +import wandb +import torch +import sys +sys.path.append('/root/taker/src') + +from taker import Model +from taker.activations import get_midlayer_data +from taker.prune import prune_and_evaluate, evaluate_all +from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem + +hook_config = """ +post_mlp: sae_encode, mask, collect, sae_decode +mlp_pre_out: collect +""" +c = PruningConfig("doesnt matter", + attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False, + ff_frac=0.0, attn_frac=0.0, sae_frac=0.2, + token_limit=100, focus="civil", cripple="toxic", wandb_entity="seperability", recalculate_activations=False, + wandb_project="bens-tests", wandb_run_name="gemma-2b mlp sae prune test default scoring", n_steps=10) +m = Model("google/gemma-2-2b", hook_config=hook_config) + +for layer in range(m.cfg.n_layers): + sae_hook = 
m.hooks.neuron_sae_encode[f"layer_{layer}_post_mlp"] + sae_hook.load_sae_from_pretrained("gemma-scope-2b-pt-mlp-canonical", f"layer_{layer}/width_16k/canonical") + +#grabbing mlp activations seems needed so things dont break (to create raw activations dict?) +m.hooks.enable_collect_hooks(["mlp_pre_out"], run_assert=True) +m.hooks.enable_collect_hooks(["post_mlp"], run_assert=True) + +#TODO: seems we have to have collect_ff=True to get the activations for sae +focus_data = get_midlayer_data( m, "civil", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) +cripple_data = get_midlayer_data( m, "toxic", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) + +history = RunDataHistory(c.datasets) +wandb.init( + project=c.wandb_project, + entity=c.wandb_entity, + name=c.wandb_run_name, + ) +wandb.config.update(c.to_dict(), allow_val_change=True) + +with torch.no_grad(): + #evaluate without pruning first + data = RunDataItem() + eval_out = evaluate_all(m, c.eval_sample_size, c.datasets, + dataset_tokens_to_skip=c.collection_sample_size) + data.update(eval_out) + history.add(data) + + for i in range(c.n_steps): + print(f"Step {i}") + data = prune_and_evaluate(m, c, focus_data, cripple_data, i) + history.add(data) \ No newline at end of file diff --git a/examples/early-testing/sae_prune_gpt2.py b/examples/early-testing/sae_prune_gpt2.py new file mode 100644 index 0000000..ffa74af --- /dev/null +++ b/examples/early-testing/sae_prune_gpt2.py @@ -0,0 +1,55 @@ +import wandb +import torch +import sys +sys.path.append('/root/taker/src') + +from taker import Model +from taker.activations import get_midlayer_data +from taker.prune import prune_and_evaluate, evaluate_all +from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem + +hook_config = """ +pre_decoder: sae_encode, mask, collect, sae_decode +mlp_pre_out: collect 
+attn_pre_out: collect +""" #last line is needed to recalc activations +c = PruningConfig("doesnt matter", + attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False, + ff_frac=0.0, attn_frac=0.0, sae_frac=0.2, + token_limit=100, focus="civil", cripple="toxic", wandb_entity="seperability", recalculate_activations=True, + wandb_project="bens-tests", wandb_run_name="delete me", n_steps=10) +m = Model("gpt2", hook_config=hook_config) + +for layer in range(m.cfg.n_layers): + sae_hook = m.hooks.neuron_sae_encode[f"layer_{layer}_pre_decoder"] + sae_hook.load_sae_from_pretrained("gpt2-small-res-jb", f"blocks.{layer}.hook_resid_pre") + +#grabbing mlp activations seems needed so things dont break (to create raw activations dict?) +m.hooks.enable_collect_hooks(["mlp_pre_out"], run_assert=True) +m.hooks.enable_collect_hooks(["pre_decoder"], run_assert=True) +m.hooks.enable_collect_hooks(["attn_pre_out"], run_assert=True) #Needed to recalc activations + +#TODO: seems we have to have collect_ff=True to get the activations for sae +focus_data = get_midlayer_data( m, "civil", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) +cripple_data = get_midlayer_data( m, "toxic", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) + +history = RunDataHistory(c.datasets) +wandb.init( + project=c.wandb_project, + entity=c.wandb_entity, + name=c.wandb_run_name, + ) +wandb.config.update(c.to_dict(), allow_val_change=True) + +with torch.no_grad(): + #evaluate without pruning first + data = RunDataItem() + eval_out = evaluate_all(m, c.eval_sample_size, c.datasets, + dataset_tokens_to_skip=c.collection_sample_size) + data.update(eval_out) + history.add(data) + + for i in range(c.n_steps): + print(f"Step {i}") + data = prune_and_evaluate(m, c, focus_data, cripple_data, i) + history.add(data) \ No newline at end of file diff --git 
a/src/taker/prune.py b/src/taker/prune.py index 644490d..13a3e6c 100644 --- a/src/taker/prune.py +++ b/src/taker/prune.py @@ -48,10 +48,14 @@ def prune_and_evaluate( # Get midlayer activations of FF and ATTN if pruning_config.recalculate_activations: + sae_enabled = False + if pruning_config.sae_frac > 0: + sae_enabled = True + focus_out = get_midlayer_data( opt, pruning_config.focus, - pruning_config.collection_sample_size, pruning_config.attn_mode ) + pruning_config.collection_sample_size, pruning_config.attn_mode, calculate_sae=sae_enabled, collect_sae=sae_enabled ) cripple_out = get_midlayer_data( opt, pruning_config.cripple, - pruning_config.collection_sample_size, pruning_config.attn_mode ) + pruning_config.collection_sample_size, pruning_config.attn_mode, calculate_sae=sae_enabled, collect_sae=sae_enabled ) # Otherwise, import activation data, and adjust the "pruning fraction" else: From d84e56f43490c8321bd1b32c063c0de3d2767ab8 Mon Sep 17 00:00:00 2001 From: nnebp Date: Mon, 19 May 2025 17:47:45 +0000 Subject: [PATCH 8/9] some example SAE pruning scripts --- .../sae_prune_gpt2_save_activations.py | 95 +++++++++++++++++++ src/taker/prune.py | 3 - 2 files changed, 95 insertions(+), 3 deletions(-) create mode 100644 examples/early-testing/sae_prune_gpt2_save_activations.py diff --git a/examples/early-testing/sae_prune_gpt2_save_activations.py b/examples/early-testing/sae_prune_gpt2_save_activations.py new file mode 100644 index 0000000..c8cbae8 --- /dev/null +++ b/examples/early-testing/sae_prune_gpt2_save_activations.py @@ -0,0 +1,95 @@ +import wandb +import torch +import sys +import os +from pathlib import Path +sys.path.append('/root/taker/src') + +from taker import Model +from taker.activations import get_midlayer_data +from taker.prune import prune_and_evaluate, evaluate_all +from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem + +# Create directory for saving activations +SAVE_DIR = Path("examples/early-testing/sae_activations") 
+SAVE_DIR.mkdir(exist_ok=True) + +hook_config = """ +pre_decoder: sae_encode, mask, collect, sae_decode +mlp_pre_out: collect +attn_pre_out: collect +""" #last line is needed to recalc activations +c = PruningConfig("doesnt matter", + attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False, + ff_frac=0.0, attn_frac=0.0, sae_frac=0.2, + token_limit=100, focus="civil", cripple="toxic", wandb_entity="seperability", recalculate_activations=False, + wandb_project="bens-tests", wandb_run_name="delete me", n_steps=10) +m = Model("gpt2", hook_config=hook_config) + +for layer in range(m.cfg.n_layers): + sae_hook = m.hooks.neuron_sae_encode[f"layer_{layer}_pre_decoder"] + sae_hook.load_sae_from_pretrained("gpt2-small-res-jb", f"blocks.{layer}.hook_resid_pre") + +#grabbing mlp activations seems needed so things dont break (to create raw activations dict?) +m.hooks.enable_collect_hooks(["mlp_pre_out"], run_assert=True) +m.hooks.enable_collect_hooks(["pre_decoder"], run_assert=True) +m.hooks.enable_collect_hooks(["attn_pre_out"], run_assert=True) #Needed to recalc activations + +# Save activations for each layer +def save_activations(data, dataset_type, step): + for layer in range(m.cfg.n_layers): + layer_dir = SAVE_DIR / f"layer_{layer}" + layer_dir.mkdir(exist_ok=True) + + # Get activations for this layer from the raw data + activations = data.raw["sae"]["pre_decoder"][:, layer, :] # Shape: [batch, d_sae] + + # Save as torch tensor with step number in filename + save_path = layer_dir / f"step_{step:03d}_{dataset_type}_activations.pt" + torch.save(activations.detach().cpu(), save_path) + + # Save metadata + metadata = { + "shape": activations.shape, + "mean": float(activations.mean().item()), + "std": float(activations.std().item()), + "min": float(activations.min().item()), + "max": float(activations.max().item()), + "step": step, + "layer": layer, + "dataset_type": dataset_type + } + torch.save(metadata, layer_dir / 
f"step_{step:03d}_{dataset_type}_metadata.pt") + +history = RunDataHistory(c.datasets) +wandb.init( + project=c.wandb_project, + entity=c.wandb_entity, + name=c.wandb_run_name, + ) +wandb.config.update(c.to_dict(), allow_val_change=True) + +with torch.no_grad(): + #evaluate without pruning first + data = RunDataItem() + eval_out = evaluate_all(m, c.eval_sample_size, c.datasets, + dataset_tokens_to_skip=c.collection_sample_size) + data.update(eval_out) + history.add(data) + + # Save initial activations (step -1) + focus_data = get_midlayer_data(m, "civil", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) + cripple_data = get_midlayer_data(m, "toxic", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) + save_activations(focus_data, "focus", -1) + save_activations(cripple_data, "cripple", -1) + + for i in range(c.n_steps): + print(f"Step {i}") + data = prune_and_evaluate(m, c, focus_data, cripple_data, i) + history.add(data) + + # Get and save activations after each pruning step + focus_data = get_midlayer_data(m, "civil", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) + cripple_data = get_midlayer_data(m, "toxic", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) + save_activations(focus_data, "focus", i) + save_activations(cripple_data, "cripple", i) \ No newline at end of file diff --git a/src/taker/prune.py b/src/taker/prune.py index 13a3e6c..a03bdc3 100644 --- a/src/taker/prune.py +++ b/src/taker/prune.py @@ -105,9 +105,6 @@ def score_and_prune( opt: Model, sae_hook_points = [point for point, layers in opt.hooks.hook_config.hook_points.items() if 'all' in layers and any('sae' in hook for hook in layers['all'])] for sae_hook in sae_hook_points: - #FIXME: delete. 
work in progress - #print("sae activationOverview") - #print(focus_activations_data) sae_focus_data = focus_activations_data.sae[sae_hook] sae_cripple_data = cripple_activations_data.sae[sae_hook] sae_scoring_fn = score_indices_by(pruning_config.sae_scoring) From 4a7836799d4822718851660ef284a80f96a24b17 Mon Sep 17 00:00:00 2001 From: nnebp Date: Mon, 19 May 2025 18:00:43 +0000 Subject: [PATCH 9/9] cleaning up old files --- examples/early-testing/sae_prune.py | 59 ----------- examples/early-testing/sae_test_delete_me.py | 102 ------------------- 2 files changed, 161 deletions(-) delete mode 100644 examples/early-testing/sae_prune.py delete mode 100644 examples/early-testing/sae_test_delete_me.py diff --git a/examples/early-testing/sae_prune.py b/examples/early-testing/sae_prune.py deleted file mode 100644 index 4ef920f..0000000 --- a/examples/early-testing/sae_prune.py +++ /dev/null @@ -1,59 +0,0 @@ -import wandb -import torch -import sys -sys.path.append('/root/taker/src') - -from taker import Model -from taker.activations import get_midlayer_data -from taker.prune import prune_and_evaluate, evaluate_all -from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem - -#TODO: -# 1. copy pruning code in here -# 2. add hooks and other things needed for all the sae stuff -# 3. 
grab the sae activations - -#hook_config = """ -#pre_decoder: sae_encode, mask, collect, sae_decode -#""" -hook_config = """ -attn_pre_out: sae_encode, collect, sae_decode -""" -c = PruningConfig("does it matter?", - attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False, - ff_frac=0.0, attn_frac=0.0, sae_frac=0.1, - token_limit=512, focus="pile", cripple="code", wandb_entity="seperability", recalculate_activations=False, - wandb_project="bens-tests", wandb_run_name="sae prune test", n_steps=10) -m = Model("google/gemma-2-2b", hook_config=hook_config) - -for layer in range(m.cfg.n_layers): - sae_hook = m.hooks.neuron_sae_encode[f"layer_{layer}_attn_pre_out"] - sae_hook.load_sae_from_pretrained("gemma-scope-2b-pt-att-canonical", f"layer_{layer}/width_65k/canonical") - #sae_hook.load_sae_from_pretrained("gpt2-small-res-jb", f"blocks.{layer}.hook_resid_pre") - -m.hooks.enable_collect_hooks(["mlp_pre_out"], run_assert=True) -m.hooks.enable_collect_hooks(["pre_decoder"], run_assert=True) - -focus_data = get_midlayer_data( m, "pile", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=False, calculate_attn=False, calculate_ff=False) -cripple_data = get_midlayer_data( m, "code", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=False, calculate_attn=False, calculate_ff=False) - -history = RunDataHistory(c.datasets) -wandb.init( - project=c.wandb_project, - entity=c.wandb_entity, - name=c.wandb_run_name, - ) -wandb.config.update(c.to_dict(), allow_val_change=True) - -with torch.no_grad(): - #evaluate without pruning first - data = RunDataItem() - eval_out = evaluate_all(m, c.eval_sample_size, c.datasets, - dataset_tokens_to_skip=c.collection_sample_size) - data.update(eval_out) - history.add(data) - - for i in range(c.n_steps): - print(f"Step {i}") - data = prune_and_evaluate(m, c, focus_data, cripple_data, i) - history.add(data) \ No newline at end of file diff --git a/examples/early-testing/sae_test_delete_me.py 
b/examples/early-testing/sae_test_delete_me.py deleted file mode 100644 index 77945ee..0000000 --- a/examples/early-testing/sae_test_delete_me.py +++ /dev/null @@ -1,102 +0,0 @@ - -# Pre-residual hooks -from taker import Model -import torch -from taker.activations import get_midlayer_data - -#pre_decoder: sae_encode, collect, sae_decode -hook_config = """ -pre_decoder: sae_encode, collect, sae_decode -mlp_pre_out: collect -""" -m = Model("gpt2", hook_config=hook_config) -#m = Model("gpt2") -#print(len(m.hooks["mlp_pre_out"]["collect"])) -print("model info:") -print(m.cfg.n_layers) -print(m.cfg.d_model) -print(m.cfg.n_heads) -print(m.cfg.d_mlp) -#--------------------- -sae_hook_points = [point for point, layers in m.hooks.hook_config.hook_points.items() - if 'all' in layers and any('sae' in hook for hook in layers['all'])] -sae_dict = dict.fromkeys(sae_hook_points) -#--------------------- -print(m.hooks.hook_config.hook_points) -#print(m.hooks.hook_config.hook_points.items(), " hook points") -print(sae_hook_points, " hook points") -for layer in range(m.cfg.n_layers): - sae_hook = m.hooks.neuron_sae_encode[f"layer_{layer}_pre_decoder"] - sae_hook.load_sae_from_pretrained("gpt2-small-res-jb", f"blocks.{layer}.hook_resid_pre") - #print(sae_hook.sae_config["d_sae"]) - -#testing enabling hhooks -#m.hooks.enable_collect_hooks(["pre_decoder", "mlp_pre_out"], run_assert=True) -m.hooks.enable_collect_hooks(["mlp_pre_out"], run_assert=True) -m.hooks.enable_collect_hooks(["pre_decoder"], run_assert=True) - -#m.get_outputs_embeds("Hello, world!") -#m.get_outputs_embeds("kill all humans!") - -print("sae data: ") -# Working with layers -#layer_0 = m.layers[0] - -#stuff = layer_0 -#print(stuff) -cripple_data = get_midlayer_data( m, "toxic", 100, collect_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) -sae_acts = (m.hooks.collects["layer_0_pre_decoder"].activation) - -#test deletes -print("deleting neurons") -#print("hooks ", m.hooks) 
-#m.hooks.collects["layer_0_pre_decoder"].delete_neurons([]) -#m.hooks["mlp_pre_out"].delete_neurons([]) -m.hooks["pre_decoder"].delete_neurons([]) -print(m.hooks["mlp_pre_out"][2]) -#m.hooks["pre_decoder"][2].delete_neurons([]) - - -print("here we go") -print(sae_acts.shape) -print(sae_acts) -print(torch.count_nonzero(sae_acts), " non zero") -print(sae_acts[sae_acts != 0]) -#exit() -#cripple_data = get_midlayer_data( m, "toxic", 100, collect_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) -exit() -#print("hooks") -#print(m.hooks) -for layer in range(m.cfg.n_layers): - #act = m.hooks.collects[f"layer_{layer}_pre_decoder"].activation - #print(act.shape) - #print(m.hooks.neuron_sae_encode) - for sae_hook in sae_hook_points: - print(f"layer: {layer}") - print(m.hooks.neuron_sae_encode[f"layer_{layer}_{sae_hook}"].sae_config["d_sae"], " ", layer) - print(m.hooks.neuron_sae_encode[f"layer_{layer}_{sae_hook}"].sae_config, " ", layer) - #print(m.hooks.collects[f"layer_{layer}_{sae_hook}"].activation) - #print(act, " act") - #act = m.hooks.collects[f"layer_{layer}_pre_decoder"].activation - #print(m.hooks.collects[f"layer_{layer}_{sae_hook}"]) - #print(m.hooks.collects) #DOESNT WORK. its by layer - - #print(m.hooks["mlp_pre_out"]["collect"]) - #print(m.hooks["pre_decoder"]["collect"]) - - #print(m.hooks["pre_decoder"]) - -#print(len(m.hooks["pre_decoder"]["collect"])) -#print(m.hooks["pre_decoder"]["collect"]) - - -print("--------------------") -print(m.hooks.collects["layer_0_mlp_pre_out"].activation.shape) -print(m.hooks["mlp_pre_out"]["collect"][0].shape) #if you do this before the .collects call it will be empty. 
other way around is fine - -print(m.hooks.collects["layer_0_pre_decoder"].activation.shape) -print((m.hooks.collects["layer_0_pre_decoder"].activation != 0).sum()) -#print sae config -print(m.hooks.neuron_sae_encode["layer_0_pre_decoder"].sae_config["d_in"]) -#print model width -print(m.cfg.d_model) \ No newline at end of file