From e5554e5b45d325ca18953c5b3c6d2adfa3f8c446 Mon Sep 17 00:00:00 2001 From: Neelectric Date: Mon, 28 Oct 2024 18:02:51 +0000 Subject: [PATCH] Forcing torch.load in lenses.py to load the .pt file of the lens to the same device that the model is on using map_location --- tuned_lens/nn/lenses.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tuned_lens/nn/lenses.py b/tuned_lens/nn/lenses.py index 0ef308b..f882334 100644 --- a/tuned_lens/nn/lenses.py +++ b/tuned_lens/nn/lenses.py @@ -276,7 +276,8 @@ def from_unembed_and_pretrained( **{k: v for k, v in kwargs.items() if k not in load_artifact_varnames} } # Load parameters - state = th.load(ckpt_path, **th_load_kwargs) + device = unembed.unembedding.weight.device + state = th.load(ckpt_path, **th_load_kwargs, map_location=device) lens.layer_translators.load_state_dict(state)