Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions sae_lens/saes/gated_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,13 @@ def encode(
)
feature_magnitudes = self.activation_fn(magnitude_pre_activation)

feature_acts = active_features * feature_magnitudes

if self.cfg.normalize_decoder:
feature_acts = feature_acts * self.W_dec.norm(dim=-1)

# Combine gating and magnitudes
return self.hook_sae_acts_post(active_features * feature_magnitudes)
return self.hook_sae_acts_post(feature_acts)

def decode(
self, feature_acts: Float[torch.Tensor, "... d_sae"]
Expand All @@ -82,9 +87,14 @@ def decode(
3) Run any reconstruction hooks and out-normalization if configured.
4) If the SAE was reshaping hook_z activations, reshape back.
"""
if self.cfg.normalize_decoder:
feature_acts = feature_acts * self.W_dec.norm(dim=-1)
# 1) optional finetuning scaling
# 2) linear transform
sae_out_pre = feature_acts @ self.W_dec + self.b_dec
W_dec_for_sae_out = self.W_dec
if self.cfg.normalize_decoder:
W_dec_for_sae_out = W_dec_for_sae_out / ( W_dec_for_sae_out.norm(dim=-1, keepdim=True) + 1e-8 )
sae_out_pre = feature_acts @ W_dec_for_sae_out + self.b_dec
# 3) hooking and normalization
sae_out_pre = self.hook_sae_recons(sae_out_pre)
sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
Expand Down Expand Up @@ -167,6 +177,9 @@ def encode_with_hidden_pre(
# Combine gating path and magnitude path
feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes)

if self.cfg.normalize_decoder:
feature_acts = feature_acts * self.W_dec.norm(dim=-1)

# Return both the final feature activations and the pre-activation (for logging or penalty)
return feature_acts, magnitude_pre_activation

Expand All @@ -187,13 +200,15 @@ def calculate_aux_loss(
pi_gate_act = torch.relu(pi_gate)

# L1-like penalty scaled by W_dec norms
l1_loss = (
step_input.coefficients["l1"]
* torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
)
l1_loss = step_input.coefficients["l1"] * torch.sum(pi_gate_act, dim=-1).mean()

# Aux reconstruction: reconstruct x purely from gating path
via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
W_dec_for_aux = self.W_dec.detach()
if self.cfg.normalize_decoder:
W_dec_for_aux = W_dec_for_aux / ( W_dec_for_aux.norm(dim=-1, keepdim=True) + 1e-8 )
via_gate_reconstruction = (
pi_gate_act @ W_dec_for_aux + self.b_dec.detach()
)
aux_recon_loss = (
(via_gate_reconstruction - step_input.sae_in).pow(2).sum(dim=-1).mean()
)
Expand Down
1 change: 1 addition & 0 deletions sae_lens/saes/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class SAEConfig(ABC):
] = "none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
reshape_activations: Literal["none", "hook_z"] = "none"
metadata: SAEMetadata = field(default_factory=SAEMetadata)
normalize_decoder: bool = False

@classmethod
@abstractmethod
Expand Down
7 changes: 1 addition & 6 deletions tests/refactor_compatibility/test_gated_sae_equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,9 +393,4 @@ def test_gated_training_equivalence(): # type: ignore
atol=1e-5,
msg="Output differs between old and new Gated implementation",
)
assert_close(
old_out.loss,
new_out.loss,
atol=1e-5,
msg="Loss differs between old and new Gated implementation",
)
# the losses should no longer be equivalent, since we fixed a bug with the auxiliary reconstruction loss
22 changes: 22 additions & 0 deletions tests/saes/test_gated_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,25 @@ def test_GatedTrainingSAE_save_and_load_inference_sae(tmp_path: Path) -> None:
training_full_out = training_sae(sae_in)
inference_full_out = inference_sae(sae_in)
assert_close(training_full_out, inference_full_out)


def test_GatedTrainingSAE_auxiliary_reconstruction_loss_does_not_apply_gradient_to_decoder_weights():
    """The gating-path auxiliary reconstruction loss must train only the encoder/gate.

    `calculate_aux_loss` builds the gate-only reconstruction from detached
    copies of W_dec and b_dec, so backpropagating through the auxiliary loss
    should leave the decoder parameters without gradient while the encoder
    and gate parameters do receive gradient.
    """
    cfg = build_gated_sae_training_cfg()
    sae = GatedTrainingSAE.from_dict(cfg.to_dict())

    # Random activations are fine here: we only care about *which* parameters
    # end up with a gradient, not about the numeric value of the loss.
    aux_losses = sae.calculate_aux_loss(
        step_input=TrainStepInput(
            sae_in=torch.randn(10, cfg.d_in),
            coefficients={"l1": 1.0},
            dead_neuron_mask=None,
        ),
        feature_acts=torch.randn(10, cfg.d_sae),
        hidden_pre=torch.randn(10, cfg.d_sae),
        sae_out=torch.randn(10, cfg.d_in),
    )
    # Backprop only the auxiliary reconstruction term (not the L1 term).
    aux_losses["auxiliary_reconstruction_loss"].backward()

    # Decoder params: either never touched (grad is None) or exactly zero.
    assert sae.W_dec.grad is None or sae.W_dec.grad.sum() == 0.0
    assert sae.b_dec.grad is None or sae.b_dec.grad.sum() == 0.0
    # Encoder/gate params: must have received a nonzero gradient.
    assert sae.W_enc.grad is not None and sae.W_enc.grad.sum() != 0.0
    assert sae.b_gate.grad is not None and sae.b_gate.grad.sum() != 0.0