## Summary
The MPS (Metal Performance Shaders) backend on Apple Silicon produces silently incorrect inference results for GPT-2 small in TransformerLens. The outputs are numerically wrong with no errors or warnings. This likely affects anyone running TransformerLens on a Mac.
## Reproduction

```python
import torch
from transformer_lens import HookedTransformer

for device in ["cpu", "mps"]:
    model = HookedTransformer.from_pretrained("gpt2", device=device)
    model.eval()

    prompt = "When Mary and John went to the store, John gave a drink to"
    tokens = model.to_tokens(prompt)
    logits = model(tokens)

    last_pos = tokens.shape[1] - 1
    mary_id = model.to_tokens(" Mary", prepend_bos=False)[0, 0].item()
    john_id = model.to_tokens(" John", prepend_bos=False)[0, 0].item()

    probs = torch.softmax(logits[0, last_pos], dim=-1)
    p_mary = probs[mary_id].item()
    p_john = probs[john_id].item()
    ld = logits[0, last_pos, mary_id].item() - logits[0, last_pos, john_id].item()

    print(f"\n[{device.upper()}]")
    print(f"  P(Mary) = {p_mary:.4f}")
    print(f"  P(John) = {p_john:.4f}")
    print(f"  Logit Difference (Mary - John) = {ld:.3f}")
    print(f"  Top-1 prediction: '{model.to_string(logits[0, last_pos].argmax().item())}'")
```
### Expected Output (CPU, correct)

```
[CPU] P(Mary)=0.6770 P(John)=0.0235 LD=+3.362 Top1=' Mary'
```

### Actual Output (MPS, wrong)

```
[MPS] P(Mary)=0.0000 P(John)=0.0001 LD=-0.994 Top1=' the'
```
## Impact

| Metric | CPU (correct) | MPS (wrong) |
|---|---|---|
| P(Mary) | 67.7% | 0.00% |
| P(John) | 2.4% | 0.01% |
| Logit Difference | +3.36 | -0.99 |
| Top-1 prediction | " Mary" | " the" |
| IOI accuracy | 100% | 0% |
The logit difference is not merely attenuated; it is inverted (positive → negative). The model's probability on the correct answer drops from 67.7% to 0.00%.
## Why This Matters
This is especially dangerous for mechanistic interpretability research because:
- **The errors are silent.** No exceptions, no warnings, no NaN/Inf values. The model produces fluent, plausible-looking logits; they're just wrong.
- **Activation patching results are invalidated.** Since base model inference is wrong, all downstream analyses (logit lens, DLA, circuit discovery, causal tracing) produce meaningless results. We ran ~100 patching experiments before discovering this.
- **Many MI researchers use Macs.** TransformerLens is popular among independent researchers, many of whom develop on Apple Silicon without access to CUDA GPUs. They may be unknowingly producing invalid results.
- **It affects the canonical benchmark.** The IOI task (Wang et al., 2023) is the most widely replicated result in mechanistic interpretability. If it silently fails on MPS, less well-known tasks will too.
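Before running any patching experiments, a cheap cross-device smoke test can catch this class of bug. The sketch below assumes a caller-supplied `forward(tokens, device)` helper (a hypothetical wrapper around model loading and inference, not a TransformerLens API) and compares final-position logits between CPU and MPS:

```python
import torch

def devices_agree(forward, tokens, atol=1e-3):
    """Smoke test: do CPU and MPS produce the same logits?

    `forward(tokens, device)` is a caller-supplied helper (hypothetical,
    e.g. wrapping HookedTransformer.from_pretrained + a forward pass)
    that returns logits computed on the given device.
    """
    ref = forward(tokens, "cpu")
    if not torch.backends.mps.is_available():
        return True  # no MPS backend on this machine; nothing to compare
    out = forward(tokens, "mps").to("cpu")
    # Compare the final-position logits, where next-token predictions live.
    return torch.allclose(ref[..., -1, :], out[..., -1, :], atol=atol)
```

Running this once at the start of a project would have flagged the discrepancy before ~100 invalid patching experiments.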
## Root Cause

This is a PyTorch MPS issue, not a TransformerLens issue per se. PyTorch has dozens of open issues labeled `module: correctness (silent)` + `module: mps` reporting silently incorrect results for various operations including `torch.where()`, `torch.triu()`, `torch.stack()`, GELU, Conv1d, indexing, and random number generation.
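Individual suspect ops can be checked in isolation. A minimal sketch (the helper name `max_abs_diff` is ours) that runs an op on CPU and, where the backend exists, on MPS, and reports the largest elementwise discrepancy:

```python
import torch

def max_abs_diff(op, *cpu_args):
    """Run `op` on CPU and (if available) on MPS; return the largest
    elementwise absolute difference between the two results."""
    ref = op(*cpu_args)
    if not torch.backends.mps.is_available():
        return 0.0  # no MPS backend to compare against
    mps_args = [a.to("mps") for a in cpu_args]
    out = op(*mps_args).to("cpu")
    return (ref - out).abs().max().item()

# A couple of ops that have appeared in MPS correctness reports.
x = torch.randn(64, 64)
for name, fn in [("triu", torch.triu),
                 ("gelu", torch.nn.functional.gelu)]:
    print(f"{name}: max |cpu - mps| = {max_abs_diff(fn, x):.3e}")
```

Note that a single-op check passing does not clear the backend: in a deep network, small per-op errors can compound, and some reported bugs only trigger on particular shapes, strides, or dtypes.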
## Suggested Actions

- **Add an MPS warning to TransformerLens.** When `device="mps"` is detected, emit a prominent warning.
- **Add an MPS validation check.** Run the IOI sanity check and warn if the logit difference is negative.
- **Document this in the README/docs.**
- **Consider defaulting to CPU on Apple Silicon** (or at least not auto-selecting MPS).
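The first two suggestions could be combined into a small guard. This is a sketch only; `check_mps_sanity` and `MPS_WARNING` are hypothetical names, not existing TransformerLens APIs:

```python
import warnings

MPS_WARNING = (
    "The PyTorch MPS backend has known silent-correctness bugs; "
    "results may differ from CPU. Validate against device='cpu' "
    "before trusting interpretability results."
)

def check_mps_sanity(logit_diff: float, device: str) -> None:
    """Hypothetical guard: warn whenever MPS is selected, and warn again
    if an IOI-style logit difference (correct name minus distractor)
    comes out negative, which indicates inverted inference."""
    if str(device) != "mps":
        return
    warnings.warn(MPS_WARNING, RuntimeWarning, stacklevel=2)
    if logit_diff < 0:
        warnings.warn(
            f"IOI sanity check failed on MPS (logit diff {logit_diff:+.3f} < 0); "
            "inference is likely numerically wrong.",
            RuntimeWarning,
            stacklevel=2,
        )

check_mps_sanity(-0.994, "mps")  # emits both warnings
check_mps_sanity(3.362, "cpu")   # silent
```

Emitting a `RuntimeWarning` rather than raising keeps MPS usable for users who have verified their workload, while making the risk hard to miss.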
## Environment
- macOS: 26.2 (Tahoe)
- Python: 3.9.6
- PyTorch: 2.8.0
- TransformerLens: 2.17.0
- Hardware: Mac mini, Apple M4
## Related PyTorch Issues