diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index a98b0193a..f7c909624 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -1201,7 +1201,11 @@ def forward( loss_per_token: Whether to return loss per token prepend_bos: Whether to prepend BOS token padding_side: Which side to pad on - start_at_layer: Layer to start forward pass from + start_at_layer: Not implemented in TransformerBridge. The bridge delegates + to HuggingFace's model.forward() which owns the layer iteration loop, + making start_at_layer infeasible without monkey-patching HF internals + (fragile across HF versions) or exception-based layer skipping (corrupts + model state). Raises NotImplementedError if a non-None value is passed. stop_at_layer: Layer to stop forward pass at pixel_values: Optional image tensor for multimodal models (e.g., LLaVA, Gemma3). The tensor is passed directly to the underlying HuggingFace model. @@ -1215,6 +1219,14 @@ def forward( Model output based on return_type """ + if start_at_layer is not None: + raise NotImplementedError( + "start_at_layer is not supported in TransformerBridge. " + "The bridge delegates to HuggingFace's model.forward() which controls " + "the layer iteration loop. See the TransformerBridge review plan for a " + "detailed analysis of implementation approaches and their tradeoffs." + ) + # Set stop_at_layer flag on all blocks if requested if stop_at_layer is not None and hasattr(self, "blocks"): for block in self.blocks: @@ -1382,11 +1394,18 @@ def forward( # Execution stopped at the requested layer return e.layer_output finally: - # Clean up the stop_at_layer flag on all blocks + # Clean up state that may be inconsistent after StopAtLayerException if stop_at_layer is not None and hasattr(self, "blocks"): + # Reset the stop flag on all blocks for block in self.blocks: block._stop_at_layer_idx = None + # Clear any stale KV cache — layers after the stop point didn't + # execute, so the cache is incomplete and would corrupt subsequent + # generate() calls that expect a full cache. + if hasattr(self, "_last_hf_cache"): + del self._last_hf_cache + def get_hook_point(self, hook_name: str) -> Optional[HookPoint]: """Get a hook point by name from the bridge's hook system.""" if hook_name in self._hook_registry: @@ -1465,7 +1484,7 @@ def run_with_cache( return_cache_object: Whether to return ActivationCache object remove_batch_dim: Whether to remove batch dimension names_filter: Filter for which activations to cache (str, list of str, or callable) - stop_at_layer: Layer to stop forward pass at (not yet fully implemented) + stop_at_layer: Layer to stop forward pass at (uses StopAtLayerException; cleans up KV cache on stop) **kwargs: Additional arguments # type: ignore[name-defined] Returns: @@ -1660,7 +1679,7 @@ def run_with_hooks( clear_contexts: Whether to clear hook contexts return_type: What to return ("logits", "loss", etc.) names_filter: Filter for hook names (not used directly, for compatibility) - stop_at_layer: Layer to stop at (not yet fully implemented) + stop_at_layer: Layer to stop at (uses StopAtLayerException; cleans up KV cache on stop) remove_batch_dim: Whether to remove batch dimension from hook inputs (only works for batch_size==1) **kwargs: Additional arguments @@ -1806,8 +1825,12 @@ def generate( repetition by dividing positive logits and multiplying negative logits for previously seen tokens. Default 1.0 (no penalty). use_past_kv_cache: If True, use KV caching for faster generation - prepend_bos: Not used in Bridge (kept for API compatibility) - padding_side: Not used in Bridge (kept for API compatibility) + prepend_bos: Accepted for API compatibility but not applied during generation. + The HF model expects tokens in its native format (tokenizer defaults). + Overriding BOS can silently degrade generation quality. + padding_side: Accepted for API compatibility but not applied during generation. + The generation loop always extends tokens to the right, so overriding + initial padding_side creates inconsistent token layout. return_type: The type of output to return - 'input', 'str', or 'tokens' verbose: Not used in Bridge (kept for API compatibility) output_logits: If True, return a ModelOutput with sequences and logits tuple @@ -1819,7 +1842,30 @@ def generate( Generated sequence as string, list of strings, or tensor depending on input type and return_type. If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes. """ - # Convert input to tokens using to_tokens() for consistent special token handling + # prepend_bos and padding_side are intentionally not applied during generation. + # The HF model expects tokens in its native format. Overriding BOS can silently + # degrade quality, and overriding padding_side conflicts with the generation loop + # which always extends tokens to the right. + if prepend_bos is not None: + import warnings + + warnings.warn( + "prepend_bos is ignored during TransformerBridge.generate(). " + "The HF model expects tokens with the tokenizer's default BOS handling. " + "To control BOS, tokenize with to_tokens(prepend_bos=...) and pass the " + "resulting tensor to generate().", + stacklevel=2, + ) + if padding_side is not None: + import warnings + + warnings.warn( + "padding_side is ignored during TransformerBridge.generate(). " + "The generation loop extends tokens to the right regardless of initial " + "padding. To control padding, tokenize with to_tokens(padding_side=...) " + "and pass the resulting tensor to generate().", + stacklevel=2, + ) _generate_from_embeds = False if isinstance(input, str): input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) diff --git a/transformer_lens/model_bridge/generalized_components/gated_mlp.py b/transformer_lens/model_bridge/generalized_components/gated_mlp.py index fbf511665..4b863e1ff 100644 --- a/transformer_lens/model_bridge/generalized_components/gated_mlp.py +++ b/transformer_lens/model_bridge/generalized_components/gated_mlp.py @@ -87,6 +87,10 @@ def __init__( def forward(self, *args, **kwargs) -> torch.Tensor: """Forward pass through the gated MLP bridge. + Intermediate hooks (gate.hook_out, in.hook_out, out.hook_in) only fire in + compatibility mode with processed weights enabled. In non-compatibility mode, + the HF component is called as an opaque forward and only hook_in/hook_out fire. + Args: *args: Positional arguments for the original component **kwargs: Keyword arguments for the original component @@ -98,6 +102,8 @@ def forward(self, *args, **kwargs) -> torch.Tensor: hidden_states = args[0] hidden_states = self.hook_in(hidden_states) if hasattr(self, "_processed_W_gate") and hasattr(self, "_processed_W_in"): + assert self._processed_W_in is not None # Guarded by hasattr check above + assert self._processed_W_out is not None gate_output = torch.nn.functional.linear( hidden_states, self._processed_W_gate, self._processed_b_gate ) @@ -118,6 +124,14 @@ def forward(self, *args, **kwargs) -> torch.Tensor: hidden, self._processed_W_out, self._processed_b_out ) else: + import warnings + + warnings.warn( + "Processed weights flag set but weights missing — " + "falling back to original component. " + "Intermediate MLP hooks will not fire.", + stacklevel=2, + ) new_args = (hidden_states,) + args[1:] output = self.original_component(*new_args, **kwargs) # type: ignore[misc] output = self.hook_out(output) @@ -159,17 +173,46 @@ def set_processed_weights( return b_gate = weights.get("gate.bias") + W_in = weights.get("in.weight") + b_in = weights.get("in.bias") + W_out = weights.get("out.weight") + b_out = weights.get("out.bias") + if verbose: print(f" Setting W_gate with shape: {W_gate.shape}") if b_gate is not None: print(f" Setting b_gate with shape: {b_gate.shape}") + if W_in is not None: + print(f" Setting W_in with shape: {W_in.shape}") + if W_out is not None: + print(f" Setting W_out with shape: {W_out.shape}") - gate_module = getattr(self, "gate", None) self._use_processed_weights = True self._processed_W_gate = W_gate self._processed_b_gate = b_gate + self._processed_W_in = W_in + self._processed_b_in = b_in + self._processed_W_out = W_out + self._processed_b_out = b_out + + # Distribute to submodules if they support it + gate_module = getattr(self, "gate", None) if gate_module and hasattr(gate_module, "set_processed_weights"): gate_weights: Dict[str, torch.Tensor] = {"weight": W_gate} if b_gate is not None: gate_weights["bias"] = b_gate gate_module.set_processed_weights(gate_weights, verbose=verbose) + + in_module = getattr(self, "in", None) + if in_module and hasattr(in_module, "set_processed_weights") and W_in is not None: + in_weights: Dict[str, torch.Tensor] = {"weight": W_in} + if b_in is not None: + in_weights["bias"] = b_in + in_module.set_processed_weights(in_weights, verbose=verbose) + + out_module = getattr(self, "out", None) + if out_module and hasattr(out_module, "set_processed_weights") and W_out is not None: + out_weights: Dict[str, torch.Tensor] = {"weight": W_out} + if b_out is not None: + out_weights["bias"] = b_out + out_module.set_processed_weights(out_weights, verbose=verbose)