Skip to content

Commit 9c5a2a8

Browse files
zachgoldfine44 (Zachary Goldfine) and claude
authored
Fix deprecated torch_dtype argument in HuggingFace from_pretrained calls (#1225)
Replace `torch_dtype=dtype` with `dtype=dtype` in all internal calls to HuggingFace's `from_pretrained()` methods in loading_from_pretrained.py. The `torch_dtype` parameter is deprecated in recent versions of the transformers library in favor of `dtype`. The backwards-compatible acceptance of `torch_dtype` from users (lines 2389-2391) is preserved so existing user code continues to work. Fixes #1093 Co-authored-by: Zachary Goldfine <zacharygoldfine@Zacharys-Air.attlocal.net> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d1dc12d commit 9c5a2a8

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

transformer_lens/loading_from_pretrained.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2439,15 +2439,15 @@ def get_pretrained_state_dict(
24392439
hf_model = AutoModelForCausalLM.from_pretrained(
24402440
official_model_name,
24412441
revision=f"checkpoint-{cfg.checkpoint_value}",
2442-
torch_dtype=dtype,
2442+
dtype=dtype,
24432443
token=huggingface_token if len(huggingface_token) > 0 else None,
24442444
**kwargs,
24452445
)
24462446
elif official_model_name.startswith("EleutherAI/pythia"):
24472447
hf_model = AutoModelForCausalLM.from_pretrained(
24482448
official_model_name,
24492449
revision=f"step{cfg.checkpoint_value}",
2450-
torch_dtype=dtype,
2450+
dtype=dtype,
24512451
token=huggingface_token,
24522452
**kwargs,
24532453
)
@@ -2460,28 +2460,28 @@ def get_pretrained_state_dict(
24602460
elif "hubert" in official_model_name:
24612461
hf_model = HubertModel.from_pretrained(
24622462
official_model_name,
2463-
torch_dtype=dtype,
2463+
dtype=dtype,
24642464
token=huggingface_token if len(huggingface_token) > 0 else None,
24652465
**kwargs,
24662466
)
24672467
elif "wav2vec2" in official_model_name:
24682468
hf_model = Wav2Vec2Model.from_pretrained(
24692469
official_model_name,
2470-
torch_dtype=dtype,
2470+
dtype=dtype,
24712471
token=huggingface_token if len(huggingface_token) > 0 else None,
24722472
**kwargs,
24732473
)
24742474
elif "bert" in official_model_name:
24752475
hf_model = BertForPreTraining.from_pretrained(
24762476
official_model_name,
2477-
torch_dtype=dtype,
2477+
dtype=dtype,
24782478
token=huggingface_token if len(huggingface_token) > 0 else None,
24792479
**kwargs,
24802480
)
24812481
elif "t5" in official_model_name:
24822482
hf_model = T5ForConditionalGeneration.from_pretrained(
24832483
official_model_name,
2484-
torch_dtype=dtype,
2484+
dtype=dtype,
24852485
token=huggingface_token if len(huggingface_token) > 0 else None,
24862486
**kwargs,
24872487
)
@@ -2491,14 +2491,14 @@ def get_pretrained_state_dict(
24912491

24922492
hf_model = AutoModel.from_pretrained(
24932493
official_model_name,
2494-
torch_dtype=dtype,
2494+
dtype=dtype,
24952495
token=huggingface_token if len(huggingface_token) > 0 else None,
24962496
**kwargs,
24972497
)
24982498
else:
24992499
hf_model = AutoModelForCausalLM.from_pretrained(
25002500
official_model_name,
2501-
torch_dtype=dtype,
2501+
dtype=dtype,
25022502
token=huggingface_token if len(huggingface_token) > 0 else None,
25032503
**kwargs,
25042504
)

0 commit comments

Comments (0)