From 0aa1447f06109b85f80ab8697b9a9699a4f2bc32 Mon Sep 17 00:00:00 2001 From: josh-nowak Date: Sat, 8 Feb 2025 11:48:05 +0100 Subject: [PATCH 1/7] fix custom index.html copy command in Dockerfile for SEO --- Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 4a844c1..7f2d13a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -19,8 +19,8 @@ RUN uv sync --frozen --no-dev # Activate virtual environment ENV PATH="/app/.venv/bin:$PATH" -# Copy custom index -COPY streamlit_app/index.html /usr/local/lib/python3.11/site-packages/streamlit/static/index.html +# Copy custom index.html with SEO tags +RUN cp streamlit_app/index.html $(python -c "import streamlit; import os; print(os.path.dirname(streamlit.__file__))")/static/index.html # Expose port 8080 to world outside of the container EXPOSE 8080 From 46039176db267935ac70de35ab58b736efdb2c5f Mon Sep 17 00:00:00 2001 From: josh-nowak Date: Sat, 8 Feb 2025 12:38:22 +0100 Subject: [PATCH 2/7] switch to PyMuPDFLoader to fix parsing issue with double characters --- RAG/database/vector_database.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/RAG/database/vector_database.py b/RAG/database/vector_database.py index e896972..343acf9 100644 --- a/RAG/database/vector_database.py +++ b/RAG/database/vector_database.py @@ -1,7 +1,7 @@ import glob import os -from langchain_community.document_loaders import PDFMinerLoader +from langchain_community.document_loaders import PyMuPDFLoader from langchain_community.document_loaders.csv_loader import CSVLoader from langchain_community.vectorstores import Chroma from langchain_text_splitters import RecursiveCharacterTextSplitter @@ -92,11 +92,11 @@ def build_database(self, overwrite=True): party = file_name.split("_")[0] # Load pdf as single doc - loader = PDFMinerLoader(pdf_path, concatenate_pages=True) + loader = PyMuPDFLoader(pdf_path, mode="single") doc = loader.load() # Also load pdf as individual pages, this is important to extract the page number later - loader = PDFMinerLoader(pdf_path, concatenate_pages=False) + loader = PyMuPDFLoader(pdf_path, mode="page") doc_pages = loader.load() # Add party to metadata From c7773d2d39b1be8da437aec81d9bc68f508ee0f8 Mon Sep 17 00:00:00 2001 From: josh-nowak Date: Sat, 8 Feb 2025 22:06:15 +0100 Subject: [PATCH 3/7] add pymupdf dependency --- pyproject.toml | 1 + uv.lock | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 4e28f02..c2812fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,4 +38,5 @@ dev = [ "torch>=2.6.0", "tqdm>=4.67.1", "transformers>=4.48.2", + "pymupdf>=1.25.3", ] diff --git a/uv.lock b/uv.lock index adcf2ea..ac8e6b6 100644 --- a/uv.lock +++ b/uv.lock @@ -675,6 +675,7 @@ dev = [ { name = "pdfminer-six" }, { name = "playwright" }, { name = "plotly" }, + { name = "pymupdf" }, { name = "pytest" }, { name = "pytest-timeout" }, { name = "ragas" }, @@ -712,6 +713,7 @@ dev = [ { name = "pdfminer-six", specifier = ">=20240706" }, { name = "playwright", specifier = ">=1.49.1" }, { name = "plotly", specifier = ">=6.0.0" }, + { name = "pymupdf", specifier = ">=1.25.3" }, { name = "pytest", specifier = ">=8.3.4" }, { name = "pytest-timeout", specifier = ">=2.3.1" }, { name = "ragas", specifier = ">=0.2.12" }, @@ -2913,6 +2915,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/0a/c20ef65268514be131860e7d70d4f3892114b6a6e2a3a4a56a9f6aff901b/pymongo-4.11-cp313-cp313t-win_amd64.whl", hash = 
"sha256:488e3440f5bedcbf494fd02c0a433cb5be7e55ba44dc72202813e1007a865e6a", size = 987846 }, ] +[[package]] +name = "pymupdf" +version = "1.25.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/06/47/b61c1c44b87cbdaeecdec3f43ce524ed6b3c72172bc6184eb82c94fbc43d/pymupdf-1.25.3.tar.gz", hash = "sha256:b640187c64c5ac5d97505a92e836da299da79c2f689f3f94a67a37a493492193", size = 67259841 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/9b/98ef4b98309e9db3baa9fe572f0e61b6130bb9852d13189970f35b703499/pymupdf-1.25.3-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:96878e1b748f9c2011aecb2028c5f96b5a347a9a91169130ad0133053d97915e", size = 19343576 }, + { url = "https://files.pythonhosted.org/packages/14/62/4e12126db174c8cfbf692281cda971cc4046c5f5226032c2cfaa6f83e08d/pymupdf-1.25.3-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:6ef753005b72ebfd23470f72f7e30f61e21b0b5e748045ec5b8f89e6e3068d62", size = 18580114 }, + { url = "https://files.pythonhosted.org/packages/52/de/bd1418e31f73d37b8381cd5deacfd681e6be702b8890e123e83724569ee1/pymupdf-1.25.3-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:46d90c4f9e62d1856e8db4b9f04a202ff4a7f086a816af73abdc86adb7f5e25a", size = 19999825 }, + { url = "https://files.pythonhosted.org/packages/42/ee/3c449b0de061440ba1ac984aa845315e9e2dca0ff2003c5adfc6febff203/pymupdf-1.25.3-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a5de51efdbe4d486b6c1111c84e8a231cbfb426f3d6ff31ab530ad70e6f39756", size = 21123157 }, + { url = "https://files.pythonhosted.org/packages/83/53/71faaaf91c56f2883b13f3dd849bf2697f012eb35eb7b952d62734cff41f/pymupdf-1.25.3-cp39-abi3-win32.whl", hash = "sha256:bca72e6089f985d800596e22973f79cc08af6cbff1d93e5bda9248326a03857c", size = 15094211 }, + { url = "https://files.pythonhosted.org/packages/09/e0/d72e88a1d5e23aa381fd463057dc3d0fb29090e1e7308a870c334716579c/pymupdf-1.25.3-cp39-abi3-win_amd64.whl", hash = "sha256:4fb357438c9129fbf939b5af85323434df64e36759c399c376b62ad6da95498c", size = 16542949 }, +] + [[package]] name = "pyparsing" version = "3.2.1" From 64a5fda4460dc5f76ad753c8e94583fffe420dfa Mon Sep 17 00:00:00 2001 From: josh-nowak Date: Sat, 8 Feb 2025 22:19:20 +0100 Subject: [PATCH 4/7] apply ruff formatter --- App.py | 199 +++++------------------ RAG/database/vector_database.py | 260 +++++++++++++++---------------- RAG/evaluation/evaluation.py | 8 +- RAG/models/embedding.py | 238 ++++++++++++++-------------- pipeline/create_database.py | 7 +- pipeline/push_to_huggingface.py | 17 +- streamlit_app/utils/log.py | 8 +- streamlit_app/utils/translate.py | 5 +- tests/test_db.py | 9 +- 9 files changed, 312 insertions(+), 439 deletions(-) diff --git a/App.py b/App.py index c044d27..1e17a83 100644 --- a/App.py +++ b/App.py @@ -47,12 +47,12 @@ # model_name="gpt-4o-mini", max_tokens=MAX_TOKENS, temperature=TEMPERATURE # ) LARGE_LANGUAGE_MODEL = AzureChatOpenAI( - azure_deployment="gpt-4o-mini", + azure_deployment="gpt-4o-mini", azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_SWEDEN"], api_key=os.environ["AZURE_OPENAI_API_KEY_SWEDEN"], - api_version=os.environ["AZURE_OPENAI_API_VERSION_SWEDEN"], - max_tokens=MAX_TOKENS, - temperature=TEMPERATURE + api_version=os.environ["AZURE_OPENAI_API_VERSION_SWEDEN"], + max_tokens=MAX_TOKENS, + temperature=TEMPERATURE, ) @@ -69,9 +69,7 @@ def load_embedding_model(): @st.cache_resource def load_db_manifestos(): return VectorDatabase( - embedding_model=embedding_model, - source_type="manifestos", - 
database_directory=DATABASE_DIR_MANIFESTOS, + embedding_model=embedding_model, source_type="manifestos", database_directory=DATABASE_DIR_MANIFESTOS ) @@ -131,9 +129,7 @@ def load_db_manifestos(): # The keys represent the (random) order of appearance of the parties in the app # and not fixed parties as opposed to the above party_dict. - st.session_state.show_individual_parties = { - f"party_{i+1}": False for i in range(len(st.session_state.parties)) - } + st.session_state.show_individual_parties = {f"party_{i + 1}": False for i in range(len(st.session_state.parties))} # The "example_prompts" dictionary will contain randomly selected example prompts for the user to choose from: if "example_prompts" not in st.session_state: @@ -146,9 +142,7 @@ def load_db_manifestos(): all_example_prompts[key] = [] all_example_prompts[key].append(value) - st.session_state.example_prompts = { - key: random.sample(value, 3) for key, value in all_example_prompts.items() - } + st.session_state.example_prompts = {key: random.sample(value, 3) for key, value in all_example_prompts.items()} if "number_of_requests" not in st.session_state: st.session_state.number_of_requests = 0 @@ -183,7 +177,7 @@ def img_to_bytes(img_path): def img_to_html(img_path): - img_html = f"" + img_html = f"" return img_html @@ -194,9 +188,7 @@ def submit_query(): st.session_state.feedback_text = "" st.session_state.feedback_submitted = False st.session_state.stage = 1 - st.session_state.show_individual_parties = { - f"party_{i+1}": False for i in range(len(st.session_state.parties)) - } + st.session_state.show_individual_parties = {f"party_{i + 1}": False for i in range(len(st.session_state.parties))} random.shuffle(st.session_state.parties) @@ -218,9 +210,9 @@ def generate_response(): st.session_state.response = rag.query(query) # Assert that the response contains all parties - assert set(st.session_state.response["answer"].keys()) == set( - st.session_state.parties - ), "LLM response does not contain all parties" + assert set(st.session_state.response["answer"].keys()) == set(st.session_state.parties), ( + "LLM response does not contain all parties" + ) break except Exception as e: @@ -230,33 +222,19 @@ def generate_response(): if retry_count > max_retries: print(f"Max number of tries ({max_retries}) reached, aborting") st.session_state.response = None - st.error( - translate( - "error-api-unavailable", - st.session_state.language, - ) - ) + st.error(translate("error-api-unavailable", st.session_state.language)) # Display error message in app: raise e else: print(f"Retrying, retry number {retry_count}") pass st.session_state.log_id = add_log_dict( - { - "query": st.session_state.query, - "answer": st.session_state.response["answer"], - } + {"query": st.session_state.query, "answer": st.session_state.response["answer"]} ) def submit_feedback(feedback_rating, feedback_text): - update_log_dict( - st.session_state.log_id, - { - "feedback-rating": feedback_rating, - "feedback-text": feedback_text, - }, - ) + update_log_dict(st.session_state.log_id, {"feedback-rating": feedback_rating, "feedback-text": feedback_text}) st.rerun() @@ -277,24 +255,13 @@ def convert_date_format(date_string): ################################## with st.sidebar: - selected_language = st.radio( - label="Language", - options=["🇩🇪 Deutsch", "🇬🇧 English"], - horizontal=True, - ) + selected_language = st.radio(label="Language", options=["🇩🇪 Deutsch", "🇬🇧 English"], horizontal=True) languages = {"🇩🇪 Deutsch": "de", "🇬🇧 English": "en"} st.session_state.language = 
languages[selected_language] rag.language = st.session_state.language st.header("🗳️ electify.eu", divider="blue") -st.write( - "##### :grey[" - + translate( - "subheadline", - st.session_state.language, - ) - + "]" -) +st.write("##### :grey[" + translate("subheadline", st.session_state.language) + "]") support_button( text=f"💙  {translate('support-button', st.session_state.language)}", @@ -304,38 +271,25 @@ def convert_date_format(date_string): if st.session_state.number_of_requests >= 3: # Show support banner after 3 requests in a single session. st.info( - f"{translate('support-banner', st.session_state.language)}(https://buymeacoffee.com/electify.eu)", - icon="💙", + f"{translate('support-banner', st.session_state.language)}(https://buymeacoffee.com/electify.eu)", icon="💙" ) query = st.text_input( - label=translate( - "query-instruction", - st.session_state.language, - ), - placeholder="", - value=st.session_state.query, + label=translate("query-instruction", st.session_state.language), placeholder="", value=st.session_state.query ) col_submit, col_checkbox = st.columns([1, 3]) # Submit button with col_submit: - st.button( - translate("submit-query", st.session_state.language), - on_click=submit_query, - type="primary", - ) + st.button(translate("submit-query", st.session_state.language), on_click=submit_query, type="primary") # Checkbox to show/hide party names globally with col_checkbox: st.session_state.show_all_parties = st.checkbox( label=translate("show-party-names", st.session_state.language), value=True, - help=translate( - "show-party-names-help", - st.session_state.language, - ), + help=translate("show-party-names-help", st.session_state.language), ) # Allow the user to select up to 6 parties @@ -350,12 +304,7 @@ def update_party_selection(party): party_selection[party] = not party_selection[party] st.session_state.parties = [k for k, v in party_selection.items() if v] - st.write( - translate( - "select-parties-instruction", - st.session_state.language, - ) - ) + st.write(translate("select-parties-instruction", st.session_state.language)) for party in available_parties: st.checkbox( label=party_dict[party]["name"], @@ -365,15 +314,11 @@ def update_party_selection(party): ) if len(st.session_state.parties) == 0: - st.markdown( - f"⚠️ **:red[{translate('error-min-1-party', st.session_state.language)}]**" - ) + st.markdown(f"⚠️ **:red[{translate('error-min-1-party', st.session_state.language)}]**") # Reset to default parties st.session_state.parties = rag.parties elif len(st.session_state.parties) > 6: - st.markdown( - f"⚠️ **:red[{translate('error-max-6-parties', st.session_state.language)}]**" - ) + st.markdown(f"⚠️ **:red[{translate('error-max-6-parties', st.session_state.language)}]**") # Limit to the six first selected parties st.session_state.parties = st.session_state.parties[:6] @@ -395,43 +340,23 @@ def update_party_selection(party): # STAGE > 0: Show disclaimer once the user has submitted a query (and keep showing it) if st.session_state.stage > 0: if len(st.session_state.parties) == 0: - st.error( - translate( - "error-min-1-party", - st.session_state.language, - ) - ) + st.error(translate("error-min-1-party", st.session_state.language)) st.session_state.stage = 0 else: st.info( "☝️ " - + translate( - "disclaimer-llm", - st.session_state.language, - ) + + translate("disclaimer-llm", st.session_state.language) + " \n" - + translate( - "disclaimer-research", - st.session_state.language, - ) + + translate("disclaimer-research", st.session_state.language) + " \n" - + 
translate( - "disclaimer-random-order", - st.session_state.language, - ), + + translate("disclaimer-random-order", st.session_state.language) ) # STAGE 1: User submitted a query and we are waiting for the response if st.session_state.stage == 1: st.session_state.number_of_requests += 1 - with st.spinner( - translate( - "loading-response", - st.session_state.language, - ) - + "🕵️" - ): + with st.spinner(translate("loading-response", st.session_state.language) + "🕵️"): generate_response() st.session_state.stage = 2 @@ -439,7 +364,6 @@ def update_party_selection(party): # STAGE > 1: The response has been generated and is displayed if st.session_state.stage > 1: - # Initialize an empty list to hold all columns col_list = [] # Create a pair of columns for each party @@ -451,14 +375,11 @@ def update_party_selection(party): p = i + 1 col1, col2 = col_list[i] - most_relevant_manifesto_page_number = st.session_state.response["docs"][ - "manifestos" - ][party][0].metadata["page"] + most_relevant_manifesto_page_number = st.session_state.response["docs"]["manifestos"][party][0].metadata[ + "page" + ] - show_party = ( - st.session_state.show_all_parties - or st.session_state.show_individual_parties[f"party_{p}"] - ) + show_party = st.session_state.show_all_parties or st.session_state.show_individual_parties[f"party_{p}"] # In this column, we show the party image with col1: @@ -470,12 +391,7 @@ def update_party_selection(party): else: file_loc = "streamlit_app/assets/placeholder_logo.png" st.markdown(img_to_html(file_loc), unsafe_allow_html=True) - st.button( - translate("show-party", st.session_state.language), - on_click=reveal_party, - args=(p,), - key=p, - ) + st.button(translate("show-party", st.session_state.language), on_click=reveal_party, args=(p,), key=p) # In this column, we show the RAG response with col2: if show_party: @@ -484,16 +400,11 @@ def update_party_selection(party): st.header(f"{translate('party', st.session_state.language)} {p}") if party == "afd": - st.caption( - f"⚠️ **{translate("warning-afd", st.session_state.language)}**" - ) + st.caption(f"⚠️ **{translate('warning-afd', st.session_state.language)}**") st.write(st.session_state.response["answer"][party]) if show_party: - is_answer_empty = ( - "keine passende antwort" - in st.session_state.response["answer"][party].lower() - ) + is_answer_empty = "keine passende antwort" in st.session_state.response["answer"][party].lower() page_reference_string = ( "" @@ -502,33 +413,17 @@ def update_party_selection(party): ) st.write( - f"""{translate('learn-more-in', st.session_state.language)} [{translate('party-manifesto', st.session_state.language)} **{party_dict[party]['name']}**{page_reference_string}]({party_dict[party]['manifesto_link']}#page={most_relevant_manifesto_page_number + 1}).""" + f"""{translate("learn-more-in", st.session_state.language)} [{translate("party-manifesto", st.session_state.language)} **{party_dict[party]["name"]}**{page_reference_string}]({party_dict[party]["manifesto_link"]}#page={most_relevant_manifesto_page_number + 1}).""" ) st.markdown("---") # Display a section with all retrieved excerpts from the sources st.subheader(translate("sources-subheading", st.session_state.language)) - st.write( - translate( - "sources-intro", - st.session_state.language, - ) - ) - st.write( - translate( - "sources-excerpts-intro", - st.session_state.language, - ) - ) + st.write(translate("sources-intro", st.session_state.language)) + st.write(translate("sources-excerpts-intro", st.session_state.language)) for party in 
st.session_state.parties: - with st.expander( - translate( - "sources", - st.session_state.language, - ) - + f": {party_dict[party]['name']}" - ): + with st.expander(translate("sources", st.session_state.language) + f": {party_dict[party]['name']}"): for doc in st.session_state.response["docs"]["manifestos"][party]: manifesto_excerpt = doc.page_content.replace("\n", " ") page_number_of_excerpt = doc.metadata["page"] + 1 @@ -552,27 +447,19 @@ def update_party_selection(party): st.subheader(translate("feedback-heading", st.session_state.language)) if not st.session_state.feedback_submitted: - st.write(translate("feedback-intro", st.session_state.language)) with st.form(key="feedback-form"): - feedback_options = { - "negative": "☹️", - "neutral": "😐", - "positive": "😊", - } + feedback_options = {"negative": "☹️", "neutral": "😐", "positive": "😊"} feedback_rating = st.segmented_control( label=translate("feedback-question-rating", st.session_state.language), options=feedback_options.keys(), format_func=lambda option: feedback_options[option], ) - feedback_text = st.text_area( - label=translate("feedback-question-text", st.session_state.language) - ) + feedback_text = st.text_area(label=translate("feedback-question-text", st.session_state.language)) submitted = st.form_submit_button( - label=translate("feedback-submit", st.session_state.language), - type="primary", + label=translate("feedback-submit", st.session_state.language), type="primary" ) if submitted: st.session_state.feedback_submitted = True diff --git a/RAG/database/vector_database.py b/RAG/database/vector_database.py index 343acf9..6bc382b 100644 --- a/RAG/database/vector_database.py +++ b/RAG/database/vector_database.py @@ -8,136 +8,136 @@ class VectorDatabase: - def __init__( - self, - embedding_model, - source_type, # "manifestos" or "debates" - data_path=".", - database_directory="./chroma", - chunk_size=1000, - chunk_overlap=200, - loader="pdf", - reload=True, - ): - """ - Initializes the VectorDatabase. - - Parameters: - - embedding_model: The model used to generate embeddings for the documents. - - data_directory (str): The directory where the source documents are located. Defaults to the current directory. - - database_directory (str): The directory to store the Chroma database. Defaults to './chroma'. - - chunk_size (int): The size of text chunks to split the documents into. Defaults to 1000. - - chunk_overlap (int): The number of characters to overlap between adjacent chunks. Defaults to 100. - - loader(str): "pdf" or "csv", depending on data format - """ - - self.embedding_model = embedding_model - self.source_type = source_type - self.data_path = data_path - self.database_directory = database_directory - self.chunk_size = chunk_size - self.chunk_overlap = chunk_overlap - self.loader = loader - - if reload: - self.database = self.load_database() - - def load_database(self): - """ - Loads an existing Chroma database. - - Returns: - - The loaded Chroma database. - """ - if os.path.exists(self.database_directory): - self.database = Chroma(persist_directory=self.database_directory, embedding_function=self.embedding_model) - print("reloaded database") - else: - raise AssertionError(f"{self.database_directory} does not include database.") - - return self.database - - def build_database(self, overwrite=True): - """ - Builds a new Chroma database from the documents in the data directory. - - Parameters: - - loader: Optional, a document loader instance. If None, PyPDFDirectoryLoader will be used with the data_directory. 
- - Returns: - - The newly built Chroma database. - """ - # # If overwrite flag is true, remove old databases from directory if they exist - # if overwrite: - # if os.path.exists(self.database_directory): - # shutil.rmtree(self.database_directory) - # time.sleep(1) - - # PDF is the default loader defined above - - if os.path.exists(self.database_directory): - raise AssertionError("Delete old database first and restart session!") - - # Define text_splitter - text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) - - if self.loader == "pdf": - # loader = PyPDFDirectoryLoader(self.data_path) - # get file_paths of all pdfs in data_folder - pdf_paths = glob.glob(os.path.join(self.data_path, "*.pdf")) - - splits = [] - for pdf_path in pdf_paths: - file_name = os.path.basename(pdf_path) - party = file_name.split("_")[0] - - # Load pdf as single doc - loader = PyMuPDFLoader(pdf_path, mode="single") - doc = loader.load() - - # Also load pdf as individual pages, this is important to extract the page number later - loader = PyMuPDFLoader(pdf_path, mode="page") - doc_pages = loader.load() - - # Add party to metadata - for i in range(len(doc)): - doc[i].metadata.update({"party": party}) - - # Create splits - splits_temp = text_splitter.split_documents(doc) - - # For each split, we search for the page on which it has occurred - for split in splits_temp: - for page_number, doc_page in enumerate(doc_pages): - # Create first and second half of split - split_1 = split.page_content[: int(0.5 * len(split.page_content))] - split_2 = split.page_content[int(0.5 * len(split.page_content)) :] - # If the first half is on page page_number or the second half is on page page_number, set page=page_number - if split_1 in doc_page.page_content or split_2 in doc_page.page_content: - split.metadata.update({"page": page_number}) - - if split.metadata.get("page") is None: - split.metadata.update({"page": 1}) - - splits.extend(splits_temp) - - elif self.loader == "csv": - loader = CSVLoader(self.data_path, metadata_columns=["date", "fullName", "politicalGroup", "party"]) - # Load documents - docs = loader.load() - - # Create splits - splits = text_splitter.split_documents(docs) - - # Create database - self.database = Chroma.from_documents( - splits, - self.embedding_model, - persist_directory=self.database_directory, - collection_metadata={"hnsw:space": "cosine"}, - ) - - return self.database + def __init__( + self, + embedding_model, + source_type, # "manifestos" or "debates" + data_path=".", + database_directory="./chroma", + chunk_size=1000, + chunk_overlap=200, + loader="pdf", + reload=True, + ): + """ + Initializes the VectorDatabase. + + Parameters: + - embedding_model: The model used to generate embeddings for the documents. + - data_directory (str): The directory where the source documents are located. Defaults to the current directory. + - database_directory (str): The directory to store the Chroma database. Defaults to './chroma'. + - chunk_size (int): The size of text chunks to split the documents into. Defaults to 1000. + - chunk_overlap (int): The number of characters to overlap between adjacent chunks. Defaults to 100. 
+ - loader(str): "pdf" or "csv", depending on data format + """ + + self.embedding_model = embedding_model + self.source_type = source_type + self.data_path = data_path + self.database_directory = database_directory + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.loader = loader + + if reload: + self.database = self.load_database() + + def load_database(self): + """ + Loads an existing Chroma database. + + Returns: + - The loaded Chroma database. + """ + if os.path.exists(self.database_directory): + self.database = Chroma(persist_directory=self.database_directory, embedding_function=self.embedding_model) + print("reloaded database") + else: + raise AssertionError(f"{self.database_directory} does not include database.") + + return self.database + + def build_database(self, overwrite=True): + """ + Builds a new Chroma database from the documents in the data directory. + + Parameters: + - loader: Optional, a document loader instance. If None, PyPDFDirectoryLoader will be used with the data_directory. + + Returns: + - The newly built Chroma database. + """ + # # If overwrite flag is true, remove old databases from directory if they exist + # if overwrite: + # if os.path.exists(self.database_directory): + # shutil.rmtree(self.database_directory) + # time.sleep(1) + + # PDF is the default loader defined above + + if os.path.exists(self.database_directory): + raise AssertionError("Delete old database first and restart session!") + + # Define text_splitter + text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) + + if self.loader == "pdf": + # loader = PyPDFDirectoryLoader(self.data_path) + # get file_paths of all pdfs in data_folder + pdf_paths = glob.glob(os.path.join(self.data_path, "*.pdf")) + + splits = [] + for pdf_path in pdf_paths: + file_name = os.path.basename(pdf_path) + party = file_name.split("_")[0] + + # Load pdf as single doc + loader = PyMuPDFLoader(pdf_path, mode="single") + doc = loader.load() + + # Also load pdf as individual pages, this is important to extract the page number later + loader = PyMuPDFLoader(pdf_path, mode="page") + doc_pages = loader.load() + + # Add party to metadata + for i in range(len(doc)): + doc[i].metadata.update({"party": party}) + + # Create splits + splits_temp = text_splitter.split_documents(doc) + + # For each split, we search for the page on which it has occurred + for split in splits_temp: + for page_number, doc_page in enumerate(doc_pages): + # Create first and second half of split + split_1 = split.page_content[: int(0.5 * len(split.page_content))] + split_2 = split.page_content[int(0.5 * len(split.page_content)) :] + # If the first half is on page page_number or the second half is on page page_number, set page=page_number + if split_1 in doc_page.page_content or split_2 in doc_page.page_content: + split.metadata.update({"page": page_number}) + + if split.metadata.get("page") is None: + split.metadata.update({"page": 1}) + + splits.extend(splits_temp) + + elif self.loader == "csv": + loader = CSVLoader(self.data_path, metadata_columns=["date", "fullName", "politicalGroup", "party"]) + # Load documents + docs = loader.load() + + # Create splits + splits = text_splitter.split_documents(docs) + + # Create database + self.database = Chroma.from_documents( + splits, + self.embedding_model, + persist_directory=self.database_directory, + collection_metadata={"hnsw:space": "cosine"}, + ) + + return self.database if __name__ == "__main__": diff --git 
a/RAG/evaluation/evaluation.py b/RAG/evaluation/evaluation.py index 6986b1d..8e101cb 100644 --- a/RAG/evaluation/evaluation.py +++ b/RAG/evaluation/evaluation.py @@ -17,7 +17,7 @@ def context_relevancy(self, dataset): context = "" for i, doc in enumerate(context_docs): - context += f"Dokument {i+1}: {doc}\n\n" + context += f"Dokument {i + 1}: {doc}\n\n" prompt = f""" {instruction} @@ -29,11 +29,7 @@ def context_relevancy(self, dataset): {question}""" completion = self.client.chat.completions.create( - model="gpt-3.5-turbo", - temperature=0, - messages=[ - {"role": "user", "content": prompt}, - ], + model="gpt-3.5-turbo", temperature=0, messages=[{"role": "user", "content": prompt}] ) # Parse output into list try: diff --git a/RAG/models/embedding.py b/RAG/models/embedding.py index 730ea5d..505966e 100644 --- a/RAG/models/embedding.py +++ b/RAG/models/embedding.py @@ -7,179 +7,175 @@ def mean_pooling(model_output, attention_mask): token_embeddings = model_output[0] - input_mask_expanded = ( - attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() - ) - return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( - input_mask_expanded.sum(1), min=1e-9 - ) + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) class ManifestoBertaEmbeddings(Embeddings): - """Embeddings using ManifestoBerta for use with LangChain.""" + """Embeddings using ManifestoBerta for use with LangChain.""" - def __init__(self): - # Load the tokenizer and model - self.tokenizer = AutoTokenizer.from_pretrained( - "manifesto-project/manifestoberta-xlm-roberta-56policy-topics-sentence-2023-1-1" - ) - self.model = AutoModel.from_pretrained( - "manifesto-project/manifestoberta-xlm-roberta-56policy-topics-sentence-2023-1-1" - ) + def __init__(self): + # Load the tokenizer and model + self.tokenizer = AutoTokenizer.from_pretrained( + "manifesto-project/manifestoberta-xlm-roberta-56policy-topics-sentence-2023-1-1" + ) + self.model = AutoModel.from_pretrained( + "manifesto-project/manifestoberta-xlm-roberta-56policy-topics-sentence-2023-1-1" + ) - def _embed(self, text: str, mean_over_tokens=True) -> list[float]: - """Embed a text using ManifestoBerta. + def _embed(self, text: str, mean_over_tokens=True) -> list[float]: + """Embed a text using ManifestoBerta. - Args: - text: The text to embed. + Args: + text: The text to embed. - Returns: - Embeddings for the text. - """ + Returns: + Embeddings for the text. 
+ """ - # Encode the text - inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) + # Encode the text + inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) - # Get model output (make sure to set output_hidden_states to True) - with torch.no_grad(): - outputs = self.model(**inputs, output_hidden_states=True) + # Get model output (make sure to set output_hidden_states to True) + with torch.no_grad(): + outputs = self.model(**inputs, output_hidden_states=True) - # Extract the last hidden states - last_hidden_states = outputs.hidden_states[-1] + # Extract the last hidden states + last_hidden_states = outputs.hidden_states[-1] - # Average the token embeddings for a representation of the whole text - if mean_over_tokens: - embedding = torch.mean(last_hidden_states, dim=1) - else: - embedding = last_hidden_states + # Average the token embeddings for a representation of the whole text + if mean_over_tokens: + embedding = torch.mean(last_hidden_states, dim=1) + else: + embedding = last_hidden_states - # Convert to list - embedding_list = embedding.cpu().tolist() + # Convert to list + embedding_list = embedding.cpu().tolist() - return embedding_list[0] + return embedding_list[0] - def embed_documents(self, texts: list[str]) -> list[list[float]]: - return [self._embed(text) for text in texts] + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [self._embed(text) for text in texts] - def embed_query(self, text: str) -> list[float]: - # return self.embed_documents([text])[0] # previous version - return self._embed(text) + def embed_query(self, text: str) -> list[float]: + # return self.embed_documents([text])[0] # previous version + return self._embed(text) class E5BaseEmbedding(Embeddings): - """Embeddings using ManifestoBerta for use with LangChain.""" + """Embeddings using ManifestoBerta for use with LangChain.""" - def __init__(self): - # Load the tokenizer and model - self.tokenizer = AutoTokenizer.from_pretrained("danielheinz/e5-base-sts-en-de") + def __init__(self): + # Load the tokenizer and model + self.tokenizer = AutoTokenizer.from_pretrained("danielheinz/e5-base-sts-en-de") - self.model = AutoModel.from_pretrained("danielheinz/e5-base-sts-en-de") + self.model = AutoModel.from_pretrained("danielheinz/e5-base-sts-en-de") - def _embed(self, text: str, mean_over_tokens=True) -> list[float]: - """Embed a text using ManifestoBerta. + def _embed(self, text: str, mean_over_tokens=True) -> list[float]: + """Embed a text using ManifestoBerta. - Args: - text: The text to embed. + Args: + text: The text to embed. - Returns: - Embeddings for the text. - """ + Returns: + Embeddings for the text. 
+ """ - # Encode the text - inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) + # Encode the text + inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) - # Get model output (make sure to set output_hidden_states to True) - with torch.no_grad(): - outputs = self.model(**inputs, output_hidden_states=True) + # Get model output (make sure to set output_hidden_states to True) + with torch.no_grad(): + outputs = self.model(**inputs, output_hidden_states=True) - # Extract the last hidden states - last_hidden_states = outputs.hidden_states[-1] + # Extract the last hidden states + last_hidden_states = outputs.hidden_states[-1] - # Average the token embeddings for a representation of the whole text - if mean_over_tokens: - embedding = torch.mean(last_hidden_states, dim=1) - else: - embedding = last_hidden_states + # Average the token embeddings for a representation of the whole text + if mean_over_tokens: + embedding = torch.mean(last_hidden_states, dim=1) + else: + embedding = last_hidden_states - # Convert to list - embedding_list = embedding.cpu().tolist() + # Convert to list + embedding_list = embedding.cpu().tolist() - return embedding_list[0] + return embedding_list[0] - def embed_documents(self, texts: list[str]) -> list[list[float]]: - return [self._embed(text) for text in texts] + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [self._embed(text) for text in texts] - def embed_query(self, text: str) -> list[float]: - # return self.embed_documents([text])[0] # previous version - return self._embed(text) + def embed_query(self, text: str) -> list[float]: + # return self.embed_documents([text])[0] # previous version + return self._embed(text) class JinaAIEmbedding(Embeddings): - """Embeddings using ManifestoBerta for use with LangChain.""" + """Embeddings using ManifestoBerta for use with LangChain.""" - def __init__(self): - # Load the tokenizer and model - self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-de") - self.model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-de", trust_remote_code=True) + def __init__(self): + # Load the tokenizer and model + self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-de") + self.model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-de", trust_remote_code=True) - def _embed(self, text: str, mean_over_tokens=True) -> list[float]: - """Embed a text using ManifestoBerta. + def _embed(self, text: str, mean_over_tokens=True) -> list[float]: + """Embed a text using ManifestoBerta. - Args: - text: The text to embed. + Args: + text: The text to embed. - Returns: - Embeddings for the text. - """ + Returns: + Embeddings for the text. 
+ """ - # Encode the text - inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt") + # Encode the text + inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt") - # Get model output (make sure to set output_hidden_states to True) - with torch.no_grad(): - model_output = self.model(**inputs) + # Get model output (make sure to set output_hidden_states to True) + with torch.no_grad(): + model_output = self.model(**inputs) - embedding = mean_pooling(model_output, inputs["attention_mask"]) - embedding = F.normalize(embedding, p=2, dim=1) + embedding = mean_pooling(model_output, inputs["attention_mask"]) + embedding = F.normalize(embedding, p=2, dim=1) - # Convert to list - embedding_list = embedding.cpu().tolist() + # Convert to list + embedding_list = embedding.cpu().tolist() - return embedding_list[0] + return embedding_list[0] - def embed_documents(self, texts: list[str]) -> list[list[float]]: - return [self._embed(text) for text in texts] + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [self._embed(text) for text in texts] - def embed_query(self, text: str) -> list[float]: - # return self.embed_documents([text])[0] # previous version - return self._embed(text) + def embed_query(self, text: str) -> list[float]: + # return self.embed_documents([text])[0] # previous version + return self._embed(text) class SentenceTransformerEmbedding(Embeddings): - """Embeddings using ManifestoBerta for use with LangChain.""" + """Embeddings using ManifestoBerta for use with LangChain.""" - def __init__(self, model_name="multi-qa-mpnet-base-dot-v1"): - # Load the tokenizer and model - self.model = SentenceTransformer(model_name) + def __init__(self, model_name="multi-qa-mpnet-base-dot-v1"): + # Load the tokenizer and model + self.model = SentenceTransformer(model_name) - def _embed(self, text: str) -> list[float]: - """Embed a text using ManifestoBerta. + def _embed(self, text: str) -> list[float]: + """Embed a text using ManifestoBerta. - Args: - text: The text to embed. + Args: + text: The text to embed. - Returns: - Embeddings for the text. - """ + Returns: + Embeddings for the text. + """ - # Encode the text - embedding = self.model.encode(text) - embedding = [float(e) for e in embedding] - return embedding + # Encode the text + embedding = self.model.encode(text) + embedding = [float(e) for e in embedding] + return embedding - def embed_documents(self, texts: list[str]) -> list[list[float]]: - return [self._embed(text) for text in texts] + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [self._embed(text) for text in texts] - def embed_query(self, text: str) -> list[float]: - # return self.embed_documents([text])[0] # previous version - return self._embed(text) + def embed_query(self, text: str) -> list[float]: + # return self.embed_documents([text])[0] # previous version + return self._embed(text) diff --git a/pipeline/create_database.py b/pipeline/create_database.py index 25403e5..a87c43c 100644 --- a/pipeline/create_database.py +++ b/pipeline/create_database.py @@ -45,8 +45,8 @@ def clean_up_old_database(): def download_manifestos(): - """Downloads manifesto PDFs from manifesto links in the party_dict.json file. - Saves them to the data/manifestos/01_pdf_originals directory. + """Downloads manifesto PDFs from manifesto links in the party_dict.json file. + Saves them to the data/manifestos/01_pdf_originals directory. 
""" # Download manifesto PDFs logger.info("Downloading manifesto PDFs...") @@ -61,9 +61,10 @@ def download_manifestos(): with open(data_path, "wb") as f: f.write(response.content) + def build_database(): """Builds a vector database from the manifesto PDFs. - Saves the database to the data/manifestos/chroma/{embedding_name}/ directory. + Saves the database to the data/manifestos/chroma/{embedding_name}/ directory. """ # instantiate database logger.info("Initializing vector database...") diff --git a/pipeline/push_to_huggingface.py b/pipeline/push_to_huggingface.py index c3f8d92..ed4d7da 100644 --- a/pipeline/push_to_huggingface.py +++ b/pipeline/push_to_huggingface.py @@ -6,25 +6,22 @@ embedding_name = "openai" DATABASE_DIR = f"data/manifestos/chroma/{embedding_name}/" + def push_to_huggingface(): """Creates a zip file of the database folder and pushes it to the Hugging Face dataset. - NOTE: Make sure to set the HUGGINGFACE_TOKEN environment variable. + NOTE: Make sure to set the HUGGINGFACE_TOKEN environment variable. """ - login(token=os.getenv("HUGGINGFACE_TOKEN")) + login(token=os.getenv("HUGGINGFACE_TOKEN")) # Create zip file of the database folder - shutil.make_archive( - base_name = DATABASE_DIR.rstrip("/"), - format = 'zip', - root_dir = DATABASE_DIR - ) + shutil.make_archive(base_name=DATABASE_DIR.rstrip("/"), format="zip", root_dir=DATABASE_DIR) upload_file( path_or_fileobj=f"{DATABASE_DIR.rstrip('/')}.zip", repo_id="cliedl/electify", repo_type="dataset", - path_in_repo=f"{os.path.basename(DATABASE_DIR.rstrip('/'))}.zip" + path_in_repo=f"{os.path.basename(DATABASE_DIR.rstrip('/'))}.zip", ) -if __name__ == "__main__": - push_to_huggingface() \ No newline at end of file +if __name__ == "__main__": + push_to_huggingface() diff --git a/streamlit_app/utils/log.py b/streamlit_app/utils/log.py index 19ebc93..5034fde 100644 --- a/streamlit_app/utils/log.py +++ b/streamlit_app/utils/log.py @@ -18,9 +18,7 @@ def add_log_dict(dictionary: dict) -> str: The inserted id. 
""" # Add timestamp to dictionary - dictionary["created_at"] = datetime.now(timezone.utc) + timedelta( - hours=1 - ) # Berlin timezone (UTC+1) + dictionary["created_at"] = datetime.now(timezone.utc) + timedelta(hours=1) # Berlin timezone (UTC+1) client = MongoClient(os.getenv("MONGODB_URI"), server_api=ServerApi("1")) db = client[os.getenv("DATABASE_NAME")] @@ -44,9 +42,7 @@ def update_log_dict(_id: str, dictionary: dict): client = MongoClient(os.getenv("MONGODB_URI"), server_api=ServerApi("1")) db = client[os.getenv("DATABASE_NAME")] collection = db[os.getenv("COLLECTION_NAME")] - result = collection.update_one( - filter={"_id": ObjectId(_id)}, update={"$set": dictionary} - ) + result = collection.update_one(filter={"_id": ObjectId(_id)}, update={"$set": dictionary}) return result diff --git a/streamlit_app/utils/translate.py b/streamlit_app/utils/translate.py index bca409a..1fffc81 100644 --- a/streamlit_app/utils/translate.py +++ b/streamlit_app/utils/translate.py @@ -3,11 +3,10 @@ import pandas as pd current_file_directory = os.path.dirname(os.path.abspath(__file__)) -dictionary = pd.read_csv( - os.path.join(current_file_directory, "language_dictionary.csv"), delimiter=";" -) +dictionary = pd.read_csv(os.path.join(current_file_directory, "language_dictionary.csv"), delimiter=";") dictionary = dictionary.set_index("key") + def translate(key, language): if language in dictionary.columns: return dictionary.loc[key][language] diff --git a/tests/test_db.py b/tests/test_db.py index 532ea2e..b412363 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -19,29 +19,30 @@ party_dict = json.load(f) db = VectorDatabase( - embedding_model=embedding_model, - source_type="manifestos", - database_directory=DATABASE_DIR_MANIFESTOS, + embedding_model=embedding_model, source_type="manifestos", database_directory=DATABASE_DIR_MANIFESTOS ) + def test_db_connection(): assert db.database is not None results = db.database.similarity_search("Wie sieht die Steuerpolitik der Parteien aus?") assert len(results) > 0 + def test_party_counts(): collection = db.database._collection party_counts = {} for doc in collection.get()["metadatas"]: if "party" in doc: party = doc["party"] - party_counts[party] = party_counts.get(party, 0) + 1 + party_counts[party] = party_counts.get(party, 0) + 1 for party in sorted(party_counts.keys()): print(f"{party}: {party_counts[party]}") assert set(party_counts.keys()) == set(party_dict.keys()) + def test_has_page_metadata(): collection = db.database._collection for doc in collection.get()["metadatas"]: From 276f0b47c538b6b470b71e56ca598cf7e1c918cb Mon Sep 17 00:00:00 2001 From: josh-nowak Date: Mon, 10 Feb 2025 15:56:04 +0100 Subject: [PATCH 5/7] fix: only show afd disclaimer if party name is shown --- App.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/App.py b/App.py index 1e17a83..0bf20b9 100644 --- a/App.py +++ b/App.py @@ -399,7 +399,7 @@ def update_party_selection(party): else: st.header(f"{translate('party', st.session_state.language)} {p}") - if party == "afd": + if party == "afd" and show_party: st.caption(f"⚠️ **{translate('warning-afd', st.session_state.language)}**") st.write(st.session_state.response["answer"][party]) From 27daee95c906025a7752f873306868e03c9c2a10 Mon Sep 17 00:00:00 2001 From: josh-nowak Date: Mon, 10 Feb 2025 15:58:18 +0100 Subject: [PATCH 6/7] remove custom index.html until we find a more stable solution --- Dockerfile | 3 --- 1 file changed, 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 7f2d13a..cf28bb8 100644 --- 
a/Dockerfile
+++ b/Dockerfile
@@ -19,9 +19,6 @@ RUN uv sync --frozen --no-dev
 # Activate virtual environment
 ENV PATH="/app/.venv/bin:$PATH"
 
-# Copy custom index.html with SEO tags
-RUN cp streamlit_app/index.html $(python -c "import streamlit; import os; print(os.path.dirname(streamlit.__file__))")/static/index.html
-
 # Expose port 8080 to world outside of the container
 EXPOSE 8080
 

From 6b73d39c9dd362932907aa633c4e8e806287826f Mon Sep 17 00:00:00 2001
From: josh-nowak
Date: Mon, 10 Feb 2025 22:04:35 +0100
Subject: [PATCH 7/7] add new index.html for SEO

---
 Dockerfile               |   3 +
 streamlit_app/index.html | 138 ++++++++++++++++++++++++++++-----------
 2 files changed, 104 insertions(+), 37 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index cf28bb8..7f2d13a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -19,6 +19,9 @@ RUN uv sync --frozen --no-dev
 # Activate virtual environment
 ENV PATH="/app/.venv/bin:$PATH"
 
+# Copy custom index.html with SEO tags
+RUN cp streamlit_app/index.html $(python -c "import streamlit; import os; print(os.path.dirname(streamlit.__file__))")/static/index.html
+
 # Expose port 8080 to world outside of the container
 EXPOSE 8080
 
diff --git a/streamlit_app/index.html b/streamlit_app/index.html
index 7f901ce..9a28189 100644
--- a/streamlit_app/index.html
+++ b/streamlit_app/index.html
@@ -1,45 +1,109 @@
[hunk body not recoverable in this copy: the HTML markup was stripped during extraction. Recoverable text from the hunk: the page title "Electify: Informiere dich zur Bundestagswahl 2025 mit künstlicher Intelligenz", the heading "electify.eu", and the tagline "Informiere dich über die Positionen der Parteien zur Bundestagswahl 2025."]
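
For illustration only (not part of the patches above): the RUN step re-added in PATCH 7/7 resolves Streamlit's installed package directory at build time instead of hard-coding the /usr/local/lib/python3.11/site-packages path that PATCH 1/7 replaced. A minimal Python sketch of the same idea, assuming Streamlit is importable in the active environment and that streamlit_app/index.html exists relative to the working directory:

# Sketch: overwrite Streamlit's bundled static/index.html with the SEO-enabled
# version, mirroring the RUN cp ... $(python -c ...) step in the Dockerfile.
import os
import shutil

import streamlit

# Resolve the static directory of the installed streamlit package.
static_index = os.path.join(os.path.dirname(streamlit.__file__), "static", "index.html")
shutil.copyfile("streamlit_app/index.html", static_index)
print(f"Replaced {static_index}")

Because the destination is derived from streamlit.__file__, the copy keeps working if the base image's Python version (and therefore the site-packages location) changes.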