diff --git a/tuned_lens/nn/lenses.py b/tuned_lens/nn/lenses.py index 0ef308b..f882334 100644 --- a/tuned_lens/nn/lenses.py +++ b/tuned_lens/nn/lenses.py @@ -276,7 +276,8 @@ def from_unembed_and_pretrained( **{k: v for k, v in kwargs.items() if k not in load_artifact_varnames} } # Load parameters - state = th.load(ckpt_path, **th_load_kwargs) + device = unembed.unembedding.weight.device + state = th.load(ckpt_path, **th_load_kwargs, map_location=device) lens.layer_translators.load_state_dict(state)