-
Notifications
You must be signed in to change notification settings - Fork 93
Expand file tree
/
Copy pathreranking.py
More file actions
170 lines (122 loc) · 4.47 KB
/
reranking.py
File metadata and controls
170 lines (122 loc) · 4.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# --- Imports and environment setup -------------------------------------------
# Standard library
import os

# Third-party
import chromadb
import numpy as np
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader  # noqa: F401 (kept from original; unused here)
from openai import OpenAI
from pypdf import PdfReader  # fix: was imported twice in the original

# Local helpers
from helper_utils import word_wrap, load_chroma  # noqa: F401 (load_chroma kept from original)

# Load OPENAI_API_KEY from a .env file, then build the OpenAI client.
load_dotenv()
openai_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=openai_key)

# Embedding function for the Chroma collection (default sentence-transformers model).
embedding_function = SentenceTransformerEmbeddingFunction()

# Extract the text of every page of the annual report.
reader = PdfReader("data/microsoft-annual-report.pdf")
pdf_texts = [p.extract_text().strip() for p in reader.pages]

# Filter out pages that yielded no extractable text.
pdf_texts = [text for text in pdf_texts if text]
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    SentenceTransformersTokenTextSplitter,
)

# Pass 1: split the whole document on natural boundaries into ~1000-character chunks.
character_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1000, chunk_overlap=0
)
character_split_texts = character_splitter.split_text("\n\n".join(pdf_texts))

# Pass 2: re-split each chunk into pieces of at most 256 tokens so they fit the
# sentence-transformers embedding model.
token_splitter = SentenceTransformersTokenTextSplitter(
    chunk_overlap=0, tokens_per_chunk=256
)
token_split_texts = [
    piece
    for chunk in character_split_texts
    for piece in token_splitter.split_text(chunk)
]
# Build (or reuse) an in-memory Chroma collection and index every token chunk.
chroma_client = chromadb.Client()
chroma_collection = chroma_client.get_or_create_collection(
    "microsoft-collect", embedding_function=embedding_function
)

# Each chunk's position doubles as its stable string ID.
ids = [str(position) for position, _ in enumerate(token_split_texts)]
chroma_collection.add(ids=ids, documents=token_split_texts)
count = chroma_collection.count()
query = "What has been the investment in research and development?"

# Bi-encoder retrieval: top 10 chunks by embedding similarity.
results = chroma_collection.query(
    query_texts=query, n_results=10, include=["documents", "embeddings"]
)
retrieved_documents = results["documents"][0]

# Show what the first-stage retriever returned.
for document in retrieved_documents:
    print(word_wrap(document))
    print("")

from sentence_transformers import CrossEncoder

# Cross-encoder re-scores every (query, document) pair jointly — slower than the
# bi-encoder but much more accurate for ranking.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
pairs = [[query, doc] for doc in retrieved_documents]
scores = cross_encoder.predict(pairs)

print("Scores:")
for score in scores:
    print(score)

# Highest score first, reported as 1-based positions in the retrieval order.
print("New Ordering:")
for rank in np.argsort(scores)[::-1] + 1:
    print(rank)
# Query expansion: pair the original question with pre-generated paraphrases so
# retrieval covers more phrasings of the same information need.
original_query = (
    "What were the most important factors that contributed to increases in revenue?"
)
generated_queries = [
    "What were the major drivers of revenue growth?",
    "Were there any new product launches that contributed to the increase in revenue?",
    "Did any changes in pricing or promotions impact the revenue growth?",
    "What were the key market trends that facilitated the increase in revenue?",
    "Did any acquisitions or partnerships contribute to the revenue growth?",
]

# Run every query (original first) against the collection in a single call.
queries = [original_query] + generated_queries
results = chroma_collection.query(
    query_texts=queries, n_results=10, include=["documents", "embeddings"]
)
retrieved_documents = results["documents"]

# Collapse the per-query result lists into one deduplicated pool.
unique_documents = list(
    {doc for per_query_docs in retrieved_documents for doc in per_query_docs}
)

# Re-score every unique document against the ORIGINAL query only.
pairs = [[original_query, doc] for doc in unique_documents]
scores = cross_encoder.predict(pairs)

print("Scores:")
for score in scores:
    print(score)

print("New Ordering:")
for position in np.argsort(scores)[::-1]:
    print(position)

# ====
# Keep the five best-scoring documents and join them into one context string.
top_indices = np.argsort(scores)[::-1][:5]
top_documents = [unique_documents[i] for i in top_indices]
context = "\n\n".join(top_documents)
# Generate the final answer using the OpenAI model
def generate_multi_query(query, context, model="gpt-3.5-turbo"):
    """Answer *query* with an OpenAI chat model, grounded in *context*.

    Args:
        query: The user's question about the annual report.
        context: Concatenated top-ranked document chunks used to ground the answer.
        model: OpenAI chat-completion model name.

    Returns:
        The model's answer as a list of lines (the raw reply split on "\\n").
    """
    # Plain string — the original used an f-string with no placeholders.
    prompt = """
    You are a knowledgeable financial research assistant. 
    Your users are inquiring about an annual report. 
    """
    messages = [
        {
            "role": "system",
            "content": prompt,
        },
        {
            "role": "user",
            "content": f"based on the following context:\n\n{context}\n\nAnswer the query: '{query}'",
        },
    ]

    # NOTE(review): relies on the module-level `client` built during setup.
    response = client.chat.completions.create(
        model=model,
        messages=messages,
    )
    content = response.choices[0].message.content
    # Split into individual lines so callers can post-process line by line.
    return content.split("\n")
# Produce and display the final grounded answer (a list of answer lines).
final_answer = generate_multi_query(query=original_query, context=context)
print("Final Answer:")
print(final_answer)