Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions examples/neuron-mapping/compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch

def compare_pruned_ff_criteria(
    cripple_repos: list[str],
    model_size: str,
    directory: str = "/home/ubuntu/taker-rashid/examples/neuron-mapping/saved_tensors/",
    focus_repo: str = "pile",
):
    """Compare the saved ``ff_criteria`` tensors of several cripple datasets.

    For every ordered pair ``(repo1, repo2)`` of distinct entries in
    ``cripple_repos`` this computes
    ``sum(repo1_criteria AND repo2_criteria) / sum(repo1_criteria)``, i.e. the
    fraction of entries set in repo1's criteria that are also set in repo2's.

    Args:
        cripple_repos: Dataset names, e.g. ``["physics", "bio", "code"]``.
        model_size: Identifier used both as the sub-directory name and inside
            the file name.
        directory: Root directory holding one sub-directory per model size.
            Defaults to the location originally hard-coded in this script.
        focus_repo: Name of the focus dataset embedded in the file names.

    Returns:
        A dict with ``"model_size"`` mapped to ``model_size`` and, for each
        repo, a nested dict ``{other_repo: ratio}`` where ``ratio`` is a
        0-dim torch tensor. The ratio is NaN when repo1's criteria has no
        set entries (0 / 0).

    Raises:
        Whatever ``torch.load`` raises when an expected file is missing, and
        ``KeyError`` when a loaded file has no ``"ff_criteria"`` entry.
    """
    import os  # local import: this module only imports torch at file level

    # Each "<repo>-<focus>-<model_size>-recent.pt" file is read exactly once;
    # the pairwise loop below then works purely on in-memory tensors.
    criteria = {}
    for repo in cripple_repos:
        if repo in criteria:
            continue
        path = os.path.join(
            directory, model_size, f"{repo}-{focus_repo}-{model_size}-recent.pt"
        )
        criteria[repo] = torch.load(path)["ff_criteria"]

    ratios = {"model_size": model_size}
    for repo1 in cripple_repos:
        repo1_total = torch.sum(criteria[repo1])  # invariant over repo2
        ratios[repo1] = {}
        for repo2 in cripple_repos:
            if repo1 == repo2:
                continue
            matches = torch.logical_and(criteria[repo1], criteria[repo2])
            ratios[repo1][repo2] = torch.sum(matches) / repo1_total

    return ratios

if __name__ == "__main__":
    # Only run the comparison when executed as a script, so importing this
    # module does not trigger file I/O.
    # NOTE(review): the second argument is used verbatim as a directory name
    # and inside the file name; confirm it matches the identifier the tensors
    # were saved under.
    print(compare_pruned_ff_criteria(["physics", "bio", "code"], "nickypro/tinyllama-15M"))
73 changes: 73 additions & 0 deletions examples/neuron-mapping/prune_repos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@

from taker.data_classes import PruningConfig
from taker.parser import cli_parser
from taker.prune import run_pruning
import torch

def compare_pruned_ff_criteria(
    cripple_repos: list[str],
    model_size: str,
    directory: str = "/home/ubuntu/taker-rashid/examples/neuron-mapping/saved_tensors/",
    focus_repo: str = "pile",
):
    """Compare the saved ``ff_criteria`` tensors of several cripple datasets.

    For every ordered pair ``(repo1, repo2)`` of distinct entries in
    ``cripple_repos`` this computes
    ``sum(repo1_criteria AND repo2_criteria) / sum(repo1_criteria)``, i.e. the
    fraction of entries set in repo1's criteria that are also set in repo2's.

    Args:
        cripple_repos: Dataset names, e.g. ``["physics", "bio", "code"]``.
        model_size: Identifier used both as the sub-directory name and inside
            the file name.
        directory: Root directory holding one sub-directory per model size.
            Defaults to the location originally hard-coded in this script.
        focus_repo: Name of the focus dataset embedded in the file names.

    Returns:
        A dict with ``"model_size"`` mapped to ``model_size`` and, for each
        repo, a nested dict ``{other_repo: ratio}`` where ``ratio`` is a
        0-dim torch tensor. The ratio is NaN when repo1's criteria has no
        set entries (0 / 0).

    Raises:
        Whatever ``torch.load`` raises when an expected file is missing, and
        ``KeyError`` when a loaded file has no ``"ff_criteria"`` entry.
    """
    import os  # local import: `os` is not imported at the top of this script

    print("model_size: ", model_size)

    # Each "<repo>-<focus>-<model_size>-recent.pt" file is read exactly once;
    # the pairwise loop below then works purely on in-memory tensors.
    criteria = {}
    for repo in cripple_repos:
        if repo in criteria:
            continue
        path = os.path.join(
            directory, model_size, f"{repo}-{focus_repo}-{model_size}-recent.pt"
        )
        criteria[repo] = torch.load(path)["ff_criteria"]

    ratios = {"model_size": model_size}
    for repo1 in cripple_repos:
        repo1_total = torch.sum(criteria[repo1])  # invariant over repo2
        ratios[repo1] = {}
        for repo2 in cripple_repos:
            if repo1 == repo2:
                continue
            matches = torch.logical_and(criteria[repo1], criteria[repo2])
            ratios[repo1][repo2] = torch.sum(matches) / repo1_total

    return ratios


# Configure initial model and tests
# NOTE: `cripple` and `ff_frac` set here are only initial values; both are
# overwritten on every iteration of the loops below.
c = PruningConfig(
    wandb_project = "testing", # repo to push results to
    model_repo = "nickypro/tinyllama-15M",
    # "metallama/llama-2-7b"
    token_limit = 1000, # trim the input to this max length
    run_pre_test = True, # evaluate the unpruned model
    eval_sample_size = 1e3,
    collection_sample_size = 1e3,
    # Removals parameters
    ff_frac = 0.2, # % of feed forward neurons to prune
    attn_frac = 0.00, # % of attention neurons to prune
    focus = "pile", # the “reference” dataset
    cripple = "physics", # the “unlearned” dataset
    additional_datasets=tuple(), # any extra datasets to evaluate on
    recalculate_activations = False, # iterative vs non-iterative
    n_steps = 1,
)

# Parse CLI for arguments
# c, args = cli_parser(c)

#list of repos to cripple
cripple_repos = ["physics", "biology","chemistry", "math", "code", "poems", "civil", "stories"]
ff_frac_to_prune = [0.01]
# Text after the last '-' of the model repo name ("15M" for the repo above);
# compare_pruned_ff_criteria uses it as the sub-directory and file-name part.
# NOTE(review): presumably this must equal the identifier the pruning run
# saves its tensors under - confirm against the save function.
model_size = c.model_repo.split('-')[-1]

# Run the iterated pruning for each cripple repo, for a range of ff_frac pruned
# Maps each ff_frac value to the pairwise overlap ratios computed for it.
shared_pruning_data = {}
for ff_frac in ff_frac_to_prune:
    c.ff_frac = ff_frac
    for repo in cripple_repos:
        c.cripple = repo
        print("running iteration for ", c.cripple, " vs ", c.focus, "with ff_frac: ", ff_frac)
        # Gradient tracking is disabled for the whole pruning run.
        with torch.no_grad():
            model, history = run_pruning(c)
    # Compare only after every repo has been pruned for this ff_frac: the
    # comparison loads one "-recent.pt" file per repo in cripple_repos.
    ratios = compare_pruned_ff_criteria(cripple_repos, model_size)
    shared_pruning_data[ff_frac] = ratios
print(shared_pruning_data)
13 changes: 11 additions & 2 deletions src/taker/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,17 @@ def save_timestamped_tensor_dict( opt: Model,
data: Dict[str, Tensor],
name: str ):
now = datetime.datetime.now().strftime( "%Y-%m-%d_%H:%M:%S" )
os.makedirs( f'tmp/{opt.model_size}', exist_ok=True )
filename = f'tmp/{opt.model_size}/{opt.model_size}-{name}-{now}.pt'
os.makedirs( f'saved_tensors/{opt.model_size}', exist_ok=True )
filename = f'saved_tensors/{opt.model_size}/{name}-{opt.model_size}-{now}.pt'
torch.save( data, filename )
print( f'Saved {filename} to {opt.model_size}' )
return filename

def save_tensor_dict( opt: Model,
        data: Dict[str, Tensor],
        name: str ):
    """Save ``data`` under a fixed, non-timestamped "recent" file name.

    The file is written to
    ``saved_tensors/{opt.model_size}/{name}-{opt.model_size}-recent.pt``
    relative to the current working directory. Because the name carries no
    timestamp, a later call with the same ``name`` and model size overwrites
    the earlier file.

    Args:
        opt: Object whose ``model_size`` attribute is used in the path.
        data: Mapping of tensors handed to ``torch.save``.
        name: Prefix of the file name.

    Returns:
        The relative path of the written file.
    """
    filename = f'saved_tensors/{opt.model_size}/{name}-{opt.model_size}-recent.pt'
    # Create the parent directory of the final path rather than only
    # 'saved_tensors/{model_size}': model_size appears a second time inside
    # the file name, so a '/' in it (or in name) adds further directory levels.
    os.makedirs( os.path.dirname(filename), exist_ok=True )
    torch.save( data, filename )
    print( f'Saved {filename} to {opt.model_size}' )
    return filename
Expand Down
15 changes: 10 additions & 5 deletions src/taker/prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .eval import evaluate_all
from .scoring import score_indices_by, score_indices
from .activations import get_midlayer_activations, get_top_frac, \
choose_attn_heads_by, save_timestamped_tensor_dict
choose_attn_heads_by, save_timestamped_tensor_dict, save_tensor_dict
from .texts import prepare

def prune_and_evaluate(
Expand Down Expand Up @@ -60,6 +60,7 @@ def prune_and_evaluate(

# Prune the model using the activation data
data = score_and_prune(opt, focus_out, cripple_out, c)
# Should return a dict with data["deletions"]["ff_pruned"]

# Evaluate the model
with torch.no_grad():
Expand All @@ -73,7 +74,7 @@ def score_and_prune( opt: Model,
focus_activations_data: ActivationOverview,
cripple_activations_data: ActivationOverview,
pruning_config: PruningConfig,
save=False,
save=True,
):
# Get the top fraction FF activations and prune
ff_frac, ff_eps = pruning_config.ff_frac, pruning_config.ff_eps
Expand Down Expand Up @@ -133,7 +134,9 @@ def score_and_prune( opt: Model,
"attn_criteria": attn_criteria if do_attn else None,
}
if save:
save_timestamped_tensor_dict( opt, tensor_data, "activation_metrics" )
#original save function with timestamp, but also version with most recent run saved without timestamp for easy loading, will overwrite old version.
save_timestamped_tensor_dict( opt, tensor_data, pruning_config.cripple + "-" + pruning_config.focus )
save_tensor_dict( opt, tensor_data, pruning_config.cripple + "-" + pruning_config.focus)

# Initialize the output dictionary
data = RunDataItem()
Expand All @@ -143,6 +146,8 @@ def score_and_prune( opt: Model,
"attn_threshold": attn_threshold if do_attn else 0,
"ff_del": float( torch.sum(ff_criteria) ) if do_ff else 0,
"attn_del": float( torch.sum(attn_criteria) ) if do_attn else 0,
"ff_scores": ff_scores.cpu().numpy(),
"ff_criteria": ff_criteria.cpu().numpy(),
}})

data.update({'deletions_per_layer': {
Expand Down Expand Up @@ -273,7 +278,7 @@ def run_pruning(c: PruningConfig):
entity=c.wandb_entity,
name=c.wandb_run_name,
)
wandb.config.update(c.to_dict())
wandb.config.update(c.to_dict(), allow_val_change=True)

# Evaluate model before removal of any neurons
if c.run_pre_test:
Expand Down Expand Up @@ -310,7 +315,7 @@ def run_pruning(c: PruningConfig):
print(history.history[-1])
print(history.df.T)
print(history.df.T.to_csv())

# print("masks: ", opt.masks["mlp_pre_out"])
return opt, history

######################################################################################
Expand Down
30 changes: 25 additions & 5 deletions src/taker/texts.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,17 +291,37 @@ def infer_dataset_config(dataset_name:str, dataset_subset:str=None):
dataset_image_label_key = "coarse_label",
dataset_filter=DatasetFilters.filter_veh2,
),
EvalConfig("bio",
dataset_repo = "camel-ai/biology",
dataset_text_key = "message_2",
dataset_has_test_split = False,
),
EvalConfig("emotion",
dataset_repo = "dair-ai/emotion",
dataset_type = "text-classification",
dataset_text_key = "text",
dataset_text_label_key = "label",
dataset_has_test_split = True,
),
EvalConfig("biology",
dataset_repo = "camel-ai/biology",
dataset_text_key = "message_2",
dataset_has_test_split = False,
),
EvalConfig("physics",
dataset_repo = "camel-ai/physics",
dataset_text_key = "message_2",
dataset_has_test_split = False,
),
EvalConfig("chemistry",
dataset_repo = "camel-ai/chemistry",
dataset_text_key = "message_2",
dataset_has_test_split = False,
),
EvalConfig("math",
dataset_repo = "camel-ai/math",
dataset_text_key = "message_2",
dataset_has_test_split = False,
),
EvalConfig("poems",
dataset_repo = "sadFaceEmoji/english-poems",
dataset_text_key = "poem",
dataset_has_test_split = False,
)
]

Expand Down