Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions transformer_lens/model_bridge/generalized_components/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,6 @@ def revert(self, input_value, *full_context):
if hasattr(self, "hook_rot_k"):
self.hook_rot_k.hook_conversion = TransposeRotaryHeads()

def _setup_hook_z_reshape(self) -> None:
"""Backward compatibility alias for _setup_qkv_hook_reshaping."""
self._setup_qkv_hook_reshaping()

def _update_kv_cache(
self, k: torch.Tensor, v: torch.Tensor, **kwargs: Any
) -> tuple[torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -367,6 +363,45 @@ def _apply_output_projection(self, attn_output: torch.Tensor) -> torch.Tensor:
attn_output = self.o(attn_output)
return attn_output

def _softmax_dropout_pattern(
self,
attn_scores: torch.Tensor,
target_dtype: torch.dtype | None = None,
upcast_to_fp32: bool = False,
) -> torch.Tensor:
"""Apply softmax, dropout, and hook_pattern to attention scores.

Args:
attn_scores: Raw attention scores [batch, heads, q_seq, kv_seq].
target_dtype: If set, cast weights to this dtype after softmax.
upcast_to_fp32: If True, compute softmax in float32 for numerical
stability, then cast to target_dtype.
"""
if upcast_to_fp32:
attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32)
if target_dtype is not None:
attn_weights = attn_weights.to(target_dtype)
else:
attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)
if target_dtype is not None:
attn_weights = attn_weights.to(target_dtype)
attn_weights = self._apply_attn_dropout(attn_weights)
attn_weights = self.hook_pattern(attn_weights)
return attn_weights

def _reshape_attn_output(
self,
attn_output: torch.Tensor,
batch_size: int,
seq_len: int,
num_heads: int,
head_dim: int,
) -> torch.Tensor:
"""Reshape attention output from [batch, heads, seq, dim] to [batch, seq, hidden]."""
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim)
return attn_output

def _apply_reconstruct_attention_mask(
self,
attn_scores: torch.Tensor,
Expand Down
90 changes: 53 additions & 37 deletions transformer_lens/model_bridge/generalized_components/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,60 +96,76 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
f"Original component not set for {self.name}. Call set_original_component() first."
)

# Check if we should stop before executing this block
# The _stop_at_layer_idx attribute is set by the bridge's forward method
if hasattr(self, "_stop_at_layer_idx") and self._stop_at_layer_idx is not None:
# Extract layer index (supports TL/GPT-2/LLaMA naming patterns)
if self.name is not None:
# Try multiple patterns to extract layer index
match = (
re.search(r"blocks\.(\d+)", self.name)
or re.search(r"\.h\.(\d+)", self.name)
or re.search(r"\.layers\.(\d+)", self.name)
)
else:
match = None
if match:
layer_idx = int(match.group(1))
if layer_idx == self._stop_at_layer_idx:
# Get the input tensor to return
if len(args) > 0 and isinstance(args[0], torch.Tensor):
input_tensor = args[0]
elif "hidden_states" in kwargs and isinstance(
kwargs["hidden_states"], torch.Tensor
):
input_tensor = kwargs["hidden_states"]
else:
raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}")
# Run hook_in on the input before stopping
input_tensor = self.hook_in(input_tensor)
raise StopAtLayerException(input_tensor)

if len(args) > 0 and isinstance(args[0], torch.Tensor):
hooked_input = self.hook_in(args[0])
args = (hooked_input,) + args[1:]
elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])
self._check_stop_at_layer(*args, **kwargs)
args, kwargs = self._hook_input_hidden_states(args, kwargs)

# Filter kwargs to only include parameters accepted by the original component
# This prevents errors when passing encoder-specific params to decoder-only models
filtered_kwargs = self._filter_kwargs_for_forward(kwargs, len(args))

output = self.original_component(*args, **filtered_kwargs)
return self._apply_output_hook(output)

def _apply_output_hook(self, output: Any, wrap_single_element: bool = True) -> Any:
"""Hook the primary tensor in the output and return the result.

Args:
output: Raw output from the original component (tensor or tuple).
wrap_single_element: If True, single-element tuples stay as tuples after
hooking (default, required by most HF models). If False, single-element
tuples are unwrapped to a bare tensor (Bloom convention).
"""
if isinstance(output, tuple) and len(output) > 0:
first = output[0]
if isinstance(first, torch.Tensor):
first = self.hook_out(first)
# Always return tuple to maintain consistency with HF's expected format
# e.g. GPT2Model does hidden_states = outputs[0], it expects outputs to be a tuple
if len(output) == 1:
return (first,)
return (first,) if wrap_single_element else first
output = (first,) + output[1:]
return output
if isinstance(output, torch.Tensor):
output = self.hook_out(output)
return output

def _check_stop_at_layer(self, *args: Any, **kwargs: Any) -> None:
"""Check if execution should stop before this block. Raises StopAtLayerException.

The _stop_at_layer_idx attribute is set by the bridge's forward method.
Supports TL/GPT-2/LLaMA naming patterns for layer index extraction.
"""
if not (hasattr(self, "_stop_at_layer_idx") and self._stop_at_layer_idx is not None):
return
if self.name is not None:
match = (
re.search(r"blocks\.(\d+)", self.name)
or re.search(r"\.h\.(\d+)", self.name)
or re.search(r"\.layers\.(\d+)", self.name)
)
else:
match = None
if match:
layer_idx = int(match.group(1))
if layer_idx == self._stop_at_layer_idx:
if len(args) > 0 and isinstance(args[0], torch.Tensor):
input_tensor = args[0]
elif "hidden_states" in kwargs and isinstance(
kwargs["hidden_states"], torch.Tensor
):
input_tensor = kwargs["hidden_states"]
else:
raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}")
input_tensor = self.hook_in(input_tensor)
raise StopAtLayerException(input_tensor)

def _hook_input_hidden_states(self, args: tuple, kwargs: dict) -> tuple[tuple, dict]:
"""Apply hook_in to the hidden_states input, whether in args or kwargs."""
if len(args) > 0 and isinstance(args[0], torch.Tensor):
hooked_input = self.hook_in(args[0])
args = (hooked_input,) + args[1:]
elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])
return args, kwargs

def _filter_kwargs_for_forward(
self, kwargs: Dict[str, Any], num_positional_args: int = 0
) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,19 +187,18 @@ def _reconstruct_attention(
attn_scores = self.hook_attn_scores(attn_scores)

# Softmax in float32 for numerical stability (matches HF BLOOM)
attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32)
attn_weights = attn_weights.to(q.dtype)

attn_weights = self._apply_attn_dropout(attn_weights)
attn_weights = self.hook_pattern(attn_weights)
attn_weights = self._softmax_dropout_pattern(
attn_scores, target_dtype=q.dtype, upcast_to_fp32=True
)

# bmm in [batch*heads, seq, seq] format for BLOOM compatibility
attn_weights_bh = attn_weights.reshape(batch_size * num_heads, seq_len, -1)
attn_output = torch.bmm(attn_weights_bh, v_bh)

attn_output = attn_output.view(batch_size, num_heads, seq_len, head_dim)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim)
attn_output = self._reshape_attn_output(
attn_output, batch_size, seq_len, num_heads, head_dim
)
attn_output = self._apply_output_projection(attn_output)

return (attn_output, attn_weights)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def __init__(
"""
super().__init__(name, config, submodules, hook_alias_overrides)
self.config = config
self._alibi_cache: Optional[torch.Tensor] = None

@staticmethod
def build_alibi_tensor(
Expand Down Expand Up @@ -111,42 +110,8 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
f"Original component not set for {self.name}. Call set_original_component() first."
)

# Check if we should stop before executing this block
if hasattr(self, "_stop_at_layer_idx") and self._stop_at_layer_idx is not None:
import re

if self.name is not None:
match = (
re.search(r"blocks\.(\d+)", self.name)
or re.search(r"\.h\.(\d+)", self.name)
or re.search(r"\.layers\.(\d+)", self.name)
)
else:
match = None
if match:
layer_idx = int(match.group(1))
if layer_idx == self._stop_at_layer_idx:
if len(args) > 0 and isinstance(args[0], torch.Tensor):
input_tensor = args[0]
elif "hidden_states" in kwargs and isinstance(
kwargs["hidden_states"], torch.Tensor
):
input_tensor = kwargs["hidden_states"]
else:
raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}")
input_tensor = self.hook_in(input_tensor)
from transformer_lens.model_bridge.exceptions import (
StopAtLayerException,
)

raise StopAtLayerException(input_tensor)

# Apply hook_in to hidden_states
if len(args) > 0 and isinstance(args[0], torch.Tensor):
hooked_input = self.hook_in(args[0])
args = (hooked_input,) + args[1:]
elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])
self._check_stop_at_layer(*args, **kwargs)
args, kwargs = self._hook_input_hidden_states(args, kwargs)

# BLOOM blocks require 'alibi' and 'attention_mask' arguments.
# If HF's BloomModel.forward() is calling us, these will already be present.
Expand Down Expand Up @@ -198,16 +163,4 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:

# Call original component
output = self.original_component(*args, **kwargs)

# Apply hook_out
if isinstance(output, tuple) and len(output) > 0:
first = output[0]
if isinstance(first, torch.Tensor):
first = self.hook_out(first)
if len(output) == 1:
return first
output = (first,) + output[1:]
return output
if isinstance(output, torch.Tensor):
output = self.hook_out(output)
return output
return self._apply_output_hook(output, wrap_single_element=False)
67 changes: 26 additions & 41 deletions transformer_lens/model_bridge/generalized_components/gated_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,7 @@ class GatedMLPBridge(MLPBridge):
"hook_pre_linear": "in.hook_out",
"hook_post": "out.hook_in",
}
property_aliases = {
"W_gate": "gate.weight",
"b_gate": "gate.bias",
"W_in": "in.weight",
"b_in": "in.bias",
"W_out": "out.weight",
"b_out": "out.bias",
}
# property_aliases inherited from MLPBridge (W_gate, b_gate, W_in, b_in, W_out, b_out)

def __init__(
self,
Expand Down Expand Up @@ -99,41 +92,33 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
Output hidden states
"""
if hasattr(self, "_use_processed_weights") and self._use_processed_weights:
assert hasattr(self, "_processed_W_gate") and hasattr(self, "_processed_W_in"), (
"Processed weights flag is set but weights are missing. "
"This indicates a bug in set_processed_weights()."
)
assert self._processed_W_in is not None
assert self._processed_W_out is not None
hidden_states = args[0]
hidden_states = self.hook_in(hidden_states)
if hasattr(self, "_processed_W_gate") and hasattr(self, "_processed_W_in"):
assert self._processed_W_in is not None # Guarded by hasattr check above
assert self._processed_W_out is not None
gate_output = torch.nn.functional.linear(
hidden_states, self._processed_W_gate, self._processed_b_gate
)
if hasattr(self, "gate") and hasattr(self.gate, "hook_out"):
gate_output = self.gate.hook_out(gate_output)
linear_output = torch.nn.functional.linear(
hidden_states, self._processed_W_in, self._processed_b_in
)
in_module = getattr(self, "in", None)
if in_module is not None and hasattr(in_module, "hook_out"):
linear_output = in_module.hook_out(linear_output) # type: ignore[misc]
act_fn = resolve_activation_fn(self.config)
activated = act_fn(gate_output)
hidden = activated * linear_output
if hasattr(self, "out") and hasattr(self.out, "hook_in"):
hidden = self.out.hook_in(hidden)
output = torch.nn.functional.linear(
hidden, self._processed_W_out, self._processed_b_out
)
else:
import warnings

warnings.warn(
"Processed weights flag set but weights missing — "
"falling back to original component. "
"Intermediate MLP hooks will not fire.",
stacklevel=2,
)
new_args = (hidden_states,) + args[1:]
output = self.original_component(*new_args, **kwargs) # type: ignore[misc]
gate_output = torch.nn.functional.linear(
hidden_states, self._processed_W_gate, self._processed_b_gate
)
if hasattr(self, "gate") and hasattr(self.gate, "hook_out"):
gate_output = self.gate.hook_out(gate_output)
linear_output = torch.nn.functional.linear(
hidden_states, self._processed_W_in, self._processed_b_in
)
in_module = getattr(self, "in", None)
if in_module is not None and hasattr(in_module, "hook_out"):
linear_output = in_module.hook_out(linear_output) # type: ignore[misc]
act_fn = resolve_activation_fn(self.config)
activated = act_fn(gate_output)
hidden = activated * linear_output
if hasattr(self, "out") and hasattr(self.out, "hook_in"):
hidden = self.out.hook_in(hidden)
output = torch.nn.functional.linear(
hidden, self._processed_W_out, self._processed_b_out
)
output = self.hook_out(output)
return output
if self.original_component is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,7 @@ class JointQKVAttentionBridge(AttentionBridge):
the individual activations from the separated q, k, and v matrices are hooked and accessible.
"""

# Property aliases point to the linear bridge weights
property_aliases = {
"W_Q": "q.weight",
"W_K": "k.weight",
"W_V": "v.weight",
"W_O": "o.weight",
"b_Q": "q.bias",
"b_K": "k.bias",
"b_V": "v.bias",
"b_O": "o.bias",
}
# property_aliases inherited from AttentionBridge (W_Q, W_K, W_V, W_O, b_Q, b_K, b_V, b_O)

def __init__(
self,
Expand Down Expand Up @@ -422,16 +412,13 @@ def _reconstruct_attention(

attn_scores = self.hook_attn_scores(attn_scores)

if reorder_and_upcast:
attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)
attn_weights = attn_weights.to(v.dtype)
else:
attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)

attn_weights = self._apply_attn_dropout(attn_weights)
attn_weights = self.hook_pattern(attn_weights)
attn_weights = self._softmax_dropout_pattern(
attn_scores,
target_dtype=v.dtype if reorder_and_upcast else None,
)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim)
attn_output = self._reshape_attn_output(
attn_output, batch_size, seq_len, num_heads, head_dim
)
attn_output = self._apply_output_projection(attn_output)
return (attn_output, attn_weights)
Loading
Loading