diff --git a/early_exit/patching/attention_mixins/base.py b/early_exit/patching/attention_mixins/base.py
index 93d33e5..9bed756 100644
--- a/early_exit/patching/attention_mixins/base.py
+++ b/early_exit/patching/attention_mixins/base.py
@@ -21,7 +21,8 @@
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.self_attn.base_self_attn_forward = self.self_attn.forward
-        self.self_attn.patched_self_attn_forward = MethodType(self.patched_attention_forward, self.self_attn)
+        self.self_attn.patched_self_attn_forward = self.self_attn.forward
+        # self.self_attn.patched_self_attn_forward = MethodType(self.patched_attention_forward, self.self_attn)

     @abstractmethod
     def patched_layer_forward(self, hidden_states: _T, *_, unfrozen_idx_or_mask: List[int], **kwargs):
diff --git a/early_exit/patching/attention_mixins/qwen2.py b/early_exit/patching/attention_mixins/qwen2.py
index e1e581b..b29dbe0 100644
--- a/early_exit/patching/attention_mixins/qwen2.py
+++ b/early_exit/patching/attention_mixins/qwen2.py
@@ -18,7 +18,7 @@
 class Qwen2DecoderLayerFakeAttentionForwardMixin(LayerFakeAttentionForwardMixin):


-    def patched_layer_forward(
+    def patched_layer_forward_old(
         self,
         hidden_states: _T,
         attention_mask: Optional[_T] = None,
@@ -103,7 +103,6 @@ def patched_layer_forward(

         return outputs

-
     @staticmethod
     def patched_attention_forward(
         self: Qwen2Attention,
@@ -255,5 +254,74 @@ def patched_attention_forward(

         return attn_output_with_zeros, attn_weights_with_zeros, past_key_value

-
-
+    def patched_layer_forward(
+        self,
+        hidden_states: _T,
+        attention_mask: Optional[_T] = None,
+        position_ids: Optional[_LT] = None,
+        past_key_value: Optional[Tuple[_T]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[_LT] = None,
+        position_embeddings: Optional[Tuple[_T, _T]] = None,
+        unfrozen_idx_or_mask: Optional[List[int] | _T] = None,
+        **kwargs,
+    ) -> Tuple[_FT, Optional[Tuple[_FT, _FT]]]:
+
+        # Store original hidden states
+        original_hidden_states = hidden_states.clone()
+
+        bsz, q_len, _ = hidden_states.size()
+
+        # Process unfrozen mask
+        if isinstance(unfrozen_idx_or_mask, list):
+            unfrozen_mask = torch.zeros(bsz, dtype=torch.bool, device=hidden_states.device)
+            if len(unfrozen_idx_or_mask) > 0:
+                unfrozen_mask[unfrozen_idx_or_mask] = True
+            unfrozen_elements = unfrozen_mask
+        elif isinstance(unfrozen_idx_or_mask, _T):
+            gen_len = unfrozen_idx_or_mask.shape[1]
+            padding_required = q_len - gen_len
+            unfrozen_elements = F.pad(
+                input=unfrozen_idx_or_mask,
+                pad=(padding_required, 0),
+                value=True  # Pre-rollout (prompt) residual stream never gets frozen
+            ).to(hidden_states.device)
+        elif unfrozen_idx_or_mask is None:
+            unfrozen_elements = torch.ones(bsz, dtype=torch.bool, device=hidden_states.device)
+
+        # Call parent's forward method
+        outputs = super().forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs
+        )
+
+        # Extract hidden states from outputs
+        if isinstance(outputs, tuple):
+            new_hidden_states = outputs[0]
+            other_outputs = outputs[1:]
+        else:
+            new_hidden_states = outputs
+            other_outputs = ()
+
+        # Apply selective update: only update unfrozen elements
+        final_hidden_states = torch.where(
+            unfrozen_elements.unsqueeze(-1),  # Expand mask to match hidden dimension
+            new_hidden_states,                # Use new values where unfrozen
+            original_hidden_states            # Keep original values where frozen
+        )
+
+        # Reconstruct outputs with updated hidden states
+        final_outputs = (final_hidden_states,) + other_outputs
+
+        # Debug assertion
+        assert (original_hidden_states == final_hidden_states)[~unfrozen_elements].all()
+
+        return final_outputs
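
The freezing logic introduced above can be exercised in isolation. The snippet below is a standalone sketch, not code from this repo: the shapes, the toy mask, and the variable names are illustrative assumptions; it only mirrors the F.pad / torch.where pattern used by the new patched_layer_forward.

    import torch
    import torch.nn.functional as F

    bsz, prompt_len, gen_len, hidden = 2, 3, 4, 8       # illustrative sizes
    q_len = prompt_len + gen_len

    original = torch.randn(bsz, q_len, hidden)           # residual stream before the layer
    new = torch.randn(bsz, q_len, hidden)                # stand-in for the layer's output

    # Per-token mask over generated positions only; True = unfrozen (may be updated)
    gen_mask = torch.tensor([[True, False, True, False],
                             [False, False, True, True]])

    # Left-pad with True so prompt positions are always unfrozen, as in the tensor branch
    unfrozen = F.pad(gen_mask, pad=(prompt_len, 0), value=True)   # (bsz, q_len)

    # Selective update: take the new value where unfrozen, keep the original where frozen
    final = torch.where(unfrozen.unsqueeze(-1), new, original)    # (bsz, q_len, hidden)

    # Frozen positions stay identical to the originals; unfrozen ones match the new output
    assert (final == original)[~unfrozen].all()
    assert (final == new)[unfrozen].all()

The unsqueeze(-1) broadcast is what lets a (bsz, q_len) boolean mask gate a (bsz, q_len, hidden) tensor without materialising a full-size mask.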