Skip to content

Commit 8ff5637

Browse files
authored
deduplicating repeated code and verifying no regressions (#1239)
1 parent b4aac4b commit 8ff5637

File tree

9 files changed

+146
-191
lines changed

9 files changed

+146
-191
lines changed

transformer_lens/model_bridge/generalized_components/attention.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,6 @@ def revert(self, input_value, *full_context):
290290
if hasattr(self, "hook_rot_k"):
291291
self.hook_rot_k.hook_conversion = TransposeRotaryHeads()
292292

293-
def _setup_hook_z_reshape(self) -> None:
294-
"""Backward compatibility alias for _setup_qkv_hook_reshaping."""
295-
self._setup_qkv_hook_reshaping()
296-
297293
def _update_kv_cache(
298294
self, k: torch.Tensor, v: torch.Tensor, **kwargs: Any
299295
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -367,6 +363,45 @@ def _apply_output_projection(self, attn_output: torch.Tensor) -> torch.Tensor:
367363
attn_output = self.o(attn_output)
368364
return attn_output
369365

366+
def _softmax_dropout_pattern(
367+
self,
368+
attn_scores: torch.Tensor,
369+
target_dtype: torch.dtype | None = None,
370+
upcast_to_fp32: bool = False,
371+
) -> torch.Tensor:
372+
"""Apply softmax, dropout, and hook_pattern to attention scores.
373+
374+
Args:
375+
attn_scores: Raw attention scores [batch, heads, q_seq, kv_seq].
376+
target_dtype: If set, cast weights to this dtype after softmax.
377+
upcast_to_fp32: If True, compute softmax in float32 for numerical
378+
stability, then cast to target_dtype.
379+
"""
380+
if upcast_to_fp32:
381+
attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32)
382+
if target_dtype is not None:
383+
attn_weights = attn_weights.to(target_dtype)
384+
else:
385+
attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)
386+
if target_dtype is not None:
387+
attn_weights = attn_weights.to(target_dtype)
388+
attn_weights = self._apply_attn_dropout(attn_weights)
389+
attn_weights = self.hook_pattern(attn_weights)
390+
return attn_weights
391+
392+
def _reshape_attn_output(
393+
self,
394+
attn_output: torch.Tensor,
395+
batch_size: int,
396+
seq_len: int,
397+
num_heads: int,
398+
head_dim: int,
399+
) -> torch.Tensor:
400+
"""Reshape attention output from [batch, heads, seq, dim] to [batch, seq, hidden]."""
401+
attn_output = attn_output.transpose(1, 2).contiguous()
402+
attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim)
403+
return attn_output
404+
370405
def _apply_reconstruct_attention_mask(
371406
self,
372407
attn_scores: torch.Tensor,

transformer_lens/model_bridge/generalized_components/block.py

Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -96,60 +96,76 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
9696
f"Original component not set for {self.name}. Call set_original_component() first."
9797
)
9898

99-
# Check if we should stop before executing this block
100-
# The _stop_at_layer_idx attribute is set by the bridge's forward method
101-
if hasattr(self, "_stop_at_layer_idx") and self._stop_at_layer_idx is not None:
102-
# Extract layer index (supports TL/GPT-2/LLaMA naming patterns)
103-
if self.name is not None:
104-
# Try multiple patterns to extract layer index
105-
match = (
106-
re.search(r"blocks\.(\d+)", self.name)
107-
or re.search(r"\.h\.(\d+)", self.name)
108-
or re.search(r"\.layers\.(\d+)", self.name)
109-
)
110-
else:
111-
match = None
112-
if match:
113-
layer_idx = int(match.group(1))
114-
if layer_idx == self._stop_at_layer_idx:
115-
# Get the input tensor to return
116-
if len(args) > 0 and isinstance(args[0], torch.Tensor):
117-
input_tensor = args[0]
118-
elif "hidden_states" in kwargs and isinstance(
119-
kwargs["hidden_states"], torch.Tensor
120-
):
121-
input_tensor = kwargs["hidden_states"]
122-
else:
123-
raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}")
124-
# Run hook_in on the input before stopping
125-
input_tensor = self.hook_in(input_tensor)
126-
raise StopAtLayerException(input_tensor)
127-
128-
if len(args) > 0 and isinstance(args[0], torch.Tensor):
129-
hooked_input = self.hook_in(args[0])
130-
args = (hooked_input,) + args[1:]
131-
elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
132-
kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])
99+
self._check_stop_at_layer(*args, **kwargs)
100+
args, kwargs = self._hook_input_hidden_states(args, kwargs)
133101

134102
# Filter kwargs to only include parameters accepted by the original component
135103
# This prevents errors when passing encoder-specific params to decoder-only models
136104
filtered_kwargs = self._filter_kwargs_for_forward(kwargs, len(args))
137105

138106
output = self.original_component(*args, **filtered_kwargs)
107+
return self._apply_output_hook(output)
108+
109+
def _apply_output_hook(self, output: Any, wrap_single_element: bool = True) -> Any:
110+
"""Hook the primary tensor in the output and return the result.
111+
112+
Args:
113+
output: Raw output from the original component (tensor or tuple).
114+
wrap_single_element: If True, single-element tuples stay as tuples after
115+
hooking (default, required by most HF models). If False, single-element
116+
tuples are unwrapped to a bare tensor (Bloom convention).
117+
"""
139118
if isinstance(output, tuple) and len(output) > 0:
140119
first = output[0]
141120
if isinstance(first, torch.Tensor):
142121
first = self.hook_out(first)
143-
# Always return tuple to maintain consistency with HF's expected format
144-
# e.g. GPT2Model does hidden_states = outputs[0], it expects outputs to be a tuple
145122
if len(output) == 1:
146-
return (first,)
123+
return (first,) if wrap_single_element else first
147124
output = (first,) + output[1:]
148125
return output
149126
if isinstance(output, torch.Tensor):
150127
output = self.hook_out(output)
151128
return output
152129

130+
def _check_stop_at_layer(self, *args: Any, **kwargs: Any) -> None:
131+
"""Check if execution should stop before this block. Raises StopAtLayerException.
132+
133+
The _stop_at_layer_idx attribute is set by the bridge's forward method.
134+
Supports TL/GPT-2/LLaMA naming patterns for layer index extraction.
135+
"""
136+
if not (hasattr(self, "_stop_at_layer_idx") and self._stop_at_layer_idx is not None):
137+
return
138+
if self.name is not None:
139+
match = (
140+
re.search(r"blocks\.(\d+)", self.name)
141+
or re.search(r"\.h\.(\d+)", self.name)
142+
or re.search(r"\.layers\.(\d+)", self.name)
143+
)
144+
else:
145+
match = None
146+
if match:
147+
layer_idx = int(match.group(1))
148+
if layer_idx == self._stop_at_layer_idx:
149+
if len(args) > 0 and isinstance(args[0], torch.Tensor):
150+
input_tensor = args[0]
151+
elif "hidden_states" in kwargs and isinstance(
152+
kwargs["hidden_states"], torch.Tensor
153+
):
154+
input_tensor = kwargs["hidden_states"]
155+
else:
156+
raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}")
157+
input_tensor = self.hook_in(input_tensor)
158+
raise StopAtLayerException(input_tensor)
159+
160+
def _hook_input_hidden_states(self, args: tuple, kwargs: dict) -> tuple[tuple, dict]:
161+
"""Apply hook_in to the hidden_states input, whether in args or kwargs."""
162+
if len(args) > 0 and isinstance(args[0], torch.Tensor):
163+
hooked_input = self.hook_in(args[0])
164+
args = (hooked_input,) + args[1:]
165+
elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
166+
kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])
167+
return args, kwargs
168+
153169
def _filter_kwargs_for_forward(
154170
self, kwargs: Dict[str, Any], num_positional_args: int = 0
155171
) -> Dict[str, Any]:

transformer_lens/model_bridge/generalized_components/bloom_attention.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -187,19 +187,18 @@ def _reconstruct_attention(
187187
attn_scores = self.hook_attn_scores(attn_scores)
188188

189189
# Softmax in float32 for numerical stability (matches HF BLOOM)
190-
attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32)
191-
attn_weights = attn_weights.to(q.dtype)
192-
193-
attn_weights = self._apply_attn_dropout(attn_weights)
194-
attn_weights = self.hook_pattern(attn_weights)
190+
attn_weights = self._softmax_dropout_pattern(
191+
attn_scores, target_dtype=q.dtype, upcast_to_fp32=True
192+
)
195193

196194
# bmm in [batch*heads, seq, seq] format for BLOOM compatibility
197195
attn_weights_bh = attn_weights.reshape(batch_size * num_heads, seq_len, -1)
198196
attn_output = torch.bmm(attn_weights_bh, v_bh)
199197

200198
attn_output = attn_output.view(batch_size, num_heads, seq_len, head_dim)
201-
attn_output = attn_output.transpose(1, 2).contiguous()
202-
attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim)
199+
attn_output = self._reshape_attn_output(
200+
attn_output, batch_size, seq_len, num_heads, head_dim
201+
)
203202
attn_output = self._apply_output_projection(attn_output)
204203

205204
return (attn_output, attn_weights)

transformer_lens/model_bridge/generalized_components/bloom_block.py

Lines changed: 3 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def __init__(
3838
"""
3939
super().__init__(name, config, submodules, hook_alias_overrides)
4040
self.config = config
41-
self._alibi_cache: Optional[torch.Tensor] = None
4241

4342
@staticmethod
4443
def build_alibi_tensor(
@@ -111,42 +110,8 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
111110
f"Original component not set for {self.name}. Call set_original_component() first."
112111
)
113112

114-
# Check if we should stop before executing this block
115-
if hasattr(self, "_stop_at_layer_idx") and self._stop_at_layer_idx is not None:
116-
import re
117-
118-
if self.name is not None:
119-
match = (
120-
re.search(r"blocks\.(\d+)", self.name)
121-
or re.search(r"\.h\.(\d+)", self.name)
122-
or re.search(r"\.layers\.(\d+)", self.name)
123-
)
124-
else:
125-
match = None
126-
if match:
127-
layer_idx = int(match.group(1))
128-
if layer_idx == self._stop_at_layer_idx:
129-
if len(args) > 0 and isinstance(args[0], torch.Tensor):
130-
input_tensor = args[0]
131-
elif "hidden_states" in kwargs and isinstance(
132-
kwargs["hidden_states"], torch.Tensor
133-
):
134-
input_tensor = kwargs["hidden_states"]
135-
else:
136-
raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}")
137-
input_tensor = self.hook_in(input_tensor)
138-
from transformer_lens.model_bridge.exceptions import (
139-
StopAtLayerException,
140-
)
141-
142-
raise StopAtLayerException(input_tensor)
143-
144-
# Apply hook_in to hidden_states
145-
if len(args) > 0 and isinstance(args[0], torch.Tensor):
146-
hooked_input = self.hook_in(args[0])
147-
args = (hooked_input,) + args[1:]
148-
elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
149-
kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])
113+
self._check_stop_at_layer(*args, **kwargs)
114+
args, kwargs = self._hook_input_hidden_states(args, kwargs)
150115

151116
# BLOOM blocks require 'alibi' and 'attention_mask' arguments.
152117
# If HF's BloomModel.forward() is calling us, these will already be present.
@@ -198,16 +163,4 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
198163

199164
# Call original component
200165
output = self.original_component(*args, **kwargs)
201-
202-
# Apply hook_out
203-
if isinstance(output, tuple) and len(output) > 0:
204-
first = output[0]
205-
if isinstance(first, torch.Tensor):
206-
first = self.hook_out(first)
207-
if len(output) == 1:
208-
return first
209-
output = (first,) + output[1:]
210-
return output
211-
if isinstance(output, torch.Tensor):
212-
output = self.hook_out(output)
213-
return output
166+
return self._apply_output_hook(output, wrap_single_element=False)

transformer_lens/model_bridge/generalized_components/gated_mlp.py

Lines changed: 26 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,7 @@ class GatedMLPBridge(MLPBridge):
6060
"hook_pre_linear": "in.hook_out",
6161
"hook_post": "out.hook_in",
6262
}
63-
property_aliases = {
64-
"W_gate": "gate.weight",
65-
"b_gate": "gate.bias",
66-
"W_in": "in.weight",
67-
"b_in": "in.bias",
68-
"W_out": "out.weight",
69-
"b_out": "out.bias",
70-
}
63+
# property_aliases inherited from MLPBridge (W_gate, b_gate, W_in, b_in, W_out, b_out)
7164

7265
def __init__(
7366
self,
@@ -99,41 +92,33 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
9992
Output hidden states
10093
"""
10194
if hasattr(self, "_use_processed_weights") and self._use_processed_weights:
95+
assert hasattr(self, "_processed_W_gate") and hasattr(self, "_processed_W_in"), (
96+
"Processed weights flag is set but weights are missing. "
97+
"This indicates a bug in set_processed_weights()."
98+
)
99+
assert self._processed_W_in is not None
100+
assert self._processed_W_out is not None
102101
hidden_states = args[0]
103102
hidden_states = self.hook_in(hidden_states)
104-
if hasattr(self, "_processed_W_gate") and hasattr(self, "_processed_W_in"):
105-
assert self._processed_W_in is not None # Guarded by hasattr check above
106-
assert self._processed_W_out is not None
107-
gate_output = torch.nn.functional.linear(
108-
hidden_states, self._processed_W_gate, self._processed_b_gate
109-
)
110-
if hasattr(self, "gate") and hasattr(self.gate, "hook_out"):
111-
gate_output = self.gate.hook_out(gate_output)
112-
linear_output = torch.nn.functional.linear(
113-
hidden_states, self._processed_W_in, self._processed_b_in
114-
)
115-
in_module = getattr(self, "in", None)
116-
if in_module is not None and hasattr(in_module, "hook_out"):
117-
linear_output = in_module.hook_out(linear_output) # type: ignore[misc]
118-
act_fn = resolve_activation_fn(self.config)
119-
activated = act_fn(gate_output)
120-
hidden = activated * linear_output
121-
if hasattr(self, "out") and hasattr(self.out, "hook_in"):
122-
hidden = self.out.hook_in(hidden)
123-
output = torch.nn.functional.linear(
124-
hidden, self._processed_W_out, self._processed_b_out
125-
)
126-
else:
127-
import warnings
128-
129-
warnings.warn(
130-
"Processed weights flag set but weights missing — "
131-
"falling back to original component. "
132-
"Intermediate MLP hooks will not fire.",
133-
stacklevel=2,
134-
)
135-
new_args = (hidden_states,) + args[1:]
136-
output = self.original_component(*new_args, **kwargs) # type: ignore[misc]
103+
gate_output = torch.nn.functional.linear(
104+
hidden_states, self._processed_W_gate, self._processed_b_gate
105+
)
106+
if hasattr(self, "gate") and hasattr(self.gate, "hook_out"):
107+
gate_output = self.gate.hook_out(gate_output)
108+
linear_output = torch.nn.functional.linear(
109+
hidden_states, self._processed_W_in, self._processed_b_in
110+
)
111+
in_module = getattr(self, "in", None)
112+
if in_module is not None and hasattr(in_module, "hook_out"):
113+
linear_output = in_module.hook_out(linear_output) # type: ignore[misc]
114+
act_fn = resolve_activation_fn(self.config)
115+
activated = act_fn(gate_output)
116+
hidden = activated * linear_output
117+
if hasattr(self, "out") and hasattr(self.out, "hook_in"):
118+
hidden = self.out.hook_in(hidden)
119+
output = torch.nn.functional.linear(
120+
hidden, self._processed_W_out, self._processed_b_out
121+
)
137122
output = self.hook_out(output)
138123
return output
139124
if self.original_component is None:

transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,7 @@ class JointQKVAttentionBridge(AttentionBridge):
2727
the individual activations from the separated q, k, and v matrices are hooked and accessible.
2828
"""
2929

30-
# Property aliases point to the linear bridge weights
31-
property_aliases = {
32-
"W_Q": "q.weight",
33-
"W_K": "k.weight",
34-
"W_V": "v.weight",
35-
"W_O": "o.weight",
36-
"b_Q": "q.bias",
37-
"b_K": "k.bias",
38-
"b_V": "v.bias",
39-
"b_O": "o.bias",
40-
}
30+
# property_aliases inherited from AttentionBridge (W_Q, W_K, W_V, W_O, b_Q, b_K, b_V, b_O)
4131

4232
def __init__(
4333
self,
@@ -422,16 +412,13 @@ def _reconstruct_attention(
422412

423413
attn_scores = self.hook_attn_scores(attn_scores)
424414

425-
if reorder_and_upcast:
426-
attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)
427-
attn_weights = attn_weights.to(v.dtype)
428-
else:
429-
attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)
430-
431-
attn_weights = self._apply_attn_dropout(attn_weights)
432-
attn_weights = self.hook_pattern(attn_weights)
415+
attn_weights = self._softmax_dropout_pattern(
416+
attn_scores,
417+
target_dtype=v.dtype if reorder_and_upcast else None,
418+
)
433419
attn_output = torch.matmul(attn_weights, v)
434-
attn_output = attn_output.transpose(1, 2).contiguous()
435-
attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim)
420+
attn_output = self._reshape_attn_output(
421+
attn_output, batch_size, seq_len, num_heads, head_dim
422+
)
436423
attn_output = self._apply_output_projection(attn_output)
437424
return (attn_output, attn_weights)

0 commit comments

Comments (0)