From c298cf01ad09806075b6b4b5ef86de6db3f700c6 Mon Sep 17 00:00:00 2001
From: Quartz_Admirer <118068910+Quartz-Admirer@users.noreply.github.com>
Date: Sat, 28 Jun 2025 21:26:23 +0300
Subject: [PATCH] Add PageRank enrichment, async PDF ingestion, and typo-tolerant search

---
 src/indexing/bib_parser.py             |   3 +-
 src/indexing/compute_pagerank.py       | 145 +++++
 src/indexing/elastic_search_indexer.py |  72 ++-
 src/indexing/entities.py               |   3 +
 src/indexing/indexing_pipeline.py      | 137 +++-
 src/indexing/parse.py                  | 150 +++--
 src/indexing/settings.py               |  16 +-
 src/search/query_processor.py          | 596 ++++++++++++++++++
 src/search/search_cli.py               | 402 ++++++++----
 9 files changed, 1318 insertions(+), 206 deletions(-)
 create mode 100644 src/indexing/compute_pagerank.py
 create mode 100644 src/search/query_processor.py

diff --git a/src/indexing/bib_parser.py b/src/indexing/bib_parser.py
index fe9b774..15a512e 100644
--- a/src/indexing/bib_parser.py
+++ b/src/indexing/bib_parser.py
@@ 
-60,6 +60,7 @@ def extract_bib_entries(path: Path) -> list[BibEntry]:
                 "title": latex_to_unicode(entry.fields.get("title")),
                 "year": entry.fields.get("year"),
                 "author": authors if authors else None,
+                "doi": entry.fields.get("doi"),
                 "language": entry.fields.get("language"),
             }
         )
@@ -68,4 +69,4 @@ def extract_bib_entries(path: Path) -> list[BibEntry]:
     df = df[(df["language"] == "eng") | (df["language"].isna())]
     df = df.dropna(subset=["url", "title", "year", "author"])
 
-    return df[["url", "title", "year", "author"]].to_dict(orient="records")
+    return df[["url", "title", "year", "author", "doi"]].to_dict(orient="records")
diff --git a/src/indexing/compute_pagerank.py b/src/indexing/compute_pagerank.py
new file mode 100644
index 0000000..a82c01a
--- /dev/null
+++ b/src/indexing/compute_pagerank.py
@@ -0,0 +1,145 @@
+"""compute_pagerank.py
+Batch job that enriches an existing Elasticsearch index with PageRank scores.
+
+Steps
+-----
+1. Fetch all documents from the configured index (keyed by DOI or synthetic ID).
+2. Retrieve outbound references for every DOI via the Crossref API (async).
+3. Build the citation graph, adding virtual nodes for papers outside the index.
+4. Run PageRank.
+5. Bulk-update each indexed document with its PageRank score (field ``pagerank``).
+
+Run as a one-off script *after* the main indexing pipeline has finished:
+
+    python -m src.indexing.compute_pagerank
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from collections import defaultdict
+from typing import Dict, List, Tuple
+
+import aiohttp
+import networkx as nx
+from tqdm import tqdm
+
+from src.indexing.elastic_search_indexer import ElasticSearchIndexer
+from src.indexing.settings import settings
+
+logger = logging.getLogger(__name__)
+
+
+async def _fetch_single_reference_list(
+    session: aiohttp.ClientSession, doi: str
+) -> Tuple[str, List[str]]:
+    """Return (doi, references[]) or (doi, []) on error."""
+    url = f"https://api.crossref.org/works/{doi}"
+    headers = {"User-Agent": "Academic Research Tool (mailto:your-email@example.com)"}
+
+    try:
+        async with session.get(url, headers=headers, timeout=30) as resp:
+            resp.raise_for_status()
+            data = await resp.json()
+            refs = data.get("message", {}).get("reference", [])
+            ref_dois = [r["DOI"] for r in refs if "DOI" in r]
+            return doi, ref_dois
+    except Exception as exc:  # noqa: BLE001
+        logger.warning("Crossref lookup failed for %s: %s", doi, exc)
+        return doi, []
+
+
+async def fetch_references(
+    dois: List[str], concurrency: int = 20
+) -> Dict[str, List[str]]:
+    """Fetch reference lists for *dois* concurrently via Crossref, with a progress bar."""
+    results: Dict[str, List[str]] = {}
+
+    # A single shared session re-uses TCP connections across all batches.
+    async with aiohttp.ClientSession() as session:
+        with tqdm(total=len(dois), desc="Crossref", unit="doi") as pbar:
+            for i in range(0, len(dois), concurrency):
+                batch = dois[i : i + concurrency]
+                tasks = [_fetch_single_reference_list(session, d) for d in batch]
+                batch_res = await asyncio.gather(*tasks, return_exceptions=False)
+                results.update(dict(batch_res))
+                pbar.update(len(batch))
+    return results
+
+
+def pagerank_with_virtual_nodes(
+    references_dict: Dict[str, List[str]],
+    alpha: float = 0.85,
+) -> Dict[str, float]:
+    """Run PageRank. 
Only original papers are returned in the result."""

+    # Collect external papers (cited by the indexed set but not indexed
+    # themselves); every such DOI becomes a virtual node in the graph.
+    external_counter: Dict[str, int] = defaultdict(int)
+    for cited_list in references_dict.values():
+        for cited in cited_list:
+            if cited not in references_dict:
+                external_counter[cited] += 1
+
+    virtual_nodes = set(external_counter)
+
+    original_nodes = set(references_dict.keys())
+    all_nodes = original_nodes | virtual_nodes
+
+    G = nx.DiGraph()
+    G.add_nodes_from(all_nodes)
+
+    for src, tgt_list in references_dict.items():
+        for tgt in tgt_list:
+            if tgt in all_nodes:
+                G.add_edge(src, tgt)
+
+    pagerank_full = nx.pagerank(G, alpha=alpha, max_iter=1000, tol=1e-9)
+    return {doi: score for doi, score in pagerank_full.items() if doi in original_nodes}
+
+
+def main() -> None:
+    index_name = settings.index_name
+    indexer = ElasticSearchIndexer(settings.es_host)
+
+    logger.info("Scanning index '%s' for documents …", index_name)
+    internal_to_esid = indexer.build_internal_id_map(index_name)
+    if not internal_to_esid:
+        logger.error("No documents found in index '%s'.", index_name)
+        return
+
+    # Separate DOIs (needed for Crossref) and synthetic IDs
+    dois = [iid for iid in internal_to_esid if not iid.startswith("SYNTH_")]
+
+    logger.info("Fetching Crossref references for %d DOIs …", len(dois))
+    references = asyncio.run(
+        fetch_references(dois, concurrency=settings.crossref_concurrency)
+    )
+
+    # Ensure synthetic IDs join the graph even without outbound references
+    for internal_id in internal_to_esid:
+        if internal_id not in references:
+            references[internal_id] = []
+
+    logger.info("Running PageRank on citation graph …")
+    pagerank_scores = pagerank_with_virtual_nodes(
+        references,
+        alpha=settings.pagerank_alpha,
+    )
+
+    id_score_map = {
+        internal_to_esid[iid]: score
+        for iid, score in pagerank_scores.items()
+        if iid in internal_to_esid
+    }
+
+    logger.info("Updating Elasticsearch documents with PageRank scores …")
+    indexer.bulk_update_field(index_name, id_score_map, field="pagerank")
+    logger.info(
+        "PageRank enrichment completed: %d documents updated.", len(id_score_map)
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/indexing/elastic_search_indexer.py b/src/indexing/elastic_search_indexer.py
index 561845d..64075a7 100644
--- a/src/indexing/elastic_search_indexer.py
+++ b/src/indexing/elastic_search_indexer.py
@@ -10,7 +10,7 @@
 import logging
 import uuid
 
-from typing import Any, Iterable, Optional
+from typing import Any, Iterable, Optional, Dict
 
 from elasticsearch import Elasticsearch, helpers
 from elasticsearch.exceptions import NotFoundError
@@ -19,6 +19,8 @@
 
 from src.indexing.entities import IndexedDocument
 
+import hashlib
+
 logger = logging.getLogger(__name__)
 
 
@@ -74,6 +76,8 @@ def create_index(
                 "author": {"type": "text"},
                 "year": {"type": "keyword"},
                 "url": {"type": "keyword"},
+                "doi": {"type": "keyword"},
+                "pagerank": {"type": "float"},
             }
             if embedding_dim:
                 properties["text_embedding"] = {
@@ -119,3 +123,69 @@ def index_documents(
         ]
         helpers.bulk(self._client, actions)
         logger.info("Indexed %d documents into '%s'.", len(chunk), index_name)
+
+    @staticmethod
+    def _create_synth_id(title: str | None, year: str | int | None, author: list[str] | str | None) -> str:
+        """Generate a stable synthetic identifier for a paper without DOI."""
+        title_part = (title or "").strip()[:50]
+        year_part = str(year or "")
+        first_author = ""
+        if isinstance(author, list) and author:
+            first_author = author[0]
+        
elif isinstance(author, str): + first_author = author.split(",")[0] + first_author = first_author[:30] + base = f"{title_part}_{year_part}_{first_author}".lower().replace(" ", "_") + return f"SYNTH_{hashlib.md5(base.encode()).hexdigest()[:12]}" + + def build_internal_id_map( + self, + index_name: str, + *, + include_fields: list[str] | None = None, + ) -> Dict[str, str]: + """Return a mapping of *internal_id -> Elasticsearch _id*. + + *internal_id* is the DOI if present, otherwise a synthetic ID derived + from title/year/author so that every document participates in graph + enrichment jobs. + """ + + src_fields = ["doi", "title", "year", "author"] + if include_fields: + src_fields.extend(include_fields) + + id_map: Dict[str, str] = {} + for hit in helpers.scan(self._client, index=index_name, _source=src_fields): + src = hit["_source"] + doc_doi = src.get("doi") + if doc_doi: + internal_id = doc_doi + else: + internal_id = self._create_synth_id(src.get("title"), src.get("year"), src.get("author")) + + id_map[internal_id] = hit["_id"] + + return id_map + + def bulk_update_field( + self, + index_name: str, + id_value_map: Dict[str, float], + field: str = "pagerank", + ) -> None: + """Update *field* for each ES doc where `_id` is a key in id_value_map.""" + + actions = [ + { + "_op_type": "update", + "_index": index_name, + "_id": es_id, + "doc": {field: value}, + } + for es_id, value in id_value_map.items() + ] + + if actions: + helpers.bulk(self._client, actions, refresh=True) + logger.info("Updated %d documents (%s field).", len(actions), field) \ No newline at end of file diff --git a/src/indexing/entities.py b/src/indexing/entities.py index 8fa976a..5617b33 100644 --- a/src/indexing/entities.py +++ b/src/indexing/entities.py @@ -16,6 +16,8 @@ class IndexedDocument(TypedDict, total=False): author: list[str] url: str year: Union[int, str] + doi: str + pagerank: float class BibEntry(TypedDict): @@ -25,3 +27,4 @@ class BibEntry(TypedDict): title: str author: list[str] year: Union[int, str] + doi: str | None diff --git a/src/indexing/indexing_pipeline.py b/src/indexing/indexing_pipeline.py index 16fcc45..0751116 100644 --- a/src/indexing/indexing_pipeline.py +++ b/src/indexing/indexing_pipeline.py @@ -3,8 +3,8 @@ Workflow --------- 1. Read the BibTeX file specified by ``BIB_FILE``. -2. Download every referenced PDF (multiprocess via Docling). -3. Convert each PDF to Markdown and embed its full text with +2. Download every referenced PDF concurrently with asynchronous HTTP requests. +3. Extract plain text from each PDF and embed it with :class:`src.common.dense_embedder.DenseEmbedder`. 4. Create the target ES index – mapping includes a `dense_vector` field sized to the embedding dimension. 
@@ -16,30 +16,32 @@ from __future__ import annotations import logging -import os +import asyncio from itertools import batched from tqdm import tqdm +import httpx +from httpx import Limits + from src.indexing.elastic_search_indexer import ElasticSearchIndexer from src.indexing.entities import IndexedDocument, BibEntry from src.indexing.bib_parser import extract_bib_entries -from src.indexing.docling_parallel import convert_in_parallel, close_pool from src.indexing.settings import settings from src.common.dense_embedder import DenseEmbedder -from .parse import process_text +from .parse import fetch_and_parse logger = logging.getLogger(__name__) -def _process_batch(entries: list[BibEntry], workers: int, embedder: DenseEmbedder) -> list[IndexedDocument]: +def _process_batch( + entries: list[BibEntry], embedder: DenseEmbedder +) -> list[IndexedDocument]: """Convert one batch of BibTeX entries to indexable documents. Args: entries: A list of BibEntry mappings to be converted. Each entry must contain at least the keys url, title, year and author. - workers: Number of worker processes used by Docling during PDF → Markdown - conversion. embedder: DenseEmbedder instance used for embedding text. Returns: @@ -54,9 +56,8 @@ def _process_batch(entries: list[BibEntry], workers: int, embedder: DenseEmbedde url = url.rstrip("/") + ".pdf" urls.append(url) - texts = convert_in_parallel(urls, workers) - # Clean raw Markdown text extracted from PDFs before embedding/indexing - texts = [process_text(t) if t else None for t in texts] + # Download and parse PDFs concurrently + texts = asyncio.run(fetch_and_parse(urls)) docs: list[IndexedDocument] = [] raw_texts: list[str] = [] @@ -70,6 +71,7 @@ def _process_batch(entries: list[BibEntry], workers: int, embedder: DenseEmbedde "year": entry.get("year"), "url": entry.get("url"), "author": entry.get("author"), + "doi": entry.get("doi"), "text": text, } ) @@ -92,7 +94,6 @@ def ingest_bib() -> None: bib_file = settings.bib_file index_name = settings.index_name batch_size = settings.batch_size - concurrency = settings.concurrency or os.cpu_count() or 4 force_delete_index = settings.force_delete_index es_host = settings.es_host max_entries = settings.max_entries @@ -114,7 +115,7 @@ def ingest_bib() -> None: with tqdm(total=len(entries), desc="Ingesting PDF batches", unit="doc") as progress: embedder = DenseEmbedder() for batch_entries in batched(entries, batch_size): - docs = _process_batch(list(batch_entries), concurrency, embedder) + docs = _process_batch(list(batch_entries), embedder) if not docs: progress.update(len(batch_entries)) continue @@ -129,9 +130,113 @@ def ingest_bib() -> None: indexer.index_documents(index_name, docs, batch_size=len(docs)) progress.update(len(batch_entries)) - # tear down worker pool - close_pool() + +async def _process_batch_async( + entries: list[BibEntry], + embedder: DenseEmbedder, + client: httpx.AsyncClient, +) -> list[IndexedDocument]: + """Convert one batch of BibTeX entries to indexable documents. + + Args: + entries: A list of BibEntry mappings to be converted. Each entry must + contain at least the keys url, title, year and author. + embedder: DenseEmbedder instance used for embedding text. + client: httpx.AsyncClient instance used for downloading PDFs. + + Returns: + A list of :class:IndexedDocument dictionaries that passed conversion + successfully (failed PDFs are silently skipped). 
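+
+    A minimal usage sketch (assumes an open event loop, a configured
+    ``DenseEmbedder`` and a shared ``httpx.AsyncClient``):
+
+        docs = await _process_batch_async(entries, embedder, client)
+        indexer.index_documents(index_name, docs, batch_size=len(docs))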
+    """
+
+    urls = []
+    for e in entries:
+        url = e["url"]
+        if not url.endswith(".pdf"):
+            url = url.rstrip("/") + ".pdf"
+        urls.append(url)
+
+    # Download and parse PDFs concurrently (shared HTTP client)
+    texts = await fetch_and_parse(urls, client=client)
+
+    docs: list[IndexedDocument] = []
+    raw_texts: list[str] = []
+    for entry, text in zip(entries, texts):
+        if not text:
+            continue
+        raw_texts.append(text)
+        docs.append(
+            {
+                "title": entry.get("title"),
+                "year": entry.get("year"),
+                "url": entry.get("url"),
+                "author": entry.get("author"),
+                "doi": entry.get("doi"),
+                "text": text,
+            }
+        )
+
+    if docs:
+        embeddings = embedder.embed_documents(raw_texts)
+        for doc, emb in zip(docs, embeddings):
+            doc["text_embedding"] = emb
+
+    return docs
+
+
+async def _ingest_bib_async() -> None:
+    """Async implementation of the ingestion pipeline (single event loop)."""
+
+    bib_file = settings.bib_file
+    index_name = settings.index_name
+    batch_size = settings.batch_size
+    force_delete_index = settings.force_delete_index
+    es_host = settings.es_host
+    max_entries = settings.max_entries
+
+    entries = extract_bib_entries(bib_file)
+    if max_entries is not None and max_entries > 0:
+        entries = entries[:max_entries]
+    if not entries:
+        logger.warning("No entries with a URL found in %s", bib_file)
+        return
+
+    indexer = ElasticSearchIndexer(es_host)
+    index_created = False
+
+    for entry in entries:
+        if not entry["url"].endswith(".pdf"):
+            entry["url"] = entry["url"].rstrip("/") + ".pdf"
+    embedder = DenseEmbedder()
+
+    # Shared HTTP client across all batches – HTTP/2 + keep-alive
+    async with httpx.AsyncClient(
+        http2=True,
+        limits=Limits(
+            max_connections=settings.acl_concurrency,
+            max_keepalive_connections=settings.acl_concurrency,
+        ),
+    ) as client:
+        with tqdm(
+            total=len(entries), desc="Ingesting PDF batches", unit="doc"
+        ) as progress:
+            for batch_entries in batched(entries, batch_size):
+                batch_list = list(batch_entries)
+                docs = await _process_batch_async(batch_list, embedder, client)
+                if not docs:
+                    progress.update(len(batch_list))
+                    continue
+
+                if not index_created:
+                    dim = len(docs[0]["text_embedding"])
+                    indexer.create_index(
+                        index_name, force_delete=force_delete_index, embedding_dim=dim
+                    )
+                    index_created = True
+
+                indexer.index_documents(index_name, docs, batch_size=len(docs))
+                progress.update(len(batch_list))
 
 
 if __name__ == "__main__":
-    ingest_bib()
+    asyncio.run(_ingest_bib_async())
diff --git a/src/indexing/parse.py b/src/indexing/parse.py
index d786382..83028c7 100644
--- a/src/indexing/parse.py
+++ b/src/indexing/parse.py
@@ -1,25 +1,35 @@
 import fitz
-import time
-import os
 import re
-from pathlib import Path
 import logging
+import asyncio
+import httpx
+
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_exponential,
+    retry_if_exception_type,
+)
+
+from src.indexing.settings import settings
 
 logger = logging.getLogger(__name__)
 
+
 def process_text(text):
-
-    match = re.search(r'\babstract\b', text, re.IGNORECASE)
+    match = re.search(r"\babstract\b", text, re.IGNORECASE)
     if match:
         # Get the text *after* the word "abstract"
-        text = text[match.end():]
+        text = text[match.end() :]
 
     # Look for common footer section headers like "References", "Bibliography", "Acknowledgements".
-    footer_match = re.search(r"\n\s*(references|bibliography|acknowledgements)\b", text, re.IGNORECASE)
+    footer_match = re.search(
+        r"\n\s*(references|bibliography|acknowledgements)\b", text, re.IGNORECASE
+    )
     if footer_match:
-        text = text[:footer_match.start()]
+        text = text[: footer_match.start()]
 
-    text = re.sub(r'(\w+)-\s*\n\s*(\w+)', r'\1\2', text)
+    text = re.sub(r"(\w+)-\s*\n\s*(\w+)", r"\1\2", text)
 
     # Remove emails and URLs (any trailing punctuation will be consumed as well)
     text = re.sub(r"[\w.+-]+@[\w-]+(?:\.[\w-]+)+", " ", text)
@@ -44,52 +54,84 @@ def process_text(text):
     text = re.sub(r"\s+", " ", text).strip()
     return text
 
-def parse_pdf(path: str | Path, *, save_md: bool = False) -> tuple[str, float]:
+
+def parse_pdf(data: bytes | bytearray) -> str:
+    """Extract and clean text from a PDF supplied as raw bytes.
+
+    The payload (e.g. a downloaded response body) is validated to start with
+    the ``%PDF`` magic bytes before parsing; pages that fail to parse are
+    skipped with a warning.
+    """
     try:
-        t0 = time.perf_counter()
-        with fitz.open(path) as doc:
-            raw = "".join(page.get_text() for page in doc)
-        elapsed = time.perf_counter() - t0
+        if not data.lstrip().startswith(b"%PDF"):
+            raise ValueError("Invalid PDF header – does not start with %PDF")
+
+        with fitz.open(stream=data, filetype="pdf") as doc:
+            texts: list[str] = []
+            for page_number in range(doc.page_count):
+                try:
+                    page = doc.load_page(page_number)
+                    texts.append(page.get_text("text"))
+                except Exception as page_exc:
+                    logger.warning(
+                        "Skipping page %d due to parsing error: %s",
+                        page_number,
+                        page_exc,
+                    )
+            raw = "".join(texts)
     except Exception as e:
-        logger.error("Cannot parse %s: %s", path, e)
+        logger.error("Cannot parse PDF: %s", e)
         raise
-    cleaned = process_text(raw)
-    # Optional debug output: сохранить файл .md рядом с PDF
-    if save_md:
-        base_name = Path(path).stem
-        md_path = Path(path).with_suffix("").with_name(f"{base_name}_pymupdf.md")
-        try:
-            with open(md_path, "w", encoding="utf-8") as f:
-                f.write(f"# {base_name} (PyMuPDF)\n\n")
-                f.write(f"**Время парсинга:** {elapsed:.3f}s\n\n")
-                f.write(cleaned)
-        except Exception as err:
-            logger.warning("Не удалось записать MD файл %s: %s", md_path, err)
-    return cleaned, elapsed
-
-if __name__ == "__main__":
-    import argparse, textwrap, sys
-
-    parser = argparse.ArgumentParser(
-        description="Extract and clean text from PDF files.",
-        formatter_class=argparse.RawDescriptionHelpFormatter,
-    )
-    parser.add_argument("pdf", nargs="+", help="Путь(и) к PDF файлам")
-    parser.add_argument(
-        "--save-md",
-        action="store_true",
-    )
-
-    args = parser.parse_args()
-
-    for pdf_path in args.pdf:
-        try:
-            cleaned, elapsed = parse_pdf(pdf_path, save_md=args.save_md)
-            print("-" * 80)
-            print(f"{pdf_path} — {len(cleaned.split())} слов, обработано за {elapsed:.2f} c")
-            snippet = cleaned[:500] + ("…" if len(cleaned) > 500 else "")
-            print(snippet)
-        except Exception as exc:
-            print(f"[ERROR] {pdf_path}: {exc}", file=sys.stderr)
-            continue
\ No newline at end of file
+    return process_text(raw)
+
+
+@retry(
+    stop=stop_after_attempt(getattr(settings, "download_retries", 3)),
+    wait=wait_exponential(multiplier=0.5, min=0.5, max=8),
+    retry=retry_if_exception_type(httpx.HTTPError),
+    reraise=True,
+)
+async def _get_with_retry(client: httpx.AsyncClient, url: str) -> bytes:
+    """Download *url*, retrying failed attempts with exponential backoff."""
+    resp = await client.get(url, timeout=30.0)
+    resp.raise_for_status()
+    return resp.content
+
+
+async def fetch_and_parse(
+    urls: list[str],
+    *,
+    
client: httpx.AsyncClient | None = None,
+) -> list[str | None]:
+    """Download multiple PDF URLs concurrently and return cleaned texts.
+
+    Args:
+        urls: HTTP(S) links pointing to PDF files.
+        client: Optional shared ``httpx.AsyncClient``; a temporary client is
+            created (and closed) when none is supplied.
+
+    Returns:
+        List of cleaned document texts (``None`` for failed downloads/parses),
+        preserving the input order.
+    """
+
+    if client is None:
+        # No shared client supplied – open a temporary one for this call.
+        async with httpx.AsyncClient() as owned_client:
+            return await fetch_and_parse(urls, client=owned_client)
+
+    results: list[str | None] = [None] * len(urls)
+    semaphore = asyncio.Semaphore(settings.acl_concurrency)
+
+    async def _worker(i: int, url: str) -> None:
+        async with semaphore:
+            try:
+                pdf_bytes = await _get_with_retry(client, url)
+            except Exception as exc:
+                logger.error("Download failed for %s: %s", url, exc)
+                return
+
+            try:
+                text = parse_pdf(pdf_bytes)
+            except Exception as exc:
+                logger.error("Parsing failed for %s: %s", url, exc)
+                return
+
+            results[i] = text
+
+    await asyncio.gather(*(_worker(i, u) for i, u in enumerate(urls)))
+    return results
diff --git a/src/indexing/settings.py b/src/indexing/settings.py
index 0c4cbd2..b8399aa 100644
--- a/src/indexing/settings.py
+++ b/src/indexing/settings.py
@@ -21,8 +21,12 @@ class Settings(BaseSettings):
         Name of the Elasticsearch index to create / populate.
     batch_size
         Number of BibTeX entries processed per batch (also used for ES bulk size).
-    concurrency
-        Worker processes for Docling PDF→Markdown conversion.
+    acl_concurrency
+        Number of asyncio workers for PDF downloading.
+    crossref_concurrency
+        Number of asyncio workers for Crossref API requests.
+    download_retries
+        Number of retries for downloading PDFs.
     max_entries
         Optional hard cap on number of BibTeX records to ingest (``None`` = all).
     force_delete_index
@@ -31,12 +35,16 @@ class Settings(BaseSettings):
         Elasticsearch HTTP endpoint (single-node setup assumed).
     openai_base_url, embedding_model_name, embedding_batch_size, openai_api_key
         See :pyfile:`src.common.settings` – parameters forwarded to the embedding client.
+    pagerank_alpha
+        Damping factor for PageRank (probability of following citation links).
     """
 
     bib_file: Path = Field(..., env="BIB_FILE")
     index_name: str = Field("papers", env="INDEX_NAME")
     batch_size: int = Field(100, env="BATCH_SIZE")
-    concurrency: int = Field(4, env="CONCURRENCY")
+    acl_concurrency: int = Field(10, env="ACL_CONCURRENCY")
+    crossref_concurrency: int = Field(10, env="CROSSREF_CONCURRENCY")
+    download_retries: int = Field(3, env="DOWNLOAD_RETRIES")
     max_entries: Optional[int] = Field(None, env="MAX_ENTRIES")
     force_delete_index: bool = Field(False, env="FORCE_DELETE_INDEX")
     es_host: str = Field("http://localhost:9200", env="ES_HOST")
@@ -49,6 +57,8 @@ class Settings(BaseSettings):
 
     openai_api_key: str = Field("not-needed", env="OPENAI_API_KEY")
 
+    pagerank_alpha: float = Field(0.85, env="PAGERANK_ALPHA")
+
     class Config:
         env_file = ".env"
         env_file_encoding = "utf-8"
diff --git a/src/search/query_processor.py b/src/search/query_processor.py
new file mode 100644
index 0000000..af87b7c
--- /dev/null
+++ b/src/search/query_processor.py
@@ -0,0 +1,596 @@
+
+import re
+import logging
+from typing import List, Tuple, Optional, Set, Dict
+from difflib import SequenceMatcher
+from collections import defaultdict
+import json
+from pathlib import Path
+
+try:
+    from spellchecker import SpellChecker
+    SPELLCHECKER_AVAILABLE = True
+except ImportError:
+    SPELLCHECKER_AVAILABLE = False
+    print("WARNING: pyspellchecker not installed. Install with: pip install pyspellchecker")
+
+try:
+    import nltk
+    from nltk.corpus import wordnet
+    NLTK_AVAILABLE = True
+except ImportError:
+    NLTK_AVAILABLE = False
+    print("WARNING: nltk not installed. 
Install with: pip install nltk")
+
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+
+class QueryProcessor:
+    """
+    Query processor with typo correction for academic search.
+
+    This system uses a multi-level approach to typo correction:
+
+    1. **Common typos dictionary** - manual dictionary of typical NLP/ML errors (empty by default, can be populated)
+    2. **NLP-specific vocabulary** - terms that should NOT be corrected
+    3. **Common English dictionary** - preloaded dictionary of common words (~30k words)
+    4. **WordNet** (optional) - for checking word existence
+
+    Attributes:
+        spell (SpellChecker): Main spell checker with preloaded dictionary
+        min_confidence (float): Minimum confidence for applying correction
+        max_corrections (int): Maximum number of corrections per query
+        _nlp_vocabulary (set): NLP/ML specialized terms
+        _common_typos (dict): Dictionary of known typos → correct spelling
+        _correction_cache (dict): Cache for speeding up repeated corrections
+    """
+
+    def __init__(
+        self,
+        custom_vocabulary: Optional[List[str]] = None,
+        min_confidence: float = 0.7,
+        max_corrections: int = 3,
+        enable_cache: bool = True,
+        dictionary_path: Optional[str] = None
+    ):
+        """
+        Initialize query processor.
+
+        Args:
+            custom_vocabulary: Additional terms for the dictionary
+            min_confidence: Confidence threshold (0.0-1.0). Corrections with lower
+                confidence are ignored
+            max_corrections: Limit on the number of corrections in one query
+            enable_cache: Enable correction cache
+            dictionary_path: Path to the user dictionary (text file)
+        """
+        self.min_confidence = min_confidence
+        self.max_corrections = max_corrections
+        self.enable_cache = enable_cache
+        self._correction_cache: Dict[str, Tuple[str, float]] = {}
+
+        if SPELLCHECKER_AVAILABLE:
+            self.spell = SpellChecker()
+            try:
+                if hasattr(self.spell.word_frequency, "unique_words"):
+                    num_words = self.spell.word_frequency.unique_words  # type: ignore[attr-defined]
+                else:
+                    num_words = len(list(self.spell.word_frequency.words()))  # type: ignore[arg-type]
+            except Exception:
+                num_words = "unknown"
+            logger.info(f"SpellChecker initialized (words in dictionary: {num_words})")
+        else:
+            self.spell = None
+            logger.warning("SpellChecker not available, using fallback mode")
+
+        if NLTK_AVAILABLE:
+            try:
+                nltk.data.find('corpora/wordnet')
+                self.use_wordnet = True
+                logger.info("WordNet available for additional validation")
+            except LookupError:
+                logger.info("Downloading WordNet...")
+                nltk.download('wordnet')
+                self.use_wordnet = True
+        else:
+            self.use_wordnet = False
+
+        self._nlp_vocabulary = {
+            'transformer', 'transformers', 'bert', 'roberta', 'xlnet', 'albert',
+            'electra', 'gpt', 'gpt2', 'gpt3', 'gpt4', 't5', 'bart', 'mbart',
+            'xlm', 'xlmr', 'distilbert', 'deberta', 'longformer', 'reformer',
+            'funnel', 'convbert', 'layoutlm', 'tapas', 'flaubert', 'camembert',
+
+            'lstm', 'gru', 'rnn', 'cnn', 'gcn', 'gnn', 'gan', 'vae', 'ae',
+            'mlp', 'ffn', 'resnet', 'densenet', 'efficientnet', 'vgg',
+
+            'embedding', 'embeddings', 'encoder', 'decoder', 'attention',
+            'multihead', 'crossattention', 'selfattention', 'positional',
+            'tokenizer', 'tokenization', 'subword', 'wordpiece', 'sentencepiece',
+            'bpe', 'dropout', 'layernorm', 'batchnorm', 'normalization',
+
+            'pretrained', 'finetuning', 'finetune', 'finetuned', 'pretraining',
+            'downstream', 'upstream', 'zeroshot', 'fewshot', 'multitask',
+            'transfer', 'learning', 'supervised', 'unsupervised', 'semisupervised',
+            
'reinforcement', 'metalearning', 'continual', 'federated',
+
+            'classification', 'generation', 'summarization', 'translation',
+            'parsing', 'tagging', 'ner', 'pos', 'srl', 'coref', 'coreference',
+            'qa', 'qg', 'nli', 'sts', 'paraphrase', 'entailment', 'sentiment',
+            'emotion', 'stance', 'factchecking', 'textual', 'similarity',
+
+            'semantic', 'syntactic', 'morphological', 'phonological', 'pragmatic',
+            'lexical', 'discourse', 'anaphora', 'cataphora', 'deixis',
+            'polysemy', 'synonymy', 'antonymy', 'hyponymy', 'hypernymy',
+
+            'dataset', 'corpus', 'corpora', 'benchmark', 'leaderboard',
+            'train', 'validation', 'dev', 'test', 'split', 'fold',
+            'augmentation', 'preprocessing', 'postprocessing',
+
+            'accuracy', 'precision', 'recall', 'f1', 'bleu', 'rouge', 'meteor',
+            'bertscore', 'bleurt', 'comet', 'chrf', 'ter', 'wer', 'cer',
+            'perplexity', 'loss', 'crossentropy', 'mse', 'mae', 'rmse',
+
+            'acl', 'naacl', 'eacl', 'aacl', 'emnlp', 'coling', 'lrec',
+            'conll', 'tacl', 'cl', 'aaai', 'ijcai', 'icml', 'neurips',
+            'iclr', 'cvpr', 'eccv', 'iccv', 'acmmm', 'sigir', 'sigkdd',
+
+            'nlp', 'nlu', 'nlg', 'mt', 'ir', 'kg', 'kge', 'ie', 're',
+            'api', 'sdk', 'gpu', 'cpu', 'tpu', 'cuda', 'onnx', 'tensorrt',
+            'pytorch', 'tensorflow', 'jax', 'keras', 'huggingface', 'spacy',
+            'nltk', 'stanford', 'allennlp', 'fairseq', 'transformers',
+            'tokenizers', 'datasets', 'wandb', 'mlflow', 'tensorboard',
+            'arxiv', 'openreview', 'github', 'sota', 'baseline', 'ablation'
+        }
+
+        self._common_typos = {}
+
+        if custom_vocabulary:
+            self._nlp_vocabulary.update(word.lower() for word in custom_vocabulary)
+            logger.info(f"Added {len(custom_vocabulary)} custom vocabulary terms")
+
+        if dictionary_path and Path(dictionary_path).exists():
+            self._load_custom_dictionary(dictionary_path)
+
+        if self.spell:
+            self.spell.word_frequency.load_words(self._nlp_vocabulary)
+            logger.info(f"Updated spellchecker with {len(self._nlp_vocabulary)} NLP terms")
+
+        self._stats = {
+            'total_queries': 0,
+            'corrected_queries': 0,
+            'total_corrections': 0,
+            'cache_hits': 0,
+            'typo_frequencies': defaultdict(int)
+        }
+
+    def _load_custom_dictionary(self, path: str) -> None:
+        """Load additional dictionary from file."""
+        try:
+            with open(path, 'r', encoding='utf-8') as f:
+                words = {line.strip().lower() for line in f if line.strip()}
+            self._nlp_vocabulary.update(words)
+            logger.info(f"Loaded {len(words)} words from {path}")
+        except Exception as e:
+            logger.error(f"Failed to load dictionary from {path}: {e}")
+
+    def process_query(self, query: str) -> Tuple[str, List[str], float]:
+        """
+        Process and correct user query.
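+
+        Correction is deliberately conservative: words in the NLP vocabulary,
+        words of one or two characters, and candidates whose similarity to the
+        original falls below ``min_confidence`` are left unchanged, and at
+        most ``max_corrections`` words are corrected per query.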
+ + Args: + query: Original user query + + Returns: + Tuple containing: + - corrected_query: Corrected query + - corrections: List of applied corrections in format "was → became" + - confidence: Confidence in corrections (0.0-1.0) + """ + self._stats['total_queries'] += 1 + + cleaned_query = self._clean_query(query) + + corrected_query, corrections = self._correct_typos(cleaned_query) + + confidence = self._calculate_confidence(query, corrected_query, corrections) + + if corrections: + self._stats['corrected_queries'] += 1 + self._stats['total_corrections'] += len(corrections) + for correction in corrections: + typo = correction.split(' → ')[0] + self._stats['typo_frequencies'][typo] += 1 + + logger.info(f"Query processing: '{query}' -> '{corrected_query}' " + f"(confidence: {confidence:.2f}, corrections: {len(corrections)})") + + return corrected_query, corrections, confidence + + def _clean_query(self, query: str) -> str: + """ + Clean and normalize query. + """ + query = ' '.join(query.split()) + + query = re.sub(r'\s+([?.!,])', r'\1', query) + query = re.sub(r'([?.!,])\s*', r'\1 ', query) + + if query.count('"') % 2 != 0: + query = query.replace('"', '') + + return query.strip() + + def _correct_typos(self, query: str) -> Tuple[str, List[str]]: + """ + Main typo correction logic. + """ + words = query.lower().split() + corrected_words = [] + corrections = [] + correction_count = 0 + + for i, word in enumerate(words): + if correction_count >= self.max_corrections: + corrected_words.append(word) + continue + + if word in self._nlp_vocabulary: + corrected_words.append(word) + continue + + if len(word) <= 2: + corrected_words.append(word) + continue + + if self.enable_cache and word in self._correction_cache: + cached_correction, cached_confidence = self._correction_cache[word] + if cached_confidence >= self.min_confidence: + corrected_words.append(cached_correction) + if cached_correction != word: + corrections.append(f"{word} → {cached_correction}") + correction_count += 1 + self._stats['cache_hits'] += 1 + else: + corrected_words.append(word) + continue + + correction = self._get_spell_correction(word) + if correction and correction != word: + similarity = SequenceMatcher(None, word, correction).ratio() + + if similarity >= self.min_confidence: + corrected_words.append(correction) + corrections.append(f"{word} → {correction}") + correction_count += 1 + + if self.enable_cache: + self._correction_cache[word] = (correction, similarity) + else: + corrected_words.append(word) + if self.enable_cache: + self._correction_cache[word] = (word, similarity) + else: + corrected_words.append(word) + if self.enable_cache: + self._correction_cache[word] = (word, 1.0) + + corrected_query = self._restore_case(corrected_words, query.split()) + + return corrected_query, corrections + + def _get_spell_correction(self, word: str) -> Optional[str]: + """ + Get correction from spell checker with additional validation. + """ + if not self.spell: + return None + + correction = self.spell.correction(word) + + if self.use_wordnet and correction and correction != word: + if wordnet.synsets(word): + return word + else: + candidates = self.spell.candidates(word) + if candidates: + for candidate in candidates: + if wordnet.synsets(candidate): + return candidate + + return correction + + def _restore_case(self, corrected_words: List[str], original_words: List[str]) -> str: + """ + Restore case and formatting from the original query. 
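+
+        Illustrative example: corrected tokens ``['bert', 'model']`` with
+        original tokens ``['brt', 'Modle']`` come back as ``'BERT Model'`` –
+        known acronyms are upper-cased, otherwise the original casing wins.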
+ """ + result = [] + + acronyms = { + 'nlp': 'NLP', 'bert': 'BERT', 'gpt': 'GPT', 'lstm': 'LSTM', + 'rnn': 'RNN', 'cnn': 'CNN', 'gan': 'GAN', 'vae': 'VAE', + 'ner': 'NER', 'pos': 'POS', 'srl': 'SRL', 'qa': 'QA', + 'nli': 'NLI', 'mt': 'MT', 'ir': 'IR', 'kg': 'KG', 'ie': 'IE', + 'acl': 'ACL', 'emnlp': 'EMNLP', 'naacl': 'NAACL', + 'coling': 'COLING', 'lrec': 'LREC', 'conll': 'CoNLL', + 'aaai': 'AAAI', 'ijcai': 'IJCAI', 'icml': 'ICML', + 'neurips': 'NeurIPS', 'iclr': 'ICLR', 'cvpr': 'CVPR', + 'sota': 'SOTA', 'bleu': 'BLEU', 'rouge': 'ROUGE', 'f1': 'F1', + 'api': 'API', 'sdk': 'SDK', 'gpu': 'GPU', 'cpu': 'CPU', + 'tpu': 'TPU', 'cuda': 'CUDA', 'onnx': 'ONNX' + } + + for i, word in enumerate(corrected_words): + if word in acronyms: + result.append(acronyms[word]) + elif i < len(original_words): + if original_words[i][0].isupper(): + result.append(word.capitalize()) + elif original_words[i].isupper(): + result.append(word.upper()) + else: + result.append(word) + else: + result.append(word) + + return ' '.join(result) + + def _calculate_confidence(self, original: str, corrected: str, corrections: List[str]) -> float: + """ + Calculate confidence in corrections. + + Factors: + - Similarity between original and corrected string + - Number of corrections + """ + if not corrections: + return 1.0 + + similarity = SequenceMatcher(None, original.lower(), corrected.lower()).ratio() + + correction_penalty = len(corrections) * 0.05 + + confidence = min(1.0, max(0.0, similarity - correction_penalty)) + + return confidence + + def suggest_alternatives(self, query: str, top_k: int = 3) -> List[str]: + """ + Suggest alternative query variants. + + """ + alternatives = [] + words = query.lower().split() + + for i, word in enumerate(words): + if word in self._nlp_vocabulary or len(word) <= 2: + continue + + if self.spell: + candidates = self.spell.candidates(word) + if candidates and word not in candidates: + for candidate in list(candidates)[:2]: + alt_words = words.copy() + alt_words[i] = candidate + alt_query = ' '.join(alt_words) + if alt_query != query.lower(): + alternatives.append(self._restore_case( + alt_words, query.split() + )) + + seen = set() + unique_alternatives = [] + for alt in alternatives: + if alt not in seen: + seen.add(alt) + unique_alternatives.append(alt) + + return unique_alternatives[:top_k] + + def get_statistics(self) -> Dict: + """Get usage statistics.""" + return { + 'total_queries': self._stats['total_queries'], + 'corrected_queries': self._stats['corrected_queries'], + 'correction_rate': (self._stats['corrected_queries'] / + max(1, self._stats['total_queries'])), + 'total_corrections': self._stats['total_corrections'], + 'cache_hits': self._stats['cache_hits'], + 'cache_size': len(self._correction_cache), + 'most_common_typos': dict(sorted( + self._stats['typo_frequencies'].items(), + key=lambda x: x[1], + reverse=True + )[:10]) + } + + def save_cache(self, path: str) -> None: + """Save correction cache for future use.""" + cache_data = { + 'cache': self._correction_cache, + 'stats': self._stats, + 'typo_frequencies': dict(self._stats['typo_frequencies']) + } + with open(path, 'w', encoding='utf-8') as f: + json.dump(cache_data, f, indent=2, ensure_ascii=False) + logger.info(f"Cache saved to {path}") + + def load_cache(self, path: str) -> None: + """Load correction cache.""" + if Path(path).exists(): + with open(path, 'r', encoding='utf-8') as f: + cache_data = json.load(f) + self._correction_cache = cache_data.get('cache', {}) + self._correction_cache = { + k: tuple(v) for k, v in 
self._correction_cache.items() + } + logger.info(f"Cache loaded from {path} ({len(self._correction_cache)} entries)") + + +def preprocess_query( + query: str, + processor: Optional[QueryProcessor] = None +) -> Tuple[str, float]: + """ + Simple function for integration into existing code. + + Args: + query: User query + processor: QueryProcessor instance (created automatically if None) + + Returns: + Tuple (corrected_query, confidence) + """ + if processor is None: + processor = QueryProcessor() + + corrected_query, corrections, confidence = processor.process_query(query) + + if corrections: + logger.info(f"Query corrections applied: {', '.join(corrections)}") + + return corrected_query, confidence + + +def demo(): + """Demo of all typo correction capabilities.""" + + print("=" * 80) + print("TYPO CORRECTION SYSTEM FOR ACADEMIC SEARCH") + print("=" * 80) + + processor = QueryProcessor( + min_confidence=0.7, + max_corrections=3, + enable_cache=True + ) + + test_queries = [ + "tansformer attenton mechansim", + "embedings for nlp", + "nueral netwrok architechture", + + "bert modle for clasification", + "gpt3 langauge generation", + + "sumarization with transformet", + "sentement analysis dataset", + + "evaluaiton metrics blue and rough", + "f-1 score and acuracy", + + "emnpl 2024 paper", + "nacaal workshop on mt", + + "preprocesing for tokenizaton", + "benchamrk performace comparison", + + "transformer attention mechanism", + "BERT embeddings visualization", + "neural machine translation", + + "transfomer-based aproach for nlp", + "multi-head attenton in bert", + ] + + print("\nTEST CASES:") + print("-" * 80) + + for i, query in enumerate(test_queries, 1): + corrected, corrections, confidence = processor.process_query(query) + + print(f"\n{i}. Original query: {query}") + print(f" Corrected: {corrected}") + + if corrections: + print(f" Corrections: {' | '.join(corrections)}") + else: + print(" Corrections: [not required]") + + print(f" Confidence: {confidence:.1%}") + + if confidence < 0.8 and corrections: + alternatives = processor.suggest_alternatives(query) + if alternatives: + print(f" Alternatives: {' | '.join(alternatives[:2])}") + + print("\n\nSTATISTICS:") + print("-" * 80) + stats = processor.get_statistics() + print(f"Total queries: {stats['total_queries']}") + print(f"Corrected queries: {stats['corrected_queries']}") + print(f"Correction rate: {stats['correction_rate']:.1%}") + print(f"Total corrections: {stats['total_corrections']}") + print(f"Cache hits: {stats['cache_hits']}") + print(f"Cache size: {stats['cache_size']}") + + if stats['most_common_typos']: + print("\nTop-5 frequent typos:") + for typo, count in list(stats['most_common_typos'].items())[:5]: + print(f" - '{typo}': {count} times") + + print("\n\nINTERACTIVE MODE") + print("-" * 80) + print("Enter queries (type 'exit' to quit)") + print("Commands: 'stats' - show stats, 'cache' - save cache") + print("-" * 80) + + while True: + try: + user_input = input("\n> ").strip() + + if user_input.lower() in ['exit', 'quit', 'q']: + break + elif user_input.lower() == 'stats': + stats = processor.get_statistics() + print(f"\nProcessed queries: {stats['total_queries']}") + print(f"Correction rate: {stats['correction_rate']:.1%}") + continue + elif user_input.lower() == 'cache': + processor.save_cache('query_corrections_cache.json') + print("Cache saved to query_corrections_cache.json") + continue + + if not user_input: + continue + + corrected, corrections, confidence = processor.process_query(user_input) + + print(f"\nResult: 
{corrected}") + + if corrections: + print(f" Corrections: {' | '.join(corrections)}") + print(f" Confidence: {confidence:.1%}") + + if confidence < 0.8: + alternatives = processor.suggest_alternatives(user_input) + if alternatives: + print(f" Alternatives: {' | '.join(alternatives)}") + else: + print(" [No corrections required]") + + except KeyboardInterrupt: + print("\n\nExiting...") + break + + processor.save_cache('query_corrections_cache.json') + print("\nCache saved for future use") + + +if __name__ == "__main__": + print("Dependency check...") + if not SPELLCHECKER_AVAILABLE: + print("WARNING: pyspellchecker not installed. Install with: pip install pyspellchecker") + if not NLTK_AVAILABLE: + print("WARNING: nltk not installed. Install with: pip install nltk") + + if SPELLCHECKER_AVAILABLE: + print("pyspellchecker available") + if NLTK_AVAILABLE: + print("nltk available") + + print() + + demo() \ No newline at end of file diff --git a/src/search/search_cli.py b/src/search/search_cli.py index 6dd5e44..6ff403e 100644 --- a/src/search/search_cli.py +++ b/src/search/search_cli.py @@ -1,131 +1,271 @@ -"""Interactive search CLI – BM25 + vector similarity with client-side RRF. - -For every user query the tool: -1. Embeds the query with :class:`DenseEmbedder`. -2. Executes two ES searches (lexical + KNN). -3. Fuses the result sets locally using Reciprocal Rank Fusion. -4. Prints the top-k documents with basic metadata. - -All runtime parameters are supplied via :pyfile:`src.search.settings`. -""" - -from __future__ import annotations - -from collections import defaultdict -from elasticsearch import Elasticsearch -from src.common.dense_embedder import DenseEmbedder -from src.search.settings import settings - - -def _print_hit(rank: int, hit: dict) -> None: - source = hit["_source"] - title = source.get("title", "") - year = source.get("year", "?") - authors = ", ".join(source.get("author", []) or []) - print(f"{rank}. 
{title} ({year})") - if authors: - print(f" {authors}") - if source.get("url"): - print(f" {source['url']}") - - -def _connect() -> Elasticsearch: - """Create and validate an Elasticsearch connection.""" - - client = Elasticsearch(settings.es_host) - if not client.ping(): - raise SystemExit(f"Cannot connect to Elasticsearch at {settings.es_host}") - return client - - -def _bm25_body(query: str, size: int) -> dict: - return { - "query": {"match": {"text": {"query": query, "fuzziness": "AUTO"}}}, - "size": size, - "_source": ["title", "author", "year", "url"], - } - - -def _knn_body(query_vector: list[float]) -> dict: - return { - "knn": { - "field": "text_embedding", - "query_vector": query_vector, - "k": settings.knn_k, - "num_candidates": settings.knn_candidates, - }, - "size": settings.knn_k, - "_source": ["title", "author", "year", "url"], - } - - -def _rrf_fuse( - lex_hits: list[dict], knn_hits: list[dict], window_size: int, top_k: int -) -> list[dict]: - """Combine two ranking lists using Reciprocal Rank Fusion.""" - - def rrf(rank: int, k: int) -> float: - return 1.0 / (k + rank) - - scores: dict[str, float] = defaultdict(float) - doc_store: dict[str, dict] = {} - - for r, hit in enumerate(lex_hits, 1): - doc_id = hit["_id"] - scores[doc_id] += rrf(r, window_size) - doc_store[doc_id] = hit - - for r, hit in enumerate(knn_hits, 1): - doc_id = hit["_id"] - scores[doc_id] += rrf(r, window_size) - doc_store[doc_id] = hit - - sorted_docs = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k] - return [doc_store[doc_id] for doc_id, _ in sorted_docs] - - -def _interactive_loop(client: Elasticsearch, embedder: DenseEmbedder) -> None: - """Prompt the user for queries and display fused results.""" - - while True: - try: - query = input("query> ").strip() - except (KeyboardInterrupt, EOFError): - print() - break - - if not query or query.lower() in {"exit", "quit"}: - break - - query_vector = embedder.embed_query(query) - window_size = max(settings.knn_k, settings.top_k) - - lex_hits = client.search( - index=settings.index_name, - body=_bm25_body(query, window_size), - )["hits"]["hits"] - - knn_hits = client.search( - index=settings.index_name, - body=_knn_body(query_vector), - )["hits"]["hits"] - - if not lex_hits and not knn_hits: - print("No matches\n") - continue - - hits = _rrf_fuse(lex_hits, knn_hits, window_size, settings.top_k) - - for i, hit in enumerate(hits, 1): - _print_hit(i, hit) - print() - - -def main() -> None: - client = _connect() - embedder = DenseEmbedder() - _interactive_loop(client, embedder) - - -if __name__ == "__main__": - main() +"""Enhanced interactive search CLI with typo correction. 
+ +This is an improved version of search_cli.py that includes: +- Automatic typo correction before search +- Query suggestions for ambiguous corrections +- Confidence scoring for corrections +- Option to use original query if correction confidence is low +""" + +from __future__ import annotations + +from collections import defaultdict +from elasticsearch import Elasticsearch +from src.common.dense_embedder import DenseEmbedder +from src.search.settings import settings +from src.search.query_processor import QueryProcessor, preprocess_query +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def _print_hit(rank: int, hit: dict) -> None: + source = hit["_source"] + title = source.get("title", "") + year = source.get("year", "?") + authors = ", ".join(source.get("author", []) or []) + pagerank = source.get("pagerank", 0.0) + + print(f"{rank}. {title} ({year})") + if authors: + print(f" Authors: {authors}") + if source.get("url"): + print(f" URL: {source['url']}") + if pagerank > 0: + print(f" PageRank: {pagerank:.6f}") + + +def _connect() -> Elasticsearch: + """Create and validate an Elasticsearch connection.""" + client = Elasticsearch(settings.es_host) + if not client.ping(): + raise SystemExit(f"Cannot connect to Elasticsearch at {settings.es_host}") + return client + + +def _bm25_body(query: str, size: int) -> dict: + return { + "query": {"match": {"text": {"query": query, "fuzziness": "AUTO"}}}, + "size": size, + "_source": ["title", "author", "year", "url", "pagerank"], + } + + +def _knn_body(query_vector: list[float]) -> dict: + return { + "knn": { + "field": "text_embedding", + "query_vector": query_vector, + "k": settings.knn_k, + "num_candidates": settings.knn_candidates, + }, + "size": settings.knn_k, + "_source": ["title", "author", "year", "url", "pagerank"], + } + + +def _rrf_fuse( + lex_hits: list[dict], knn_hits: list[dict], window_size: int, top_k: int +) -> list[dict]: + """ + Combine BM25, KNN, and PageRank using Reciprocal Rank Fusion (RRF). 
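+
+    Each document accumulates ``1 / (window_size + rank)`` for every ranking
+    it appears in: BM25, KNN, and a ranking of the retrieved docs by their
+    stored ``pagerank`` value, so the three signals are fused without any
+    score normalisation.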
+ """ + def rrf(rank: int, k: int) -> float: + return 1.0 / (k + rank) + + scores: dict[str, float] = defaultdict(float) + doc_store: dict[str, dict] = {} + + # Store hits and prepare rank dictionaries + bm25_rank: dict[str, int] = {} + knn_rank: dict[str, int] = {} + pagerank_values: dict[str, float] = {} + + for r, hit in enumerate(lex_hits, 1): + doc_id = hit["_id"] + doc_store[doc_id] = hit + bm25_rank[doc_id] = r + pr = hit["_source"].get("pagerank") + if pr is not None: + pagerank_values[doc_id] = pr + + for r, hit in enumerate(knn_hits, 1): + doc_id = hit["_id"] + doc_store.setdefault(doc_id, hit) + knn_rank[doc_id] = r + pr = hit["_source"].get("pagerank") + if pr is not None: + pagerank_values[doc_id] = pr + + # Determine PageRank ranking among all collected docs + pr_sorted = sorted(pagerank_values.items(), key=lambda x: x[1], reverse=True) + pr_rank = {doc_id: r + 1 for r, (doc_id, _) in enumerate(pr_sorted)} + + # Fuse using RRF formula + for doc_id in doc_store: + if doc_id in bm25_rank: + scores[doc_id] += rrf(bm25_rank[doc_id], window_size) + if doc_id in knn_rank: + scores[doc_id] += rrf(knn_rank[doc_id], window_size) + if doc_id in pr_rank: + scores[doc_id] += rrf(pr_rank[doc_id], window_size) + + sorted_docs = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k] + return [doc_store[d_id] for d_id, _ in sorted_docs] + + +def _execute_search( + client: Elasticsearch, + embedder: DenseEmbedder, + query: str, + show_corrections: bool = True +) -> list[dict]: + """Execute the search with the given query.""" + # Embed the query for vector search + query_vector = embedder.embed_query(query) + window_size = max(settings.knn_k, settings.top_k) + + # Execute BM25 search + lex_hits = client.search( + index=settings.index_name, + body=_bm25_body(query, window_size), + )["hits"]["hits"] + + # Execute KNN search + knn_hits = client.search( + index=settings.index_name, + body=_knn_body(query_vector), + )["hits"]["hits"] + + if not lex_hits and not knn_hits: + return [] + + # Fuse results + return _rrf_fuse(lex_hits, knn_hits, window_size, settings.top_k) + + +def _interactive_loop( + client: Elasticsearch, + embedder: DenseEmbedder, + query_processor: QueryProcessor +) -> None: + """Enhanced prompt loop with typo correction.""" + + print("\n=== ACL Anthology Search Engine ===") + print("Type 'help' for commands, 'exit' to quit\n") + + while True: + try: + query = input("query> ").strip() + except (KeyboardInterrupt, EOFError): + print("\nGoodbye!") + break + + if not query: + continue + + # Handle commands + if query.lower() == "exit" or query.lower() == "quit": + print("Goodbye!") + break + elif query.lower() == "help": + print("\nCommands:") + print(" help - Show this help message") + print(" exit/quit - Exit the search engine") + print(" ! - Skip typo correction for this query") + print("\nJust type your search query to find papers.\n") + continue + + # Check if user wants to skip correction (prefix with !) 
+ skip_correction = query.startswith("!") + if skip_correction: + query = query[1:].strip() + + # Process query with typo correction unless skipped + if not skip_correction: + corrected_query, corrections, confidence = query_processor.process_query(query) + + # Show corrections if any were made + if corrections and corrected_query != query: + print(f"\n📝 Query corrected: '{query}' → '{corrected_query}'") + print(f" Corrections: {', '.join(corrections)}") + print(f" Confidence: {confidence:.2%}") + + # If confidence is low, ask user + if confidence < 0.8: + print("\n Low confidence in corrections. Choose an option:") + print(f" 1) Use corrected: '{corrected_query}'") + print(f" 2) Use original: '{query}'") + + # Show alternatives if available + alternatives = query_processor.suggest_alternatives(query) + if alternatives: + print(" 3) Other suggestions:") + for i, alt in enumerate(alternatives, 1): + print(f" 3.{i}) '{alt}'") + + choice = input("\n Enter choice (1/2/3.x): ").strip() + + if choice == "2": + corrected_query = query + elif choice.startswith("3."): + try: + alt_idx = int(choice[2:]) - 1 + if 0 <= alt_idx < len(alternatives): + corrected_query = alternatives[alt_idx] + except ValueError: + pass + + query_to_search = corrected_query + else: + query_to_search = query + else: + query_to_search = query + + # Execute search + print(f"\n🔍 Searching for: '{query_to_search}'...") + hits = _execute_search(client, embedder, query_to_search) + + if not hits: + print("No matches found.\n") + + # Suggest alternatives if no results + if not skip_correction: + alternatives = query_processor.suggest_alternatives(query_to_search) + if alternatives: + print("Did you mean:") + for alt in alternatives: + print(f" - {alt}") + print() + continue + + # Display results + print(f"\nFound {len(hits)} results:\n") + for i, hit in enumerate(hits, 1): + _print_hit(i, hit) + print() # Empty line between results + + +def main() -> None: + """Enhanced main function with query processing.""" + # Initialize components + client = _connect() + embedder = DenseEmbedder() + + # Initialize query processor with custom NLP vocabulary + custom_vocab = [ + # Add conference-specific terms + 'emnlp', 'naacl', 'eacl', 'coling', 'lrec', 'conll', 'tacl', + # Add more technical terms as needed + 'roberta', 'xlnet', 'albert', 'electra', 't5', 'bart', + 'squad', 'glue', 'superglue', 'bleu', 'rouge', 'meteor', + ] + query_processor = QueryProcessor(custom_vocabulary=custom_vocab) + + # Run interactive loop + _interactive_loop(client, embedder, query_processor) + + +if __name__ == "__main__": + main() \ No newline at end of file