Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 53 additions & 7 deletions transformer_lens/model_bridge/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,11 @@ def forward(
loss_per_token: Whether to return loss per token
prepend_bos: Whether to prepend BOS token
padding_side: Which side to pad on
start_at_layer: Layer to start forward pass from
start_at_layer: Not implemented in TransformerBridge. The bridge delegates
to HuggingFace's model.forward() which owns the layer iteration loop,
making start_at_layer infeasible without monkey-patching HF internals
(fragile across HF versions) or exception-based layer skipping (corrupts
model state). Raises NotImplementedError if a non-None value is passed.
stop_at_layer: Layer to stop forward pass at
pixel_values: Optional image tensor for multimodal models (e.g., LLaVA, Gemma3).
The tensor is passed directly to the underlying HuggingFace model.
Expand All @@ -1215,6 +1219,14 @@ def forward(
Model output based on return_type
"""

if start_at_layer is not None:
raise NotImplementedError(
"start_at_layer is not supported in TransformerBridge. "
"The bridge delegates to HuggingFace's model.forward() which controls "
"the layer iteration loop. See the TransformerBridge review plan for a "
"detailed analysis of implementation approaches and their tradeoffs."
)

# Set stop_at_layer flag on all blocks if requested
if stop_at_layer is not None and hasattr(self, "blocks"):
for block in self.blocks:
Expand Down Expand Up @@ -1382,11 +1394,18 @@ def forward(
# Execution stopped at the requested layer
return e.layer_output
finally:
# Clean up the stop_at_layer flag on all blocks
# Clean up state that may be inconsistent after StopAtLayerException
if stop_at_layer is not None and hasattr(self, "blocks"):
# Reset the stop flag on all blocks
for block in self.blocks:
block._stop_at_layer_idx = None

# Clear any stale KV cache — layers after the stop point didn't
# execute, so the cache is incomplete and would corrupt subsequent
# generate() calls that expect a full cache.
if hasattr(self, "_last_hf_cache"):
del self._last_hf_cache

def get_hook_point(self, hook_name: str) -> Optional[HookPoint]:
"""Get a hook point by name from the bridge's hook system."""
if hook_name in self._hook_registry:
Expand Down Expand Up @@ -1465,7 +1484,7 @@ def run_with_cache(
return_cache_object: Whether to return ActivationCache object
remove_batch_dim: Whether to remove batch dimension
names_filter: Filter for which activations to cache (str, list of str, or callable)
stop_at_layer: Layer to stop forward pass at (not yet fully implemented)
stop_at_layer: Layer to stop forward pass at (uses StopAtLayerException; cleans up KV cache on stop)
**kwargs: Additional arguments
# type: ignore[name-defined]
Returns:
Expand Down Expand Up @@ -1660,7 +1679,7 @@ def run_with_hooks(
clear_contexts: Whether to clear hook contexts
return_type: What to return ("logits", "loss", etc.)
names_filter: Filter for hook names (not used directly, for compatibility)
stop_at_layer: Layer to stop at (not yet fully implemented)
stop_at_layer: Layer to stop at (uses StopAtLayerException; cleans up KV cache on stop)
remove_batch_dim: Whether to remove batch dimension from hook inputs (only works for batch_size==1)
**kwargs: Additional arguments

Expand Down Expand Up @@ -1806,8 +1825,12 @@ def generate(
repetition by dividing positive logits and multiplying negative logits for
previously seen tokens. Default 1.0 (no penalty).
use_past_kv_cache: If True, use KV caching for faster generation
prepend_bos: Not used in Bridge (kept for API compatibility)
padding_side: Not used in Bridge (kept for API compatibility)
prepend_bos: Accepted for API compatibility but not applied during generation.
The HF model expects tokens in its native format (tokenizer defaults).
Overriding BOS can silently degrade generation quality.
padding_side: Accepted for API compatibility but not applied during generation.
The generation loop always extends tokens to the right, so overriding
initial padding_side creates inconsistent token layout.
return_type: The type of output to return - 'input', 'str', or 'tokens'
verbose: Not used in Bridge (kept for API compatibility)
output_logits: If True, return a ModelOutput with sequences and logits tuple
Expand All @@ -1819,7 +1842,30 @@ def generate(
Generated sequence as string, list of strings, or tensor depending on input type and return_type.
If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes.
"""
# Convert input to tokens using to_tokens() for consistent special token handling
# prepend_bos and padding_side are intentionally not applied during generation.
# The HF model expects tokens in its native format. Overriding BOS can silently
# degrade quality, and overriding padding_side conflicts with the generation loop
# which always extends tokens to the right.
if prepend_bos is not None:
import warnings

warnings.warn(
"prepend_bos is ignored during TransformerBridge.generate(). "
"The HF model expects tokens with the tokenizer's default BOS handling. "
"To control BOS, tokenize with to_tokens(prepend_bos=...) and pass the "
"resulting tensor to generate().",
stacklevel=2,
)
if padding_side is not None:
import warnings

warnings.warn(
"padding_side is ignored during TransformerBridge.generate(). "
"The generation loop extends tokens to the right regardless of initial "
"padding. To control padding, tokenize with to_tokens(padding_side=...) "
"and pass the resulting tensor to generate().",
stacklevel=2,
)
_generate_from_embeds = False
if isinstance(input, str):
input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ def __init__(
def forward(self, *args, **kwargs) -> torch.Tensor:
"""Forward pass through the gated MLP bridge.

Intermediate hooks (gate.hook_out, in.hook_out, out.hook_in) only fire in
compatibility mode with processed weights enabled. In non-compatibility mode,
the HF component is called as an opaque forward and only hook_in/hook_out fire.

Args:
*args: Positional arguments for the original component
**kwargs: Keyword arguments for the original component
Expand All @@ -98,6 +102,8 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
hidden_states = args[0]
hidden_states = self.hook_in(hidden_states)
if hasattr(self, "_processed_W_gate") and hasattr(self, "_processed_W_in"):
assert self._processed_W_in is not None # hasattr above only proves the attribute exists; set_processed_weights may have stored None
assert self._processed_W_out is not None
gate_output = torch.nn.functional.linear(
hidden_states, self._processed_W_gate, self._processed_b_gate
)
Expand All @@ -118,6 +124,14 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
hidden, self._processed_W_out, self._processed_b_out
)
else:
import warnings

warnings.warn(
"Processed weights flag set but weights missing — "
"falling back to original component. "
"Intermediate MLP hooks will not fire.",
stacklevel=2,
)
new_args = (hidden_states,) + args[1:]
output = self.original_component(*new_args, **kwargs) # type: ignore[misc]
output = self.hook_out(output)
Expand Down Expand Up @@ -159,17 +173,46 @@ def set_processed_weights(
return
b_gate = weights.get("gate.bias")

W_in = weights.get("in.weight")
b_in = weights.get("in.bias")
W_out = weights.get("out.weight")
b_out = weights.get("out.bias")

if verbose:
print(f" Setting W_gate with shape: {W_gate.shape}")
if b_gate is not None:
print(f" Setting b_gate with shape: {b_gate.shape}")
if W_in is not None:
print(f" Setting W_in with shape: {W_in.shape}")
if W_out is not None:
print(f" Setting W_out with shape: {W_out.shape}")

gate_module = getattr(self, "gate", None)
self._use_processed_weights = True
self._processed_W_gate = W_gate
self._processed_b_gate = b_gate
self._processed_W_in = W_in
self._processed_b_in = b_in
self._processed_W_out = W_out
self._processed_b_out = b_out

# Distribute to submodules if they support it
gate_module = getattr(self, "gate", None)
if gate_module and hasattr(gate_module, "set_processed_weights"):
gate_weights: Dict[str, torch.Tensor] = {"weight": W_gate}
if b_gate is not None:
gate_weights["bias"] = b_gate
gate_module.set_processed_weights(gate_weights, verbose=verbose)

in_module = getattr(self, "in", None)
if in_module and hasattr(in_module, "set_processed_weights") and W_in is not None:
in_weights: Dict[str, torch.Tensor] = {"weight": W_in}
if b_in is not None:
in_weights["bias"] = b_in
in_module.set_processed_weights(in_weights, verbose=verbose)

out_module = getattr(self, "out", None)
if out_module and hasattr(out_module, "set_processed_weights") and W_out is not None:
out_weights: Dict[str, torch.Tensor] = {"weight": W_out}
if b_out is not None:
out_weights["bias"] = b_out
out_module.set_processed_weights(out_weights, verbose=verbose)
Loading