-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreproduce_error.py
More file actions
115 lines (83 loc) · 3.86 KB
/
reproduce_error.py
File metadata and controls
115 lines (83 loc) · 3.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from sae_lens import HookedSAETransformer, SAE, SAEConfig
from gemma_utils import get_gemma_2_config, gemma_2_sae_loader
import torch
torch.set_grad_enabled(False)
from sae_lens import SAE
device = "cpu"
model = HookedSAETransformer.from_pretrained("google/gemma-2-2b-it", device=device)

# Layer -> Gemma Scope folder. Kept for reference only: the SAE below is loaded
# from a hard-coded folder name, and full_strings / attn_repo_id / layers are
# not read anywhere else in this script.
full_strings = {
    10: "layer_10/width_16k/average_l0_77",
}
attn_repo_id = "google/gemma-scope-2b-pt-att"
layers = [10]

with torch.no_grad():
    # NOTE(review): these residual-stream SAEs come from the "pt" release while
    # the model above is the "-it" checkpoint; confirm this pairing is intended.
    repo_id = "google/gemma-scope-2b-pt-res"
    folder_name = "layer_10/width_16k/average_l0_77"
    config = get_gemma_2_config(repo_id, folder_name)  # not used below
    # Build the SAE from the loader's config dict and load its weights.
    cfg, state_dict, log_spar = gemma_2_sae_loader(repo_id, folder_name)
    sae_cfg = SAEConfig.from_dict(cfg)
    sae = SAE(sae_cfg)
    sae.load_state_dict(state_dict)
    sae.d_head = 256  # NOTE(review): not read anywhere in this script; confirm it is needed
    sae.use_error_term = True
string = "The quick brown fox jumps over the lazy dog."
tokens = model.to_tokens(string)


def sae_filter(name):
    """Return True only for hook names containing "hook_sae_output" (the only activations cached)."""
    return "hook_sae_output" in name


# ============= Original Logits =========
# Baseline run with no SAE attached. Presumably no hook name matches the filter
# yet, so original_cache is expected to be empty; only the logits are used.
original_logits, original_cache = model.run_with_cache(tokens, names_filter=sae_filter)
model.reset_hooks(including_permanent=True)
# ============= Add SAEs with error term ===========
sae.use_error_term = True
model.add_sae(sae)
logits_with_sae_we, cache_with_sae_we = model.run_with_cache(tokens, names_filter=sae_filter)
# ============ Add SAEs w/o error term =========
model.reset_saes()  # detach the SAE added above before attaching the copy
from copy import deepcopy

# Work on a copy so the error term can be switched off without touching `sae`,
# which the later ablation stage re-attaches with its error term enabled.
sae_ne = deepcopy(sae)
sae_ne.use_error_term = False
model.add_sae(sae_ne)
logits_with_sae_ne, cache_with_sae_ne = model.run_with_cache(tokens, names_filter=sae_filter)
# =========== Add SAEs with error term and intervene on the SAE features =========
# NOTE: the ablation hook used in this stage overwrites *every* feature
# activation, not a single feature.
# Clear hooks, then detach the no-error-term SAE from the previous stage.
model.reset_hooks()  # the original author marks this ordering (hooks, then SAEs) as the correct one
model.reset_saes()
# Re-attach the original SAE with its error term enabled.
sae.use_error_term = True
model.add_sae(sae)
def ablate_hooked_sae(acts, hook):
    """Overwrite every element of ``acts`` with 20, in place, and return it.

    Intended as a forward hook: ``acts`` is the tensor passing through the hook
    point and ``hook`` is the hook point itself (unused here). The constant is
    deliberately absurd so any downstream effect is unmistakable.

    Args:
        acts: activation tensor to clamp (any rank).
        hook: the hook point that invoked this function; ignored.

    Returns:
        The same tensor object, now filled with 20.
    """
    # Ellipsis indexing fills the whole tensor regardless of rank.
    acts[...] = 20
    return acts
# Run with the clamping hook temporarily attached to the SAE's post-activation
# hook point; the hook is scoped to this `with` block.
with model.hooks(fwd_hooks=[("blocks.10.hook_resid_post.hook_sae_acts_post", ablate_hooked_sae)]):
    logits_with_ablated_sae, cache_with_ablated_sae = model.run_with_cache(tokens, names_filter=sae_filter)
# ===== Comparison of the logits ==========
# Expectations (from the original author's comments): the SAE with its error
# term should leave the logits close to the original; dropping the error term,
# or clamping every feature activation, should move them. atol=1 is a loose
# tolerance.
print("Original Logits & Logits with SAEs with error term (should be True)")
print(torch.allclose(logits_with_sae_we, original_logits, atol=1))
print("Original Logits & Logits with SAEs without error term (should be False)")
print(torch.allclose(logits_with_sae_ne, original_logits, atol=1))
print("Original Logits & Logits with SAEs with error term and ablation (should be False)")
print(torch.allclose(logits_with_ablated_sae, original_logits, atol=1))
# ===== Comparison of the SAE output ==========
# Replace each cache with the tensor recorded at the SAE output hook point
# (this rebinds the three cache names to tensors).
cache_with_sae_we = cache_with_sae_we["blocks.10.hook_resid_post.hook_sae_output"]
cache_with_sae_ne = cache_with_sae_ne["blocks.10.hook_resid_post.hook_sae_output"]
cache_with_ablated_sae = cache_with_ablated_sae["blocks.10.hook_resid_post.hook_sae_output"]
print("Cache with SAEs with error term & Cache with SAEs without error term (should be False)")
print(torch.allclose(cache_with_sae_we, cache_with_sae_ne, atol=1))
print("Cache with SAEs with error term & Cache with SAEs with error term and ablation (should be False)")
print(torch.allclose(cache_with_ablated_sae, cache_with_sae_we, atol=1))
print("Cache with SAEs with no error term & Cache with SAEs with error term and ablation (should be False)")
print(torch.allclose(cache_with_ablated_sae, cache_with_sae_ne, atol=1))
# Output recorded from an earlier run of this script, kept for reference; this
# bare string is never used at runtime. Its labels do not all match the current
# print statements.
# The surprising results are the third logits check and the second cache check:
# both print True although the comments expect False. That is, clamping the SAE
# feature activations changed neither the logits nor the recorded SAE output
# while the error term was enabled.
"""
Logits:
Original Logits & Logits with SAEs with error term
True
Original Logits & Logits with SAEs with error term
False
Original Logits & Logits with SAEs with error term
True
------------------
SAE Output:
Cache with SAEs with error term & Cache with SAEs without error term (should be False)
False
Cache with SAEs with error term & Cache with SAEs with error term and ablation (should be False)
True
Cache with SAEs with no error term & Cache with SAEs with error term and ablation (should be False)
False
"""