Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 38 additions & 32 deletions transformer_lens/HookedTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1618,16 +1618,44 @@ def load_and_process_state_dict(
logging.warning(
"You are using MoE, so the layer norm weights can't be folded! Skipping"
)
elif self.cfg.normalization_type in ["LN", "LNPre"]:
state_dict = self.fold_layer_norm(state_dict)
elif self.cfg.normalization_type in ["RMS", "RMSPre"]:
state_dict = self.fold_layer_norm(
state_dict, fold_biases=False, center_weights=False
)
else:
elif self.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
logging.warning(
"You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping"
)
else:
ln_keys_present = any(
k.endswith((".ln1.w", ".ln2.w", "ln_final.w")) for k in state_dict
)
if not ln_keys_present:
logging.warning(
"fold_ln=True but no LayerNorm weights found in state_dict. "
"The model may have been saved with already-folded LayerNorms. "
"Skipping fold."
)
else:
if self.cfg.normalization_type == "LN":
self.cfg.normalization_type = "LNPre"
self.ln_final = LayerNormPre(self.cfg)
for layer in self.blocks:
layer.ln1 = LayerNormPre(self.cfg)
layer.ln2 = LayerNormPre(self.cfg)
if self.cfg.is_layer_norm_activation():
layer.mlp.ln = LayerNormPre(self.cfg)
elif self.cfg.normalization_type == "RMS":
self.cfg.normalization_type = "RMSPre"
self.ln_final = RMSNormPre(self.cfg)
for layer in self.blocks:
layer.ln1 = RMSNormPre(self.cfg)
layer.ln2 = RMSNormPre(self.cfg)
if self.cfg.is_layer_norm_activation():
layer.mlp.ln = RMSNormPre(self.cfg)

if self.cfg.normalization_type in ["LNPre"]:
state_dict = self.fold_layer_norm(state_dict)
elif self.cfg.normalization_type in ["RMSPre"]:
state_dict = self.fold_layer_norm(
state_dict, fold_biases=False, center_weights=False
)

if center_writing_weights:
if self.cfg.normalization_type not in ["LN", "LNPre"]:
Expand Down Expand Up @@ -1658,6 +1686,9 @@ def load_and_process_state_dict(
self.load_state_dict({key: state_dict[key]}, strict=False)
del state_dict[key]

if fold_ln:
self.setup()

def fill_missing_keys(self, state_dict):
    """Return *state_dict* with any absent model parameters filled in.

    Thin wrapper that delegates entirely to ``loading.fill_missing_keys``,
    passing this model instance so the helper can consult its configuration
    and parameter set when deciding which keys to add.

    Args:
        state_dict: Mapping of parameter names to tensors, possibly missing
            some keys expected by this model.

    Returns:
        The completed state dict produced by ``loading.fill_missing_keys``.
    """
    completed = loading.fill_missing_keys(self, state_dict)
    return completed

Expand Down Expand Up @@ -2052,31 +2083,6 @@ def process_weights_(
version of the same model.
"""
state_dict = self.state_dict()
if fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1:
# If we're using MoE, we don't fold the layer norm weights, so we don't need to do any preprocessing
# A warning is already issued in `load_and_process_state_dict`
pass
elif fold_ln and self.cfg.normalization_type == "LN":
# If we're folding the LN into the weights, we need to replace all the layernorm layers
# with LayerNormPres, which do not have learnable parameters. This is somewhat hacky,
# but it's the easiest way to do it.
self.cfg.normalization_type = "LNPre"
self.ln_final = LayerNormPre(self.cfg)
for layer in self.blocks:
layer.ln1 = LayerNormPre(self.cfg)
layer.ln2 = LayerNormPre(self.cfg)
if self.cfg.is_layer_norm_activation():
layer.mlp.ln = LayerNormPre(self.cfg)
elif fold_ln and self.cfg.normalization_type == "RMS":
# We do the same for RMSNorm if used
self.cfg.normalization_type = "RMSPre"
self.ln_final = RMSNormPre(self.cfg)
for layer in self.blocks:
layer.ln1 = RMSNormPre(self.cfg)
layer.ln2 = RMSNormPre(self.cfg)
if self.cfg.is_layer_norm_activation():
layer.mlp.ln = RMSNormPre(self.cfg)

self.load_and_process_state_dict(
state_dict,
fold_ln=fold_ln,
Expand Down
Loading