[Bug Report] load_and_process_state_dict handles LayerNorm folding poorly #219

@afspies

Description

Describe the bug
If one attempts to load the state_dict of a model that was saved without folded LayerNorms (i.e. without LayerNormPre), calling load_and_process_state_dict(state_dict, fold_ln=True) fails due to the strict use of load_state_dict. This can be circumvented by instead doing:

```python
model = HookedTransformer(model_cfg)
model.load_and_process_state_dict(state_dict, fold_ln=False)
model.process_weights_(fold_ln=True)
model.setup()
```

Without calling model.setup(), the LayerNorm hooks remain inside the model but are not properly attached, so the corresponding activations are not returned by run_with_cache, causing issues in the ActivationCache manipulation helpers.

Additionally, if the original model was saved with folded LayerNorms, calling load_and_process_state_dict(state_dict, fold_ln=True) raises an error because no LayerNorm parameters are present in the state dict.
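For reference on the strictness issue: the failure above comes from PyTorch's nn.Module.load_state_dict, which by default raises on missing or unexpected keys. A minimal, self-contained sketch (plain PyTorch, not TransformerLens) of how a checkpoint missing LayerNorm parameters behaves under strict vs. non-strict loading:

```python
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.ln = nn.LayerNorm(4)


model = Block()

# Simulate a checkpoint saved after LayerNorm folding, i.e. one that
# contains no LayerNorm parameters.
state_dict = {k: v for k, v in model.state_dict().items() if not k.startswith("ln")}

# strict=True (the default) raises a RuntimeError on the missing keys;
# strict=False loads what it can and reports the mismatches instead.
result = model.load_state_dict(state_dict, strict=False)
print(result.missing_keys)  # ['ln.weight', 'ln.bias']
```

This suggests one possible direction for a fix: load non-strictly when fold_ln=True and verify that the missing keys are exactly the (folded) LayerNorm parameters.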

Metadata

Assignees: No one assigned
Labels: help wanted
Projects: No projects
Milestone: No milestone
Development: No branches or pull requests