diff --git a/examples/early-testing/sae_prune_gemma.py b/examples/early-testing/sae_prune_gemma.py new file mode 100644 index 0000000..b64192a --- /dev/null +++ b/examples/early-testing/sae_prune_gemma.py @@ -0,0 +1,53 @@ +import wandb +import torch +import sys +sys.path.append('/root/taker/src') + +from taker import Model +from taker.activations import get_midlayer_data +from taker.prune import prune_and_evaluate, evaluate_all +from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem + +hook_config = """ +post_mlp: sae_encode, mask, collect, sae_decode +mlp_pre_out: collect +""" +c = PruningConfig("doesnt matter", + attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False, + ff_frac=0.0, attn_frac=0.0, sae_frac=0.2, + token_limit=100, focus="civil", cripple="toxic", wandb_entity="seperability", recalculate_activations=False, + wandb_project="bens-tests", wandb_run_name="gemma-2b mlp sae prune test default scoring", n_steps=10) +m = Model("google/gemma-2-2b", hook_config=hook_config) + +for layer in range(m.cfg.n_layers): + sae_hook = m.hooks.neuron_sae_encode[f"layer_{layer}_post_mlp"] + sae_hook.load_sae_from_pretrained("gemma-scope-2b-pt-mlp-canonical", f"layer_{layer}/width_16k/canonical") + +#grabbing mlp activations seems needed so things dont break (to create raw activations dict?) 
+m.hooks.enable_collect_hooks(["mlp_pre_out"], run_assert=True) +m.hooks.enable_collect_hooks(["post_mlp"], run_assert=True) + +#TODO: seems we have to have collect_ff=True to get the activations for sae +focus_data = get_midlayer_data( m, "civil", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) +cripple_data = get_midlayer_data( m, "toxic", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) + +history = RunDataHistory(c.datasets) +wandb.init( + project=c.wandb_project, + entity=c.wandb_entity, + name=c.wandb_run_name, + ) +wandb.config.update(c.to_dict(), allow_val_change=True) + +with torch.no_grad(): + #evaluate without pruning first + data = RunDataItem() + eval_out = evaluate_all(m, c.eval_sample_size, c.datasets, + dataset_tokens_to_skip=c.collection_sample_size) + data.update(eval_out) + history.add(data) + + for i in range(c.n_steps): + print(f"Step {i}") + data = prune_and_evaluate(m, c, focus_data, cripple_data, i) + history.add(data) \ No newline at end of file diff --git a/examples/early-testing/sae_prune_gpt2.py b/examples/early-testing/sae_prune_gpt2.py new file mode 100644 index 0000000..ffa74af --- /dev/null +++ b/examples/early-testing/sae_prune_gpt2.py @@ -0,0 +1,55 @@ +import wandb +import torch +import sys +sys.path.append('/root/taker/src') + +from taker import Model +from taker.activations import get_midlayer_data +from taker.prune import prune_and_evaluate, evaluate_all +from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem + +hook_config = """ +pre_decoder: sae_encode, mask, collect, sae_decode +mlp_pre_out: collect +attn_pre_out: collect +""" #last line is needed to recalc activations +c = PruningConfig("doesnt matter", + attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False, + ff_frac=0.0, attn_frac=0.0, sae_frac=0.2, + token_limit=100, focus="civil", 
cripple="toxic", wandb_entity="seperability", recalculate_activations=True, + wandb_project="bens-tests", wandb_run_name="delete me", n_steps=10) +m = Model("gpt2", hook_config=hook_config) + +for layer in range(m.cfg.n_layers): + sae_hook = m.hooks.neuron_sae_encode[f"layer_{layer}_pre_decoder"] + sae_hook.load_sae_from_pretrained("gpt2-small-res-jb", f"blocks.{layer}.hook_resid_pre") + +#grabbing mlp activations seems needed so things dont break (to create raw activations dict?) +m.hooks.enable_collect_hooks(["mlp_pre_out"], run_assert=True) +m.hooks.enable_collect_hooks(["pre_decoder"], run_assert=True) +m.hooks.enable_collect_hooks(["attn_pre_out"], run_assert=True) #Needed to recalc activations + +#TODO: seems we have to have collect_ff=True to get the activations for sae +focus_data = get_midlayer_data( m, "civil", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) +cripple_data = get_midlayer_data( m, "toxic", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) + +history = RunDataHistory(c.datasets) +wandb.init( + project=c.wandb_project, + entity=c.wandb_entity, + name=c.wandb_run_name, + ) +wandb.config.update(c.to_dict(), allow_val_change=True) + +with torch.no_grad(): + #evaluate without pruning first + data = RunDataItem() + eval_out = evaluate_all(m, c.eval_sample_size, c.datasets, + dataset_tokens_to_skip=c.collection_sample_size) + data.update(eval_out) + history.add(data) + + for i in range(c.n_steps): + print(f"Step {i}") + data = prune_and_evaluate(m, c, focus_data, cripple_data, i) + history.add(data) \ No newline at end of file diff --git a/examples/early-testing/sae_prune_gpt2_save_activations.py b/examples/early-testing/sae_prune_gpt2_save_activations.py new file mode 100644 index 0000000..c8cbae8 --- /dev/null +++ b/examples/early-testing/sae_prune_gpt2_save_activations.py @@ -0,0 +1,95 @@ +import wandb 
+import torch +import sys +import os +from pathlib import Path +sys.path.append('/root/taker/src') + +from taker import Model +from taker.activations import get_midlayer_data +from taker.prune import prune_and_evaluate, evaluate_all +from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem + +# Create directory for saving activations +SAVE_DIR = Path("examples/early-testing/sae_activations") +SAVE_DIR.mkdir(parents=True, exist_ok=True) + +hook_config = """ +pre_decoder: sae_encode, mask, collect, sae_decode +mlp_pre_out: collect +attn_pre_out: collect +""" #last line is needed to recalc activations +c = PruningConfig("doesnt matter", + attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False, + ff_frac=0.0, attn_frac=0.0, sae_frac=0.2, + token_limit=100, focus="civil", cripple="toxic", wandb_entity="seperability", recalculate_activations=False, + wandb_project="bens-tests", wandb_run_name="delete me", n_steps=10) +m = Model("gpt2", hook_config=hook_config) + +for layer in range(m.cfg.n_layers): + sae_hook = m.hooks.neuron_sae_encode[f"layer_{layer}_pre_decoder"] + sae_hook.load_sae_from_pretrained("gpt2-small-res-jb", f"blocks.{layer}.hook_resid_pre") + +#grabbing mlp activations seems needed so things dont break (to create raw activations dict?) 
+m.hooks.enable_collect_hooks(["mlp_pre_out"], run_assert=True) +m.hooks.enable_collect_hooks(["pre_decoder"], run_assert=True) +m.hooks.enable_collect_hooks(["attn_pre_out"], run_assert=True) #Needed to recalc activations + +# Save activations for each layer +def save_activations(data, dataset_type, step): + for layer in range(m.cfg.n_layers): + layer_dir = SAVE_DIR / f"layer_{layer}" + layer_dir.mkdir(exist_ok=True) + + # Get activations for this layer from the raw data + activations = data.raw["sae"]["pre_decoder"][:, layer, :] # Shape: [batch, d_sae] + + # Save as torch tensor with step number in filename + save_path = layer_dir / f"step_{step:03d}_{dataset_type}_activations.pt" + torch.save(activations.detach().cpu(), save_path) + + # Save metadata + metadata = { + "shape": activations.shape, + "mean": float(activations.mean().item()), + "std": float(activations.std().item()), + "min": float(activations.min().item()), + "max": float(activations.max().item()), + "step": step, + "layer": layer, + "dataset_type": dataset_type + } + torch.save(metadata, layer_dir / f"step_{step:03d}_{dataset_type}_metadata.pt") + +history = RunDataHistory(c.datasets) +wandb.init( + project=c.wandb_project, + entity=c.wandb_entity, + name=c.wandb_run_name, + ) +wandb.config.update(c.to_dict(), allow_val_change=True) + +with torch.no_grad(): + #evaluate without pruning first + data = RunDataItem() + eval_out = evaluate_all(m, c.eval_sample_size, c.datasets, + dataset_tokens_to_skip=c.collection_sample_size) + data.update(eval_out) + history.add(data) + + # Save initial activations (step -1) + focus_data = get_midlayer_data(m, "civil", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) + cripple_data = get_midlayer_data(m, "toxic", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) + save_activations(focus_data, "focus", -1) + 
save_activations(cripple_data, "cripple", -1) + + for i in range(c.n_steps): + print(f"Step {i}") + data = prune_and_evaluate(m, c, focus_data, cripple_data, i) + history.add(data) + + # Get and save activations after each pruning step + focus_data = get_midlayer_data(m, "civil", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) + cripple_data = get_midlayer_data(m, "toxic", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False) + save_activations(focus_data, "focus", i) + save_activations(cripple_data, "cripple", i) \ No newline at end of file diff --git a/src/taker/activations.py b/src/taker/activations.py index 847f09e..9bd8bd2 100644 --- a/src/taker/activations.py +++ b/src/taker/activations.py @@ -130,7 +130,7 @@ def get_midlayer_data(opt: Model, #assuming all layers have the hook sae_shape = (opt.cfg.n_layers, opt.hooks.neuron_sae_encode[f"layer_0_{sae_hook}"].sae_config["d_sae"]) sae_data[sae_hook] = ActivationCollector(sae_shape, opt.output_device, collect_sae) - opt.hooks.enable_collect_hooks([sae_hook_points]) + opt.hooks.enable_collect_hooks([sae_hook]) if do_collect: criteria_raw = [] diff --git a/src/taker/data_classes.py b/src/taker/data_classes.py index 800e56c..b95c914 100644 --- a/src/taker/data_classes.py +++ b/src/taker/data_classes.py @@ -606,6 +606,8 @@ class PruningConfig: token_limit: int = None ff_frac: float = 0.1 ff_eps: float = 0.001 + sae_frac: float = 0.0 + sae_eps: float = 0.001 attn_frac: float = 0.0 attn_eps: float = 1e-4 dtype: str = "fp16" @@ -623,6 +625,8 @@ class PruningConfig: attn_offset_mode: str = "zero" ff_offset_mode: str = "zero" + sae_scoring: str = "abs" + attn_scoring: str = "abs" attn_mode: str = "pre-out" svd_attn: bool = False diff --git a/src/taker/prune.py b/src/taker/prune.py index 8b5588d..a03bdc3 100644 --- a/src/taker/prune.py +++ b/src/taker/prune.py @@ -40,17 +40,22 @@ def 
prune_and_evaluate( # Find out what we are doing do_ff = pruning_config.ff_frac > 0 do_attn = pruning_config.attn_frac > 0 - if not do_ff and not do_attn: - raise ValueError("Must prune at least one of FF or Attention") + do_sae = pruning_config.sae_frac > 0 + if not do_ff and not do_attn and not do_sae: + raise ValueError("Must prune at least one of FF or Attention or SAE") if do_attn and pruning_config.attn_mode not in ["pre-out", "value"]: raise NotImplementedError("attn_mode must be 'pre-out' or 'value'") # Get midlayer activations of FF and ATTN if pruning_config.recalculate_activations: + sae_enabled = False + if pruning_config.sae_frac > 0: + sae_enabled = True + focus_out = get_midlayer_data( opt, pruning_config.focus, - pruning_config.collection_sample_size, pruning_config.attn_mode ) + pruning_config.collection_sample_size, pruning_config.attn_mode, calculate_sae=sae_enabled, collect_sae=sae_enabled ) cripple_out = get_midlayer_data( opt, pruning_config.cripple, - pruning_config.collection_sample_size, pruning_config.attn_mode ) + pruning_config.collection_sample_size, pruning_config.attn_mode, calculate_sae=sae_enabled, collect_sae=sae_enabled ) # Otherwise, import activation data, and adjust the "pruning fraction" else: @@ -81,9 +86,11 @@ def score_and_prune( opt: Model, ): # Get the top fraction FF activations and prune ff_frac, ff_eps = pruning_config.ff_frac, pruning_config.ff_eps + sae_frac, sae_eps = pruning_config.sae_frac, pruning_config.sae_eps attn_frac, attn_eps = pruning_config.attn_frac, pruning_config.attn_eps do_ff = ff_frac > 0 do_attn = attn_frac > 0 + do_sae = sae_frac > 0 act_subset = pruning_config.scoring_normalization if do_ff > 0: @@ -94,6 +101,18 @@ def score_and_prune( opt: Model, ff_scores = ff_scoring_fn(opt, ff_focus_data, ff_cripple_data, ff_eps) ff_criteria, ff_threshold = get_top_frac(ff_scores, ff_frac) opt.hooks.delete_mlp_neurons(ff_criteria) + if do_sae > 0: + sae_hook_points = [point for point, layers in 
opt.hooks.hook_config.hook_points.items() + if 'all' in layers and any('sae' in hook for hook in layers['all'])] + for sae_hook in sae_hook_points: + sae_focus_data = focus_activations_data.sae[sae_hook] + sae_cripple_data = cripple_activations_data.sae[sae_hook] + sae_scoring_fn = score_indices_by(pruning_config.sae_scoring) + + sae_scores = sae_scoring_fn(opt, sae_focus_data.orig, sae_cripple_data.orig, sae_eps) + sae_criteria, sae_threshold = get_top_frac(sae_scores, sae_frac) + + opt.hooks[sae_hook].delete_neurons(sae_criteria) #NOTE(review): indexes hooks by point name, but SAE hooks elsewhere are keyed per layer (layer_{i}_{point}) — confirm this lookup resolves # Get the top fraction of Attention activations and prune if do_attn > 0: