@@ -1201,7 +1201,11 @@ def forward(
12011201 loss_per_token: Whether to return loss per token
12021202 prepend_bos: Whether to prepend BOS token
12031203 padding_side: Which side to pad on
1204- start_at_layer: Layer to start forward pass from
1204+ start_at_layer: Not implemented in TransformerBridge. The bridge delegates
1205+ to HuggingFace's model.forward() which owns the layer iteration loop,
1206+ making start_at_layer infeasible without monkey-patching HF internals
1207+ (fragile across HF versions) or exception-based layer skipping (corrupts
1208+ model state). Raises NotImplementedError if a non-None value is passed.
12051209 stop_at_layer: Layer to stop forward pass at
12061210 pixel_values: Optional image tensor for multimodal models (e.g., LLaVA, Gemma3).
12071211 The tensor is passed directly to the underlying HuggingFace model.
@@ -1215,6 +1219,14 @@ def forward(
12151219 Model output based on return_type
12161220 """
12171221
1222+ if start_at_layer is not None :
1223+ raise NotImplementedError (
1224+ "start_at_layer is not supported in TransformerBridge. "
1225+ "The bridge delegates to HuggingFace's model.forward() which controls "
1226+ "the layer iteration loop. See the TransformerBridge review plan for a "
1227+ "detailed analysis of implementation approaches and their tradeoffs."
1228+ )
1229+
12181230 # Set stop_at_layer flag on all blocks if requested
12191231 if stop_at_layer is not None and hasattr (self , "blocks" ):
12201232 for block in self .blocks :
@@ -1382,11 +1394,18 @@ def forward(
13821394 # Execution stopped at the requested layer
13831395 return e .layer_output
13841396 finally :
1385- # Clean up the stop_at_layer flag on all blocks
1397+ # Clean up state that may be inconsistent after StopAtLayerException
13861398 if stop_at_layer is not None and hasattr (self , "blocks" ):
1399+ # Reset the stop flag on all blocks
13871400 for block in self .blocks :
13881401 block ._stop_at_layer_idx = None
13891402
1403+ # Clear any stale KV cache — layers after the stop point didn't
1404+ # execute, so the cache is incomplete and would corrupt subsequent
1405+ # generate() calls that expect a full cache.
1406+ if hasattr (self , "_last_hf_cache" ):
1407+ del self ._last_hf_cache
1408+
13901409 def get_hook_point (self , hook_name : str ) -> Optional [HookPoint ]:
13911410 """Get a hook point by name from the bridge's hook system."""
13921411 if hook_name in self ._hook_registry :
@@ -1465,7 +1484,7 @@ def run_with_cache(
14651484 return_cache_object: Whether to return ActivationCache object
14661485 remove_batch_dim: Whether to remove batch dimension
14671486 names_filter: Filter for which activations to cache (str, list of str, or callable)
1468- stop_at_layer: Layer to stop forward pass at (not yet fully implemented )
1487+ stop_at_layer: Layer to stop forward pass at (uses StopAtLayerException; cleans up KV cache on stop )
14691488 **kwargs: Additional arguments
14701489 # type: ignore[name-defined]
14711490 Returns:
@@ -1660,7 +1679,7 @@ def run_with_hooks(
16601679 clear_contexts: Whether to clear hook contexts
16611680 return_type: What to return ("logits", "loss", etc.)
16621681 names_filter: Filter for hook names (not used directly, for compatibility)
1663- stop_at_layer: Layer to stop at (not yet fully implemented )
1682+ stop_at_layer: Layer to stop at (uses StopAtLayerException; cleans up KV cache on stop )
16641683 remove_batch_dim: Whether to remove batch dimension from hook inputs (only works for batch_size==1)
16651684 **kwargs: Additional arguments
16661685
@@ -1806,8 +1825,12 @@ def generate(
18061825 repetition by dividing positive logits and multiplying negative logits for
18071826 previously seen tokens. Default 1.0 (no penalty).
18081827 use_past_kv_cache: If True, use KV caching for faster generation
1809- prepend_bos: Not used in Bridge (kept for API compatibility)
1810- padding_side: Not used in Bridge (kept for API compatibility)
1828+ prepend_bos: Accepted for API compatibility but not applied during generation.
1829+ The HF model expects tokens in its native format (tokenizer defaults).
1830+ Overriding BOS can silently degrade generation quality. A UserWarning is emitted if a non-None value is passed.
1831+ padding_side: Accepted for API compatibility but not applied during generation.
1832+ The generation loop always extends tokens to the right, so overriding
1833+ initial padding_side creates inconsistent token layout. A UserWarning is emitted if a non-None value is passed.
18111834 return_type: The type of output to return - 'input', 'str', or 'tokens'
18121835 verbose: Not used in Bridge (kept for API compatibility)
18131836 output_logits: If True, return a ModelOutput with sequences and logits tuple
@@ -1819,7 +1842,30 @@ def generate(
18191842 Generated sequence as string, list of strings, or tensor depending on input type and return_type.
18201843 If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes.
18211844 """
1822- # Convert input to tokens using to_tokens() for consistent special token handling
1845+ # prepend_bos and padding_side are intentionally not applied during generation.
1846+ # The HF model expects tokens in its native format. Overriding BOS can silently
1847+ # degrade quality, and overriding padding_side conflicts with the generation loop
1848+ # which always extends tokens to the right.
1849+ if prepend_bos is not None :
1850+ import warnings
1851+
1852+ warnings .warn (
1853+ "prepend_bos is ignored during TransformerBridge.generate(). "
1854+ "The HF model expects tokens with the tokenizer's default BOS handling. "
1855+ "To control BOS, tokenize with to_tokens(prepend_bos=...) and pass the "
1856+ "resulting tensor to generate()." ,
1857+ stacklevel = 2 ,
1858+ )
1859+ if padding_side is not None :
1860+ import warnings
1861+
1862+ warnings .warn (
1863+ "padding_side is ignored during TransformerBridge.generate(). "
1864+ "The generation loop extends tokens to the right regardless of initial "
1865+ "padding. To control padding, tokenize with to_tokens(padding_side=...) "
1866+ "and pass the resulting tensor to generate()." ,
1867+ stacklevel = 2 ,
1868+ )
18231869 _generate_from_embeds = False
18241870 if isinstance (input , str ):
18251871 input_tokens = self .to_tokens (input , move_to_device = True , truncate = False )
0 commit comments