Question
Is this resampling logic intended to faithfully follow a paper?
For context: I am new to SAEs and the sae_lens library, and I have never used the dictionary_learning library.
For context 2: I have a problem with the sparsity/dead_features metric being too high after I changed the sae_lens implementation to use .detach() in via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec (roughly as sketched below).
For context 3: I noticed that the sparsity/dead_features metric does not decrease over the course of training, and I can't find any neuron-resampling logic in the sae_lens library. So now I am looking at the dictionary_learning implementation to see how to implement neuron resampling myself (a training-loop sketch is included after the StandardTrainer snippet below).
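For reference, this is roughly the change I made, written as a standalone sketch rather than the actual sae_lens method (the names just mirror the line quoted above):

```python
import torch

def via_gate_reconstruction(pi_gate_act: torch.Tensor,
                            W_dec: torch.Tensor,
                            b_dec: torch.Tensor) -> torch.Tensor:
    # Reconstruct through the gate path with a frozen decoder: .detach() keeps
    # gradients from this auxiliary term out of W_dec and b_dec.
    return pi_gate_act @ W_dec.detach() + b_dec.detach()
```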
GatedSAETrainer (dictionary_learning/dictionary_learning/trainers/gdm.py, line 13 in 60ec6bf):

```python
class GatedSAETrainer(SAETrainer):
```
GatedAnnealTrainer (dictionary_learning/dictionary_learning/trainers/gated_anneal.py, lines 105 to 135 in 60ec6bf):

```python
def resample_neurons(self, deads, activations):
    with t.no_grad():
        if deads.sum() == 0: return
        print(f"resampling {deads.sum().item()} neurons")

        # compute loss for each activation
        losses = (activations - self.ae(activations)).norm(dim=-1)

        # sample input to create encoder/decoder weights from
        n_resample = min([deads.sum(), losses.shape[0]])
        indices = t.multinomial(losses, num_samples=n_resample, replacement=False)
        sampled_vecs = activations[indices]

        # reset encoder/decoder weights for dead neurons
        alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean()
        self.ae.encoder.weight[deads][:n_resample] = sampled_vecs * alive_norm * 0.2
        self.ae.decoder.weight[:,deads][:,:n_resample] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T
        self.ae.encoder.bias[deads][:n_resample] = 0.

        # reset Adam parameters for dead neurons
        state_dict = self.optimizer.state_dict()['state']
        ## encoder weight
        state_dict[1]['exp_avg'][deads] = 0.
        state_dict[1]['exp_avg_sq'][deads] = 0.
        ## encoder bias
        state_dict[2]['exp_avg'][deads] = 0.
        state_dict[2]['exp_avg_sq'][deads] = 0.
        ## decoder weight
        state_dict[3]['exp_avg'][:,deads] = 0.
        state_dict[3]['exp_avg_sq'][:,deads] = 0.
```
StandardTrainer (dictionary_learning/dictionary_learning/trainers/standard.py, lines 76 to 109 in 60ec6bf):

```python
def resample_neurons(self, deads, activations):
    with t.no_grad():
        if deads.sum() == 0: return
        print(f"resampling {deads.sum().item()} neurons")

        # compute loss for each activation
        losses = (activations - self.ae(activations)).norm(dim=-1)

        # sample input to create encoder/decoder weights from
        n_resample = min([deads.sum(), losses.shape[0]])
        indices = t.multinomial(losses, num_samples=n_resample, replacement=False)
        sampled_vecs = activations[indices]

        # get norm of the living neurons
        alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean()

        # resample first n_resample dead neurons
        deads[deads.nonzero()[n_resample:]] = False
        self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2
        self.ae.decoder.weight[:,deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T
        self.ae.encoder.bias[deads] = 0.

        # reset Adam parameters for dead neurons
        state_dict = self.optimizer.state_dict()['state']
        ## encoder weight
        state_dict[1]['exp_avg'][deads] = 0.
        state_dict[1]['exp_avg_sq'][deads] = 0.
        ## encoder bias
        state_dict[2]['exp_avg'][deads] = 0.
        state_dict[2]['exp_avg_sq'][deads] = 0.
        ## decoder weight
        state_dict[3]['exp_avg'][:,deads] = 0.
        state_dict[3]['exp_avg_sq'][:,deads] = 0.
```
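For completeness, this is the kind of loop I am imagining around resample_neurons, based on the StandardTrainer version above. It is only a sketch: the trainer.update / trainer.ae.encode calls and the thresholds are my own assumptions, and I have not checked them against either library's actual API.

```python
import torch as t

def train_with_resampling(trainer, activation_loader, dict_size,
                          resample_every=25_000, device="cuda"):
    # One counter per dictionary feature; a feature that never fires between
    # resampling passes is treated as dead.
    fire_counts = t.zeros(dict_size, dtype=t.long, device=device)

    for step, activations in enumerate(activation_loader):
        trainer.update(step, activations)                 # hypothetical training-step call
        with t.no_grad():
            feats = trainer.ae.encode(activations)        # hypothetical: batch of feature activations
            fire_counts += (feats > 0).sum(dim=0)

        if step > 0 and step % resample_every == 0:
            deads = fire_counts == 0                      # boolean mask over dictionary features
            trainer.resample_neurons(deads, activations)  # the method quoted above
            fire_counts.zero_()
```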