In the loss for the JumpReLU trainer:
```python
f = JumpReLUFunction.apply(pre_jump, self.ae.threshold, self.bandwidth)

active_indices = f.sum(0) > 0
did_fire = torch.zeros_like(self.num_tokens_since_fired, dtype=torch.bool)
did_fire[active_indices] = True
self.num_tokens_since_fired += x.size(0)
self.num_tokens_since_fired[active_indices] = 0
self.dead_features = (
    (self.num_tokens_since_fired > self.dead_feature_threshold).sum().item()
)

recon = self.ae.decode(f)

recon_loss = (x - recon).pow(2).sum(dim=-1).mean()
l0 = StepFunction.apply(f, self.ae.threshold, self.bandwidth).sum(dim=-1).mean()
```
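To see why the forward pass is unaffected, here is a minimal NumPy sketch of the two forward functions (a hypothetical standalone re-implementation for illustration, not the repo's code): for a positive threshold, `f` is either 0 or equal to `pre_jump` above the threshold, so the step mask comes out the same either way.

```python
import numpy as np

def jumprelu(pre, theta):
    # JumpReLU forward: keep the pre-activation only where it exceeds theta
    return pre * (pre > theta)

def step(z, theta):
    # Heaviside step used by the L0 penalty in the forward pass
    return (z > theta).astype(float)

theta = 0.5
pre = np.array([-1.0, 0.2, 0.5, 0.7, 2.0])
f = jumprelu(pre, theta)

# For theta > 0, applying the step to f or to pre gives the same mask,
# since f is either 0 (below theta) or equal to pre (above theta).
same = np.array_equal(step(f, theta), step(pre, theta))
```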
The threshold is applied twice: first on line 156 and then again on line 170. I think that on line 170 the StepFunction should be applied to the `pre_jump` value instead of `f` (this is also how it is done in the Colab linked in the docstring, as well as in equation 10 of the paper). While this does not matter in the forward pass, it may affect the pseudo-derivative.
(Excerpt above is from dictionary_learning/dictionary_learning/trainers/jumprelu.py, lines 156 to 170 at commit 60ec6bf.)