Skip to content

Commit 519ae6c

Browse files
authored
Feature/prepend bos padding side bridge (#1236)
* Bridge & MLP layer feature improvements * Adjusted our approach to prepend_bos and padding_side
1 parent 71a6418 commit 519ae6c

File tree

2 files changed

+97
-8
lines changed

2 files changed

+97
-8
lines changed

transformer_lens/model_bridge/bridge.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,7 +1201,11 @@ def forward(
12011201
loss_per_token: Whether to return loss per token
12021202
prepend_bos: Whether to prepend BOS token
12031203
padding_side: Which side to pad on
1204-
start_at_layer: Layer to start forward pass from
1204+
start_at_layer: Not implemented in TransformerBridge. The bridge delegates
1205+
to HuggingFace's model.forward() which owns the layer iteration loop,
1206+
making start_at_layer infeasible without monkey-patching HF internals
1207+
(fragile across HF versions) or exception-based layer skipping (corrupts
1208+
model state). Raises NotImplementedError if a non-None value is passed.
12051209
stop_at_layer: Layer to stop forward pass at
12061210
pixel_values: Optional image tensor for multimodal models (e.g., LLaVA, Gemma3).
12071211
The tensor is passed directly to the underlying HuggingFace model.
@@ -1215,6 +1219,14 @@ def forward(
12151219
Model output based on return_type
12161220
"""
12171221

1222+
if start_at_layer is not None:
1223+
raise NotImplementedError(
1224+
"start_at_layer is not supported in TransformerBridge. "
1225+
"The bridge delegates to HuggingFace's model.forward() which controls "
1226+
"the layer iteration loop. See the TransformerBridge review plan for a "
1227+
"detailed analysis of implementation approaches and their tradeoffs."
1228+
)
1229+
12181230
# Set stop_at_layer flag on all blocks if requested
12191231
if stop_at_layer is not None and hasattr(self, "blocks"):
12201232
for block in self.blocks:
@@ -1382,11 +1394,18 @@ def forward(
13821394
# Execution stopped at the requested layer
13831395
return e.layer_output
13841396
finally:
1385-
# Clean up the stop_at_layer flag on all blocks
1397+
# Clean up state that may be inconsistent after StopAtLayerException
13861398
if stop_at_layer is not None and hasattr(self, "blocks"):
1399+
# Reset the stop flag on all blocks
13871400
for block in self.blocks:
13881401
block._stop_at_layer_idx = None
13891402

1403+
# Clear any stale KV cache — layers after the stop point didn't
1404+
# execute, so the cache is incomplete and would corrupt subsequent
1405+
# generate() calls that expect a full cache.
1406+
if hasattr(self, "_last_hf_cache"):
1407+
del self._last_hf_cache
1408+
13901409
def get_hook_point(self, hook_name: str) -> Optional[HookPoint]:
13911410
"""Get a hook point by name from the bridge's hook system."""
13921411
if hook_name in self._hook_registry:
@@ -1465,7 +1484,7 @@ def run_with_cache(
14651484
return_cache_object: Whether to return ActivationCache object
14661485
remove_batch_dim: Whether to remove batch dimension
14671486
names_filter: Filter for which activations to cache (str, list of str, or callable)
1468-
stop_at_layer: Layer to stop forward pass at (not yet fully implemented)
1487+
stop_at_layer: Layer to stop forward pass at (uses StopAtLayerException; cleans up KV cache on stop)
14691488
**kwargs: Additional arguments
14701489
# type: ignore[name-defined]
14711490
Returns:
@@ -1660,7 +1679,7 @@ def run_with_hooks(
16601679
clear_contexts: Whether to clear hook contexts
16611680
return_type: What to return ("logits", "loss", etc.)
16621681
names_filter: Filter for hook names (not used directly, for compatibility)
1663-
stop_at_layer: Layer to stop at (not yet fully implemented)
1682+
stop_at_layer: Layer to stop at (uses StopAtLayerException; cleans up KV cache on stop)
16641683
remove_batch_dim: Whether to remove batch dimension from hook inputs (only works for batch_size==1)
16651684
**kwargs: Additional arguments
16661685
@@ -1806,8 +1825,12 @@ def generate(
18061825
repetition by dividing positive logits and multiplying negative logits for
18071826
previously seen tokens. Default 1.0 (no penalty).
18081827
use_past_kv_cache: If True, use KV caching for faster generation
1809-
prepend_bos: Not used in Bridge (kept for API compatibility)
1810-
padding_side: Not used in Bridge (kept for API compatibility)
1828+
prepend_bos: Accepted for API compatibility but not applied during generation.
1829+
The HF model expects tokens in its native format (tokenizer defaults).
1830+
Overriding BOS can silently degrade generation quality.
1831+
padding_side: Accepted for API compatibility but not applied during generation.
1832+
The generation loop always extends tokens to the right, so overriding
1833+
initial padding_side creates inconsistent token layout.
18111834
return_type: The type of output to return - 'input', 'str', or 'tokens'
18121835
verbose: Not used in Bridge (kept for API compatibility)
18131836
output_logits: If True, return a ModelOutput with sequences and logits tuple
@@ -1819,7 +1842,30 @@ def generate(
18191842
Generated sequence as string, list of strings, or tensor depending on input type and return_type.
18201843
If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes.
18211844
"""
1822-
# Convert input to tokens using to_tokens() for consistent special token handling
1845+
# prepend_bos and padding_side are intentionally not applied during generation.
1846+
# The HF model expects tokens in its native format. Overriding BOS can silently
1847+
# degrade quality, and overriding padding_side conflicts with the generation loop
1848+
# which always extends tokens to the right.
1849+
if prepend_bos is not None:
1850+
import warnings
1851+
1852+
warnings.warn(
1853+
"prepend_bos is ignored during TransformerBridge.generate(). "
1854+
"The HF model expects tokens with the tokenizer's default BOS handling. "
1855+
"To control BOS, tokenize with to_tokens(prepend_bos=...) and pass the "
1856+
"resulting tensor to generate().",
1857+
stacklevel=2,
1858+
)
1859+
if padding_side is not None:
1860+
import warnings
1861+
1862+
warnings.warn(
1863+
"padding_side is ignored during TransformerBridge.generate(). "
1864+
"The generation loop extends tokens to the right regardless of initial "
1865+
"padding. To control padding, tokenize with to_tokens(padding_side=...) "
1866+
"and pass the resulting tensor to generate().",
1867+
stacklevel=2,
1868+
)
18231869
_generate_from_embeds = False
18241870
if isinstance(input, str):
18251871
input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)

transformer_lens/model_bridge/generalized_components/gated_mlp.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ def __init__(
8787
def forward(self, *args, **kwargs) -> torch.Tensor:
8888
"""Forward pass through the gated MLP bridge.
8989
90+
Intermediate hooks (gate.hook_out, in.hook_out, out.hook_in) only fire in
91+
compatibility mode with processed weights enabled. In non-compatibility mode,
92+
the HF component is called as an opaque forward and only hook_in/hook_out fire.
93+
9094
Args:
9195
*args: Positional arguments for the original component
9296
**kwargs: Keyword arguments for the original component
@@ -98,6 +102,8 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
98102
hidden_states = args[0]
99103
hidden_states = self.hook_in(hidden_states)
100104
if hasattr(self, "_processed_W_gate") and hasattr(self, "_processed_W_in"):
105+
assert self._processed_W_in is not None # Guarded by hasattr check above
106+
assert self._processed_W_out is not None
101107
gate_output = torch.nn.functional.linear(
102108
hidden_states, self._processed_W_gate, self._processed_b_gate
103109
)
@@ -118,6 +124,14 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
118124
hidden, self._processed_W_out, self._processed_b_out
119125
)
120126
else:
127+
import warnings
128+
129+
warnings.warn(
130+
"Processed weights flag set but weights missing — "
131+
"falling back to original component. "
132+
"Intermediate MLP hooks will not fire.",
133+
stacklevel=2,
134+
)
121135
new_args = (hidden_states,) + args[1:]
122136
output = self.original_component(*new_args, **kwargs) # type: ignore[misc]
123137
output = self.hook_out(output)
@@ -159,17 +173,46 @@ def set_processed_weights(
159173
return
160174
b_gate = weights.get("gate.bias")
161175

176+
W_in = weights.get("in.weight")
177+
b_in = weights.get("in.bias")
178+
W_out = weights.get("out.weight")
179+
b_out = weights.get("out.bias")
180+
162181
if verbose:
163182
print(f" Setting W_gate with shape: {W_gate.shape}")
164183
if b_gate is not None:
165184
print(f" Setting b_gate with shape: {b_gate.shape}")
185+
if W_in is not None:
186+
print(f" Setting W_in with shape: {W_in.shape}")
187+
if W_out is not None:
188+
print(f" Setting W_out with shape: {W_out.shape}")
166189

167-
gate_module = getattr(self, "gate", None)
168190
self._use_processed_weights = True
169191
self._processed_W_gate = W_gate
170192
self._processed_b_gate = b_gate
193+
self._processed_W_in = W_in
194+
self._processed_b_in = b_in
195+
self._processed_W_out = W_out
196+
self._processed_b_out = b_out
197+
198+
# Distribute to submodules if they support it
199+
gate_module = getattr(self, "gate", None)
171200
if gate_module and hasattr(gate_module, "set_processed_weights"):
172201
gate_weights: Dict[str, torch.Tensor] = {"weight": W_gate}
173202
if b_gate is not None:
174203
gate_weights["bias"] = b_gate
175204
gate_module.set_processed_weights(gate_weights, verbose=verbose)
205+
206+
in_module = getattr(self, "in", None)
207+
if in_module and hasattr(in_module, "set_processed_weights") and W_in is not None:
208+
in_weights: Dict[str, torch.Tensor] = {"weight": W_in}
209+
if b_in is not None:
210+
in_weights["bias"] = b_in
211+
in_module.set_processed_weights(in_weights, verbose=verbose)
212+
213+
out_module = getattr(self, "out", None)
214+
if out_module and hasattr(out_module, "set_processed_weights") and W_out is not None:
215+
out_weights: Dict[str, torch.Tensor] = {"weight": W_out}
216+
if b_out is not None:
217+
out_weights["bias"] = b_out
218+
out_module.set_processed_weights(out_weights, verbose=verbose)

0 commit comments

Comments (0)