Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions examples/early-testing/test_residual_pruning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#!/usr/bin/env python3
"""Simple test script for residual stream pruning functionality."""

from taker import Model
from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem
from taker.prune import prune_and_evaluate
from taker.activations import get_midlayer_data
from taker.eval import evaluate_all
import torch
import wandb

# Hook configuration: mask + collect on the post-decoder (residual stream) point.
hook_config = """
post_decoder: mask, collect
"""

# Pruning configuration: residual-stream pruning only (ff/attn fractions zeroed).
config = PruningConfig("gpt2",
    dtype="fp32",  # MPS compatibility
    wandb_entity = "seperability",
    wandb_project = "bens-tests",
    wandb_run_name = "gpt2 residual stream prune test",
    token_limit = 100,
    # Residual stream pruning only
    ff_frac = 0.0,
    attn_frac = 0.0,
    residual_frac = 0.1,  # Prune 10% of residual dimensions
    residual_scoring = "abs",
    focus = "civil",
    cripple = "toxic",
    recalculate_activations = False,
    collection_sample_size = 100,
    eval_sample_size = 100,
    n_steps = 3,
)


def _collect_residual_data(model):
    """Collect residual-stream activations for both datasets.

    Returns a (focus, cripple) pair from get_midlayer_data with only the
    residual stream collected/calculated (ff and attn disabled).
    """
    shared_kwargs = dict(
        collect_residual=True, calculate_residual=True,
        collect_ff=False, calculate_ff=False,
        collect_attn=False, calculate_attn=False,
    )
    focus = get_midlayer_data(model, "civil", 100, **shared_kwargs)
    cripple = get_midlayer_data(model, "toxic", 100, **shared_kwargs)
    return focus, cripple


model = Model("gpt2", hook_config=hook_config)
model.hooks.enable_collect_hooks(["post_decoder"], run_assert=True)

# Get initial activations
focus_data, cripple_data = _collect_residual_data(model)

history = RunDataHistory(list(config.datasets))
wandb.init(
    project=config.wandb_project,
    entity=config.wandb_entity,
    name=config.wandb_run_name,
)
wandb.config.update(config.to_dict(), allow_val_change=True)

torch.set_grad_enabled(False)

# Run the pruning
with torch.no_grad():
    # Evaluate without pruning first, to get a baseline row in the history.
    data = RunDataItem()
    baseline_eval = evaluate_all(model, config.eval_sample_size, config.datasets,
                                 dataset_tokens_to_skip=config.collection_sample_size)
    data.update(baseline_eval)
    history.add(data)

    for i in range(config.n_steps):
        print(f"Step {i}")
        data = prune_and_evaluate(model, config, focus_data, cripple_data, i)
        history.add(data)
        print(f"Residual dimensions pruned: {data.deletions['residual_del']}")
        print(f"Residual threshold: {data.deletions['residual_threshold']}")

        # Get new activations after pruning for next iteration
        focus_data, cripple_data = _collect_residual_data(model)
24 changes: 23 additions & 1 deletion src/taker/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ def get_midlayer_data(opt: Model,
calculate_ff: bool = True,
calculate_attn: bool = True,
calculate_sae: bool = False,
calculate_residual: bool = False,
collect_ff: bool = False,
collect_attn: bool = False,
collect_ids: bool = False,
collect_sae: bool = False,
collect_residual: bool = False,
dataset_texts_to_skip: int = None,
random_subset_frac: float = None,
eval_config: EvalConfig = None,
Expand All @@ -101,8 +103,9 @@ def get_midlayer_data(opt: Model,

do_ff = calculate_ff or collect_ff
do_attn = calculate_attn or collect_attn
do_collect = collect_ff or collect_attn or collect_ids or collect_sae
do_collect = collect_ff or collect_attn or collect_ids or collect_sae or collect_residual
do_sae = calculate_sae or collect_sae
do_residual = calculate_residual or collect_residual

# Get things ready for collection
opt.hooks.disable_all_collect_hooks()
Expand All @@ -122,6 +125,12 @@ def get_midlayer_data(opt: Model,
if attn_peak is not None:
attn_data_peak_centered = ActivationCollector(attn_shape, opt.output_device)

# residual stream activation collector
if do_residual:
residual_shape = (opt.cfg.n_layers, opt.cfg.d_model)
residual_data = ActivationCollector(residual_shape, opt.output_device, collect_residual)
opt.hooks.enable_collect_hooks(["post_decoder"])

if do_sae:
sae_hook_points = [point for point, layers in opt.hooks.hook_config.hook_points.items()
if 'all' in layers and any('sae' in hook for hook in layers['all'])]
Expand Down Expand Up @@ -160,6 +169,11 @@ def get_midlayer_data(opt: Model,
if do_attn:
attn_acts = opt.collect_recent_attn_pre_out()
attn_acts = einops.rearrange(attn_acts, "b l t nh dh -> (b t) l nh dh")
if do_residual:
residual_acts = opt.hooks.get_all_layer_data("post_decoder", "collect")
# Stack across layers: (n_layers, batch, seq_len, d_model)
residual_acts = torch.stack([act for act in residual_acts if act is not None])
residual_acts = einops.rearrange(residual_acts, "l b t d -> (b t) l d")
if do_sae:
sae_acts = {}
for sae_hook in sae_hook_points:
Expand Down Expand Up @@ -201,6 +215,8 @@ def get_midlayer_data(opt: Model,
attn_data.add_all(attn_acts[criteria_indices])
if attn_peak is not None:
attn_data_peak_centered.add_all((attn_acts - attn_peak)[criteria_indices])
if do_residual:
residual_data.add_all(residual_acts[criteria_indices])
if do_sae:
for sae_hook in sae_hook_points:
sae_data[sae_hook].add_all(sae_acts[sae_hook])
Expand Down Expand Up @@ -230,6 +246,10 @@ def get_midlayer_data(opt: Model,
orig=attn_data.summary(dtype=opt.dtype),
peak_centered = attn_data_peak_centered.summary(dtype=opt.dtype, allow_nan=True) if attn_peak is not None else None,
)
if calculate_residual:
output["residual"] = ActivationSummaryHolder(
orig=residual_data.summary(dtype=opt.dtype),
)
if calculate_sae:
output["sae"] = {}
for sae_hook in sae_hook_points:
Expand All @@ -242,6 +262,8 @@ def get_midlayer_data(opt: Model,
output["raw"]["mlp"] = ff_data.get_raw()
if collect_attn:
output["raw"]["attn"] = attn_data.get_raw()
if collect_residual:
output["raw"]["residual"] = residual_data.get_raw()
if collect_sae:
output["raw"]["sae"] = {sae_hook: sae_data[sae_hook].get_raw() for sae_hook in sae_hook_points}
if collect_ids:
Expand Down
4 changes: 4 additions & 0 deletions src/taker/data_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ class ActivationOverview:
mlp: Optional[ActivationSummaryHolder] = None
sae: Optional[Dict[str, ActivationSummaryHolder]] = None
attn: Optional[ActivationSummaryHolder] = None
residual: Optional[ActivationSummaryHolder] = None
raw: Optional[dict] = None
misc_data: Optional[dict] = None

Expand Down Expand Up @@ -610,6 +611,9 @@ class PruningConfig:
sae_eps: float = 0.001
attn_frac: float = 0.0
attn_eps: float = 1e-4
residual_frac: float = 0.0
residual_eps: float = 0.001
residual_scoring: str = "abs"
dtype: str = "fp16"
use_accelerator: bool = True
model_device: Optional[str] = None
Expand Down
3 changes: 3 additions & 0 deletions src/taker/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ def delete_mlp_neurons(self, remove_indices, layer: int = None):
def delete_attn_neurons(self, remove_indices, layer: int = None):
return self["attn_pre_out"].delete_neurons(remove_indices, layer)

def delete_residual_dimensions(self, remove_indices, layer: int = None):
    """Delete residual-stream dimensions by delegating to the "post_decoder" hook point.

    Mirrors delete_mlp_neurons / delete_attn_neurons above, but targets the
    residual stream via the "post_decoder" hooks.

    Args:
        remove_indices: indices/mask of residual dimensions to delete —
            presumably shaped (n_layers, d_model) when `layer` is None;
            TODO(review): confirm against delete_neurons' contract.
        layer: optional single layer to restrict the deletion to; passed
            straight through to delete_neurons.

    Returns:
        Whatever self["post_decoder"].delete_neurons returns.
    """
    return self["post_decoder"].delete_neurons(remove_indices, layer)

def reset_neuron_replace(self):
[h.reset() for h in self.neuron_replace.values()]

Expand Down
56 changes: 49 additions & 7 deletions src/taker/prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ def prune_and_evaluate(
"""
c = copy.deepcopy(pruning_config)

# Find out what we are doing
# Find out what we are doing
do_ff = pruning_config.ff_frac > 0
do_attn = pruning_config.attn_frac > 0
do_sae = pruning_config.sae_frac > 0
if not do_ff and not do_attn and not do_sae:
raise ValueError("Must prune at least one of FF or Attention or SAE")
do_residual = pruning_config.residual_frac > 0
if not do_ff and not do_attn and not do_sae and not do_residual:
raise ValueError("Must prune at least one of FF, Attention, SAE, or Residual")
if do_attn and pruning_config.attn_mode not in ["pre-out", "value"]:
raise NotImplementedError("attn_mode must be 'pre-out' or 'value'")

Expand All @@ -51,11 +52,16 @@ def prune_and_evaluate(
sae_enabled = False
if pruning_config.sae_frac > 0:
sae_enabled = True
residual_enabled = pruning_config.residual_frac > 0

focus_out = get_midlayer_data( opt, pruning_config.focus,
pruning_config.collection_sample_size, pruning_config.attn_mode, calculate_sae=sae_enabled, collect_sae=sae_enabled )
pruning_config.collection_sample_size, pruning_config.attn_mode,
calculate_sae=sae_enabled, collect_sae=sae_enabled,
calculate_residual=residual_enabled, collect_residual=residual_enabled )
cripple_out = get_midlayer_data( opt, pruning_config.cripple,
pruning_config.collection_sample_size, pruning_config.attn_mode, calculate_sae=sae_enabled, collect_sae=sae_enabled )
pruning_config.collection_sample_size, pruning_config.attn_mode,
calculate_sae=sae_enabled, collect_sae=sae_enabled,
calculate_residual=residual_enabled, collect_residual=residual_enabled )

# Otherwise, import activation data, and adjust the "pruning fraction"
else:
Expand Down Expand Up @@ -88,11 +94,19 @@ def score_and_prune( opt: Model,
ff_frac, ff_eps = pruning_config.ff_frac, pruning_config.ff_eps
sae_frac, sae_eps = pruning_config.sae_frac, pruning_config.sae_eps
attn_frac, attn_eps = pruning_config.attn_frac, pruning_config.attn_eps
residual_frac, residual_eps = pruning_config.residual_frac, pruning_config.residual_eps
do_ff = ff_frac > 0
do_attn = attn_frac > 0
do_sae = sae_frac > 0
do_residual = residual_frac > 0

act_subset = pruning_config.scoring_normalization

# Initialize variables to avoid undefined variable issues
ff_scores, ff_criteria, ff_threshold = None, None, 0
attn_scores, attn_criteria, attn_threshold = None, None, 0
residual_scores, residual_criteria, residual_threshold = None, None, 0

if do_ff > 0:
ff_focus_data = focus_activations_data.mlp[act_subset]
ff_cripple_data = cripple_activations_data.mlp[act_subset]
Expand All @@ -114,6 +128,22 @@ def score_and_prune( opt: Model,

opt.hooks[sae_hook].delete_neurons(sae_criteria)

# Residual stream pruning logic
if do_residual > 0:
# Get activation data for residual stream
residual_focus_data = focus_activations_data.residual.orig
residual_cripple_data = cripple_activations_data.residual.orig

# Score residual stream dimensions
residual_scoring_fn = score_indices_by(pruning_config.residual_scoring)
residual_scores = residual_scoring_fn(opt, residual_focus_data, residual_cripple_data, residual_eps)

# Determine which dimensions to prune
residual_criteria, residual_threshold = get_top_frac(residual_scores, residual_frac)

# Perform the actual pruning
opt.hooks.delete_residual_dimensions(residual_criteria)

# Get the top fraction of Attention activations and prune
if do_attn > 0:
attn_focus_data = focus_activations_data.attn[act_subset]
Expand Down Expand Up @@ -153,8 +183,10 @@ def score_and_prune( opt: Model,
"ff_scores": ff_scores if do_ff else None,
# FIXME: doesn't return attn_std_mean
"attn_scores": attn_scores if do_attn else None,
"residual_scores": residual_scores if do_residual else None,
"ff_criteria": ff_criteria if do_ff else None,
"attn_criteria": attn_criteria if do_attn else None,
"residual_criteria": residual_criteria if do_residual else None,
}

if save:
Expand All @@ -169,13 +201,16 @@ def score_and_prune( opt: Model,
data.update({'deletions': {
"ff_threshold": ff_threshold if do_ff else 0,
"attn_threshold": attn_threshold if do_attn else 0,
"residual_threshold": residual_threshold if do_residual else 0,
"ff_del": float( torch.sum(ff_criteria) ) if do_ff else 0,
"attn_del": float( torch.sum(attn_criteria) ) if do_attn else 0,
"residual_del": float( torch.sum(residual_criteria) ) if do_residual else 0,
}})

data.update({'deletions_per_layer': {
'ff': ff_criteria.sum(dim=-1).tolist() if do_ff else [],
'attn': attn_criteria.sum(dim=-1).tolist() if do_attn else [],
'residual': residual_criteria.sum(dim=-1).tolist() if do_residual else [],
}})

# Save removals and scores to history
Expand Down Expand Up @@ -338,10 +373,17 @@ def run_pruning(c: PruningConfig):

# Non-iteratively get activations, then iteratively prune and evaluate
else:
sae_enabled = c.sae_frac > 0
residual_enabled = c.residual_frac > 0

focus_out = get_midlayer_data(opt, c.focus,
c.collection_sample_size, c.attn_mode)
c.collection_sample_size, c.attn_mode,
calculate_sae=sae_enabled, collect_sae=sae_enabled,
calculate_residual=residual_enabled, collect_residual=residual_enabled)
cripple_out = get_midlayer_data(opt, c.cripple,
c.collection_sample_size, c.attn_mode)
c.collection_sample_size, c.attn_mode,
calculate_sae=sae_enabled, collect_sae=sae_enabled,
calculate_residual=residual_enabled, collect_residual=residual_enabled)
for i in range(c.n_steps):
data = prune_and_evaluate(opt, c, focus_out, cripple_out, i)
history.add(data)
Expand Down
Loading