early_exit/patching/attention_mixins/base.py: 3 changes (2 additions & 1 deletion)
@@ -21,7 +21,8 @@ def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.self_attn.base_self_attn_forward = self.self_attn.forward
        self.self_attn.patched_self_attn_forward = MethodType(self.patched_attention_forward, self.self_attn)
        self.self_attn.patched_self_attn_forward = self.self_attn.forward
        # self.self_attn.patched_self_attn_forward = MethodType(self.patched_attention_forward, self.self_attn)

    @abstractmethod
    def patched_layer_forward(self, hidden_states: _T, *_, unfrozen_idx_or_mask: List[int], **kwargs):
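Note on the base.py change above: patched_self_attn_forward is now simply an alias of the unmodified attention forward, with the MethodType binding of patched_attention_forward left commented out. As a minimal sketch of what that MethodType call does (a hypothetical Attn class and patched_forward function, not from this repo), binding a plain function to one instance installs it as a bound method on that instance only:

from types import MethodType

class Attn:
    def forward(self, x):
        return x

def patched_forward(self, x):
    # toy replacement so the swap is observable
    return 2 * x

attn = Attn()
attn.base_self_attn_forward = attn.forward                          # handle on the original
attn.patched_self_attn_forward = MethodType(patched_forward, attn)  # bound to this instance
print(attn.base_self_attn_forward(3))     # 3
print(attn.patched_self_attn_forward(3))  # 6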
early_exit/patching/attention_mixins/qwen2.py: 76 changes (72 additions & 4 deletions)
@@ -18,7 +18,7 @@

class Qwen2DecoderLayerFakeAttentionForwardMixin(LayerFakeAttentionForwardMixin):

    def patched_layer_forward(
    def patched_layer_forward_old(
        self,
        hidden_states: _T,
        attention_mask: Optional[_T] = None,
@@ -103,7 +103,6 @@ def patched_layer_forward(

        return outputs


    @staticmethod
    def patched_attention_forward(
        self: Qwen2Attention,
@@ -255,5 +254,74 @@ def patched_attention_forward(

        return attn_output_with_zeros, attn_weights_with_zeros, past_key_value



    def patched_layer_forward(
        self,
        hidden_states: _T,
        attention_mask: Optional[_T] = None,
        position_ids: Optional[_LT] = None,
        past_key_value: Optional[Tuple[_T]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[_LT] = None,
        position_embeddings: Optional[Tuple[_T, _T]] = None,
        unfrozen_idx_or_mask: Optional[List[int] | _T] = None,
        **kwargs,
    ) -> Tuple[_FT, Optional[Tuple[_FT, _FT]]]:
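        """
        Run the wrapped decoder layer's forward pass, then copy the original
        residual stream back in for every element that unfrozen_idx_or_mask
        marks as frozen, so frozen positions leave this layer unchanged.
        """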

        # Store original hidden states
        original_hidden_states = hidden_states.clone()

        bsz, q_len, _ = hidden_states.size()

        # Process unfrozen mask
        if isinstance(unfrozen_idx_or_mask, list):
            unfrozen_mask = torch.zeros(bsz, dtype=torch.bool, device=hidden_states.device)
            if len(unfrozen_idx_or_mask) > 0:
                unfrozen_mask[unfrozen_idx_or_mask] = True
            unfrozen_elements = unfrozen_mask
        elif isinstance(unfrozen_idx_or_mask, _T):
            gen_len = unfrozen_idx_or_mask.shape[1]
            padding_required = q_len - gen_len
            unfrozen_elements = F.pad(
                input=unfrozen_idx_or_mask,
                pad=(padding_required, 0),
                value=True  # Pre-rollout (prompt) residual stream never gets frozen
            ).to(hidden_states.device)
        elif unfrozen_idx_or_mask is None:
            unfrozen_elements = torch.ones(bsz, dtype=torch.bool, device=hidden_states.device)

        # Call parent's forward method
        outputs = super().forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs
        )

        # Extract hidden states from outputs
        if isinstance(outputs, tuple):
            new_hidden_states = outputs[0]
            other_outputs = outputs[1:]
        else:
            new_hidden_states = outputs
            other_outputs = ()

        # Apply selective update: only update unfrozen elements. A 1-D (per-batch)
        # mask needs two trailing singleton dims to broadcast over (q_len, hidden).
        if unfrozen_elements.dim() == 1:
            broadcast_mask = unfrozen_elements[:, None, None]
        else:
            broadcast_mask = unfrozen_elements.unsqueeze(-1)  # per-position mask: add hidden dim
        final_hidden_states = torch.where(
            broadcast_mask,           # True where the residual stream may be updated
            new_hidden_states,        # use new values where unfrozen
            original_hidden_states    # keep original values where frozen
        )

        # Reconstruct outputs with updated hidden states
        final_outputs = (final_hidden_states,) + other_outputs

        # Sanity check: frozen elements must come out of this layer unchanged
        assert (original_hidden_states == final_hidden_states)[~unfrozen_elements].all()

        return final_outputs
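For reference, a minimal self-contained sketch of the freezing logic in the new patched_layer_forward, using made-up toy dimensions (bsz, q_len, gen_len, hidden are illustrative only): the generation-time mask is left-padded with True so prompt positions are never frozen, and torch.where copies the original residual stream back wherever the mask is False.

import torch
import torch.nn.functional as F

bsz, q_len, gen_len, hidden = 2, 6, 4, 8   # toy dimensions
original = torch.randn(bsz, q_len, hidden)
updated = original + 1.0                   # stand-in for the decoder layer's output

# (bsz, gen_len) mask over generated tokens; True = allowed to change
gen_mask = torch.tensor([[True, False, True, True],
                         [False, False, True, True]])

# Left-pad with True so prompt positions are always unfrozen
unfrozen = F.pad(gen_mask, pad=(q_len - gen_len, 0), value=True)

final = torch.where(unfrozen.unsqueeze(-1), updated, original)

# Frozen positions are untouched, unfrozen ones took the new values
assert (final == original)[~unfrozen].all()
assert (final == updated)[unfrozen].all()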