-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathentropy.py
More file actions
106 lines (79 loc) · 3.41 KB
/
entropy.py
File metadata and controls
106 lines (79 loc) · 3.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# %%
# %%
from torch.nn.functional import log_softmax
import torch
import pandas as pd
from sae_lens import HookedSAETransformer
import torch.nn.functional as F
import plotly.express as px
# %%
def logits_to_entropy(logits):
    """Shannon entropy (in nats) of the softmax distribution over the last dim.

    Args:
        logits: tensor of unnormalized scores; softmax is taken over dim=-1.

    Returns:
        Tensor with the last dimension reduced away: H = -sum(p * log p).
    """
    log_p = log_softmax(logits, dim=-1)
    # p * log p computed from log-probs for numerical stability
    return torch.sum(-log_p.exp() * log_p, dim=-1)
def logits_to_varentropy(logits):
    """Varentropy of the softmax distribution over the last dim.

    Varentropy is the variance of the surprisal:
        Var[-log p] = E[(-log p)^2] - (E[-log p])^2 = E[(-log p)^2] - H^2.

    Args:
        logits: tensor of unnormalized scores; softmax is taken over dim=-1.

    Returns:
        Tensor with the last dimension reduced away; always >= 0, and exactly 0
        for a uniform (or deterministic) distribution.
    """
    log_probs = log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    entropy = -(log_probs * probs).sum(dim=-1)
    second_moment = (probs * log_probs.pow(2)).sum(dim=-1)
    # BUG FIX: variance subtracts the *square* of the mean surprisal (H**2),
    # not H itself; the original `second_moment - entropy` was not a variance.
    return second_moment - entropy ** 2
def logits_to_prob(logits, pos, tok_id1, tok_id2):
    """Softmax probabilities of two specific tokens at each requested position.

    Args:
        logits: tensor indexed as [batch, position, vocab]; only batch 0 is read.
        pos: iterable of position indices to inspect.
        tok_id1, tok_id2: vocabulary ids whose probabilities are extracted.

    Returns:
        List of (p(tok_id1), p(tok_id2)) float pairs, one per position in `pos`.
    """
    probs = log_softmax(logits, dim=-1).exp()
    pairs = []
    for position in pos:
        first = probs[0, position, tok_id1].item()
        second = probs[0, position, tok_id2].item()
        pairs.append((first, second))
    return pairs
def plot_get_entropy(generation_dict, rep_tok):
    """Save heatmaps of next-token entropy at list-item positions, for clean
    and corrupted prompts.

    For each token sequence in `generation_dict`, entropy is sampled at the
    positions just before each hyphen token (id 235290, skipping the first)
    plus one position derived from the last line-break token (id 108). The
    corrupted pass overwrites the token at position 8 with `rep_tok`.

    NOTE(review): reads the module-level global `model`; must be called after
    the model is loaded.

    Args:
        generation_dict: mapping whose values are iterables of token tensors,
            each indexed as [1, seq_len] (assumed — TODO confirm with caller).
        rep_tok: token id used both to corrupt position 8 and to tag filenames.

    Side effects:
        Writes entropy_plot_clean_{rep_tok}.png and
        entropy_plot_corrupt_{rep_tok}.png; empties the CUDA cache between
        the two passes.
    """
    # Clean prompts.
    stacked = _stack_hyphen_entropy(generation_dict, corrupt_tok=None)
    fig = px.imshow(stacked.cpu().numpy(), aspect='auto')
    fig.write_image(f"entropy_plot_clean_{rep_tok}.png")
    torch.cuda.empty_cache()
    # Corrupted prompts.
    # BUG FIX: the original discarded the second px.imshow() result, so the
    # "corrupt" file was written from the stale clean-prompt figure.
    stacked = _stack_hyphen_entropy(generation_dict, corrupt_tok=rep_tok)
    fig = px.imshow(stacked.cpu().numpy(), aspect='auto')
    fig.write_image(f"entropy_plot_corrupt_{rep_tok}.png")


def _stack_hyphen_entropy(generation_dict, corrupt_tok):
    """Run the model on every sequence and stack entropies at hyphen positions.

    Each per-sequence entropy row is left-padded with zeros to the longest
    row so the results can be concatenated into one [n_seqs, max_positions]
    tensor. When `corrupt_tok` is not None, position 8 of a *copy* of each
    sequence is overwritten with it (the original code mutated the caller's
    tensors in place).
    """
    entropy_rows = []
    for val in generation_dict.values():
        for toks in val:
            if corrupt_tok is not None:
                toks = toks.clone()  # avoid mutating generation_dict's tensors
                toks[0, 8] = corrupt_tok
            with torch.no_grad():
                hyphen_pos = torch.where(toks[0] == 235290)[0]
                break_pos = torch.where(toks[0] == 108)[0]
                # Positions just before each hyphen after the first, plus one
                # anchored two tokens before the last line break.
                positions = (hyphen_pos[1:] - 1).tolist() + [break_pos[-1].item() - 2]
                logits = model(toks)
                entropy = logits_to_entropy(logits)
                entropy_rows.append(entropy[:, positions])
    max_size = max(row.size(1) for row in entropy_rows)
    padded = [
        F.pad(row, (max_size - row.size(1), 0), "constant", 0)  # left padding
        for row in entropy_rows
    ]
    return torch.cat(padded, dim=0)
if __name__ == "__main__":
    # Load the model once; plot_get_entropy reads it as a module-level global.
    model = HookedSAETransformer.from_pretrained("google/gemma-2-2b-it")
    # (generation dict path, replacement token id) pairs to plot.
    runs = [
        ("generation_dicts/gemma2_generation_dict.pt", 1497),
        ("generation_dicts/gemma2_generation_long_dict.pt", 3309),
    ]
    for dict_path, rep_tok in runs:
        generation_dict = torch.load(dict_path)
        plot_get_entropy(generation_dict, rep_tok)