-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsae.py
More file actions
820 lines (673 loc) · 30.1 KB
/
sae.py
File metadata and controls
820 lines (673 loc) · 30.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
"""Most of this is just copied over from Arthur's code and slightly simplified:
https://github.com/ArthurConmy/sae/blob/main/sae/model.py
"""
import json
import os
import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, Optional, Tuple, TypeVar, Union, overload
T = TypeVar("T", bound="SAE")
import einops
import torch
from jaxtyping import Float
from safetensors.torch import save_file
from torch import nn
from transformer_lens.hook_points import HookedRootModule, HookPoint
from sae_lens.config import DTYPE_MAP
from sae_lens.toolkit.pretrained_sae_loaders import (
NAMED_PRETRAINED_SAE_LOADERS,
handle_config_defaulting,
read_sae_from_disk,
)
from sae_lens.toolkit.pretrained_saes_directory import (
get_norm_scaling_factor,
get_pretrained_saes_directory,
)
# File names used inside an SAE directory. All three are written by `SAE.save_model`;
# `SAE.load_from_pretrained` reads the config and weights files.
SPARSITY_PATH = "sparsity.safetensors"  # optional {"sparsity": tensor} passed to save_model
SAE_WEIGHTS_PATH = "sae_weights.safetensors"  # the module's state_dict
SAE_CFG_PATH = "cfg.json"  # JSON dump of SAEConfig.to_dict()
@dataclass
class SAEConfig:
    """Serializable description of an SAE: its architecture, the data it was
    trained on, and where/how it should be materialized at runtime."""

    # --- architecture ---------------------------------------------------
    architecture: Literal["standard", "gated", "jumprelu"]
    # --- forward pass ---------------------------------------------------
    d_in: int
    d_sae: int
    activation_fn_str: str
    apply_b_dec_to_input: bool
    finetuning_scaling_factor: bool
    # --- provenance: what the SAE was trained on -------------------------
    context_size: int
    model_name: str
    hook_name: str
    hook_layer: int
    hook_head_index: Optional[int]
    prepend_bos: bool
    dataset_path: str
    dataset_trust_remote_code: bool
    normalize_activations: str
    # --- misc -------------------------------------------------------------
    dtype: str
    device: str
    sae_lens_training_version: Optional[str]
    activation_fn_kwargs: dict[str, Any] = field(default_factory=dict)
    neuronpedia_id: Optional[str] = None
    model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
        """Build a config from a (possibly legacy) dictionary.

        Legacy key names are translated to their current equivalents, and any
        key that is not a field of this dataclass is silently dropped. If two
        keys map to the same field, the one appearing later wins.
        """
        legacy_to_current = {  # old : new
            "hook_point": "hook_name",
            "hook_point_head_index": "hook_head_index",
            "hook_point_layer": "hook_layer",
            "activation_fn": "activation_fn_str",
        }
        known_fields = cls.__dataclass_fields__  # pylint: disable=no-member
        init_kwargs: dict[str, Any] = {}
        for raw_key, value in config_dict.items():
            field_name = legacy_to_current.get(raw_key, raw_key)
            if field_name in known_fields:
                init_kwargs[field_name] = value
        return cls(**init_kwargs)

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-friendly dict of this config (stable key order).

        ``activation_fn_kwargs`` is normalized so a falsy value becomes ``{}``.
        """
        serialized_keys = (
            "architecture",
            "d_in",
            "d_sae",
            "dtype",
            "device",
            "model_name",
            "hook_name",
            "hook_layer",
            "hook_head_index",
            "activation_fn_str",  # use string for serialization
            "activation_fn_kwargs",
            "apply_b_dec_to_input",
            "finetuning_scaling_factor",
            "sae_lens_training_version",
            "prepend_bos",
            "dataset_path",
            "dataset_trust_remote_code",
            "context_size",
            "normalize_activations",
            "neuronpedia_id",
            "model_from_pretrained_kwargs",
        )
        result = {name: getattr(self, name) for name in serialized_keys}
        # Re-assigning an existing key keeps its position in the dict.
        result["activation_fn_kwargs"] = self.activation_fn_kwargs or {}
        return result
class SAE(HookedRootModule):
    """
    Core Sparse Autoencoder (SAE) class used for inference. For training, see `TrainingSAE`.

    The architecture is chosen by ``cfg.architecture`` ("standard", "gated" or
    "jumprelu"); ``self.encode`` is bound to the matching ``encode_*`` method
    at construction time.
    """

    cfg: SAEConfig
    dtype: torch.dtype
    device: torch.device

    # analysis
    use_error_term: bool

    def __init__(
        self,
        cfg: SAEConfig,
        use_error_term: bool = False,
    ):
        """Create parameters, hook points and normalization helpers from ``cfg``.

        Args:
            cfg: Architecture, sizes, dtype/device and provenance of the SAE.
            use_error_term: If True, ``forward`` adds a detached reconstruction
                error term to its output (see ``forward``).

        Raises:
            ValueError: If ``cfg.architecture`` is not recognised.
        """
        super().__init__()

        self.cfg = cfg

        if cfg.model_from_pretrained_kwargs:
            warnings.warn(
                "\nThis SAE has non-empty model_from_pretrained_kwargs. "
                "\nFor optimal performance, load the model like so:\n"
                "model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)",
                category=UserWarning,
                stacklevel=1,
            )

        self.activation_fn = get_activation_fn(
            cfg.activation_fn_str, **cfg.activation_fn_kwargs or {}
        )
        self.dtype = DTYPE_MAP[cfg.dtype]
        self.device = torch.device(cfg.device)
        self.use_error_term = use_error_term

        if self.cfg.architecture == "standard":
            self.initialize_weights_basic()
            self.encode = self.encode_standard
        elif self.cfg.architecture == "gated":
            self.initialize_weights_gated()
            self.encode = self.encode_gated
        elif self.cfg.architecture == "jumprelu":
            self.initialize_weights_jumprelu()
            self.encode = self.encode_jumprelu
        else:
            raise ValueError(
                f"Unknown SAE architecture: {self.cfg.architecture!r}. "
                "Expected one of 'standard', 'gated', 'jumprelu'."
            )

        # handle presence / absence of scaling factor.
        if self.cfg.finetuning_scaling_factor:
            self.apply_finetuning_scaling_factor = (
                lambda x: x * self.finetuning_scaling_factor
            )
        else:
            self.apply_finetuning_scaling_factor = lambda x: x

        # set up hooks
        self.hook_sae_input = HookPoint()
        self.hook_sae_acts_pre = HookPoint()
        self.hook_sae_acts_post = HookPoint()
        self.hook_sae_output = HookPoint()
        self.hook_sae_recons = HookPoint()
        self.hook_sae_error = HookPoint()

        # handle hook_z reshaping if needed.
        # this is very cursed and should be refactored. it exists so that we can reshape out
        # the z activations for hook_z SAEs. but don't know d_head if we split up the forward pass
        # into a separate encode and decode function.
        # this will cause errors if we call decode before encode.
        if self.cfg.hook_name.endswith("_z"):
            self.turn_on_forward_pass_hook_z_reshaping()
        else:
            # need to default the reshape fns
            self.turn_off_forward_pass_hook_z_reshaping()

        # handle run time activation normalization if needed:
        if self.cfg.normalize_activations == "constant_norm_rescale":

            # we need to scale the norm of the input and store the scaling factor
            def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
                self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
                x = x * self.x_norm_coeff
                return x

            def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
                x = x / self.x_norm_coeff
                del self.x_norm_coeff  # prevents reusing
                return x

            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out

        elif self.cfg.normalize_activations == "layer_norm":

            # we need to scale the norm of the input and store the scaling factor
            def run_time_activation_ln_in(
                x: torch.Tensor, eps: float = 1e-5
            ) -> torch.Tensor:
                mu = x.mean(dim=-1, keepdim=True)
                x = x - mu
                std = x.std(dim=-1, keepdim=True)
                x = x / (std + eps)
                self.ln_mu = mu
                self.ln_std = std
                return x

            def run_time_activation_ln_out(x: torch.Tensor, eps: float = 1e-5):
                return x * self.ln_std + self.ln_mu

            self.run_time_activation_norm_fn_in = run_time_activation_ln_in
            self.run_time_activation_norm_fn_out = run_time_activation_ln_out
        else:
            self.run_time_activation_norm_fn_in = lambda x: x
            self.run_time_activation_norm_fn_out = lambda x: x

        self.setup()  # Required for `HookedRootModule`s

    def initialize_weights_basic(self):
        """Create W_enc, b_enc, W_dec, b_dec (and optionally the finetuning scale)."""
        # no config changes encoder bias init for now.
        self.b_enc = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

        # Start with the default init strategy:
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
                )
            )
        )

        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
                )
            )
        )

        # methods which change b_dec as a function of the dataset are implemented after init.
        self.b_dec = nn.Parameter(
            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
        )

        # scaling factor for fine-tuning (not to be used in initial training)
        # TODO: Make this optional and not included with all SAEs by default (but maintain backwards compatibility)
        if self.cfg.finetuning_scaling_factor:
            self.finetuning_scaling_factor = nn.Parameter(
                torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
            )

    def initialize_weights_gated(self):
        """Create the gated-SAE parameters: shared W_enc, b_gate, r_mag, b_mag, W_dec, b_dec."""
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
                )
            )
        )

        self.b_gate = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

        # r_mag is a log-scale per-feature multiplier on W_enc for the magnitude path.
        self.r_mag = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

        self.b_mag = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
                )
            )
        )

        self.b_dec = nn.Parameter(
            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
        )

    def initialize_weights_jumprelu(self):
        """Create the standard-SAE parameters plus a per-feature ``threshold``."""
        # The params are identical to the standard SAE
        # except we use a threshold parameter too
        self.threshold = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )
        self.b_enc = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
                )
            )
        )
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
                )
            )
        )

        self.b_dec = nn.Parameter(
            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
        )

    @overload
    def to(
        self: T,
        device: Optional[Union[torch.device, str]] = ...,
        dtype: Optional[torch.dtype] = ...,
        non_blocking: bool = ...,
    ) -> T: ...

    @overload
    def to(self: T, dtype: torch.dtype, non_blocking: bool = ...) -> T: ...

    @overload
    def to(self: T, tensor: torch.Tensor, non_blocking: bool = ...) -> T: ...

    def to(self, *args: Any, **kwargs: Any) -> "SAE":  # type: ignore
        """Move/cast the module like ``nn.Module.to`` and keep ``cfg``, ``.device``
        and ``.dtype`` in sync with the requested target."""
        device_arg = None
        dtype_arg = None

        # Check args
        for arg in args:
            if isinstance(arg, (torch.device, str)):
                device_arg = arg
            elif isinstance(arg, torch.dtype):
                dtype_arg = arg
            elif isinstance(arg, torch.Tensor):
                device_arg = arg.device
                dtype_arg = arg.dtype

        # Check kwargs
        device_arg = kwargs.get("device", device_arg)
        dtype_arg = kwargs.get("dtype", dtype_arg)

        if device_arg is not None:
            # Convert device to torch.device if it's a string
            device = (
                torch.device(device_arg) if isinstance(device_arg, str) else device_arg
            )

            # Update the cfg.device
            self.cfg.device = str(device)

            # Update the .device property
            self.device = device

        if dtype_arg is not None:
            # Update the cfg.dtype
            self.cfg.dtype = str(dtype_arg)

            # Update the .dtype property
            self.dtype = dtype_arg

        # Call the parent class's to() method to handle all cases (device, dtype, tensor)
        return super().to(*args, **kwargs)

    # Basic Forward Pass Functionality.
    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Encode then decode ``x``.

        When ``use_error_term`` is True the output is ``sae_out + sae_error``,
        where ``sae_error = x - clean_reconstruction``. The clean reconstruction
        is recomputed under ``torch.no_grad()`` with no hook points firing, so
        interventions on SAE features still change the output (a simpler
        ``x - sae_out.detach()`` would always reconstruct perfectly) while
        gradients flow only through ``sae_out``. See A.3 in
        https://arxiv.org/pdf/2403.19647.pdf for more detail.

        Raises:
            ValueError: If an error term is requested for an unsupported architecture.
        """
        feature_acts = self.encode(x)
        sae_out = self.decode(feature_acts)

        if not self.use_error_term:
            return self.hook_sae_output(sae_out)

        if self.cfg.architecture not in ("standard", "gated", "jumprelu"):
            raise ValueError(f"No error term implemented for {self.cfg.architecture=}")

        with torch.no_grad():
            # move x to correct dtype
            x = x.to(self.dtype)
            x_reconstruct_clean = self._reconstruct_without_hooks(x)
            sae_error = self.hook_sae_error(x - x_reconstruct_clean)
        # sae_out keeps its graph: only the error term is detached.
        return self.hook_sae_output(sae_out + sae_error)

    def _maybe_hook(
        self, hook_point: HookPoint, tensor: torch.Tensor, use_hooks: bool
    ) -> torch.Tensor:
        """Pass ``tensor`` through ``hook_point`` only when ``use_hooks`` is True."""
        return hook_point(tensor) if use_hooks else tensor

    def _prepare_sae_in(self, x: torch.Tensor, use_hooks: bool = True) -> torch.Tensor:
        """Input pipeline shared by every architecture.

        Casts to ``self.dtype``, flattens hook_z heads if needed, fires
        ``hook_sae_input`` (when ``use_hooks``), applies run-time activation
        normalization, then subtracts ``b_dec`` if ``cfg.apply_b_dec_to_input``.
        """
        x = x.to(self.dtype)
        x = self.reshape_fn_in(x)  # type: ignore
        x = self._maybe_hook(self.hook_sae_input, x, use_hooks)
        x = self.run_time_activation_norm_fn_in(x)
        return x - (self.b_dec * self.cfg.apply_b_dec_to_input)

    def _encode_standard_impl(
        self, x: torch.Tensor, use_hooks: bool
    ) -> torch.Tensor:
        """Standard encoder: activation_fn(sae_in @ W_enc + b_enc)."""
        sae_in = self._prepare_sae_in(x, use_hooks)
        # "... d_in, d_in d_sae -> ... d_sae",
        hidden_pre = self._maybe_hook(
            self.hook_sae_acts_pre, sae_in @ self.W_enc + self.b_enc, use_hooks
        )
        return self._maybe_hook(
            self.hook_sae_acts_post, self.activation_fn(hidden_pre), use_hooks
        )

    def _encode_gated_impl(self, x: torch.Tensor, use_hooks: bool) -> torch.Tensor:
        """Gated encoder: binary gate from one path times magnitudes from another."""
        sae_in = self._prepare_sae_in(x, use_hooks)

        # Gating path
        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
        active_features = (gating_pre_activation > 0).to(self.dtype)

        # Magnitude path with weight sharing
        magnitude_pre_activation = self._maybe_hook(
            self.hook_sae_acts_pre,
            sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag,
            use_hooks,
        )
        feature_magnitudes = self.activation_fn(magnitude_pre_activation)

        return self._maybe_hook(
            self.hook_sae_acts_post, active_features * feature_magnitudes, use_hooks
        )

    def _encode_jumprelu_impl(
        self, x: torch.Tensor, use_hooks: bool
    ) -> torch.Tensor:
        """JumpReLU encoder: activation_fn(hidden_pre) zeroed where hidden_pre <= threshold."""
        sae_in = self._prepare_sae_in(x, use_hooks)
        # "... d_in, d_in d_sae -> ... d_sae",
        hidden_pre = self._maybe_hook(
            self.hook_sae_acts_pre, sae_in @ self.W_enc + self.b_enc, use_hooks
        )
        return self._maybe_hook(
            self.hook_sae_acts_post,
            self.activation_fn(hidden_pre) * (hidden_pre > self.threshold),
            use_hooks,
        )

    def _reconstruct_without_hooks(self, x: torch.Tensor) -> torch.Tensor:
        """Encode and decode ``x`` exactly like ``forward`` does, but with no hook
        point firing. Used to compute the clean reconstruction for the error term."""
        encoders = {
            "standard": self._encode_standard_impl,
            "gated": self._encode_gated_impl,
            "jumprelu": self._encode_jumprelu_impl,
        }
        feature_acts = encoders[self.cfg.architecture](x, use_hooks=False)
        recons = (
            self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + self.b_dec
        )
        # Same post-processing order as `decode`: undo normalization, then un-flatten heads.
        recons = self.run_time_activation_norm_fn_out(recons)
        return self.reshape_fn_out(recons, self.d_head)  # type: ignore

    def encode_gated(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """Calculate gated-SAE feature activations from inputs (hooks fire)."""
        return self._encode_gated_impl(x, use_hooks=True)

    def encode_jumprelu(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """Calculate JumpReLU-SAE feature activations from inputs (hooks fire).

        ``hook_sae_input`` fires at the same point as for the other
        architectures: after reshaping, before normalization and b_dec centering.
        """
        return self._encode_jumprelu_impl(x, use_hooks=True)

    def encode_standard(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """Calculate standard-SAE feature activations from inputs (hooks fire)."""
        return self._encode_standard_impl(x, use_hooks=True)

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """Decodes SAE feature activation tensor into a reconstructed input activation tensor."""
        # "... d_sae, d_sae d_in -> ... d_in",
        sae_out = self.hook_sae_recons(
            self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + self.b_dec
        )

        # handle run time activation normalization if needed
        # will fail if you call this twice without calling encode in between.
        sae_out = self.run_time_activation_norm_fn_out(sae_out)

        # handle hook z reshaping if needed.
        sae_out = self.reshape_fn_out(sae_out, self.d_head)  # type: ignore

        return sae_out

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """Rescale parameters in place so every W_dec row has unit norm while the
        SAE's reconstructions stay unchanged (feature activations scale up by
        each row's former norm)."""
        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)  # (d_sae, 1)
        norms = W_dec_norms.squeeze(-1)  # (d_sae,)
        self.W_dec.data = self.W_dec.data / W_dec_norms
        self.W_enc.data = self.W_enc.data * W_dec_norms.T
        if self.cfg.architecture == "gated":
            # r_mag is NOT rescaled: the magnitude path uses W_enc * exp(r_mag),
            # and W_enc was already rescaled above. Scaling the log-space r_mag
            # by the norms would change exp(r_mag) non-linearly.
            self.b_gate.data = self.b_gate.data * norms
            self.b_mag.data = self.b_mag.data * norms
        else:
            self.b_enc.data = self.b_enc.data * norms
            if self.cfg.architecture == "jumprelu":
                # threshold is compared against hidden_pre, which now scales by the norms.
                self.threshold.data = self.threshold.data * norms

    @torch.no_grad()
    def fold_activation_norm_scaling_factor(
        self, activation_norm_scaling_factor: float
    ):
        """Fold a constant input-scaling factor into the weights so the SAE can be
        applied to raw (unscaled) activations and produces raw-space outputs.

        Encoder weights absorb the factor; all output-space parameters (W_dec and
        b_dec, which is also used for input centering) are divided by it.
        """
        self.W_enc.data = self.W_enc.data * activation_norm_scaling_factor
        # previously weren't doing this.
        self.W_dec.data = self.W_dec.data / activation_norm_scaling_factor
        # b_dec lives in the same (normalized) space as the reconstruction.
        self.b_dec.data = self.b_dec.data / activation_norm_scaling_factor

        # once we normalize, we shouldn't need to scale activations.
        self.cfg.normalize_activations = "none"

    def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None):
        """Save weights, config and (optionally) a sparsity tensor into directory ``path``.

        The directory (including missing parents) is created if needed.
        """
        os.makedirs(path, exist_ok=True)

        # generate the weights
        save_file(self.state_dict(), os.path.join(path, SAE_WEIGHTS_PATH))

        # save the config
        config = self.cfg.to_dict()

        with open(os.path.join(path, SAE_CFG_PATH), "w") as f:
            json.dump(config, f)

        if sparsity is not None:
            sparsity_in_dict = {"sparsity": sparsity}
            save_file(sparsity_in_dict, os.path.join(path, SPARSITY_PATH))  # type: ignore

    @classmethod
    def load_from_pretrained(
        cls, path: str, device: str = "cpu", dtype: str | None = None
    ) -> "SAE":
        """Load an SAE previously written by ``save_model`` from directory ``path``.

        Args:
            path: Directory containing the config and weights files.
            device: Device to place the SAE on (overrides the saved config).
            dtype: Optional dtype string overriding the saved config.
        """
        # get the config
        config_path = os.path.join(path, SAE_CFG_PATH)
        with open(config_path, "r") as f:
            cfg_dict = json.load(f)
        cfg_dict = handle_config_defaulting(cfg_dict)
        cfg_dict["device"] = device
        if dtype is not None:
            cfg_dict["dtype"] = dtype

        weight_path = os.path.join(path, SAE_WEIGHTS_PATH)
        cfg_dict, state_dict = read_sae_from_disk(
            cfg_dict=cfg_dict,
            weight_path=weight_path,
            device=device,
            dtype=DTYPE_MAP[cfg_dict["dtype"]],
        )

        sae_cfg = SAEConfig.from_dict(cfg_dict)

        sae = cls(sae_cfg)
        sae.load_state_dict(state_dict)

        return sae

    @classmethod
    def from_pretrained(
        cls,
        release: str,
        sae_id: str,
        device: str = "cpu",
    ) -> Tuple["SAE", dict[str, Any], Optional[torch.Tensor]]:
        """
        Load a pretrained SAE from the Hugging Face model hub.

        Args:
            release: The release name. This will be mapped to a huggingface repo id based on the
                pretrained_saes.yaml file. A name not in the directory that contains a "/" is
                used directly as the Hugging Face repo id.
            sae_id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
            device: The device to load the SAE on.

        Returns:
            ``(sae, cfg_dict, log_sparsities)``: the loaded SAE, the config dict returned by the
            conversion loader, and whatever sparsity tensor that loader returned (may be None).

        Raises:
            ValueError: If the release or SAE id cannot be resolved, or the release names an
                unknown conversion loader.
        """

        # get sae directory
        sae_directory = get_pretrained_saes_directory()

        # get the repo id and path to the SAE
        if release not in sae_directory:
            if "/" not in release:
                raise ValueError(
                    f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
                )
        elif sae_id not in sae_directory[release].saes_map:
            # If using Gemma Scope and not the canonical release, give a hint to use it
            if (
                "gemma-scope" in release
                and "canonical" not in release
                and f"{release}-canonical" in sae_directory
            ):
                canonical_ids = list(
                    sae_directory[release + "-canonical"].saes_map.keys()
                )
                # Shorten the lengthy string of valid IDs
                if len(canonical_ids) > 5:
                    str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
                else:
                    str_canonical_ids = str(canonical_ids)
                value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
            else:
                value_suffix = ""

            valid_ids = list(sae_directory[release].saes_map.keys())
            # Shorten the lengthy string of valid IDs
            if len(valid_ids) > 5:
                str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
            else:
                str_valid_ids = str(valid_ids)

            raise ValueError(
                f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
                + value_suffix
            )
        sae_info = sae_directory.get(release, None)
        hf_repo_id = sae_info.repo_id if sae_info is not None else release
        hf_path = sae_info.saes_map[sae_id] if sae_info is not None else sae_id
        config_overrides = sae_info.config_overrides if sae_info is not None else None
        neuronpedia_id = (
            sae_info.neuronpedia_id[sae_id] if sae_info is not None else None
        )

        conversion_loader_name = "sae_lens"
        if sae_info is not None and sae_info.conversion_func is not None:
            conversion_loader_name = sae_info.conversion_func
        if conversion_loader_name not in NAMED_PRETRAINED_SAE_LOADERS:
            raise ValueError(
                f"Conversion func {conversion_loader_name} not found in NAMED_PRETRAINED_SAE_LOADERS."
            )
        conversion_loader = NAMED_PRETRAINED_SAE_LOADERS[conversion_loader_name]

        cfg_dict, state_dict, log_sparsities = conversion_loader(
            repo_id=hf_repo_id,
            folder_name=hf_path,
            device=device,
            force_download=False,
            cfg_overrides=config_overrides,
        )

        sae = cls(SAEConfig.from_dict(cfg_dict))
        sae.load_state_dict(state_dict)
        sae.cfg.neuronpedia_id = neuronpedia_id

        # Check if normalization is 'expected_average_only_in'
        if cfg_dict.get("normalize_activations") == "expected_average_only_in":
            norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
            if norm_scaling_factor is not None:
                sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
                cfg_dict["normalize_activations"] = "none"
            else:
                warnings.warn(
                    f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
                )

        return sae, cfg_dict, log_sparsities

    def get_name(self):
        """Return an identifier built from model name, hook name and d_sae."""
        sae_name = f"sae_{self.cfg.model_name}_{self.cfg.hook_name}_{self.cfg.d_sae}"
        return sae_name

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "SAE":
        """Construct a freshly-initialized SAE from a config dictionary."""
        return cls(SAEConfig.from_dict(config_dict))

    def turn_on_forward_pass_hook_z_reshaping(self):
        """Flatten ``(..., n_heads, d_head)`` inputs on the way in and restore them
        on the way out. ``d_head`` is captured from the first input seen."""
        assert self.cfg.hook_name.endswith(
            "_z"
        ), "This method should only be called for hook_z SAEs."

        def reshape_fn_in(x: torch.Tensor):
            self.d_head = x.shape[-1]  # type: ignore
            # After the first call d_head is known, so swap in a plain reshaper.
            self.reshape_fn_in = lambda x: einops.rearrange(
                x, "... n_heads d_head -> ... (n_heads d_head)"
            )
            return einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")

        self.reshape_fn_in = reshape_fn_in

        self.reshape_fn_out = lambda x, d_head: einops.rearrange(
            x, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
        )
        self.hook_z_reshaping_mode = True

    def turn_off_forward_pass_hook_z_reshaping(self):
        """Make the input/output reshape functions identities."""
        self.reshape_fn_in = lambda x: x
        self.reshape_fn_out = lambda x, d_head: x
        self.d_head = None
        self.hook_z_reshaping_mode = False
class TopK(nn.Module):
    """Sparsify along the last axis: keep the ``k`` largest entries, zero the rest.

    The surviving values are passed through ``postact_fn`` (ReLU by default)
    and written back at their original positions.
    """

    def __init__(
        self, k: int, postact_fn: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU()
    ):
        super().__init__()
        self.k = k
        self.postact_fn = postact_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        kept_values, kept_indices = torch.topk(x, k=self.k, dim=-1)
        activated = self.postact_fn(kept_values)
        # Out-of-place scatter into a zero tensor shaped like the input.
        return torch.zeros_like(x).scatter(-1, kept_indices, activated)
def get_activation_fn(
    activation_fn: str, **kwargs: Any
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Return the activation callable named by ``activation_fn``.

    Args:
        activation_fn: One of ``"relu"``, ``"tanh-relu"`` or ``"topk"``.
        **kwargs: Extra options. ``"topk"`` requires ``k`` (entries kept along the
            last axis) and optionally accepts ``postact_fn`` (applied to the kept
            values; defaults to ReLU). Other names ignore kwargs.

    Raises:
        ValueError: If the name is unknown, or ``"topk"`` is requested without ``k``.
    """
    if activation_fn == "relu":
        return torch.nn.ReLU()

    if activation_fn == "tanh-relu":

        def tanh_relu(x: torch.Tensor) -> torch.Tensor:
            return torch.tanh(torch.relu(x))

        return tanh_relu

    if activation_fn == "topk":
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if "k" not in kwargs:
            raise ValueError("TopK activation function requires a k value.")
        postact_fn = kwargs.get("postact_fn", nn.ReLU())  # default post-activation: ReLU
        return TopK(kwargs["k"], postact_fn)

    raise ValueError(f"Unknown activation function: {activation_fn}")