diff --git a/App.py b/App.py
index c044d27..0bf20b9 100644
--- a/App.py
+++ b/App.py
@@ -47,12 +47,12 @@
# model_name="gpt-4o-mini", max_tokens=MAX_TOKENS, temperature=TEMPERATURE
# )
LARGE_LANGUAGE_MODEL = AzureChatOpenAI(
- azure_deployment="gpt-4o-mini",
+ azure_deployment="gpt-4o-mini",
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_SWEDEN"],
api_key=os.environ["AZURE_OPENAI_API_KEY_SWEDEN"],
- api_version=os.environ["AZURE_OPENAI_API_VERSION_SWEDEN"],
- max_tokens=MAX_TOKENS,
- temperature=TEMPERATURE
+ api_version=os.environ["AZURE_OPENAI_API_VERSION_SWEDEN"],
+ max_tokens=MAX_TOKENS,
+ temperature=TEMPERATURE,
)
@@ -69,9 +69,7 @@ def load_embedding_model():
@st.cache_resource
def load_db_manifestos():
return VectorDatabase(
- embedding_model=embedding_model,
- source_type="manifestos",
- database_directory=DATABASE_DIR_MANIFESTOS,
+ embedding_model=embedding_model, source_type="manifestos", database_directory=DATABASE_DIR_MANIFESTOS
)
@@ -131,9 +129,7 @@ def load_db_manifestos():
# The keys represent the (random) order in which the parties appear in the app,
# not fixed parties as in the party_dict above.
- st.session_state.show_individual_parties = {
- f"party_{i+1}": False for i in range(len(st.session_state.parties))
- }
+ st.session_state.show_individual_parties = {f"party_{i + 1}": False for i in range(len(st.session_state.parties))}
# The "example_prompts" dictionary will contain randomly selected example prompts for the user to choose from:
if "example_prompts" not in st.session_state:
@@ -146,9 +142,7 @@ def load_db_manifestos():
all_example_prompts[key] = []
all_example_prompts[key].append(value)
- st.session_state.example_prompts = {
- key: random.sample(value, 3) for key, value in all_example_prompts.items()
- }
+ st.session_state.example_prompts = {key: random.sample(value, 3) for key, value in all_example_prompts.items()}
if "number_of_requests" not in st.session_state:
st.session_state.number_of_requests = 0
@@ -183,7 +177,7 @@ def img_to_bytes(img_path):
def img_to_html(img_path):
-    img_html = (
-        f"<img src='data:image/png;base64,{img_to_bytes(img_path)}' class='img-fluid'>"
-    )
+    img_html = f"<img src='data:image/png;base64,{img_to_bytes(img_path)}' class='img-fluid'>"
return img_html
@@ -194,9 +188,7 @@ def submit_query():
st.session_state.feedback_text = ""
st.session_state.feedback_submitted = False
st.session_state.stage = 1
- st.session_state.show_individual_parties = {
- f"party_{i+1}": False for i in range(len(st.session_state.parties))
- }
+ st.session_state.show_individual_parties = {f"party_{i + 1}": False for i in range(len(st.session_state.parties))}
random.shuffle(st.session_state.parties)
@@ -218,9 +210,9 @@ def generate_response():
st.session_state.response = rag.query(query)
# Assert that the response contains all parties
- assert set(st.session_state.response["answer"].keys()) == set(
- st.session_state.parties
- ), "LLM response does not contain all parties"
+ assert set(st.session_state.response["answer"].keys()) == set(st.session_state.parties), (
+ "LLM response does not contain all parties"
+ )
break
except Exception as e:
@@ -230,33 +222,19 @@ def generate_response():
if retry_count > max_retries:
print(f"Max number of tries ({max_retries}) reached, aborting")
st.session_state.response = None
- st.error(
- translate(
- "error-api-unavailable",
- st.session_state.language,
- )
- )
+ st.error(translate("error-api-unavailable", st.session_state.language))
# Display error message in app:
raise e
else:
print(f"Retrying, retry number {retry_count}")
pass
st.session_state.log_id = add_log_dict(
- {
- "query": st.session_state.query,
- "answer": st.session_state.response["answer"],
- }
+ {"query": st.session_state.query, "answer": st.session_state.response["answer"]}
)
def submit_feedback(feedback_rating, feedback_text):
- update_log_dict(
- st.session_state.log_id,
- {
- "feedback-rating": feedback_rating,
- "feedback-text": feedback_text,
- },
- )
+ update_log_dict(st.session_state.log_id, {"feedback-rating": feedback_rating, "feedback-text": feedback_text})
st.rerun()
@@ -277,24 +255,13 @@ def convert_date_format(date_string):
##################################
with st.sidebar:
- selected_language = st.radio(
- label="Language",
- options=["🇩🇪 Deutsch", "🇬🇧 English"],
- horizontal=True,
- )
+ selected_language = st.radio(label="Language", options=["🇩🇪 Deutsch", "🇬🇧 English"], horizontal=True)
languages = {"🇩🇪 Deutsch": "de", "🇬🇧 English": "en"}
st.session_state.language = languages[selected_language]
rag.language = st.session_state.language
st.header("🗳️ electify.eu", divider="blue")
-st.write(
- "##### :grey["
- + translate(
- "subheadline",
- st.session_state.language,
- )
- + "]"
-)
+st.write("##### :grey[" + translate("subheadline", st.session_state.language) + "]")
support_button(
text=f"💙 {translate('support-button', st.session_state.language)}",
@@ -304,38 +271,25 @@ def convert_date_format(date_string):
if st.session_state.number_of_requests >= 3:
# Show support banner after 3 requests in a single session.
st.info(
- f"{translate('support-banner', st.session_state.language)}(https://buymeacoffee.com/electify.eu)",
- icon="💙",
+ f"{translate('support-banner', st.session_state.language)}(https://buymeacoffee.com/electify.eu)", icon="💙"
)
query = st.text_input(
- label=translate(
- "query-instruction",
- st.session_state.language,
- ),
- placeholder="",
- value=st.session_state.query,
+ label=translate("query-instruction", st.session_state.language), placeholder="", value=st.session_state.query
)
col_submit, col_checkbox = st.columns([1, 3])
# Submit button
with col_submit:
- st.button(
- translate("submit-query", st.session_state.language),
- on_click=submit_query,
- type="primary",
- )
+ st.button(translate("submit-query", st.session_state.language), on_click=submit_query, type="primary")
# Checkbox to show/hide party names globally
with col_checkbox:
st.session_state.show_all_parties = st.checkbox(
label=translate("show-party-names", st.session_state.language),
value=True,
- help=translate(
- "show-party-names-help",
- st.session_state.language,
- ),
+ help=translate("show-party-names-help", st.session_state.language),
)
# Allow the user to select up to 6 parties
@@ -350,12 +304,7 @@ def update_party_selection(party):
party_selection[party] = not party_selection[party]
st.session_state.parties = [k for k, v in party_selection.items() if v]
- st.write(
- translate(
- "select-parties-instruction",
- st.session_state.language,
- )
- )
+ st.write(translate("select-parties-instruction", st.session_state.language))
for party in available_parties:
st.checkbox(
label=party_dict[party]["name"],
@@ -365,15 +314,11 @@ def update_party_selection(party):
)
if len(st.session_state.parties) == 0:
- st.markdown(
- f"⚠️ **:red[{translate('error-min-1-party', st.session_state.language)}]**"
- )
+ st.markdown(f"⚠️ **:red[{translate('error-min-1-party', st.session_state.language)}]**")
# Reset to default parties
st.session_state.parties = rag.parties
elif len(st.session_state.parties) > 6:
- st.markdown(
- f"⚠️ **:red[{translate('error-max-6-parties', st.session_state.language)}]**"
- )
+ st.markdown(f"⚠️ **:red[{translate('error-max-6-parties', st.session_state.language)}]**")
# Limit to the six first selected parties
st.session_state.parties = st.session_state.parties[:6]
@@ -395,43 +340,23 @@ def update_party_selection(party):
# STAGE > 0: Show disclaimer once the user has submitted a query (and keep showing it)
if st.session_state.stage > 0:
if len(st.session_state.parties) == 0:
- st.error(
- translate(
- "error-min-1-party",
- st.session_state.language,
- )
- )
+ st.error(translate("error-min-1-party", st.session_state.language))
st.session_state.stage = 0
else:
st.info(
"☝️ "
- + translate(
- "disclaimer-llm",
- st.session_state.language,
- )
+ + translate("disclaimer-llm", st.session_state.language)
+ " \n"
- + translate(
- "disclaimer-research",
- st.session_state.language,
- )
+ + translate("disclaimer-research", st.session_state.language)
+ " \n"
- + translate(
- "disclaimer-random-order",
- st.session_state.language,
- ),
+ + translate("disclaimer-random-order", st.session_state.language)
)
# STAGE 1: User submitted a query and we are waiting for the response
if st.session_state.stage == 1:
st.session_state.number_of_requests += 1
- with st.spinner(
- translate(
- "loading-response",
- st.session_state.language,
- )
- + "🕵️"
- ):
+ with st.spinner(translate("loading-response", st.session_state.language) + "🕵️"):
generate_response()
st.session_state.stage = 2
@@ -439,7 +364,6 @@ def update_party_selection(party):
# STAGE > 1: The response has been generated and is displayed
if st.session_state.stage > 1:
-
# Initialize an empty list to hold all columns
col_list = []
# Create a pair of columns for each party
@@ -451,14 +375,11 @@ def update_party_selection(party):
p = i + 1
col1, col2 = col_list[i]
- most_relevant_manifesto_page_number = st.session_state.response["docs"][
- "manifestos"
- ][party][0].metadata["page"]
+ most_relevant_manifesto_page_number = st.session_state.response["docs"]["manifestos"][party][0].metadata[
+ "page"
+ ]
- show_party = (
- st.session_state.show_all_parties
- or st.session_state.show_individual_parties[f"party_{p}"]
- )
+ show_party = st.session_state.show_all_parties or st.session_state.show_individual_parties[f"party_{p}"]
# In this column, we show the party image
with col1:
@@ -470,12 +391,7 @@ def update_party_selection(party):
else:
file_loc = "streamlit_app/assets/placeholder_logo.png"
st.markdown(img_to_html(file_loc), unsafe_allow_html=True)
- st.button(
- translate("show-party", st.session_state.language),
- on_click=reveal_party,
- args=(p,),
- key=p,
- )
+ st.button(translate("show-party", st.session_state.language), on_click=reveal_party, args=(p,), key=p)
# In this column, we show the RAG response
with col2:
if show_party:
@@ -483,17 +399,12 @@ def update_party_selection(party):
else:
st.header(f"{translate('party', st.session_state.language)} {p}")
- if party == "afd":
- st.caption(
- f"⚠️ **{translate("warning-afd", st.session_state.language)}**"
- )
+ if party == "afd" and show_party:
+ st.caption(f"⚠️ **{translate('warning-afd', st.session_state.language)}**")
st.write(st.session_state.response["answer"][party])
if show_party:
- is_answer_empty = (
- "keine passende antwort"
- in st.session_state.response["answer"][party].lower()
- )
+ is_answer_empty = "keine passende antwort" in st.session_state.response["answer"][party].lower()
page_reference_string = (
""
@@ -502,33 +413,17 @@ def update_party_selection(party):
)
st.write(
- f"""{translate('learn-more-in', st.session_state.language)} [{translate('party-manifesto', st.session_state.language)} **{party_dict[party]['name']}**{page_reference_string}]({party_dict[party]['manifesto_link']}#page={most_relevant_manifesto_page_number + 1})."""
+ f"""{translate("learn-more-in", st.session_state.language)} [{translate("party-manifesto", st.session_state.language)} **{party_dict[party]["name"]}**{page_reference_string}]({party_dict[party]["manifesto_link"]}#page={most_relevant_manifesto_page_number + 1})."""
)
st.markdown("---")
# Display a section with all retrieved excerpts from the sources
st.subheader(translate("sources-subheading", st.session_state.language))
- st.write(
- translate(
- "sources-intro",
- st.session_state.language,
- )
- )
- st.write(
- translate(
- "sources-excerpts-intro",
- st.session_state.language,
- )
- )
+ st.write(translate("sources-intro", st.session_state.language))
+ st.write(translate("sources-excerpts-intro", st.session_state.language))
for party in st.session_state.parties:
- with st.expander(
- translate(
- "sources",
- st.session_state.language,
- )
- + f": {party_dict[party]['name']}"
- ):
+ with st.expander(translate("sources", st.session_state.language) + f": {party_dict[party]['name']}"):
for doc in st.session_state.response["docs"]["manifestos"][party]:
manifesto_excerpt = doc.page_content.replace("\n", " ")
page_number_of_excerpt = doc.metadata["page"] + 1
@@ -552,27 +447,19 @@ def update_party_selection(party):
st.subheader(translate("feedback-heading", st.session_state.language))
if not st.session_state.feedback_submitted:
-
st.write(translate("feedback-intro", st.session_state.language))
with st.form(key="feedback-form"):
- feedback_options = {
- "negative": "☹️",
- "neutral": "😐",
- "positive": "😊",
- }
+ feedback_options = {"negative": "☹️", "neutral": "😐", "positive": "😊"}
feedback_rating = st.segmented_control(
label=translate("feedback-question-rating", st.session_state.language),
options=feedback_options.keys(),
format_func=lambda option: feedback_options[option],
)
- feedback_text = st.text_area(
- label=translate("feedback-question-text", st.session_state.language)
- )
+ feedback_text = st.text_area(label=translate("feedback-question-text", st.session_state.language))
submitted = st.form_submit_button(
- label=translate("feedback-submit", st.session_state.language),
- type="primary",
+ label=translate("feedback-submit", st.session_state.language), type="primary"
)
if submitted:
st.session_state.feedback_submitted = True
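Aside on the reflowed assert in generate_response(): the completeness check raises AssertionError when the LLM omits a party, and the enclosing try/except retries it like any other failure. Schematically, the pattern these hunks touch looks like the sketch below; it is simplified, and the max_retries value shown is an assumption, since the real value sits outside the visible hunks.

    # Simplified sketch of generate_response()'s retry loop
    retry_count = 0
    max_retries = 3  # assumed for illustration; not shown in this diff
    while True:
        try:
            st.session_state.response = rag.query(query)
            # Fail fast if the answer is missing any selected party;
            # the AssertionError is retried like any other exception.
            assert set(st.session_state.response["answer"].keys()) == set(st.session_state.parties), (
                "LLM response does not contain all parties"
            )
            break
        except Exception as e:
            retry_count += 1
            if retry_count > max_retries:
                st.session_state.response = None
                raise e
            print(f"Retrying, retry number {retry_count}")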
diff --git a/Dockerfile b/Dockerfile
index 4a844c1..7f2d13a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -19,8 +19,8 @@ RUN uv sync --frozen --no-dev
# Activate virtual environment
ENV PATH="/app/.venv/bin:$PATH"
-# Copy custom index
-COPY streamlit_app/index.html /usr/local/lib/python3.11/site-packages/streamlit/static/index.html
+# Copy custom index.html with SEO tags
+RUN cp streamlit_app/index.html $(python -c "import streamlit; import os; print(os.path.dirname(streamlit.__file__))")/static/index.html
# Expose port 8080 to the world outside the container
EXPOSE 8080
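Aside on the Dockerfile hunk: the old COPY hard-coded a system interpreter path, which no longer exists once the app runs from the uv-managed virtual environment activated above. The RUN cp variant instead asks the installed streamlit package where it lives. Expanded, the inline python -c one-liner is just:

    import os
    import streamlit

    # Prints the installed package's directory, e.g.
    # /app/.venv/lib/python3.11/site-packages/streamlit
    print(os.path.dirname(streamlit.__file__))

The Dockerfile appends /static/index.html to that output to form the copy target, so the custom page lands wherever Streamlit is actually installed.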
diff --git a/RAG/database/vector_database.py b/RAG/database/vector_database.py
index e896972..6bc382b 100644
--- a/RAG/database/vector_database.py
+++ b/RAG/database/vector_database.py
@@ -1,143 +1,143 @@
import glob
import os
-from langchain_community.document_loaders import PDFMinerLoader
+from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
class VectorDatabase:
- def __init__(
- self,
- embedding_model,
- source_type, # "manifestos" or "debates"
- data_path=".",
- database_directory="./chroma",
- chunk_size=1000,
- chunk_overlap=200,
- loader="pdf",
- reload=True,
- ):
- """
- Initializes the VectorDatabase.
-
- Parameters:
- - embedding_model: The model used to generate embeddings for the documents.
- - data_directory (str): The directory where the source documents are located. Defaults to the current directory.
- - database_directory (str): The directory to store the Chroma database. Defaults to './chroma'.
- - chunk_size (int): The size of text chunks to split the documents into. Defaults to 1000.
- - chunk_overlap (int): The number of characters to overlap between adjacent chunks. Defaults to 100.
- - loader(str): "pdf" or "csv", depending on data format
- """
-
- self.embedding_model = embedding_model
- self.source_type = source_type
- self.data_path = data_path
- self.database_directory = database_directory
- self.chunk_size = chunk_size
- self.chunk_overlap = chunk_overlap
- self.loader = loader
-
- if reload:
- self.database = self.load_database()
-
- def load_database(self):
- """
- Loads an existing Chroma database.
-
- Returns:
- - The loaded Chroma database.
- """
- if os.path.exists(self.database_directory):
- self.database = Chroma(persist_directory=self.database_directory, embedding_function=self.embedding_model)
- print("reloaded database")
- else:
- raise AssertionError(f"{self.database_directory} does not include database.")
-
- return self.database
-
- def build_database(self, overwrite=True):
- """
- Builds a new Chroma database from the documents in the data directory.
-
- Parameters:
- - loader: Optional, a document loader instance. If None, PyPDFDirectoryLoader will be used with the data_directory.
-
- Returns:
- - The newly built Chroma database.
- """
- # # If overwrite flag is true, remove old databases from directory if they exist
- # if overwrite:
- # if os.path.exists(self.database_directory):
- # shutil.rmtree(self.database_directory)
- # time.sleep(1)
-
- # PDF is the default loader defined above
-
- if os.path.exists(self.database_directory):
- raise AssertionError("Delete old database first and restart session!")
-
- # Define text_splitter
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
-
- if self.loader == "pdf":
- # loader = PyPDFDirectoryLoader(self.data_path)
- # get file_paths of all pdfs in data_folder
- pdf_paths = glob.glob(os.path.join(self.data_path, "*.pdf"))
-
- splits = []
- for pdf_path in pdf_paths:
- file_name = os.path.basename(pdf_path)
- party = file_name.split("_")[0]
-
- # Load pdf as single doc
- loader = PDFMinerLoader(pdf_path, concatenate_pages=True)
- doc = loader.load()
-
- # Also load pdf as individual pages, this is important to extract the page number later
- loader = PDFMinerLoader(pdf_path, concatenate_pages=False)
- doc_pages = loader.load()
-
- # Add party to metadata
- for i in range(len(doc)):
- doc[i].metadata.update({"party": party})
-
- # Create splits
- splits_temp = text_splitter.split_documents(doc)
-
- # For each split, we search for the page on which it has occurred
- for split in splits_temp:
- for page_number, doc_page in enumerate(doc_pages):
- # Create first and second half of split
- split_1 = split.page_content[: int(0.5 * len(split.page_content))]
- split_2 = split.page_content[int(0.5 * len(split.page_content)) :]
- # If the first half is on page page_number or the second half is on page page_number, set page=page_number
- if split_1 in doc_page.page_content or split_2 in doc_page.page_content:
- split.metadata.update({"page": page_number})
-
- if split.metadata.get("page") is None:
- split.metadata.update({"page": 1})
-
- splits.extend(splits_temp)
-
- elif self.loader == "csv":
- loader = CSVLoader(self.data_path, metadata_columns=["date", "fullName", "politicalGroup", "party"])
- # Load documents
- docs = loader.load()
-
- # Create splits
- splits = text_splitter.split_documents(docs)
-
- # Create database
- self.database = Chroma.from_documents(
- splits,
- self.embedding_model,
- persist_directory=self.database_directory,
- collection_metadata={"hnsw:space": "cosine"},
- )
-
- return self.database
+ def __init__(
+ self,
+ embedding_model,
+ source_type, # "manifestos" or "debates"
+ data_path=".",
+ database_directory="./chroma",
+ chunk_size=1000,
+ chunk_overlap=200,
+ loader="pdf",
+ reload=True,
+ ):
+ """
+ Initializes the VectorDatabase.
+
+ Parameters:
+ - embedding_model: The model used to generate embeddings for the documents.
+        - source_type (str): "manifestos" or "debates", the kind of source documents.
+        - data_path (str): The directory where the source documents are located. Defaults to the current directory.
+        - database_directory (str): The directory to store the Chroma database. Defaults to './chroma'.
+        - chunk_size (int): The size of text chunks to split the documents into. Defaults to 1000.
+        - chunk_overlap (int): The number of characters to overlap between adjacent chunks. Defaults to 200.
+        - loader (str): "pdf" or "csv", depending on the data format.
+ """
+
+ self.embedding_model = embedding_model
+ self.source_type = source_type
+ self.data_path = data_path
+ self.database_directory = database_directory
+ self.chunk_size = chunk_size
+ self.chunk_overlap = chunk_overlap
+ self.loader = loader
+
+ if reload:
+ self.database = self.load_database()
+
+ def load_database(self):
+ """
+ Loads an existing Chroma database.
+
+ Returns:
+ - The loaded Chroma database.
+ """
+ if os.path.exists(self.database_directory):
+ self.database = Chroma(persist_directory=self.database_directory, embedding_function=self.embedding_model)
+ print("reloaded database")
+ else:
+ raise AssertionError(f"{self.database_directory} does not include database.")
+
+ return self.database
+
+ def build_database(self, overwrite=True):
+ """
+ Builds a new Chroma database from the documents in the data directory.
+
+        Parameters:
+        - overwrite (bool): Currently unused; the old-database cleanup that used it is commented out below.
+
+ Returns:
+ - The newly built Chroma database.
+ """
+ # # If overwrite flag is true, remove old databases from directory if they exist
+ # if overwrite:
+ # if os.path.exists(self.database_directory):
+ # shutil.rmtree(self.database_directory)
+ # time.sleep(1)
+
+ # PDF is the default loader defined above
+
+ if os.path.exists(self.database_directory):
+ raise AssertionError("Delete old database first and restart session!")
+
+ # Define text_splitter
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
+
+ if self.loader == "pdf":
+ # loader = PyPDFDirectoryLoader(self.data_path)
+ # get file_paths of all pdfs in data_folder
+ pdf_paths = glob.glob(os.path.join(self.data_path, "*.pdf"))
+
+ splits = []
+ for pdf_path in pdf_paths:
+ file_name = os.path.basename(pdf_path)
+ party = file_name.split("_")[0]
+
+ # Load pdf as single doc
+ loader = PyMuPDFLoader(pdf_path, mode="single")
+ doc = loader.load()
+
+            # Also load the PDF as individual pages; this is needed to recover the page number later
+ loader = PyMuPDFLoader(pdf_path, mode="page")
+ doc_pages = loader.load()
+
+ # Add party to metadata
+ for i in range(len(doc)):
+ doc[i].metadata.update({"party": party})
+
+ # Create splits
+ splits_temp = text_splitter.split_documents(doc)
+
+ # For each split, we search for the page on which it has occurred
+ for split in splits_temp:
+ for page_number, doc_page in enumerate(doc_pages):
+ # Create first and second half of split
+ split_1 = split.page_content[: int(0.5 * len(split.page_content))]
+ split_2 = split.page_content[int(0.5 * len(split.page_content)) :]
+ # If the first half is on page page_number or the second half is on page page_number, set page=page_number
+ if split_1 in doc_page.page_content or split_2 in doc_page.page_content:
+ split.metadata.update({"page": page_number})
+
+ if split.metadata.get("page") is None:
+ split.metadata.update({"page": 1})
+
+ splits.extend(splits_temp)
+
+ elif self.loader == "csv":
+ loader = CSVLoader(self.data_path, metadata_columns=["date", "fullName", "politicalGroup", "party"])
+ # Load documents
+ docs = loader.load()
+
+ # Create splits
+ splits = text_splitter.split_documents(docs)
+
+ # Create database
+ self.database = Chroma.from_documents(
+ splits,
+ self.embedding_model,
+ persist_directory=self.database_directory,
+ collection_metadata={"hnsw:space": "cosine"},
+ )
+
+ return self.database
if __name__ == "__main__":
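The substantive change in this file is the loader swap: PDFMinerLoader(..., concatenate_pages=True/False) becomes PyMuPDFLoader(..., mode="single"/"page"), keeping the two-pass pattern of loading each PDF once as a whole (for splitting) and once per page (to recover page numbers); the pymupdf dependency is added to pyproject.toml further down. A minimal sketch of the two modes, with a hypothetical file name:

    from langchain_community.document_loaders import PyMuPDFLoader

    # mode="single": one Document containing the whole PDF
    whole = PyMuPDFLoader("spd_manifesto.pdf", mode="single").load()

    # mode="page": one Document per page (the page index also ends up in metadata)
    pages = PyMuPDFLoader("spd_manifesto.pdf", mode="page").load()
    print(len(whole), len(pages))  # 1, <page count>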
diff --git a/RAG/evaluation/evaluation.py b/RAG/evaluation/evaluation.py
index 6986b1d..8e101cb 100644
--- a/RAG/evaluation/evaluation.py
+++ b/RAG/evaluation/evaluation.py
@@ -17,7 +17,7 @@ def context_relevancy(self, dataset):
context = ""
for i, doc in enumerate(context_docs):
- context += f"Dokument {i+1}: {doc}\n\n"
+ context += f"Dokument {i + 1}: {doc}\n\n"
prompt = f"""
{instruction}
@@ -29,11 +29,7 @@ def context_relevancy(self, dataset):
{question}"""
completion = self.client.chat.completions.create(
- model="gpt-3.5-turbo",
- temperature=0,
- messages=[
- {"role": "user", "content": prompt},
- ],
+ model="gpt-3.5-turbo", temperature=0, messages=[{"role": "user", "content": prompt}]
)
# Parse output into list
try:
diff --git a/RAG/models/embedding.py b/RAG/models/embedding.py
index 730ea5d..505966e 100644
--- a/RAG/models/embedding.py
+++ b/RAG/models/embedding.py
@@ -7,179 +7,175 @@
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0]
- input_mask_expanded = (
- attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
- )
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
- input_mask_expanded.sum(1), min=1e-9
- )
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
class ManifestoBertaEmbeddings(Embeddings):
- """Embeddings using ManifestoBerta for use with LangChain."""
+ """Embeddings using ManifestoBerta for use with LangChain."""
- def __init__(self):
- # Load the tokenizer and model
- self.tokenizer = AutoTokenizer.from_pretrained(
- "manifesto-project/manifestoberta-xlm-roberta-56policy-topics-sentence-2023-1-1"
- )
- self.model = AutoModel.from_pretrained(
- "manifesto-project/manifestoberta-xlm-roberta-56policy-topics-sentence-2023-1-1"
- )
+ def __init__(self):
+ # Load the tokenizer and model
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ "manifesto-project/manifestoberta-xlm-roberta-56policy-topics-sentence-2023-1-1"
+ )
+ self.model = AutoModel.from_pretrained(
+ "manifesto-project/manifestoberta-xlm-roberta-56policy-topics-sentence-2023-1-1"
+ )
- def _embed(self, text: str, mean_over_tokens=True) -> list[float]:
- """Embed a text using ManifestoBerta.
+ def _embed(self, text: str, mean_over_tokens=True) -> list[float]:
+ """Embed a text using ManifestoBerta.
- Args:
- text: The text to embed.
+ Args:
+ text: The text to embed.
- Returns:
- Embeddings for the text.
- """
+ Returns:
+ Embeddings for the text.
+ """
- # Encode the text
- inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+ # Encode the text
+ inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
- # Get model output (make sure to set output_hidden_states to True)
- with torch.no_grad():
- outputs = self.model(**inputs, output_hidden_states=True)
+ # Get model output (make sure to set output_hidden_states to True)
+ with torch.no_grad():
+ outputs = self.model(**inputs, output_hidden_states=True)
- # Extract the last hidden states
- last_hidden_states = outputs.hidden_states[-1]
+ # Extract the last hidden states
+ last_hidden_states = outputs.hidden_states[-1]
- # Average the token embeddings for a representation of the whole text
- if mean_over_tokens:
- embedding = torch.mean(last_hidden_states, dim=1)
- else:
- embedding = last_hidden_states
+ # Average the token embeddings for a representation of the whole text
+ if mean_over_tokens:
+ embedding = torch.mean(last_hidden_states, dim=1)
+ else:
+ embedding = last_hidden_states
- # Convert to list
- embedding_list = embedding.cpu().tolist()
+ # Convert to list
+ embedding_list = embedding.cpu().tolist()
- return embedding_list[0]
+ return embedding_list[0]
- def embed_documents(self, texts: list[str]) -> list[list[float]]:
- return [self._embed(text) for text in texts]
+ def embed_documents(self, texts: list[str]) -> list[list[float]]:
+ return [self._embed(text) for text in texts]
- def embed_query(self, text: str) -> list[float]:
- # return self.embed_documents([text])[0] # previous version
- return self._embed(text)
+ def embed_query(self, text: str) -> list[float]:
+ # return self.embed_documents([text])[0] # previous version
+ return self._embed(text)
class E5BaseEmbedding(Embeddings):
- """Embeddings using ManifestoBerta for use with LangChain."""
+    """Embeddings using the danielheinz/e5-base-sts-en-de model for use with LangChain."""
- def __init__(self):
- # Load the tokenizer and model
- self.tokenizer = AutoTokenizer.from_pretrained("danielheinz/e5-base-sts-en-de")
+ def __init__(self):
+ # Load the tokenizer and model
+ self.tokenizer = AutoTokenizer.from_pretrained("danielheinz/e5-base-sts-en-de")
- self.model = AutoModel.from_pretrained("danielheinz/e5-base-sts-en-de")
+ self.model = AutoModel.from_pretrained("danielheinz/e5-base-sts-en-de")
- def _embed(self, text: str, mean_over_tokens=True) -> list[float]:
- """Embed a text using ManifestoBerta.
+ def _embed(self, text: str, mean_over_tokens=True) -> list[float]:
+        """Embed a text using the E5 model.
- Args:
- text: The text to embed.
+ Args:
+ text: The text to embed.
- Returns:
- Embeddings for the text.
- """
+ Returns:
+ Embeddings for the text.
+ """
- # Encode the text
- inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+ # Encode the text
+ inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
- # Get model output (make sure to set output_hidden_states to True)
- with torch.no_grad():
- outputs = self.model(**inputs, output_hidden_states=True)
+ # Get model output (make sure to set output_hidden_states to True)
+ with torch.no_grad():
+ outputs = self.model(**inputs, output_hidden_states=True)
- # Extract the last hidden states
- last_hidden_states = outputs.hidden_states[-1]
+ # Extract the last hidden states
+ last_hidden_states = outputs.hidden_states[-1]
- # Average the token embeddings for a representation of the whole text
- if mean_over_tokens:
- embedding = torch.mean(last_hidden_states, dim=1)
- else:
- embedding = last_hidden_states
+ # Average the token embeddings for a representation of the whole text
+ if mean_over_tokens:
+ embedding = torch.mean(last_hidden_states, dim=1)
+ else:
+ embedding = last_hidden_states
- # Convert to list
- embedding_list = embedding.cpu().tolist()
+ # Convert to list
+ embedding_list = embedding.cpu().tolist()
- return embedding_list[0]
+ return embedding_list[0]
- def embed_documents(self, texts: list[str]) -> list[list[float]]:
- return [self._embed(text) for text in texts]
+ def embed_documents(self, texts: list[str]) -> list[list[float]]:
+ return [self._embed(text) for text in texts]
- def embed_query(self, text: str) -> list[float]:
- # return self.embed_documents([text])[0] # previous version
- return self._embed(text)
+ def embed_query(self, text: str) -> list[float]:
+ # return self.embed_documents([text])[0] # previous version
+ return self._embed(text)
class JinaAIEmbedding(Embeddings):
- """Embeddings using ManifestoBerta for use with LangChain."""
+    """Embeddings using jinaai/jina-embeddings-v2-base-de for use with LangChain."""
- def __init__(self):
- # Load the tokenizer and model
- self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-de")
- self.model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-de", trust_remote_code=True)
+ def __init__(self):
+ # Load the tokenizer and model
+ self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-de")
+ self.model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-de", trust_remote_code=True)
- def _embed(self, text: str, mean_over_tokens=True) -> list[float]:
- """Embed a text using ManifestoBerta.
+ def _embed(self, text: str, mean_over_tokens=True) -> list[float]:
+        """Embed a text using the Jina model.
- Args:
- text: The text to embed.
+ Args:
+ text: The text to embed.
- Returns:
- Embeddings for the text.
- """
+ Returns:
+ Embeddings for the text.
+ """
- # Encode the text
- inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
+ # Encode the text
+ inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
- # Get model output (make sure to set output_hidden_states to True)
- with torch.no_grad():
- model_output = self.model(**inputs)
+        # Get model output
+ with torch.no_grad():
+ model_output = self.model(**inputs)
- embedding = mean_pooling(model_output, inputs["attention_mask"])
- embedding = F.normalize(embedding, p=2, dim=1)
+ embedding = mean_pooling(model_output, inputs["attention_mask"])
+ embedding = F.normalize(embedding, p=2, dim=1)
- # Convert to list
- embedding_list = embedding.cpu().tolist()
+ # Convert to list
+ embedding_list = embedding.cpu().tolist()
- return embedding_list[0]
+ return embedding_list[0]
- def embed_documents(self, texts: list[str]) -> list[list[float]]:
- return [self._embed(text) for text in texts]
+ def embed_documents(self, texts: list[str]) -> list[list[float]]:
+ return [self._embed(text) for text in texts]
- def embed_query(self, text: str) -> list[float]:
- # return self.embed_documents([text])[0] # previous version
- return self._embed(text)
+ def embed_query(self, text: str) -> list[float]:
+ # return self.embed_documents([text])[0] # previous version
+ return self._embed(text)
class SentenceTransformerEmbedding(Embeddings):
- """Embeddings using ManifestoBerta for use with LangChain."""
+    """Embeddings using a SentenceTransformer model for use with LangChain."""
- def __init__(self, model_name="multi-qa-mpnet-base-dot-v1"):
- # Load the tokenizer and model
- self.model = SentenceTransformer(model_name)
+ def __init__(self, model_name="multi-qa-mpnet-base-dot-v1"):
+        # Load the model
+ self.model = SentenceTransformer(model_name)
- def _embed(self, text: str) -> list[float]:
- """Embed a text using ManifestoBerta.
+ def _embed(self, text: str) -> list[float]:
+        """Embed a text using the SentenceTransformer model.
- Args:
- text: The text to embed.
+ Args:
+ text: The text to embed.
- Returns:
- Embeddings for the text.
- """
+ Returns:
+ Embeddings for the text.
+ """
- # Encode the text
- embedding = self.model.encode(text)
- embedding = [float(e) for e in embedding]
- return embedding
+ # Encode the text
+ embedding = self.model.encode(text)
+ embedding = [float(e) for e in embedding]
+ return embedding
- def embed_documents(self, texts: list[str]) -> list[list[float]]:
- return [self._embed(text) for text in texts]
+ def embed_documents(self, texts: list[str]) -> list[list[float]]:
+ return [self._embed(text) for text in texts]
- def embed_query(self, text: str) -> list[float]:
- # return self.embed_documents([text])[0] # previous version
- return self._embed(text)
+ def embed_query(self, text: str) -> list[float]:
+ # return self.embed_documents([text])[0] # previous version
+ return self._embed(text)
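The mean_pooling helper at the top of this file is only reflowed, not changed in behavior: it averages token embeddings while zeroing out padding positions via the attention mask. A toy check, assuming mean_pooling is in scope and feeding it a fake 1-tuple shaped like a transformers model output:

    import torch

    # Batch of 1, two tokens, embedding dim 2; the second token is padding.
    token_embeddings = torch.tensor([[[1.0, 2.0], [9.0, 9.0]]])
    attention_mask = torch.tensor([[1, 0]])

    # mean_pooling reads model_output[0], hence the 1-tuple
    pooled = mean_pooling((token_embeddings,), attention_mask)
    print(pooled)  # tensor([[1., 2.]]): the padded token contributes nothing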
diff --git a/pipeline/create_database.py b/pipeline/create_database.py
index 25403e5..a87c43c 100644
--- a/pipeline/create_database.py
+++ b/pipeline/create_database.py
@@ -45,8 +45,8 @@ def clean_up_old_database():
def download_manifestos():
- """Downloads manifesto PDFs from manifesto links in the party_dict.json file.
- Saves them to the data/manifestos/01_pdf_originals directory.
+ """Downloads manifesto PDFs from manifesto links in the party_dict.json file.
+ Saves them to the data/manifestos/01_pdf_originals directory.
"""
# Download manifesto PDFs
logger.info("Downloading manifesto PDFs...")
@@ -61,9 +61,10 @@ def download_manifestos():
with open(data_path, "wb") as f:
f.write(response.content)
+
def build_database():
"""Builds a vector database from the manifesto PDFs.
- Saves the database to the data/manifestos/chroma/{embedding_name}/ directory.
+ Saves the database to the data/manifestos/chroma/{embedding_name}/ directory.
"""
# instantiate database
logger.info("Initializing vector database...")
diff --git a/pipeline/push_to_huggingface.py b/pipeline/push_to_huggingface.py
index c3f8d92..ed4d7da 100644
--- a/pipeline/push_to_huggingface.py
+++ b/pipeline/push_to_huggingface.py
@@ -6,25 +6,22 @@
embedding_name = "openai"
DATABASE_DIR = f"data/manifestos/chroma/{embedding_name}/"
+
def push_to_huggingface():
"""Creates a zip file of the database folder and pushes it to the Hugging Face dataset.
- NOTE: Make sure to set the HUGGINGFACE_TOKEN environment variable.
+ NOTE: Make sure to set the HUGGINGFACE_TOKEN environment variable.
"""
- login(token=os.getenv("HUGGINGFACE_TOKEN"))
+ login(token=os.getenv("HUGGINGFACE_TOKEN"))
# Create zip file of the database folder
- shutil.make_archive(
- base_name = DATABASE_DIR.rstrip("/"),
- format = 'zip',
- root_dir = DATABASE_DIR
- )
+ shutil.make_archive(base_name=DATABASE_DIR.rstrip("/"), format="zip", root_dir=DATABASE_DIR)
upload_file(
path_or_fileobj=f"{DATABASE_DIR.rstrip('/')}.zip",
repo_id="cliedl/electify",
repo_type="dataset",
- path_in_repo=f"{os.path.basename(DATABASE_DIR.rstrip('/'))}.zip"
+ path_in_repo=f"{os.path.basename(DATABASE_DIR.rstrip('/'))}.zip",
)
-if __name__ == "__main__":
- push_to_huggingface()
\ No newline at end of file
+if __name__ == "__main__":
+ push_to_huggingface()
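One detail the reflowed make_archive call relies on: shutil appends the format suffix itself and returns the final archive path, which is why base_name strips the trailing slash from DATABASE_DIR instead of pointing at a ".zip" name. Illustrated with the directory used in this file:

    import shutil

    # Produces data/manifestos/chroma/openai.zip; base_name carries no suffix.
    archive = shutil.make_archive(
        base_name="data/manifestos/chroma/openai",
        format="zip",
        root_dir="data/manifestos/chroma/openai/",
    )
    print(archive)  # data/manifestos/chroma/openai.zip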
diff --git a/pyproject.toml b/pyproject.toml
index 4e28f02..c2812fd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,4 +38,5 @@ dev = [
"torch>=2.6.0",
"tqdm>=4.67.1",
"transformers>=4.48.2",
+ "pymupdf>=1.25.3",
]
diff --git a/streamlit_app/index.html b/streamlit_app/index.html
index 7f901ce..9a28189 100644
--- a/streamlit_app/index.html
+++ b/streamlit_app/index.html
@@ -1,45 +1,109 @@
[HTML markup lost in extraction: this hunk replaces the old 45-line index.html with a 109-line version that adds the SEO meta tags referenced in the Dockerfile, built around the site description "Informiere dich über die Positionen der Parteien zur Bundestagswahl 2025." ("Learn about the parties' positions for the 2025 German federal election.")]