-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathmodel_similarity.py
More file actions
91 lines (79 loc) · 2.77 KB
/
model_similarity.py
File metadata and controls
91 lines (79 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
Model similarity results using ITDA.
"""
import argparse
from itertools import product
import numpy as np
import matplotlib.pyplot as plt
import torch
from layer_similarity import (
get_atoms_from_wandb_run,
get_layered_runs_for_models,
get_similarity_measure,
)
from transformer_lens import HookedTransformer
def _parse_args(argv=None):
    """Parse the command-line options for the model-similarity script.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Get the ITDA IoU similarity between a set of models."
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        required=True,
        help="List of model names to compare.",
    )
    parser.add_argument(
        "--wandb_entity",
        type=str,
        default="patrickaaleask",
        help="Entity (user or team) under which the W&B project lives.",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default="itda",
        help="Name of the W&B project for logging or fetching runs.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="artifacts/similarities/model_sim.png",
        help="Path the similarity heatmap image is written to.",
    )
    return parser.parse_args(argv)


def _collect_atom_union(model_name, device, entity, project):
    """Concatenate the atom indices of every layered ITDA run of one model.

    Args:
        model_name: Name passed to ``HookedTransformer.from_pretrained``.
        device: Torch device string the model is loaded on.
        entity: W&B entity forwarded to ``get_layered_runs_for_models``.
        project: W&B project forwarded to the run/atom helpers.

    Returns:
        The per-run arrays joined along axis 0 with ``np.concatenate``.

    Raises:
        ValueError: If no runs were returned for ``model_name``.
    """
    # The model is only needed for its layer count; drop it right away so it
    # is not kept in memory while the W&B runs are being fetched.
    model = HookedTransformer.from_pretrained(model_name, device=device)
    n_layers = model.cfg.n_layers
    del model

    runs = get_layered_runs_for_models(
        [model_name],
        list(range(1, n_layers)),
        entity=entity,
        project=project,
    )[model_name]

    # Element [1] of the helper's result is used as the atom indices
    # (NOTE(review): inferred from the original variable name -- confirm
    # against layer_similarity.get_atoms_from_wandb_run).
    per_run_indices = [
        get_atoms_from_wandb_run(run, project)[1] for run in runs.values()
    ]
    if not per_run_indices:
        # np.concatenate([]) would raise an opaque ValueError; say what is
        # actually wrong instead.
        raise ValueError(
            f"No ITDA runs found for model {model_name!r} in W&B project "
            f"{entity}/{project}."
        )
    return np.concatenate(per_run_indices, axis=0)


def _plot_similarities(similarities, labels, output_path):
    """Draw the annotated similarity heatmap and save it to ``output_path``.

    Args:
        similarities: Square 2-D array; entry ``[i, j]`` is the similarity
            between ``labels[i]`` and ``labels[j]``.
        labels: Model names used as tick labels on both axes.
        output_path: Destination image path; missing parent directories are
            created.
    """
    from pathlib import Path

    fig, ax = plt.subplots()
    cax = ax.matshow(similarities, cmap="viridis")
    fig.colorbar(cax)

    ticks = np.arange(len(labels))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels, rotation=90)
    ax.set_yticklabels(labels)

    # TODO: The models names are long so this is not lovely
    # matshow puts the row index on the y axis, hence text(col, row, ...).
    for (row, col), value in np.ndenumerate(similarities):
        ax.text(col, row, f"{value:.2f}", ha="center", va="center", color="white")

    fig.tight_layout()

    # savefig does not create directories, so make sure the target exists.
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(destination)
    plt.close(fig)


def main(argv=None):
    """Compute pairwise similarities between models and plot them.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    args = _parse_args(argv)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_grad_enabled(False)
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # dict.fromkeys de-duplicates while keeping order, so a model listed
    # twice on the command line is only loaded and fetched once.
    atom_unions = {
        model_name: _collect_atom_union(
            model_name, device, args.wandb_entity, args.wandb_project
        )
        for model_name in dict.fromkeys(args.models)
    }

    n_models = len(args.models)
    # product() iterates row-major, so the reshape puts model1 on the rows
    # and model2 on the columns.
    similarities = np.array(
        [
            get_similarity_measure(atom_unions[model1], atom_unions[model2])
            for model1, model2 in product(args.models, args.models)
        ]
    ).reshape(n_models, n_models)

    _plot_similarities(similarities, args.models, args.output)


if __name__ == "__main__":
    main()