-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgemma_utils.py
More file actions
169 lines (149 loc) · 5.51 KB
/
gemma_utils.py
File metadata and controls
169 lines (149 loc) · 5.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import re
from typing import Optional, Dict, Any, Tuple
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from safetensors import safe_open
def get_gemma_2_config(
    repo_id: str,
    folder_name: str,
    d_sae_override: Optional[int] = None,
    layer_override: Optional[int] = None,
) -> Dict[str, Any]:
    """Build an SAE-Lens-style config dict for a Gemma Scope SAE.

    The SAE width (``d_sae``) and layer are parsed out of ``folder_name``
    (e.g. ``"layer_12/width_16k/average_l0_82"``); the model size and hook
    point are parsed out of ``repo_id`` (e.g. ``"gemma-scope-2b-pt-res"``).

    Args:
        repo_id: HuggingFace repo id; must contain a model-size token
            (``2b``/``9b``/``27b``) and a hook token (``res``/``mlp``/``att``).
        folder_name: Subfolder name; expected to contain a ``width_*`` token
            and a ``layer_<n>`` token.
        d_sae_override: Fallback SAE width when no ``width_*`` token matches.
        layer_override: Fallback layer when no ``layer_<n>`` token matches.

    Returns:
        A config dict (architecture, d_in, d_sae, hook_name, ...).

    Raises:
        ValueError: If width/layer/model/hook cannot be determined.
    """
    # --- SAE width ---------------------------------------------------------
    width_lookup = {
        "width_16k": 16384,
        "width_32k": 32768,
        "width_65k": 65536,
        "width_131k": 131072,
        "width_262k": 262144,
        "width_524k": 524288,
        "width_1m": 1048576,
    }
    d_sae: Optional[int] = None
    for token, width in width_lookup.items():
        if token in folder_name:
            d_sae = width
            break
    if d_sae is None:
        if not d_sae_override:
            raise ValueError("Width not found in folder_name and no override provided.")
        d_sae = d_sae_override

    # --- Layer -------------------------------------------------------------
    layer_match = re.search(r"layer_(\d+)", folder_name)
    layer = int(layer_match.group(1)) if layer_match else layer_override
    if layer is None:
        raise ValueError("Layer not found in folder_name and no override provided.")

    # --- Model size --------------------------------------------------------
    model_params = {
        "2b": {"name": "gemma-2-2b", "d_in": 2304},
        "9b": {"name": "gemma-2-9b", "d_in": 3584},
        "27b": {"name": "gemma-2-27b", "d_in": 4608},
    }
    model_info = None
    size_key = None
    for key, info in model_params.items():
        if key in repo_id:
            model_info, size_key = info, key
            break
    if not model_info:
        raise ValueError("Model name not found in repo_id.")
    model_name = model_info["name"]
    d_in = model_info["d_in"]

    # --- Hook point --------------------------------------------------------
    if "res" in repo_id:
        hook_name = f"blocks.{layer}.hook_resid_post"
    elif "mlp" in repo_id:
        hook_name = f"blocks.{layer}.hook_mlp_out"
    elif "att" in repo_id:
        hook_name = f"blocks.{layer}.attn.hook_z"
        # Attention SAEs act on the concatenated head outputs, whose width
        # differs from the residual-stream width.
        attn_d_in = {"2b": 2048, "9b": 4096, "27b": 4608}
        d_in = attn_d_in.get(size_key, d_in)
    else:
        raise ValueError("Hook name not found in folder_name.")

    return {
        "architecture": "jumprelu",
        "d_in": d_in,
        "d_sae": d_sae,
        "dtype": "float32",
        "model_name": model_name,
        "hook_name": hook_name,
        "hook_layer": layer,
        "hook_head_index": None,
        "activation_fn_str": "relu",
        "finetuning_scaling_factor": False,
        "sae_lens_training_version": None,
        "prepend_bos": True,
        "dataset_path": "monology/pile-uncopyrighted",
        "context_size": 1024,
        "dataset_trust_remote_code": True,
        "apply_b_dec_to_input": False,
        "normalize_activations": None,
    }
def gemma_2_sae_loader(
    repo_id: str,
    folder_name: str,
    device: str = "cpu",
    force_download: bool = False,
    cfg_overrides: Optional[Dict[str, Any]] = None,
    d_sae_override: Optional[int] = None,
    layer_override: Optional[int] = None,
) -> Tuple[Dict[str, Any], Dict[str, torch.Tensor], Optional[torch.Tensor]]:
    """
    Custom loader for Gemma 2 (Gemma Scope) SAEs.

    Downloads ``params.npz`` from ``repo_id``/``folder_name`` on the
    HuggingFace Hub, renames ``w_*`` keys to ``W_*``, converts every array to
    a float32 torch tensor on ``device``, and normalizes the optional
    ``scaling_factor`` entry.

    Args:
        repo_id: HuggingFace repo id (also used to infer model/hook config).
        folder_name: Subfolder within the repo holding ``params.npz``.
        device: Torch device string the tensors are moved to.
        force_download: Re-download even if a cached copy exists.
        cfg_overrides: Optional dict merged over the inferred config.
        d_sae_override: Passed through to :func:`get_gemma_2_config`.
        layer_override: Passed through to :func:`get_gemma_2_config`.

    Returns:
        ``(cfg_dict, state_dict, log_sparsity)``; ``log_sparsity`` is always
        ``None`` for Gemma 2 SAEs.

    Raises:
        ValueError: If the config cannot be inferred, or a non-unit
            ``scaling_factor`` is present while the config says
            ``finetuning_scaling_factor`` is False.
    """
    cfg_dict = get_gemma_2_config(repo_id, folder_name, d_sae_override, layer_override)
    cfg_dict["device"] = device

    # Apply overrides if provided (overrides win over inferred values).
    if cfg_overrides is not None:
        cfg_dict.update(cfg_overrides)

    # Download the SAE weights.
    sae_path = hf_hub_download(
        repo_id=repo_id,
        filename="params.npz",
        subfolder=folder_name,
        force_download=force_download,
    )

    # Load and convert the weights; JAX-style "w_enc"/"w_dec" keys become
    # "W_enc"/"W_dec" to match the SAE-Lens naming convention.
    state_dict: Dict[str, torch.Tensor] = {}
    with np.load(sae_path) as data:
        for key in data.keys():
            state_dict_key = "W_" + key[2:] if key.startswith("w_") else key
            state_dict[state_dict_key] = (
                torch.tensor(data[key]).to(dtype=torch.float32).to(device)
            )

    # Handle scaling factor: drop it when it is all-ones (a no-op), otherwise
    # require the config to acknowledge it and rename it.
    if "scaling_factor" in state_dict:
        if torch.allclose(
            state_dict["scaling_factor"], torch.ones_like(state_dict["scaling_factor"])
        ):
            del state_dict["scaling_factor"]
            cfg_dict["finetuning_scaling_factor"] = False
        else:
            # Explicit raise instead of `assert`: asserts are stripped under
            # `python -O`, which would silently skip this consistency check.
            if not cfg_dict["finetuning_scaling_factor"]:
                raise ValueError(
                    "Scaling factor is present but finetuning_scaling_factor is False."
                )
            state_dict["finetuning_scaling_factor"] = state_dict.pop("scaling_factor")
    else:
        cfg_dict["finetuning_scaling_factor"] = False

    # No sparsity tensor is shipped with Gemma 2 SAEs.
    log_sparsity = None
    return cfg_dict, state_dict, log_sparsity
# Mapping from dtype-name strings (as stored in cfg dicts, e.g. "float32")
# to the corresponding torch dtype objects.
DTYPE_MAP = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
import pandas as pd
def get_all_string_min_l0_resid_gemma():
    """For each layer of the ``gemma-scope-2b-pt-res`` release, return the SAE
    id string ``layer_<L>/width_16k/average_l0_<min>`` whose average L0 is the
    smallest among that layer's 16k-width SAEs.

    Returns:
        list[str]: One id string per layer, in the order layers first appear
        in the release's ``saes_map``.
    """
    directory = get_pretrained_saes_directory()
    df = pd.DataFrame.from_records(
        {release: info.__dict__ for release, info in directory.items()}
    ).T
    # .iloc[0] (not [0]): the row index holds release-name strings, so plain
    # integer indexing would rely on pandas' deprecated positional fallback.
    resid_dict = df[df["release"] == "gemma-scope-2b-pt-res"]["saes_map"].iloc[0]

    # layer -> smallest average-L0 value seen so far (kept as a string so the
    # output reproduces the id verbatim).
    min_l0_per_layer: Dict[str, str] = {}
    for sae_id in resid_dict:
        parts = [segment.split("_")[-1] for segment in sae_id.split("/")]
        layer, width, l0 = parts
        if width != "16k":
            continue
        # Compare numerically: as strings, "9" > "105" lexicographically,
        # which would pick the wrong minimum.
        if layer not in min_l0_per_layer or int(l0) < int(min_l0_per_layer[layer]):
            min_l0_per_layer[layer] = l0

    return [
        f"layer_{layer}/width_16k/average_l0_{l0}"
        for layer, l0 in min_l0_per_layer.items()
    ]