
[Question] Residual Peeking in TopK and BatchTopK auxiliary losses #666

@xXCoolinXx


Questions

I had a question about the particular auxiliary loss implemented for the TopK and BatchTopK SAEs.

The specific implementation used by this repo is:

```python
def calculate_topk_aux_loss(
    self,
    sae_in: torch.Tensor,
    sae_out: torch.Tensor,
    hidden_pre: torch.Tensor,
    dead_neuron_mask: torch.Tensor | None,
) -> torch.Tensor:
    """
    Calculate TopK auxiliary loss.

    This auxiliary loss encourages dead neurons to learn useful features by having
    them reconstruct the residual error from the live neurons. It's a key part of
    preventing neuron death in TopK SAEs.
    """
    # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
    # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
    if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
        return sae_out.new_tensor(0.0)

    if self.cfg.normalize_activations in ("constant_norm_rescale", "layer_norm"):
        raise ValueError(
            "TopK auxiliary loss does not support activation normalization "
            f"(normalize_activations={self.cfg.normalize_activations!r}). "
            "The aux loss reconstruction would be in normalized space while the "
            "residual is in the original space, producing incorrect gradients."
        )

    residual = (sae_in - sae_out).detach()

    # Heuristic from Appendix B.1 in the paper
    k_aux = sae_in.shape[-1] // 2

    # Reduce the scale of the loss if there are a small number of dead latents
    scale = min(num_dead / k_aux, 1.0)
    k_aux = min(k_aux, num_dead)

    auxk_acts = calculate_topk_aux_acts(
        k_aux=k_aux,
        hidden_pre=hidden_pre,
        dead_neuron_mask=dead_neuron_mask,
    )

    # Encourage the top ~50% of dead latents to predict the residual of the
    # top k living latents. Per the paper (Appendix A.2), the reconstruction
    # is ê = W_dec @ z (no bias), since b_dec is already in the residual.
    recons = act_times_W_dec(
        auxk_acts, self.W_dec, self.cfg.rescale_acts_by_decoder_norm
    )
    # Apply the same reshaping as decode() so recons matches the residual's shape
    recons = self.reshape_fn_out(recons, self.d_head)
    auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
    return self.cfg.aux_loss_coefficient * scale * auxk_loss

@override
def process_state_dict_for_saving_inference(
    self, state_dict: dict[str, Any]
) -> None:
    super().process_state_dict_for_saving_inference(state_dict)
    if self.cfg.rescale_acts_by_decoder_norm:
        _fold_norm_topk(
            W_enc=state_dict["W_enc"],
            b_enc=state_dict["b_enc"],
            W_dec=state_dict["W_dec"],
        )
```

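More simply, this code selects, for each sample, the $k_{\text{aux}} = \min(\lfloor d_{\text{in}}/2 \rfloor, n_{\text{dead}})$ dead latents with the largest pre-activations and provides a gradient signal by training them to reconstruct the detached residual `sae_in - sae_out` left by the live latents.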
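To make the mechanics concrete, here is a minimal standalone sketch of the same computation. It is not the repo's code: `topk_aux_loss_sketch` is a hypothetical name, and it assumes a plain `(batch, d_in)` input, a `(d_sae, d_in)` decoder with no decoder-norm rescaling or output reshaping, and no `aux_loss_coefficient`. The dead-latent selection is re-implemented with `torch.where` plus `topk`, on the assumption that this is roughly what `calculate_topk_aux_acts` does.

```python
import torch


def topk_aux_loss_sketch(
    sae_in: torch.Tensor,      # (batch, d_in)
    sae_out: torch.Tensor,     # (batch, d_in), reconstruction from the live TopK latents
    hidden_pre: torch.Tensor,  # (batch, d_sae), encoder pre-activations
    W_dec: torch.Tensor,       # (d_sae, d_in)
    dead_mask: torch.Tensor,   # (d_sae,), True for dead latents
) -> torch.Tensor:
    """Hypothetical standalone version of the TopK auxiliary loss."""
    num_dead = int(dead_mask.sum())
    if num_dead == 0:
        return sae_out.new_tensor(0.0)

    # The residual is detached, so it acts as a fixed regression target.
    residual = (sae_in - sae_out).detach()

    # k_aux = d_in // 2, capped at the number of dead latents; the loss is
    # scaled down when fewer than d_in // 2 latents are dead.
    k_aux = sae_in.shape[-1] // 2
    scale = min(num_dead / k_aux, 1.0)
    k_aux = min(k_aux, num_dead)

    # Hide live latents behind -inf, then keep the k_aux largest dead
    # pre-activations in each row (no ReLU, matching the excerpt's intent).
    neg_inf = torch.full_like(hidden_pre, float("-inf"))
    masked = torch.where(dead_mask[None, :], hidden_pre, neg_inf)
    top = masked.topk(k_aux, dim=-1)
    aux_acts = torch.zeros_like(hidden_pre).scatter(-1, top.indices, top.values)

    # Reconstruct the residual from the selected dead latents only (no decoder bias).
    recons = aux_acts @ W_dec
    return scale * (recons - residual).pow(2).sum(dim=-1).mean()
```

Because the live latents are masked out before `topk` and the residual is detached, this loss only sends gradient to the dead latents' pre-activations and decoder rows; the live latents that produced the residual are untouched by it.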

The issue is that if, for example, a latent or group of latents is undertrained, it will consistently leave a large residual whenever it fires. We want those original latents to be solely responsible for whatever feature they are modeling (a linear direction, a manifold, etc.), but the latents chosen by the auxiliary loss may instead become error-correcting terms that absorb this residual. The result is that neither the original latent(s) modeling the feature nor the error-correcting latent fully describes the feature on its own.
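A toy illustration of this concern, with every number made up for the example: a single feature direction `f`, a live latent whose decoder row has only learned half of `f`, and one dead latent picked by the auxiliary loss. For simplicity, only the dead latent's decoder row is trained, by plain gradient descent on the auxiliary loss alone, with its activation held fixed.

```python
import torch

# Hypothetical toy setup: one true feature direction f in R^4.
f = torch.tensor([1.0, 0.0, 0.0, 0.0])
x = 2.0 * f  # an input that contains only this feature, with magnitude 2

# An undertrained live latent: it fires with activation 2, but its decoder row
# has only learned half of f, so it leaves a residual of exactly 1.0 * f.
live_dec = 0.5 * f
sae_out = 2.0 * live_dec
residual = (x - sae_out).detach()

# A dead latent selected by the aux loss: fixed activation 1.0 and a decoder
# row initialised to an unrelated direction.
dead_act = torch.tensor(1.0)
dead_dec = torch.tensor([0.0, 0.0, 0.0, 0.1], requires_grad=True)

opt = torch.optim.SGD([dead_dec], lr=0.1)
for _ in range(200):
    opt.zero_grad()
    aux_loss = (dead_act * dead_dec - residual).pow(2).sum()
    aux_loss.backward()
    opt.step()

# The dead latent's decoder row has rotated onto f: it now encodes the
# leftover part of the live latent's feature.
cos = torch.nn.functional.cosine_similarity(dead_dec.detach(), f, dim=0)
print(dead_dec.detach(), cos.item())
```

After training, the live latent contributes `1.0 * f` and the dead latent the other `1.0 * f`. Together they reconstruct `x`, but the feature's reconstruction is now split across two latents instead of being owned by the original one.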

My main question is whether anything can be done to fix this. One obvious solution is to use a JumpReLU SAE, since it relies on a pre-activation loss on latents that are not firing, rather than an auxiliary loss based on the residual.
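For comparison, here is a minimal sketch of the kind of hinge-style pre-activation penalty meant here, assuming the form $\sum_i \mathrm{ReLU}(\theta_i - h_i)\,\lVert W_{\text{dec},i} \rVert$ and, as an illustrative choice, restricting it to dead latents. `pre_act_loss_sketch` and its arguments are hypothetical; the exact form, masking, and coefficient in any given JumpReLU implementation may differ. The relevant property is that the gradient on each pre-activation depends only on its threshold and decoder norm, never on the residual.

```python
import torch


def pre_act_loss_sketch(
    hidden_pre: torch.Tensor,  # (batch, d_sae) encoder pre-activations
    threshold: torch.Tensor,   # (d_sae,) JumpReLU thresholds
    W_dec: torch.Tensor,       # (d_sae, d_in)
    dead_mask: torch.Tensor,   # (d_sae,) True for dead latents
    coefficient: float,
) -> torch.Tensor:
    """Hypothetical hinge penalty pushing dead pre-activations up toward their thresholds."""
    # How far each pre-activation sits below its threshold (0 if above).
    gap = torch.relu(threshold - hidden_pre)
    # Weight by the decoder norm so the penalty cannot be dodged by shrinking W_dec.
    weighted = gap * W_dec.norm(dim=-1)
    # Only dead latents are penalised in this sketch.
    weighted = weighted * dead_mask.to(weighted.dtype)
    return coefficient * weighted.sum(dim=-1).mean()
```

Gradient descent on this term simply raises a dead latent's pre-activation (scaled by its decoder norm) until it crosses the threshold; no latent is asked to explain another latent's error.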

This is undesirable for me, though: I've found that whenever I make any architectural modification to the SAE, the JumpReLU hyperparameters need to be completely retuned (and sometimes the JumpReLU function itself needs to be modified). BatchTopK, by contrast, is flexible enough to work with a much wider variety of architectural modifications without significant tuning.

I'm curious whether there is some straight-through-estimator-style approach for TopK/BatchTopK that could selectively apply a small gradient signal that is more "natural" and doesn't lead to this error-correction issue. I'd also like to know whether anyone thinks this is an actual issue, or whether I'm misinterpreting something.
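As one concrete reading of the straight-through idea, here is a hypothetical sketch (`LeakyTopK` and `leak` are made-up names, not an existing API): the forward pass is ordinary per-sample TopK with a ReLU, while the backward pass gives every latent that did not fire a gradient of `leak` times its incoming gradient, as if it had fired with slope `leak`. A BatchTopK variant would pick the top `k * batch` entries over the flattened batch instead of per row.

```python
import torch


class LeakyTopK(torch.autograd.Function):
    """Hypothetical TopK whose backward pass leaks a small gradient to non-firing latents."""

    @staticmethod
    def forward(ctx, hidden_pre: torch.Tensor, k: int, leak: float) -> torch.Tensor:
        # Standard per-sample TopK: keep the k largest pre-activations in each row.
        top = hidden_pre.topk(k, dim=-1)
        mask = torch.zeros_like(hidden_pre).scatter(
            -1, top.indices, torch.ones_like(top.values)
        ).bool()
        ctx.save_for_backward(hidden_pre, mask)
        ctx.leak = leak
        return torch.where(mask, torch.relu(hidden_pre), torch.zeros_like(hidden_pre))

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        hidden_pre, mask = ctx.saved_tensors
        # Latents that actually fired (selected and positive) get the exact gradient.
        fired = mask & (hidden_pre > 0)
        # Every other latent gets a scaled-down straight-through gradient.
        grad_in = torch.where(fired, grad_out, ctx.leak * grad_out)
        # No gradients for the non-tensor arguments k and leak.
        return grad_in, None, None


# Usage: acts = LeakyTopK.apply(hidden_pre, k, 0.01)
```

Note that the leaked gradient on a non-firing latent is still `leak` times its gradient from the main reconstruction loss, which is itself driven by the reconstruction error. So this removes the separate residual-fitting objective rather than obviously removing the error-correction incentive.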
