
Conversation

@karthikviswanathn (Collaborator)

Changes to make patched_layer_forward() faster. Instead of re-implementing the decoder layer's computation step by step, the patched forward now delegates to the parent layer with

outputs = super().forward(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    cache_position=cache_position,
    position_embeddings=position_embeddings,
    **kwargs,
)

and then restores the activations of frozen positions in a single vectorized merge:

torch.where(
    unfrozen_elements.unsqueeze(-1),  # Expand mask to match the hidden dimension
    new_hidden_states,                # Use new values where unfrozen
    original_hidden_states,           # Keep original values where frozen
)

