Skip to content

Commit b0ef355

Browse files
authored
Added additional bridge analysis tools (#1237)
1 parent 519ae6c commit b0ef355

File tree

2 files changed

+319
-0
lines changed

2 files changed

+319
-0
lines changed
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""Tests for TransformerBridge mechanistic interpretability analysis methods.
2+
3+
Tests tokens_to_residual_directions, accumulated_bias, all_composition_scores,
4+
all_head_labels, and top-level W_E/W_U/b_U properties. Validates against
5+
HookedTransformer for correctness, not just shape/type.
6+
7+
Uses distilgpt2 (CI-cached).
8+
"""
9+
10+
import pytest
11+
import torch
12+
13+
from transformer_lens import HookedTransformer
14+
from transformer_lens.model_bridge.bridge import TransformerBridge
15+
16+
17+
@pytest.fixture(scope="module")
def bridge_compat():
    """Module-scoped distilgpt2 TransformerBridge with compatibility mode enabled."""
    bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu")
    bridge.enable_compatibility_mode()
    return bridge
22+
23+
24+
@pytest.fixture(scope="module")
def reference_ht():
    """Module-scoped HookedTransformer reference model (distilgpt2 on CPU)."""
    model = HookedTransformer.from_pretrained("distilgpt2", device="cpu")
    return model
27+
28+
29+
class TestTopLevelWeightProperties:
    """Test W_E, W_U, b_U delegate to the correct component tensors."""

    def test_W_E_is_same_object_as_embed(self, bridge_compat):
        """bridge.W_E should be the exact same tensor as bridge.embed.W_E."""
        top_level = bridge_compat.W_E
        component = bridge_compat.embed.W_E
        assert top_level is component

    def test_W_U_equals_unembed(self, bridge_compat):
        """bridge.W_U should equal bridge.unembed.W_U (may be a view/transpose)."""
        assert torch.equal(bridge_compat.W_U, bridge_compat.unembed.W_U)

    def test_b_U_equals_unembed(self, bridge_compat):
        """bridge.b_U should equal bridge.unembed.b_U."""
        assert torch.equal(bridge_compat.b_U, bridge_compat.unembed.b_U)

    def test_W_E_matches_hooked_transformer(self, bridge_compat, reference_ht):
        """bridge.W_E values should match HookedTransformer.W_E."""
        assert bridge_compat.W_E.shape == reference_ht.W_E.shape
        # After weight processing, embeddings may differ due to centering.
        # But shapes must match and both must be non-zero.
        for embedding in (bridge_compat.W_E, reference_ht.W_E):
            assert embedding.std() > 0

    def test_W_U_matches_hooked_transformer(self, bridge_compat, reference_ht):
        """bridge.W_U values should match HookedTransformer.W_U."""
        bridge_W_U = bridge_compat.W_U
        ht_W_U = reference_ht.W_U
        assert bridge_W_U.shape == ht_W_U.shape
        max_diff = (bridge_W_U - ht_W_U).abs().max().item()
        assert max_diff < 1e-4, f"W_U differs by {max_diff}"
57+
58+
59+
class TestTokensToResidualDirections:
    """Test tokens_to_residual_directions produces correct unembedding vectors."""

    def test_single_token_string(self, bridge_compat):
        """String token should return a 1-D vector of size d_model."""
        direction = bridge_compat.tokens_to_residual_directions("hello")
        assert direction.shape == (bridge_compat.cfg.d_model,)

    def test_single_token_int(self, bridge_compat):
        """Integer token should return a 1-D vector of size d_model."""
        direction = bridge_compat.tokens_to_residual_directions(100)
        assert direction.shape == (bridge_compat.cfg.d_model,)

    def test_equals_W_U_column(self, bridge_compat):
        """Result should be exactly the corresponding column of W_U."""
        token_id = 42
        actual = bridge_compat.tokens_to_residual_directions(token_id)
        expected = bridge_compat.W_U[:, token_id]
        assert torch.equal(actual, expected)

    def test_batch_tokens(self, bridge_compat):
        """1-D tensor of tokens should return (n_tokens, d_model)."""
        token_ids = torch.tensor([100, 200, 300])
        directions = bridge_compat.tokens_to_residual_directions(token_ids)
        assert directions.shape == (3, bridge_compat.cfg.d_model)
        # Each row should match the corresponding W_U column
        for row, tok in zip(directions, token_ids):
            assert torch.equal(row, bridge_compat.W_U[:, tok])

    def test_matches_hooked_transformer(self, bridge_compat, reference_ht):
        """Output should match HookedTransformer for the same tokens."""
        token_ids = torch.tensor([10, 20, 30])
        bridge_rd = bridge_compat.tokens_to_residual_directions(token_ids)
        ht_rd = reference_ht.tokens_to_residual_directions(token_ids)
        max_diff = (bridge_rd - ht_rd).abs().max().item()
        assert max_diff < 1e-4, f"Residual directions differ by {max_diff}"
95+
96+
97+
class TestAccumulatedBias:
    """Test accumulated_bias sums biases correctly."""

    def test_layer_zero_is_zeros(self, bridge_compat):
        """accumulated_bias(0) should be all zeros (no layers processed)."""
        bias = bridge_compat.accumulated_bias(0)
        assert bias.shape == (bridge_compat.cfg.d_model,)
        assert torch.allclose(bias, torch.zeros_like(bias))

    def test_layer_one_includes_first_block(self, bridge_compat):
        """accumulated_bias(1) should include block 0's biases and be non-zero."""
        bias = bridge_compat.accumulated_bias(1)
        assert bias.shape == (bridge_compat.cfg.d_model,)
        # distilgpt2 has biases, so this should be non-zero
        assert bias.norm() > 0

    def test_monotonically_increasing_norm(self, bridge_compat):
        """Accumulated bias norm should generally increase with more layers."""
        # Not strictly monotonic, but bias(n_layers) should have larger norm than bias(0)
        no_layers = bridge_compat.accumulated_bias(0)
        all_layers = bridge_compat.accumulated_bias(bridge_compat.cfg.n_layers)
        assert all_layers.norm() > no_layers.norm()

    def test_matches_hooked_transformer(self, bridge_compat, reference_ht):
        """Output should match HookedTransformer."""
        for layer in [0, 1, 3, bridge_compat.cfg.n_layers]:
            bridge_bias = bridge_compat.accumulated_bias(layer)
            ht_bias = reference_ht.accumulated_bias(layer)
            max_diff = (bridge_bias - ht_bias).abs().max().item()
            assert max_diff < 1e-4, f"accumulated_bias({layer}) differs by {max_diff}"

    def test_mlp_input_flag(self, bridge_compat, reference_ht):
        """mlp_input=True should include the current layer's attn bias."""
        bridge_bias = bridge_compat.accumulated_bias(1, mlp_input=True)
        ht_bias = reference_ht.accumulated_bias(1, mlp_input=True)
        max_diff = (bridge_bias - ht_bias).abs().max().item()
        assert max_diff < 1e-4, f"accumulated_bias(1, mlp_input=True) differs by {max_diff}"
134+
135+
136+
class TestAllCompositionScores:
    """Test all_composition_scores produces correct composition score matrices."""

    def test_shape(self, bridge_compat):
        """Shape should be (n_layers, n_heads, n_layers, n_heads)."""
        cfg = bridge_compat.cfg
        expected_shape = (cfg.n_layers, cfg.n_heads, cfg.n_layers, cfg.n_heads)
        assert bridge_compat.all_composition_scores("Q").shape == expected_shape

    def test_upper_triangular_masking(self, bridge_compat):
        """Scores should be zero where left_layer >= right_layer."""
        scores = bridge_compat.all_composition_scores("Q")
        for l1 in range(bridge_compat.cfg.n_layers):
            for l2 in range(l1 + 1):  # l2 <= l1
                block = scores[l1, :, l2, :]
                assert (block == 0).all(), f"Scores at L{l1}->L{l2} should be zero (upper triangular)"

    def test_nonzero_above_diagonal(self, bridge_compat):
        """At least some scores above the diagonal should be non-zero."""
        scores = bridge_compat.all_composition_scores("Q")
        # Check L0 -> L1 (first above-diagonal block)
        assert scores[0, :, 1, :].abs().sum() > 0

    def test_all_modes_work(self, bridge_compat):
        """Q, K, V modes should all produce valid tensors."""
        for mode in ("Q", "K", "V"):
            scores = bridge_compat.all_composition_scores(mode)
            assert not torch.isnan(scores).any(), f"NaN in {mode} composition scores"

    def test_invalid_mode_raises(self, bridge_compat):
        """Invalid mode should raise ValueError."""
        with pytest.raises(ValueError, match="mode must be one of"):
            bridge_compat.all_composition_scores("X")
171+
172+
173+
class TestAllHeadLabels:
    """Test all_head_labels produces correct labels."""

    def test_count(self, bridge_compat):
        """Should have n_layers * n_heads labels."""
        cfg = bridge_compat.cfg
        assert len(bridge_compat.all_head_labels) == cfg.n_layers * cfg.n_heads

    def test_format(self, bridge_compat):
        """Labels should follow L{layer}H{head} format."""
        labels = bridge_compat.all_head_labels
        heads_per_layer = bridge_compat.cfg.n_heads
        assert labels[0] == "L0H0"
        assert labels[1] == "L0H1"
        assert labels[heads_per_layer] == "L1H0"

    def test_matches_hooked_transformer(self, bridge_compat, reference_ht):
        """Should match HookedTransformer's labels exactly."""
        # NOTE(review): bridge exposes this as a property, HookedTransformer as a
        # method — hence the call parentheses on the right-hand side only.
        assert bridge_compat.all_head_labels == reference_ht.all_head_labels()

transformer_lens/model_bridge/bridge.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,21 @@ def b_out(self) -> torch.Tensor:
11031103
"""Stack the MLP output biases across all layers."""
11041104
return self._stack_block_params("mlp.b_out")
11051105

1106+
@property
def W_U(self) -> torch.Tensor:
    """Unembedding matrix (d_model, d_vocab). Maps residual stream to logits.

    Delegates directly to the unembed component; no copy is made here.
    """
    unembed_matrix = self.unembed.W_U
    return unembed_matrix
1110+
1111+
@property
def b_U(self) -> torch.Tensor:
    """Unembedding bias (d_vocab).

    Delegates directly to the unembed component; no copy is made here.
    """
    unembed_bias = self.unembed.b_U
    return unembed_bias
1115+
1116+
@property
def W_E(self) -> torch.Tensor:
    """Token embedding matrix (d_vocab, d_model).

    Delegates directly to the embed component; no copy is made here.
    """
    embedding_matrix = self.embed.W_E
    return embedding_matrix
1120+
11061121
@property
11071122
def QK(self):
11081123
return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))
@@ -1111,6 +1126,119 @@ def QK(self):
11111126
def OV(self):
11121127
return FactoredMatrix(self.W_V, self.W_O)
11131128

1129+
# ------------------------------------------------------------------
1130+
# Mechanistic interpretability analysis methods
1131+
# ------------------------------------------------------------------
1132+
1133+
def tokens_to_residual_directions(
    self,
    tokens: Union[str, int, torch.Tensor],
) -> torch.Tensor:
    """Map tokens to their unembedding vectors (residual stream directions).

    Returns the columns of W_U for the given tokens — i.e. the directions in
    the residual stream that the model dots with to produce each token's logit.

    WARNING: If you use this without folding in LayerNorm (compatibility mode),
    the results will be misleading because LN weights change the unembed map.

    Args:
        tokens: A single token (str, int, or scalar tensor), a 1-D tensor of
            token IDs, or a 2-D batch of token IDs.

    Returns:
        Tensor of unembedding vectors with shape matching the input token shape
        plus a trailing d_model dimension.
    """
    # Multi-token tensor: fancy-index the vocab axis, then move d_model last.
    # NOTE(review): a 1-element 1-D tensor (numel == 1) falls through to the
    # scalar path below and loses its leading axis — mirrors HookedTransformer.
    if isinstance(tokens, torch.Tensor) and tokens.numel() > 1:
        directions = self.W_U[:, tokens]
        return einops.rearrange(directions, "d_model ... -> ... d_model")

    # Scalar path: normalize the input down to a plain int token id.
    if isinstance(tokens, str):
        token_id = self.to_single_token(tokens)
    elif isinstance(tokens, int):
        token_id = tokens
    elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1:
        token_id = int(tokens.item())
    else:
        raise ValueError(f"Invalid token type: {type(tokens)}")
    return self.W_U[:, token_id]
1171+
1172+
def accumulated_bias(
    self,
    layer: int,
    mlp_input: bool = False,
    include_mlp_biases: bool = True,
) -> torch.Tensor:
    """Sum of attention and MLP output biases up to the input of a given layer.

    Args:
        layer: Layer number in [0, n_layers]. 0 means no layers, n_layers means all.
        mlp_input: If True, include the attention output bias of the target layer
            (i.e. bias up to the MLP input of that layer).
        include_mlp_biases: Whether to include MLP biases. Useful to set False when
            expanding attn_out into individual heads but keeping mlp_out as-is.

    Returns:
        Tensor of shape [d_model] with the accumulated bias.
    """
    total = torch.zeros(self.cfg.d_model, device=self.cfg.device)
    # Every block strictly before `layer` contributes both of its output biases.
    # Biases may be absent (e.g. bias-free architectures), hence the None checks.
    for block in self.blocks[:layer]:
        attn_bias = getattr(block.attn, "b_O", None)
        if attn_bias is not None:
            total = total + attn_bias
        if not include_mlp_biases:
            continue
        mlp_bias = getattr(block.mlp, "b_out", None)
        if mlp_bias is not None:
            total = total + mlp_bias
    if mlp_input:
        # The target layer's attention runs before its MLP, so its output bias
        # is already part of the residual stream at the MLP input.
        assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
        attn_bias = getattr(self.blocks[layer].attn, "b_O", None)
        if attn_bias is not None:
            total = total + attn_bias
    return total
1207+
1208+
def all_composition_scores(self, mode: str) -> torch.Tensor:
    """Composition scores for all pairs of heads.

    Returns an (n_layers, n_heads, n_layers, n_heads) tensor that is upper
    triangular on the layer axes (a head can only compose with later heads).

    See https://transformer-circuits.pub/2021/framework/index.html

    Args:
        mode: One of "Q", "K", "V" — which composition type to compute.

    Raises:
        ValueError: If ``mode`` is not one of "Q", "K", "V".
    """
    left_side = self.OV
    if mode == "Q":
        right_side = self.QK
    elif mode == "K":
        right_side = self.QK.T
    elif mode == "V":
        right_side = self.OV
    else:
        raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")

    raw_scores = utils.composition_scores(left_side, right_side, broadcast_dims=True)
    # Zero out pairs where the "earlier" head is not in a strictly earlier
    # layer — composition is only possible left_layer < right_layer.
    layer_idx = torch.arange(self.cfg.n_layers, device=self.cfg.device)
    earlier = layer_idx[:, None, None, None] < layer_idx[None, None, :, None]
    return torch.where(earlier, raw_scores, torch.zeros_like(raw_scores))
1236+
1237+
@property
def all_head_labels(self) -> list[str]:
    """Human-readable labels for all attention heads, e.g. ['L0H0', 'L0H1', ...].

    Layer-major order: all heads of layer 0 first, then layer 1, and so on.
    """
    labels: list[str] = []
    for layer_idx in range(self.cfg.n_layers):
        for head_idx in range(self.cfg.n_heads):
            labels.append(f"L{layer_idx}H{head_idx}")
    return labels
1241+
11141242
def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
11151243
"""Returns parameters following standard PyTorch semantics.
11161244

0 commit comments

Comments
 (0)