26 changes: 26 additions & 0 deletions config_deepseek.yaml
@@ -0,0 +1,26 @@
model:
load_precision_mode: none # Options: [none,"4bits","8bits","full"]
lora: false

lora:
r: 8
lora_alpha: 16
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
lora_dropout: 0.05
bias: "none" #sticking to example : https://huggingface.co/docs/trl/main/en/peft_integration


generation:
min_length: 0
# top_k: 0.0
# top_p: 0.95
num_beams: 1
# temperature: 0.6
do_sample: true
#pad_token_id: 0
bos_token_id: 151646
eos_token_id: 151643
max_new_tokens: 400
# return_prompt: False
# generate_ref_response: False
use_cache: true
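
For context, a minimal sketch of how a config like this can be consumed. This is illustrative only: the model name and loading code below are assumptions, not the repo's `configs_from_yaml`; what it shows is that the `lora` and `generation` sections map directly onto `peft.LoraConfig` and `model.generate` keyword arguments.

```python
import yaml
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical loader: read the YAML and split out its top-level sections.
with open("config_deepseek.yaml") as f:
    cfg = yaml.safe_load(f)

generation_kwargs = cfg["generation"]  # do_sample, max_new_tokens, bos/eos ids, ...

# Model name is an assumption; the bos/eos ids in the config match the
# Qwen-based DeepSeek-R1 distills, but any causal LM with a matching
# tokenizer would work the same way.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# The lora section is only applied when model.lora is enabled.
if cfg["model"]["lora"]:
    model = get_peft_model(model, LoraConfig(**cfg["lora"]))

inputs = tokenizer("2 + 2 =", return_tensors="pt")
outputs = model.generate(**inputs, **generation_kwargs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```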
56 changes: 34 additions & 22 deletions early_exit/patching/attention_mixins/qwen2.py
@@ -51,28 +51,31 @@ def patched_layer_forward(

elif unfrozen_idx_or_mask is None:
unfrozen_elements = torch.arange(bsz)

residual = hidden_states
residual = hidden_states.clone()

hidden_states[unfrozen_elements] = self.input_layernorm(hidden_states[unfrozen_elements])

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
unfrozen_idx_or_mask=unfrozen_idx_or_mask # Key change
)
attention_output = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
unfrozen_idx_or_mask=unfrozen_idx_or_mask # Key change
)
### TODO: the following unpacking is hacky; clean it up in a future version
if len(attention_output) == 3:
hidden_states, self_attn_weights, present_key_value = attention_output
else:
hidden_states, self_attn_weights = attention_output
present_key_value = None
hidden_states = residual + hidden_states


# Fully Connected
residual = hidden_states
residual = hidden_states.clone()
hidden_states[unfrozen_elements] = self.post_attention_layernorm(hidden_states[unfrozen_elements])
hidden_states[unfrozen_elements] = self.mlp(hidden_states[unfrozen_elements])
hidden_states[unfrozen_elements] = residual[unfrozen_elements] + hidden_states[unfrozen_elements]
@@ -82,10 +85,11 @@ def patched_layer_forward(
if output_attentions:
outputs += (self_attn_weights,)

if use_cache:
if use_cache and present_key_value is not None:
outputs += (present_key_value,)

assert (_original_hidden_states == hidden_states)[~unfrozen_elements].all()
# Removed from main: not sure whether ~unfrozen_elements is the correct way to index the frozen elements
# assert (_original_hidden_states == hidden_states)[~unfrozen_elements].all()

return outputs
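
A minimal standalone sketch (illustrative, not repo code) of why the `residual = hidden_states.clone()` changes above matter: without the clone, `residual` is the same tensor object as `hidden_states`, so the in-place layernorm assignments also overwrite the residual branch.

```python
import torch

hidden_states = torch.ones(2, 4)

# Without clone: `residual` aliases the same storage, so the in-place
# write below silently clobbers the residual branch as well.
residual = hidden_states
hidden_states[0] = 0.0
print(residual[0])  # tensor([0., 0., 0., 0.])

# With clone: the residual keeps the pre-normalization values.
hidden_states = torch.ones(2, 4)
residual = hidden_states.clone()
hidden_states[0] = 0.0
print(residual[0])  # tensor([1., 1., 1., 1.])
```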

@@ -121,13 +125,21 @@ def patched_attention_forward(
cache_position = cache_position,
position_embeddings = position_embeddings,
)

bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

# # print("Number of attention heads = ", self.config)if not hasattr(self, "num_heads"):
if not hasattr(self, "num_heads"):
self.num_heads = self.config.num_attention_heads
if not hasattr(self, "num_key_value_heads"):
self.num_key_value_heads = self.config.num_key_value_heads
if not hasattr(self, "hidden_size"):
self.hidden_size = self.config.hidden_size

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
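
The `hasattr` guards above are needed because newer `transformers` releases no longer set `num_heads`, `num_key_value_heads`, and `hidden_size` on the attention module itself; they live on the config instead. A hedged sketch of an equivalent, more compact fallback (the helper name is made up; the config attributes are standard `Qwen2Config` fields):

```python
from transformers import Qwen2Config

def resolve_head_geometry(attn_module, config: Qwen2Config) -> tuple[int, int, int]:
    """Prefer attributes cached on the module (older transformers),
    otherwise fall back to the model config (newer transformers)."""
    num_heads = getattr(attn_module, "num_heads", config.num_attention_heads)
    num_kv_heads = getattr(attn_module, "num_key_value_heads", config.num_key_value_heads)
    hidden_size = getattr(attn_module, "hidden_size", config.hidden_size)
    return num_heads, num_kv_heads, hidden_size
```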
@@ -181,7 +193,7 @@ def patched_attention_forward(
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (len(unfrozen_batch_idx), self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
@@ -221,7 +233,7 @@ def patched_attention_forward(

if not output_attentions:
attn_weights = None

return attn_output_with_zeros, attn_weights_with_zeros, past_key_value


4 changes: 3 additions & 1 deletion early_exit/patching/method_patching.py
@@ -31,7 +31,7 @@ def patched_forward_generation(self: EarlyExitModelMixin | PeftModelForCausalLM,
assert self.early_exit_mode == 'free_generate'
for name, module in self.named_modules():
if module_name_is_layer_base(name):
print(name)
# print(name)
assert module.early_exit_mode == 'free_generate'
module.exit_state = exit_state

@@ -255,5 +255,7 @@ def replace_attention_layers(model: AutoModelForCausalLM, lora_config_dict: dict
for name, param in model.named_parameters():
if 'lora' in name:
param.requires_grad = True
if 'early_exit_decision_weights' in name:
param.requires_grad = True

return model.to(device)
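
Since the patch now also unfreezes `early_exit_decision_weights`, a small hedged helper (illustrative, not part of the repo) for sanity-checking which parameters end up trainable after `replace_attention_layers`:

```python
import torch.nn as nn

def summarize_trainable(model: nn.Module) -> None:
    """Print trainable parameter names and an overall count."""
    trainable, total = 0, 0
    for name, param in model.named_parameters():
        total += param.numel()
        if param.requires_grad:
            trainable += param.numel()
            if "lora" in name or "early_exit_decision_weights" in name:
                print(f"trainable (expected): {name}")
            else:
                print(f"trainable (unexpected): {name}")
    print(f"{trainable:,} / {total:,} parameters trainable "
          f"({100 * trainable / max(total, 1):.2f}%)")
```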
11 changes: 5 additions & 6 deletions early_exit/sft_train.py
@@ -3,9 +3,9 @@
from torch.nn import functional as F
from torch.utils.data import DataLoader

from shared_util.data import CSVPromptDataset
from shared_util.load import get_model, get_tokenizer, configs_from_yaml
from shared_util.generate import generate_text
from shared_utils.data import CSVPromptDataset
from shared_utils.load import get_model, get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text

from early_exit.patching import replace_attention_layers, set_transformer_early_exit_mode

@@ -53,7 +53,7 @@


run = wandb.init(
entity="cot-mrc",
# entity="cot-mrc",
project="early-exit",
config=dict(
**config,
@@ -160,5 +160,4 @@
})

assert len(prompt_batch.idx) == 1, "Again, batch greater than 1 not allowed yet"
wandb.log(log_dict)

wandb.log(log_dict)
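
For reference, a hedged sketch of the resulting `wandb.init` call: with `entity` commented out as above, wandb falls back to the default entity of the logged-in account. The project name is the one from the diff; the config payload and logged values are illustrative.

```python
import wandb

run = wandb.init(
    project="early-exit",
    # entity omitted on purpose: wandb uses the default entity
    # configured for the current API key / login.
    config={"learning_rate": 1e-5},  # illustrative
)
run.log({"loss": 0.0})  # illustrative
run.finish()
```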
1 change: 0 additions & 1 deletion results_and_data

This file was deleted.
