diff --git a/main.py b/main.py index b609ce5..01c9da1 100644 --- a/main.py +++ b/main.py @@ -33,6 +33,7 @@ print("Saving KV cache for doc: ", doc_id) text = data[doc_id]['prompt'] input_ids = tokenizer(text, return_tensors="pt").input_ids.cuda() + token_length = input_ids.shape[1] # print("Length of input: ", input_ids.shape) st = time.monotonic() @@ -44,8 +45,8 @@ key_value = [] for i in range(len(kv)): kv[i] = list(kv[i]) - kv[i][0] = kv[i][0][:, :, :-1][0] - kv[i][1] = kv[i][1][:, :, :-1][0] + kv[i][0] = kv[i][0][:, :, :token_length][0] + kv[i][1] = kv[i][1][:, :, :token_length][0] kv[i] = tuple(kv[i]) kv = tuple(kv) kv_tensor = to_blob(kv)