diff --git a/transformer_lens/model_bridge/generalized_components/attention.py b/transformer_lens/model_bridge/generalized_components/attention.py index e3bcc7ed2..05d5e0982 100644 --- a/transformer_lens/model_bridge/generalized_components/attention.py +++ b/transformer_lens/model_bridge/generalized_components/attention.py @@ -290,10 +290,6 @@ def revert(self, input_value, *full_context): if hasattr(self, "hook_rot_k"): self.hook_rot_k.hook_conversion = TransposeRotaryHeads() - def _setup_hook_z_reshape(self) -> None: - """Backward compatibility alias for _setup_qkv_hook_reshaping.""" - self._setup_qkv_hook_reshaping() - def _update_kv_cache( self, k: torch.Tensor, v: torch.Tensor, **kwargs: Any ) -> tuple[torch.Tensor, torch.Tensor]: @@ -367,6 +363,45 @@ def _apply_output_projection(self, attn_output: torch.Tensor) -> torch.Tensor: attn_output = self.o(attn_output) return attn_output + def _softmax_dropout_pattern( + self, + attn_scores: torch.Tensor, + target_dtype: torch.dtype | None = None, + upcast_to_fp32: bool = False, + ) -> torch.Tensor: + """Apply softmax, dropout, and hook_pattern to attention scores. + + Args: + attn_scores: Raw attention scores [batch, heads, q_seq, kv_seq]. + target_dtype: If set, cast weights to this dtype after softmax. + upcast_to_fp32: If True, compute softmax in float32 for numerical + stability, then cast to target_dtype. + """ + if upcast_to_fp32: + attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32) + if target_dtype is not None: + attn_weights = attn_weights.to(target_dtype) + else: + attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1) + if target_dtype is not None: + attn_weights = attn_weights.to(target_dtype) + attn_weights = self._apply_attn_dropout(attn_weights) + attn_weights = self.hook_pattern(attn_weights) + return attn_weights + + def _reshape_attn_output( + self, + attn_output: torch.Tensor, + batch_size: int, + seq_len: int, + num_heads: int, + head_dim: int, + ) -> torch.Tensor: + """Reshape attention output from [batch, heads, seq, dim] to [batch, seq, hidden].""" + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim) + return attn_output + def _apply_reconstruct_attention_mask( self, attn_scores: torch.Tensor, diff --git a/transformer_lens/model_bridge/generalized_components/block.py b/transformer_lens/model_bridge/generalized_components/block.py index c4d434df6..48147a9d2 100644 --- a/transformer_lens/model_bridge/generalized_components/block.py +++ b/transformer_lens/model_bridge/generalized_components/block.py @@ -96,60 +96,76 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: f"Original component not set for {self.name}. Call set_original_component() first." ) - # Check if we should stop before executing this block - # The _stop_at_layer_idx attribute is set by the bridge's forward method - if hasattr(self, "_stop_at_layer_idx") and self._stop_at_layer_idx is not None: - # Extract layer index (supports TL/GPT-2/LLaMA naming patterns) - if self.name is not None: - # Try multiple patterns to extract layer index - match = ( - re.search(r"blocks\.(\d+)", self.name) - or re.search(r"\.h\.(\d+)", self.name) - or re.search(r"\.layers\.(\d+)", self.name) - ) - else: - match = None - if match: - layer_idx = int(match.group(1)) - if layer_idx == self._stop_at_layer_idx: - # Get the input tensor to return - if len(args) > 0 and isinstance(args[0], torch.Tensor): - input_tensor = args[0] - elif "hidden_states" in kwargs and isinstance( - kwargs["hidden_states"], torch.Tensor - ): - input_tensor = kwargs["hidden_states"] - else: - raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}") - # Run hook_in on the input before stopping - input_tensor = self.hook_in(input_tensor) - raise StopAtLayerException(input_tensor) - - if len(args) > 0 and isinstance(args[0], torch.Tensor): - hooked_input = self.hook_in(args[0]) - args = (hooked_input,) + args[1:] - elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor): - kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"]) + self._check_stop_at_layer(*args, **kwargs) + args, kwargs = self._hook_input_hidden_states(args, kwargs) # Filter kwargs to only include parameters accepted by the original component # This prevents errors when passing encoder-specific params to decoder-only models filtered_kwargs = self._filter_kwargs_for_forward(kwargs, len(args)) output = self.original_component(*args, **filtered_kwargs) + return self._apply_output_hook(output) + + def _apply_output_hook(self, output: Any, wrap_single_element: bool = True) -> Any: + """Hook the primary tensor in the output and return the result. + + Args: + output: Raw output from the original component (tensor or tuple). + wrap_single_element: If True, single-element tuples stay as tuples after + hooking (default, required by most HF models). If False, single-element + tuples are unwrapped to a bare tensor (Bloom convention). + """ if isinstance(output, tuple) and len(output) > 0: first = output[0] if isinstance(first, torch.Tensor): first = self.hook_out(first) - # Always return tuple to maintain consistency with HF's expected format - # e.g. GPT2Model does hidden_states = outputs[0], it expects outputs to be a tuple if len(output) == 1: - return (first,) + return (first,) if wrap_single_element else first output = (first,) + output[1:] return output if isinstance(output, torch.Tensor): output = self.hook_out(output) return output + def _check_stop_at_layer(self, *args: Any, **kwargs: Any) -> None: + """Check if execution should stop before this block. Raises StopAtLayerException. + + The _stop_at_layer_idx attribute is set by the bridge's forward method. + Supports TL/GPT-2/LLaMA naming patterns for layer index extraction. + """ + if not (hasattr(self, "_stop_at_layer_idx") and self._stop_at_layer_idx is not None): + return + if self.name is not None: + match = ( + re.search(r"blocks\.(\d+)", self.name) + or re.search(r"\.h\.(\d+)", self.name) + or re.search(r"\.layers\.(\d+)", self.name) + ) + else: + match = None + if match: + layer_idx = int(match.group(1)) + if layer_idx == self._stop_at_layer_idx: + if len(args) > 0 and isinstance(args[0], torch.Tensor): + input_tensor = args[0] + elif "hidden_states" in kwargs and isinstance( + kwargs["hidden_states"], torch.Tensor + ): + input_tensor = kwargs["hidden_states"] + else: + raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}") + input_tensor = self.hook_in(input_tensor) + raise StopAtLayerException(input_tensor) + + def _hook_input_hidden_states(self, args: tuple, kwargs: dict) -> tuple[tuple, dict]: + """Apply hook_in to the hidden_states input, whether in args or kwargs.""" + if len(args) > 0 and isinstance(args[0], torch.Tensor): + hooked_input = self.hook_in(args[0]) + args = (hooked_input,) + args[1:] + elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor): + kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"]) + return args, kwargs + def _filter_kwargs_for_forward( self, kwargs: Dict[str, Any], num_positional_args: int = 0 ) -> Dict[str, Any]: diff --git a/transformer_lens/model_bridge/generalized_components/bloom_attention.py b/transformer_lens/model_bridge/generalized_components/bloom_attention.py index 90dedecc1..cbbd29099 100644 --- a/transformer_lens/model_bridge/generalized_components/bloom_attention.py +++ b/transformer_lens/model_bridge/generalized_components/bloom_attention.py @@ -187,19 +187,18 @@ def _reconstruct_attention( attn_scores = self.hook_attn_scores(attn_scores) # Softmax in float32 for numerical stability (matches HF BLOOM) - attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32) - attn_weights = attn_weights.to(q.dtype) - - attn_weights = self._apply_attn_dropout(attn_weights) - attn_weights = self.hook_pattern(attn_weights) + attn_weights = self._softmax_dropout_pattern( + attn_scores, target_dtype=q.dtype, upcast_to_fp32=True + ) # bmm in [batch*heads, seq, seq] format for BLOOM compatibility attn_weights_bh = attn_weights.reshape(batch_size * num_heads, seq_len, -1) attn_output = torch.bmm(attn_weights_bh, v_bh) attn_output = attn_output.view(batch_size, num_heads, seq_len, head_dim) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim) + attn_output = self._reshape_attn_output( + attn_output, batch_size, seq_len, num_heads, head_dim + ) attn_output = self._apply_output_projection(attn_output) return (attn_output, attn_weights) diff --git a/transformer_lens/model_bridge/generalized_components/bloom_block.py b/transformer_lens/model_bridge/generalized_components/bloom_block.py index ccb29bc03..7a20ac094 100644 --- a/transformer_lens/model_bridge/generalized_components/bloom_block.py +++ b/transformer_lens/model_bridge/generalized_components/bloom_block.py @@ -38,7 +38,6 @@ def __init__( """ super().__init__(name, config, submodules, hook_alias_overrides) self.config = config - self._alibi_cache: Optional[torch.Tensor] = None @staticmethod def build_alibi_tensor( @@ -111,42 +110,8 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: f"Original component not set for {self.name}. Call set_original_component() first." ) - # Check if we should stop before executing this block - if hasattr(self, "_stop_at_layer_idx") and self._stop_at_layer_idx is not None: - import re - - if self.name is not None: - match = ( - re.search(r"blocks\.(\d+)", self.name) - or re.search(r"\.h\.(\d+)", self.name) - or re.search(r"\.layers\.(\d+)", self.name) - ) - else: - match = None - if match: - layer_idx = int(match.group(1)) - if layer_idx == self._stop_at_layer_idx: - if len(args) > 0 and isinstance(args[0], torch.Tensor): - input_tensor = args[0] - elif "hidden_states" in kwargs and isinstance( - kwargs["hidden_states"], torch.Tensor - ): - input_tensor = kwargs["hidden_states"] - else: - raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}") - input_tensor = self.hook_in(input_tensor) - from transformer_lens.model_bridge.exceptions import ( - StopAtLayerException, - ) - - raise StopAtLayerException(input_tensor) - - # Apply hook_in to hidden_states - if len(args) > 0 and isinstance(args[0], torch.Tensor): - hooked_input = self.hook_in(args[0]) - args = (hooked_input,) + args[1:] - elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor): - kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"]) + self._check_stop_at_layer(*args, **kwargs) + args, kwargs = self._hook_input_hidden_states(args, kwargs) # BLOOM blocks require 'alibi' and 'attention_mask' arguments. # If HF's BloomModel.forward() is calling us, these will already be present. @@ -198,16 +163,4 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: # Call original component output = self.original_component(*args, **kwargs) - - # Apply hook_out - if isinstance(output, tuple) and len(output) > 0: - first = output[0] - if isinstance(first, torch.Tensor): - first = self.hook_out(first) - if len(output) == 1: - return first - output = (first,) + output[1:] - return output - if isinstance(output, torch.Tensor): - output = self.hook_out(output) - return output + return self._apply_output_hook(output, wrap_single_element=False) diff --git a/transformer_lens/model_bridge/generalized_components/gated_mlp.py b/transformer_lens/model_bridge/generalized_components/gated_mlp.py index 4b863e1ff..80e1a2d54 100644 --- a/transformer_lens/model_bridge/generalized_components/gated_mlp.py +++ b/transformer_lens/model_bridge/generalized_components/gated_mlp.py @@ -60,14 +60,7 @@ class GatedMLPBridge(MLPBridge): "hook_pre_linear": "in.hook_out", "hook_post": "out.hook_in", } - property_aliases = { - "W_gate": "gate.weight", - "b_gate": "gate.bias", - "W_in": "in.weight", - "b_in": "in.bias", - "W_out": "out.weight", - "b_out": "out.bias", - } + # property_aliases inherited from MLPBridge (W_gate, b_gate, W_in, b_in, W_out, b_out) def __init__( self, @@ -99,41 +92,33 @@ def forward(self, *args, **kwargs) -> torch.Tensor: Output hidden states """ if hasattr(self, "_use_processed_weights") and self._use_processed_weights: + assert hasattr(self, "_processed_W_gate") and hasattr(self, "_processed_W_in"), ( + "Processed weights flag is set but weights are missing. " + "This indicates a bug in set_processed_weights()." + ) + assert self._processed_W_in is not None + assert self._processed_W_out is not None hidden_states = args[0] hidden_states = self.hook_in(hidden_states) - if hasattr(self, "_processed_W_gate") and hasattr(self, "_processed_W_in"): - assert self._processed_W_in is not None # Guarded by hasattr check above - assert self._processed_W_out is not None - gate_output = torch.nn.functional.linear( - hidden_states, self._processed_W_gate, self._processed_b_gate - ) - if hasattr(self, "gate") and hasattr(self.gate, "hook_out"): - gate_output = self.gate.hook_out(gate_output) - linear_output = torch.nn.functional.linear( - hidden_states, self._processed_W_in, self._processed_b_in - ) - in_module = getattr(self, "in", None) - if in_module is not None and hasattr(in_module, "hook_out"): - linear_output = in_module.hook_out(linear_output) # type: ignore[misc] - act_fn = resolve_activation_fn(self.config) - activated = act_fn(gate_output) - hidden = activated * linear_output - if hasattr(self, "out") and hasattr(self.out, "hook_in"): - hidden = self.out.hook_in(hidden) - output = torch.nn.functional.linear( - hidden, self._processed_W_out, self._processed_b_out - ) - else: - import warnings - - warnings.warn( - "Processed weights flag set but weights missing — " - "falling back to original component. " - "Intermediate MLP hooks will not fire.", - stacklevel=2, - ) - new_args = (hidden_states,) + args[1:] - output = self.original_component(*new_args, **kwargs) # type: ignore[misc] + gate_output = torch.nn.functional.linear( + hidden_states, self._processed_W_gate, self._processed_b_gate + ) + if hasattr(self, "gate") and hasattr(self.gate, "hook_out"): + gate_output = self.gate.hook_out(gate_output) + linear_output = torch.nn.functional.linear( + hidden_states, self._processed_W_in, self._processed_b_in + ) + in_module = getattr(self, "in", None) + if in_module is not None and hasattr(in_module, "hook_out"): + linear_output = in_module.hook_out(linear_output) # type: ignore[misc] + act_fn = resolve_activation_fn(self.config) + activated = act_fn(gate_output) + hidden = activated * linear_output + if hasattr(self, "out") and hasattr(self.out, "hook_in"): + hidden = self.out.hook_in(hidden) + output = torch.nn.functional.linear( + hidden, self._processed_W_out, self._processed_b_out + ) output = self.hook_out(output) return output if self.original_component is None: diff --git a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py index 720b8ea7e..530cff3fb 100644 --- a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py +++ b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py @@ -27,17 +27,7 @@ class JointQKVAttentionBridge(AttentionBridge): the individual activations from the separated q, k, and v matrices are hooked and accessible. """ - # Property aliases point to the linear bridge weights - property_aliases = { - "W_Q": "q.weight", - "W_K": "k.weight", - "W_V": "v.weight", - "W_O": "o.weight", - "b_Q": "q.bias", - "b_K": "k.bias", - "b_V": "v.bias", - "b_O": "o.bias", - } + # property_aliases inherited from AttentionBridge (W_Q, W_K, W_V, W_O, b_Q, b_K, b_V, b_O) def __init__( self, @@ -422,16 +412,13 @@ def _reconstruct_attention( attn_scores = self.hook_attn_scores(attn_scores) - if reorder_and_upcast: - attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1) - attn_weights = attn_weights.to(v.dtype) - else: - attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1) - - attn_weights = self._apply_attn_dropout(attn_weights) - attn_weights = self.hook_pattern(attn_weights) + attn_weights = self._softmax_dropout_pattern( + attn_scores, + target_dtype=v.dtype if reorder_and_upcast else None, + ) attn_output = torch.matmul(attn_weights, v) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim) + attn_output = self._reshape_attn_output( + attn_output, batch_size, seq_len, num_heads, head_dim + ) attn_output = self._apply_output_projection(attn_output) return (attn_output, attn_weights) diff --git a/transformer_lens/model_bridge/generalized_components/joint_qkv_position_embeddings_attention.py b/transformer_lens/model_bridge/generalized_components/joint_qkv_position_embeddings_attention.py index a13f92c6a..8868bb16d 100644 --- a/transformer_lens/model_bridge/generalized_components/joint_qkv_position_embeddings_attention.py +++ b/transformer_lens/model_bridge/generalized_components/joint_qkv_position_embeddings_attention.py @@ -205,12 +205,11 @@ def _reconstruct_attention( ) attn_scores = self.hook_attn_scores(attn_scores) - attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1) - attn_weights = self._apply_attn_dropout(attn_weights) - attn_weights = self.hook_pattern(attn_weights) + attn_weights = self._softmax_dropout_pattern(attn_scores) attn_output = torch.matmul(attn_weights, v) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim) + attn_output = self._reshape_attn_output( + attn_output, batch_size, seq_len, num_heads, head_dim + ) attn_output = self._apply_output_projection(attn_output) return (attn_output, attn_weights) diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index 2f67419db..fa837711d 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -2486,14 +2486,15 @@ "architecture_id": "MistralForCausalLM", "model_id": "trl-internal-testing/tiny-MistralForCausalLM-0.2", "status": 1, - "verified_date": "2026-03-10", + "verified_date": "2026-04-07", "metadata": null, "note": "Full verification completed with issues, low text quality", "phase1_score": 100.0, "phase2_score": 100.0, "phase3_score": 100.0, "phase4_score": 47.5, - "phase7_score": null + "phase7_score": null, + "phase8_score": null }, { "architecture_id": "MistralForCausalLM", diff --git a/transformer_lens/tools/model_registry/data/verification_history.json b/transformer_lens/tools/model_registry/data/verification_history.json index 08d286a61..bc2fde984 100644 --- a/transformer_lens/tools/model_registry/data/verification_history.json +++ b/transformer_lens/tools/model_registry/data/verification_history.json @@ -1,5 +1,5 @@ { - "last_updated": "2026-04-07T18:30:31.307768", + "last_updated": "2026-04-07T18:56:31.723897", "records": [ { "model_id": "Macropodus/macbert4mdcspell_v1", @@ -11101,26 +11101,6 @@ "invalidated": false, "invalidation_reason": null }, - { - "model_id": "distilbert/distilgpt2", - "architecture_id": "GPT2LMHeadModel", - "verified_date": "2026-04-07", - "verified_by": "verify_models", - "transformerlens_version": null, - "notes": "Full verification completed", - "invalidated": false, - "invalidation_reason": null - }, - { - "model_id": "distilbert/distilgpt2", - "architecture_id": "GPT2LMHeadModel", - "verified_date": "2026-04-07", - "verified_by": "verify_models", - "transformerlens_version": null, - "notes": "Full verification completed", - "invalidated": false, - "invalidation_reason": null - }, { "model_id": "openai-community/gpt2", "architecture_id": "GPT2LMHeadModel", @@ -11152,12 +11132,12 @@ "invalidation_reason": null }, { - "model_id": "microsoft/Phi-3-mini-4k-instruct", - "architecture_id": "Phi3ForCausalLM", + "model_id": "trl-internal-testing/tiny-MistralForCausalLM-0.2", + "architecture_id": "MistralForCausalLM", "verified_date": "2026-04-07", "verified_by": "verify_models", "transformerlens_version": null, - "notes": "Full verification completed", + "notes": "Full verification completed with issues, low text quality", "invalidated": false, "invalidation_reason": null }