diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 583ce1aea..a03701bce 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1618,16 +1618,44 @@ def load_and_process_state_dict( logging.warning( "You are using MoE, so the layer norm weights can't be folded! Skipping" ) - elif self.cfg.normalization_type in ["LN", "LNPre"]: - state_dict = self.fold_layer_norm(state_dict) - elif self.cfg.normalization_type in ["RMS", "RMSPre"]: - state_dict = self.fold_layer_norm( - state_dict, fold_biases=False, center_weights=False - ) - else: + elif self.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]: logging.warning( "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping" ) + else: + ln_keys_present = any( + k.endswith((".ln1.w", ".ln2.w", "ln_final.w")) for k in state_dict + ) + if not ln_keys_present: + logging.warning( + "fold_ln=True but no LayerNorm weights found in state_dict. " + "The model may have been saved with already-folded LayerNorms. " + "Skipping fold." + ) + else: + if self.cfg.normalization_type == "LN": + self.cfg.normalization_type = "LNPre" + self.ln_final = LayerNormPre(self.cfg) + for layer in self.blocks: + layer.ln1 = LayerNormPre(self.cfg) + layer.ln2 = LayerNormPre(self.cfg) + if self.cfg.is_layer_norm_activation(): + layer.mlp.ln = LayerNormPre(self.cfg) + elif self.cfg.normalization_type == "RMS": + self.cfg.normalization_type = "RMSPre" + self.ln_final = RMSNormPre(self.cfg) + for layer in self.blocks: + layer.ln1 = RMSNormPre(self.cfg) + layer.ln2 = RMSNormPre(self.cfg) + if self.cfg.is_layer_norm_activation(): + layer.mlp.ln = RMSNormPre(self.cfg) + + if self.cfg.normalization_type in ["LNPre"]: + state_dict = self.fold_layer_norm(state_dict) + elif self.cfg.normalization_type in ["RMSPre"]: + state_dict = self.fold_layer_norm( + state_dict, fold_biases=False, center_weights=False + ) if center_writing_weights: if self.cfg.normalization_type not in ["LN", "LNPre"]: @@ -1658,6 +1686,9 @@ def load_and_process_state_dict( self.load_state_dict({key: state_dict[key]}, strict=False) del state_dict[key] + if fold_ln: + self.setup() + def fill_missing_keys(self, state_dict): return loading.fill_missing_keys(self, state_dict) @@ -2052,31 +2083,6 @@ def process_weights_( version of the same model. """ state_dict = self.state_dict() - if fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1: - # If we're using MoE, we don't fold the layer norm weights, so we don't need to do any preprocessing - # A warning is already issued in `load_and_process_state_dict` - pass - elif fold_ln and self.cfg.normalization_type == "LN": - # If we're folding the LN into the weights, we need to replace all the layernorm layers - # with LayerNormPres, which do not have learnable parameters. This is somewhat hacky, - # but it's the easiest way to do it. - self.cfg.normalization_type = "LNPre" - self.ln_final = LayerNormPre(self.cfg) - for layer in self.blocks: - layer.ln1 = LayerNormPre(self.cfg) - layer.ln2 = LayerNormPre(self.cfg) - if self.cfg.is_layer_norm_activation(): - layer.mlp.ln = LayerNormPre(self.cfg) - elif fold_ln and self.cfg.normalization_type == "RMS": - # We do the same for RMSNorm if used - self.cfg.normalization_type = "RMSPre" - self.ln_final = RMSNormPre(self.cfg) - for layer in self.blocks: - layer.ln1 = RMSNormPre(self.cfg) - layer.ln2 = RMSNormPre(self.cfg) - if self.cfg.is_layer_norm_activation(): - layer.mlp.ln = RMSNormPre(self.cfg) - self.load_and_process_state_dict( state_dict, fold_ln=fold_ln,