fix: handle LayerNorm folding correctly in `load_and_process_state_dict` #1215
VedantMadane wants to merge 3 commits into TransformerLensOrg:dev
Conversation
Previously, calling `load_and_process_state_dict(state_dict, fold_ln=True)` had two failure modes:

1. If the state_dict had unfolded LN weights, `fold_layer_norm` removed the LN keys but the model's modules were not replaced with `LNPre`, leaving mismatched architecture and broken hooks.
2. If the state_dict was already folded (no LN keys), `fold_layer_norm` crashed with a `KeyError` trying to access missing LN weight keys.

Fix both by:

- Checking whether LN keys exist before attempting to fold (skip with a warning if already folded)
- Replacing LN/RMS modules with `LNPre`/`RMSPre` before folding, matching the logic previously only in `process_weights_`
- Calling `self.setup()` after loading to re-attach hooks
- Simplifying `process_weights_` to delegate fully to the fixed method

Fixes TransformerLensOrg#219

Signed-off-by: Vedant Madane <6527493+VedantMadane@users.noreply.github.com>
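The "check whether LN keys exist before attempting to fold" step can be sketched in isolation. This is a toy illustration over plain Python dicts with a hypothetical helper name and key names following TransformerLens conventions (`blocks.{l}.ln1.w` etc.); it is not the library's actual code.

```python
# Hypothetical sketch of the "is this state dict already folded?" check.
# Key names follow TransformerLens conventions, but this helper is
# illustrative, not library code.

def has_unfolded_ln_keys(state_dict, n_layers):
    """Return True if any LayerNorm weight keys are still present."""
    ln_keys = [f"blocks.{l}.ln{i}.w" for l in range(n_layers) for i in (1, 2)]
    ln_keys.append("ln_final.w")
    return any(k in state_dict for k in ln_keys)

# An unfolded checkpoint still carries LN weights:
unfolded = {"blocks.0.ln1.w": [1.0], "blocks.0.ln2.w": [1.0], "ln_final.w": [1.0]}
# A folded checkpoint has had them absorbed into adjacent weights:
folded = {"blocks.0.attn.W_Q": [[0.5]]}

print(has_unfolded_ln_keys(unfolded, n_layers=1))  # True
print(has_unfolded_ln_keys(folded, n_layers=1))    # False
```

A check like this lets the loading path branch between folding and skip-with-warning instead of assuming the keys are present.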
In the future, please run
For a change like this that is meant to address specific failures, it is important to add new unit tests covering these previously failing states (to prevent regression) and any new features/functionality you add (to make sure they don't break in the future).
Apologies for the formatting noise — the diff is now cleaned up to contain only the logic changes. No auto-formatter reformatting of existing lines.
Thank you! I will take a look at reviewing this soon |
Fixes #219
Problem
`load_and_process_state_dict(state_dict, fold_ln=True)` has two failure modes when called directly (outside the `from_pretrained` path):

1. **Unfolded state dict:** `fold_layer_norm` correctly removes LN keys from the state dict, but the model's `LayerNorm` modules are never replaced with `LayerNormPre`. This leaves a mismatched architecture (modules expect `w`/`b` params that no longer exist) and broken hooks after loading.
2. **Already-folded state dict:** `fold_layer_norm` crashes with a `KeyError` because it tries to access `blocks.{l}.ln1.w` keys that were already removed when the model was saved.

The workaround (from the issue) was a 3-step dance:
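The already-folded crash can be reproduced in miniature. Below is a toy stand-in (plain dicts, hypothetical names; not the library's actual `fold_layer_norm`, and the "folding math" is replaced by a key rename) showing why unconditionally popping LN keys fails on a checkpoint that was saved after folding:

```python
# Toy illustration of failure mode 2: a fold routine that unconditionally
# pops LN keys raises KeyError when the state dict is already folded.
# Names and the "folding" itself are stand-ins, not TransformerLens code.

def naive_fold(state_dict):
    # Old behavior: assumes LN keys exist and pops them unconditionally.
    w = state_dict.pop("blocks.0.ln1.w")  # KeyError if already folded
    state_dict["blocks.0.attn.W_Q_folded"] = w  # stand-in for the real math
    return state_dict

already_folded = {"blocks.0.attn.W_Q": [[0.5]]}  # no LN keys left
try:
    naive_fold(already_folded)
except KeyError as e:
    print(f"crashed with KeyError on {e}")  # the behavior reported in #219
```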
Fix
1. **Detect already-folded state dicts:** Check whether LN weight keys (`.ln1.w`, `.ln2.w`, `ln_final.w`) exist before folding. If missing, skip with a warning instead of crashing.
2. **Replace LN modules when folding:** Move the `LayerNorm` -> `LayerNormPre` (and `RMSNorm` -> `RMSNormPre`) module replacement from `process_weights_` into `load_and_process_state_dict`, so direct callers get the same correct behavior.
3. **Re-attach hooks:** Call `self.setup()` after loading when `fold_ln=True` to ensure hooks are properly connected to the new modules.
4. **Simplify `process_weights_`:** Since the module replacement now lives in `load_and_process_state_dict`, `process_weights_` can simply delegate without duplicating the logic.

What this enables