Skip to content

Commit 560ec23

Browse files
committed
Fix type of HookedTransformerConfig.device
This is typed as `Optional[str]` but sometimes returns `torch.device`. Updated the code to just return the `str` instead of wrapping with a device. I'm not confident that every function which takes a device will always be passed a string, so I didn't change functions like warn_if_mps. Found while working on #1219
1 parent 589acd4 commit 560ec23

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

tests/unit/pretrained_weight_conversions/test_apertus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_zero_biases_have_correct_device(self):
183183
"blocks.0.mlp.b_out",
184184
"unembed.b_U",
185185
]:
186-
assert sd[key].device.type == cfg.device.type, f"{key} on wrong device"
186+
assert sd[key].device.type == cfg.device, f"{key} on wrong device"
187187

188188
def test_unembed_shapes(self):
189189
cfg = make_cfg()

transformer_lens/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
from dataclasses import dataclass
8-
from typing import Optional
8+
from typing import Optional, Union
99

1010
import torch
1111
import torch.optim as optim
@@ -50,7 +50,7 @@ class HookedTransformerTrainConfig:
5050
max_grad_norm: Optional[float] = None
5151
weight_decay: Optional[float] = None
5252
optimizer_name: str = "Adam"
53-
device: Optional[str] = None
53+
device: Optional[Union[str, torch.device]] = None
5454
warmup_steps: int = 0
5555
save_every: Optional[int] = None
5656
save_dir: Optional[str] = None

transformer_lens/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,9 +1015,9 @@ def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False
10151015
print(f"column mismatch: {col_mismatch}")
10161016

10171017

1018-
def get_device():
1018+
def get_device() -> str:
10191019
if torch.cuda.is_available():
1020-
return torch.device("cuda")
1020+
return "cuda"
10211021
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
10221022
major_version = int(torch.__version__.split(".")[0])
10231023
if major_version >= 2:
@@ -1026,17 +1026,17 @@ def get_device():
10261026
_MPS_MIN_SAFE_TORCH_VERSION is not None
10271027
and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION
10281028
):
1029-
return torch.device("mps")
1029+
return "mps"
10301030
if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1":
1031-
return torch.device("mps")
1031+
return "mps"
10321032
logging.info(
10331033
"MPS device available but not auto-selected due to known correctness issues "
10341034
"(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: "
10351035
"https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
10361036
torch.__version__,
10371037
)
10381038

1039-
return torch.device("cpu")
1039+
return "cpu"
10401040

10411041

10421042
_mps_warned = False
@@ -1051,7 +1051,7 @@ def _torch_version_tuple() -> tuple[int, ...]:
10511051
return tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
10521052

10531053

1054-
def warn_if_mps(device):
1054+
def warn_if_mps(device: Union[str, torch.device]) -> None:
10551055
"""Emit a one-time warning if device is MPS and TRANSFORMERLENS_ALLOW_MPS is not set.
10561056
10571057
Automatically suppressed when the installed PyTorch version meets or exceeds

0 commit comments

Comments (0)