Skip to content

Commit bd31e43

Browse files
committed
Feature(LLMLingua): add KV-Cache Compression & HF Space Demo
1 parent b44fd13 commit bd31e43

File tree

3 files changed

+69
-11
lines changed

3 files changed

+69
-11
lines changed

README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,22 @@
22
<img src="images/LLMLingua_logo.png" alt="LLMLingua" style="width: 20%; min-width: 100px; display: block; margin: auto;">
33
</p>
44

5-
# LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models [[paper]()] & LongLLMLingua [[paper]()]
5+
# LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models [paper] & LongLLMLingua [paper]
66

77
https://github.com/microsoft/LLMLingua/assets/30883354/ef52995c-ef3c-4eac-a9fd-1acb491c325b
88

9+
You can try the LLMLingua demo in [HF Space](https://huggingface.co/spaces/microsoft/LLMLingua).
10+
911
## Tl;DR
1012

1113
LLMLingua uses a well-trained small language model after alignment, such as GPT2-small or LLaMA-7B, to detect the unimportant tokens in the prompt and enables inference with the compressed prompt in black-box LLMs, achieving up to 20x compression with minimal performance loss.
1214

13-
[LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models]() (EMNLP 2023).
15+
LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models (EMNLP 2023).<br>
1416
_Huiqiang Jiang, Qianhui Wu, Chin-Yew Lin, Yuqing Yang and Lili Qiu_
1517

1618
LongLLMLingua is a method that enhances LLMs' ability to perceive key information in long-context scenarios using prompt compression, achieving up to $28.5 in cost savings per 1,000 samples while also improving performance.
1719

18-
[LongLLMLingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression]() (Under Review).
20+
LongLLMLingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression (Under Review).<br>
1921
_Huiqiang Jiang, Qianhui Wu, Xufang Luo, Dongsheng Li, Chin-Yew Lin, Yuqing Yang and Lili Qiu_
2022

2123
## 🎥 Overview

llmlingua/prompt_compressor.py

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ def __init__(
2424
self.load_model(model_name, device_map, use_auth_token)
2525
self.sbert = None
2626
self.open_api_config = open_api_config
27+
self.cache_bos_num = 10
2728

2829
def load_model(
2930
self, model_name: str, device_map: str = "cuda", use_auth_token: bool = False
3031
):
31-
config = AutoConfig.from_pretrained(model_name)
32-
tokenizer = AutoTokenizer.from_pretrained(model_name)
32+
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
33+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
3334
tokenizer.padding_side = "left"
3435
tokenizer.pad_token_id = (
3536
config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id
@@ -40,12 +41,11 @@ def load_model(
4041
if "cuda" in device_map or "cpu" in device_map:
4142
model = AutoModelForCausalLM.from_pretrained(
4243
model_name,
43-
torch_dtype="auto",
44+
torch_dtype="auto" if device_map == "cuda" else torch.float32,
4445
config=config,
4546
ignore_mismatched_sizes=True,
47+
trust_remote_code=True,
4648
).to(device_map)
47-
if device_map == "cpu":
48-
model = model.type(torch.float32)
4949
else:
5050
model = AutoModelForCausalLM.from_pretrained(
5151
model_name,
@@ -56,10 +56,12 @@ def load_model(
5656
offload_state_dict=True,
5757
cache_dir="/tmp/cache",
5858
use_auth_token=use_auth_token,
59+
trust_remote_code=True,
5960
)
6061
self.tokenizer = tokenizer
6162
self.model = model
6263
self.context_idxs = []
64+
self.max_position_embeddings = config.max_position_embeddings
6365

6466
def get_ppl(
6567
self,
@@ -83,7 +85,7 @@ def get_ppl(
8385
past_length = 0
8486
if end is None:
8587
end = input_ids.shape[1]
86-
end = min(end, past_length + 4096)
88+
end = min(end, past_length + self.max_position_embeddings)
8789
with torch.no_grad():
8890
response = self.model(
8991
input_ids[:, past_length:end],
@@ -145,11 +147,17 @@ def compress_prompt(
145147
assert not (
146148
rank_method == "longllmlingua" and not question
147149
), "In the LongLLMLingua, it is necessary to set a question."
150+
if condition_compare and "_condition" not in condition_in_question:
151+
condition_in_question += "_condition"
148152
if rank_method == "longllmlingua":
149153
if condition_in_question == "none":
150154
condition_in_question = "after"
151155
elif rank_method == "llmlingua":
152-
condition_in_question = "none"
156+
condition_in_question = (
157+
"none"
158+
if "_condition" not in condition_in_question
159+
else "none_condition"
160+
)
153161
origin_tokens = len(
154162
encoding.encode("\n\n".join([instruction] + context + [question]).strip())
155163
)
@@ -653,8 +661,52 @@ def iterative_compress_prompt(
653661
keep_flag = torch.tensor(keep_flag).to(self.device)
654662
past_key_values, past_loss, ready_end = None, None, 0
655663
self_past_key_values, self_past_loss, self_ready_end = None, None, 0
664+
pop_compressed_input_ids, pop_self_compressed_input_ids = None, None
656665
idx = 0
657666
while end <= compressed_input_ids.shape[1]:
667+
if end > self.max_position_embeddings and past_key_values is not None:
668+
# KV-Cache Compression
669+
e, s = end - self.max_position_embeddings, self.cache_bos_num
670+
if pop_compressed_input_ids is None:
671+
pop_compressed_input_ids = compressed_input_ids[:, :e]
672+
else:
673+
pop_compressed_input_ids = torch.cat(
674+
[pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1
675+
)
676+
compressed_input_ids = compressed_input_ids[:, e:]
677+
compressed_attention_mask = compressed_attention_mask[:, e:]
678+
past_key_values = [
679+
[
680+
torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
681+
torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
682+
]
683+
for k, v in past_key_values
684+
]
685+
end, ready_end = end - e, ready_end - e
686+
if condition_compare:
687+
self_ready_end -= e
688+
if pop_self_compressed_input_ids is None:
689+
pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
690+
else:
691+
pop_self_compressed_input_ids = torch.cat(
692+
[
693+
pop_self_compressed_input_ids,
694+
self_compressed_input_ids[:, :e],
695+
],
696+
dim=-1,
697+
)
698+
self_compressed_input_ids = self_compressed_input_ids[:, e:]
699+
self_compressed_attention_mask = self_compressed_attention_mask[
700+
:, e:
701+
]
702+
self_past_key_values = [
703+
[
704+
torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
705+
torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
706+
]
707+
for k, v in self_past_key_values
708+
]
709+
658710
loss, past_key_values = self.get_ppl(
659711
"",
660712
"token",
@@ -762,6 +814,10 @@ def iterative_compress_prompt(
762814
)
763815
end += iterative_size
764816
idx += 1
817+
if pop_compressed_input_ids is not None:
818+
compressed_input_ids = torch.cat(
819+
[pop_compressed_input_ids, compressed_input_ids], dim=-1
820+
)
765821
return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]
766822

767823
def recover(

llmlingua/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
_MINOR = "1"
33
# On master and in a nightly release the patch should be one ahead of the last
44
# released build.
5-
_PATCH = "1"
5+
_PATCH = "2"
66
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
77
# https://semver.org/#is-v123-a-semantic-version for the semantics.
88
_SUFFIX = ""

0 commit comments

Comments
 (0)