Skip to content
53 changes: 53 additions & 0 deletions examples/early-testing/sae_prune_gemma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Early-testing example: prune SAE features on google/gemma-2-2b.

Attaches a Gemma-Scope SAE to every layer's post-MLP hook point, collects
activation data on a "focus" (civil) and a "cripple" (toxic) dataset, then
runs several pruning steps, logging evaluations to Weights & Biases.
"""
import wandb
import torch
import sys
sys.path.append('/root/taker/src')  # use the local taker checkout, not an installed copy

from taker import Model
from taker.activations import get_midlayer_data
from taker.prune import prune_and_evaluate, evaluate_all
from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem

# Hook layout: SAE-encode the post-MLP stream, mask and collect in SAE
# feature space, then decode back; also collect raw MLP pre-out activations.
hook_config = """
post_mlp: sae_encode, mask, collect, sae_decode
mlp_pre_out: collect
"""
# Only SAE pruning is active in this run: ff_frac and attn_frac are 0.0,
# sae_frac prunes 20% of SAE features per step over n_steps steps.
c = PruningConfig("doesnt matter",
    attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False,
    ff_frac=0.0, attn_frac=0.0, sae_frac=0.2,
    token_limit=100, focus="civil", cripple="toxic", wandb_entity="seperability", recalculate_activations=False,
    wandb_project="bens-tests", wandb_run_name="gemma-2b mlp sae prune test default scoring", n_steps=10)
m = Model("google/gemma-2-2b", hook_config=hook_config)

# Load a pretrained Gemma-Scope SAE into every layer's post_mlp encode hook.
for layer in range(m.cfg.n_layers):
    sae_hook = m.hooks.neuron_sae_encode[f"layer_{layer}_post_mlp"]
    sae_hook.load_sae_from_pretrained("gemma-scope-2b-pt-mlp-canonical", f"layer_{layer}/width_16k/canonical")

# grabbing mlp activations seems needed so things don't break (to create raw activations dict?)
m.hooks.enable_collect_hooks(["mlp_pre_out"], run_assert=True)
m.hooks.enable_collect_hooks(["post_mlp"], run_assert=True)

# TODO: seems we have to have collect_ff=True to get the activations for sae
focus_data = get_midlayer_data( m, "civil", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False)
cripple_data = get_midlayer_data( m, "toxic", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False)

# Track per-dataset results across steps and mirror them to wandb.
history = RunDataHistory(c.datasets)
wandb.init(
    project=c.wandb_project,
    entity=c.wandb_entity,
    name=c.wandb_run_name,
)
wandb.config.update(c.to_dict(), allow_val_change=True)

with torch.no_grad():
    # evaluate without pruning first, so step 0 has a baseline to compare against
    data = RunDataItem()
    eval_out = evaluate_all(m, c.eval_sample_size, c.datasets,
        dataset_tokens_to_skip=c.collection_sample_size)
    data.update(eval_out)
    history.add(data)

    # Each step scores SAE features on the pre-collected focus/cripple data,
    # prunes the top fraction, and re-evaluates.
    for i in range(c.n_steps):
        print(f"Step {i}")
        data = prune_and_evaluate(m, c, focus_data, cripple_data, i)
        history.add(data)
55 changes: 55 additions & 0 deletions examples/early-testing/sae_prune_gpt2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Early-testing example: prune SAE features on gpt2.

Attaches a pretrained residual-stream SAE to every layer's pre_decoder hook
point, collects activation data on a "focus" (civil) and a "cripple" (toxic)
dataset, then runs several pruning steps with activation recalculation,
logging evaluations to Weights & Biases.
"""
import wandb
import torch
import sys
sys.path.append('/root/taker/src')  # use the local taker checkout, not an installed copy

from taker import Model
from taker.activations import get_midlayer_data
from taker.prune import prune_and_evaluate, evaluate_all
from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem

# Hook layout: SAE-encode the pre-decoder residual stream, mask and collect
# in SAE feature space, then decode back; also collect raw MLP/attn pre-out.
hook_config = """
pre_decoder: sae_encode, mask, collect, sae_decode
mlp_pre_out: collect
attn_pre_out: collect
""" #last line is needed to recalc activations
# Only SAE pruning is active (ff_frac and attn_frac are 0.0); unlike the
# gemma example, activations are recollected on every pruning step.
c = PruningConfig("doesnt matter",
    attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False,
    ff_frac=0.0, attn_frac=0.0, sae_frac=0.2,
    token_limit=100, focus="civil", cripple="toxic", wandb_entity="seperability", recalculate_activations=True,
    wandb_project="bens-tests", wandb_run_name="delete me", n_steps=10)
m = Model("gpt2", hook_config=hook_config)

# Load the gpt2-small residual SAE into every layer's pre_decoder encode hook.
for layer in range(m.cfg.n_layers):
    sae_hook = m.hooks.neuron_sae_encode[f"layer_{layer}_pre_decoder"]
    sae_hook.load_sae_from_pretrained("gpt2-small-res-jb", f"blocks.{layer}.hook_resid_pre")

# grabbing mlp activations seems needed so things don't break (to create raw activations dict?)
m.hooks.enable_collect_hooks(["mlp_pre_out"], run_assert=True)
m.hooks.enable_collect_hooks(["pre_decoder"], run_assert=True)
m.hooks.enable_collect_hooks(["attn_pre_out"], run_assert=True) #Needed to recalc activations

# TODO: seems we have to have collect_ff=True to get the activations for sae
focus_data = get_midlayer_data( m, "civil", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False)
cripple_data = get_midlayer_data( m, "toxic", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False)

# Track per-dataset results across steps and mirror them to wandb.
history = RunDataHistory(c.datasets)
wandb.init(
    project=c.wandb_project,
    entity=c.wandb_entity,
    name=c.wandb_run_name,
)
wandb.config.update(c.to_dict(), allow_val_change=True)

with torch.no_grad():
    # evaluate without pruning first, so step 0 has a baseline to compare against
    data = RunDataItem()
    eval_out = evaluate_all(m, c.eval_sample_size, c.datasets,
        dataset_tokens_to_skip=c.collection_sample_size)
    data.update(eval_out)
    history.add(data)

    # Each step scores SAE features, prunes the top fraction, and re-evaluates.
    for i in range(c.n_steps):
        print(f"Step {i}")
        data = prune_and_evaluate(m, c, focus_data, cripple_data, i)
        history.add(data)
95 changes: 95 additions & 0 deletions examples/early-testing/sae_prune_gpt2_save_activations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Early-testing example: SAE pruning on gpt2, additionally saving per-layer
SAE activations to disk before the first step and after each pruning step."""
import wandb
import torch
import sys
import os  # NOTE(review): os appears unused in this script
from pathlib import Path
sys.path.append('/root/taker/src')  # use the local taker checkout, not an installed copy

from taker import Model
from taker.activations import get_midlayer_data
from taker.prune import prune_and_evaluate, evaluate_all
from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem

# Create directory for saving activations.
# NOTE(review): path is relative to the cwd, so run from the repo root;
# without parents=True this mkdir fails if examples/early-testing is missing.
SAVE_DIR = Path("examples/early-testing/sae_activations")
SAVE_DIR.mkdir(exist_ok=True)

# Hook layout: SAE-encode the pre-decoder residual stream, mask and collect
# in SAE feature space, then decode back; also collect raw MLP/attn pre-out.
hook_config = """
pre_decoder: sae_encode, mask, collect, sae_decode
mlp_pre_out: collect
attn_pre_out: collect
""" #last line is needed to recalc activations
# Only SAE pruning is active (ff_frac and attn_frac are 0.0).
c = PruningConfig("doesnt matter",
    attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False,
    ff_frac=0.0, attn_frac=0.0, sae_frac=0.2,
    token_limit=100, focus="civil", cripple="toxic", wandb_entity="seperability", recalculate_activations=False,
    wandb_project="bens-tests", wandb_run_name="delete me", n_steps=10)
m = Model("gpt2", hook_config=hook_config)

# Load the gpt2-small residual SAE into every layer's pre_decoder encode hook.
for layer in range(m.cfg.n_layers):
    sae_hook = m.hooks.neuron_sae_encode[f"layer_{layer}_pre_decoder"]
    sae_hook.load_sae_from_pretrained("gpt2-small-res-jb", f"blocks.{layer}.hook_resid_pre")

# grabbing mlp activations seems needed so things don't break (to create raw activations dict?)
m.hooks.enable_collect_hooks(["mlp_pre_out"], run_assert=True)
m.hooks.enable_collect_hooks(["pre_decoder"], run_assert=True)
m.hooks.enable_collect_hooks(["attn_pre_out"], run_assert=True) #Needed to recalc activations
# Save activations for each layer
def save_activations(data, dataset_type, step):
    """Persist each layer's SAE activations (and summary stats) to disk.

    For every layer, writes two files under SAVE_DIR/layer_<n>/, both tagged
    with the step number and dataset type: the raw activation tensor
    ([batch, d_sae], detached and moved to CPU) and a small metadata dict.
    Uses the module-level `m` (model) and `SAVE_DIR` globals.
    """
    # Raw SAE activations across all layers; assumes [batch, n_layers, d_sae]
    # layout — TODO confirm against ActivationCollector.
    all_acts = data.raw["sae"]["pre_decoder"]

    for layer_idx in range(m.cfg.n_layers):
        out_dir = SAVE_DIR / f"layer_{layer_idx}"
        out_dir.mkdir(exist_ok=True)

        layer_acts = all_acts[:, layer_idx, :]  # [batch, d_sae]
        prefix = f"step_{step:03d}_{dataset_type}"

        # Full tensor, detached and on CPU so the file is device-independent.
        torch.save(layer_acts.detach().cpu(), out_dir / f"{prefix}_activations.pt")

        # Lightweight summary stats saved alongside the full tensor.
        torch.save(
            {
                "shape": layer_acts.shape,
                "mean": float(layer_acts.mean().item()),
                "std": float(layer_acts.std().item()),
                "min": float(layer_acts.min().item()),
                "max": float(layer_acts.max().item()),
                "step": step,
                "layer": layer_idx,
                "dataset_type": dataset_type,
            },
            out_dir / f"{prefix}_metadata.pt",
        )

# Track per-dataset results across steps and mirror them to wandb.
history = RunDataHistory(c.datasets)
wandb.init(
    project=c.wandb_project,
    entity=c.wandb_entity,
    name=c.wandb_run_name,
)
wandb.config.update(c.to_dict(), allow_val_change=True)

with torch.no_grad():
    # evaluate without pruning first, so step 0 has a baseline to compare against
    data = RunDataItem()
    eval_out = evaluate_all(m, c.eval_sample_size, c.datasets,
        dataset_tokens_to_skip=c.collection_sample_size)
    data.update(eval_out)
    history.add(data)

    # Save initial (pre-pruning) activations under step -1.
    focus_data = get_midlayer_data(m, "civil", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False)
    cripple_data = get_midlayer_data(m, "toxic", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False)
    save_activations(focus_data, "focus", -1)
    save_activations(cripple_data, "cripple", -1)

    for i in range(c.n_steps):
        print(f"Step {i}")
        data = prune_and_evaluate(m, c, focus_data, cripple_data, i)
        history.add(data)

        # Get and save activations after each pruning step.
        # NOTE(review): these reassignments also feed the NEXT step's
        # prune_and_evaluate even though recalculate_activations=False —
        # confirm this refresh is intended and not just for saving.
        focus_data = get_midlayer_data(m, "civil", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False)
        cripple_data = get_midlayer_data(m, "toxic", 10, collect_sae=True, calculate_sae=True, collect_attn=False, collect_ff=True, calculate_attn=False, calculate_ff=False)
        save_activations(focus_data, "focus", i)
        save_activations(cripple_data, "cripple", i)
2 changes: 1 addition & 1 deletion src/taker/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def get_midlayer_data(opt: Model,
#assuming all layers have the hook
sae_shape = (opt.cfg.n_layers, opt.hooks.neuron_sae_encode[f"layer_0_{sae_hook}"].sae_config["d_sae"])
sae_data[sae_hook] = ActivationCollector(sae_shape, opt.output_device, collect_sae)
opt.hooks.enable_collect_hooks([sae_hook_points])
opt.hooks.enable_collect_hooks([sae_hook])

if do_collect:
criteria_raw = []
Expand Down
4 changes: 4 additions & 0 deletions src/taker/data_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,8 @@ class PruningConfig:
token_limit: int = None
ff_frac: float = 0.1
ff_eps: float = 0.001
sae_frac: float = 0.0
sae_eps: float = 0.001
attn_frac: float = 0.0
attn_eps: float = 1e-4
dtype: str = "fp16"
Expand All @@ -623,6 +625,8 @@ class PruningConfig:
attn_offset_mode: str = "zero"
ff_offset_mode: str = "zero"

sae_scoring: str = "abs"

attn_scoring: str = "abs"
attn_mode: str = "pre-out"
svd_attn: bool = False
Expand Down
27 changes: 23 additions & 4 deletions src/taker/prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,22 @@ def prune_and_evaluate(
# Find out what we are doing
do_ff = pruning_config.ff_frac > 0
do_attn = pruning_config.attn_frac > 0
if not do_ff and not do_attn:
raise ValueError("Must prune at least one of FF or Attention")
do_sae = pruning_config.sae_frac > 0
if not do_ff and not do_attn and not do_sae:
raise ValueError("Must prune at least one of FF or Attention or SAE")
if do_attn and pruning_config.attn_mode not in ["pre-out", "value"]:
raise NotImplementedError("attn_mode must be 'pre-out' or 'value'")

# Get midlayer activations of FF and ATTN
if pruning_config.recalculate_activations:
sae_enabled = False
if pruning_config.sae_frac > 0:
sae_enabled = True

focus_out = get_midlayer_data( opt, pruning_config.focus,
pruning_config.collection_sample_size, pruning_config.attn_mode )
pruning_config.collection_sample_size, pruning_config.attn_mode, calculate_sae=sae_enabled, collect_sae=sae_enabled )
cripple_out = get_midlayer_data( opt, pruning_config.cripple,
pruning_config.collection_sample_size, pruning_config.attn_mode )
pruning_config.collection_sample_size, pruning_config.attn_mode, calculate_sae=sae_enabled, collect_sae=sae_enabled )

# Otherwise, import activation data, and adjust the "pruning fraction"
else:
Expand Down Expand Up @@ -81,9 +86,11 @@ def score_and_prune( opt: Model,
):
# Get the top fraction FF activations and prune
ff_frac, ff_eps = pruning_config.ff_frac, pruning_config.ff_eps
sae_frac, sae_eps = pruning_config.sae_frac, pruning_config.sae_eps
attn_frac, attn_eps = pruning_config.attn_frac, pruning_config.attn_eps
do_ff = ff_frac > 0
do_attn = attn_frac > 0
do_sae = sae_frac > 0

act_subset = pruning_config.scoring_normalization
if do_ff > 0:
Expand All @@ -94,6 +101,18 @@ def score_and_prune( opt: Model,
ff_scores = ff_scoring_fn(opt, ff_focus_data, ff_cripple_data, ff_eps)
ff_criteria, ff_threshold = get_top_frac(ff_scores, ff_frac)
opt.hooks.delete_mlp_neurons(ff_criteria)
if do_sae > 0:
sae_hook_points = [point for point, layers in opt.hooks.hook_config.hook_points.items()
if 'all' in layers and any('sae' in hook for hook in layers['all'])]
for sae_hook in sae_hook_points:
sae_focus_data = focus_activations_data.sae[sae_hook]
sae_cripple_data = cripple_activations_data.sae[sae_hook]
sae_scoring_fn = score_indices_by(pruning_config.sae_scoring)

sae_scores = sae_scoring_fn(opt, sae_focus_data.orig, sae_cripple_data.orig, sae_eps)
sae_criteria, sae_threshold = get_top_frac(sae_scores, sae_frac)

opt.hooks[sae_hook].delete_neurons(sae_criteria)

# Get the top fraction of Attention activations and prune
if do_attn > 0:
Expand Down
Loading