From 691e7da11002eba47109d03e09acae281d562f09 Mon Sep 17 00:00:00 2001
From: Kaminyou
Date: Sun, 16 Nov 2025 19:43:49 +0000
Subject: [PATCH] fix incorrect cache length

---
 main.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/main.py b/main.py
index b609ce5..01c9da1 100644
--- a/main.py
+++ b/main.py
@@ -33,6 +33,7 @@
     print("Saving KV cache for doc: ", doc_id)
     text = data[doc_id]['prompt']
     input_ids = tokenizer(text, return_tensors="pt").input_ids.cuda()
+    token_length = input_ids.shape[1]
     # print("Length of input: ", input_ids.shape)

     st = time.monotonic()
@@ -44,8 +45,8 @@
     key_value = []
     for i in range(len(kv)):
         kv[i] = list(kv[i])
-        kv[i][0] = kv[i][0][:, :, :-1][0]
-        kv[i][1] = kv[i][1][:, :, :-1][0]
+        kv[i][0] = kv[i][0][:, :, :token_length][0]
+        kv[i][1] = kv[i][1][:, :, :token_length][0]
         kv[i] = tuple(kv[i])
     kv = tuple(kv)
     kv_tensor = to_blob(kv)