From c1876ae76ba5e3c20777bfbe4e8d9768d4604e37 Mon Sep 17 00:00:00 2001 From: Jacob Ioffe Date: Thu, 26 Feb 2026 02:59:26 +0000 Subject: [PATCH 1/4] Add LanceDB hybrid search (dense + FTS) for retriever recall Enable combined dense vector + full-text search via --hybrid flag in the batch pipeline. When enabled, ingestion creates both IVF_HNSW_SQ and FTS indices, and recall uses RRF reranking to merge results. On jp20 (20 PDFs, 115 queries): Dense: recall@1=0.61, recall@5=0.90, recall@10=0.96 Hybrid: recall@1=0.65, recall@5=0.94, recall@10=0.96 Co-Authored-By: Claude Opus 4.6 Signed-off-by: Jacob Ioffe --- .../src/retriever/examples/batch_pipeline.py | 12 +++++- retriever/src/retriever/ingest_modes/batch.py | 10 +++++ retriever/src/retriever/params/models.py | 2 + retriever/src/retriever/recall/core.py | 41 +++++++++++++++---- .../src/retriever/vector_store/__init__.py | 8 +++- .../retriever/vector_store/lancedb_store.py | 39 +++++++++++++----- 6 files changed, 90 insertions(+), 22 deletions(-) diff --git a/retriever/src/retriever/examples/batch_pipeline.py b/retriever/src/retriever/examples/batch_pipeline.py index 5a9ad955e..4cf7d6277 100644 --- a/retriever/src/retriever/examples/batch_pipeline.py +++ b/retriever/src/retriever/examples/batch_pipeline.py @@ -331,7 +331,7 @@ def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]: return None, float(hit.get("_distance")) if "_distance" in hit else None key = f"{Path(str(source_id)).stem}_{page_number}" - dist = float(hit.get("_distance")) if "_distance" in hit else None + dist = float(hit["_distance"]) if "_distance" in hit else float(hit["_score"]) if "_score" in hit else None return key, dist @@ -531,6 +531,11 @@ def main( dir_okay=False, help="Optional JSON file path to write end-of-run detection counts summary.", ), + hybrid: bool = typer.Option( + False, + "--hybrid/--no-hybrid", + help="Enable LanceDB hybrid mode (dense + FTS text).", + ), ) -> None: log_handle, original_stdout, original_stderr = _configure_logging(log_file) try: @@ -578,6 +583,7 @@ def main( "table_name": LANCEDB_TABLE, "overwrite": True, "create_index": True, + "hybrid": hybrid, } ) ) @@ -599,6 +605,7 @@ def main( "table_name": LANCEDB_TABLE, "overwrite": True, "create_index": True, + "hybrid": hybrid, } ) ) @@ -656,6 +663,7 @@ def main( "table_name": LANCEDB_TABLE, "overwrite": True, "create_index": True, + "hybrid": hybrid, } ) ) @@ -712,6 +720,7 @@ def main( "table_name": LANCEDB_TABLE, "overwrite": True, "create_index": True, + "hybrid": hybrid, } ) ) @@ -784,6 +793,7 @@ def main( embedding_model=_recall_model, top_k=10, ks=(1, 5, 10), + hybrid=hybrid, ) _df_query, _gold, _raw_hits, _retrieved_keys, metrics = retrieve_and_score(query_csv=query_csv, cfg=cfg) diff --git a/retriever/src/retriever/ingest_modes/batch.py b/retriever/src/retriever/ingest_modes/batch.py index 769180cfb..0b4b509c1 100644 --- a/retriever/src/retriever/ingest_modes/batch.py +++ b/retriever/src/retriever/ingest_modes/batch.py @@ -984,6 +984,16 @@ def _create_lancedb_index(self) -> None: except Exception as e: print(f"Warning: failed to create LanceDB index (continuing without index): {e}") + if kw.get("hybrid", False): + text_column = str(kw.get("text_column", "text")) + fts_language = str(kw.get("fts_language", "English")) + try: + table.create_fts_index(text_column, language=fts_language) + except Exception as e: + print( + f"Warning: FTS index creation failed on column {text_column!r} (continuing with vector-only): {e}" + ) + for index_stub in table.list_indices(): 
table.wait_for_index([index_stub.name], timeout=timedelta(seconds=600)) diff --git a/retriever/src/retriever/params/models.py b/retriever/src/retriever/params/models.py index a64ef03a8..cb84601bf 100644 --- a/retriever/src/retriever/params/models.py +++ b/retriever/src/retriever/params/models.py @@ -85,6 +85,8 @@ class LanceDbParams(_ParamsModel): embedding_key: str = "embedding" include_text: bool = True text_column: str = "text" + hybrid: bool = False + fts_language: str = "English" class BatchTuningParams(_ParamsModel): diff --git a/retriever/src/retriever/recall/core.py b/retriever/src/retriever/recall/core.py index 75cc3b296..7edb24e15 100644 --- a/retriever/src/retriever/recall/core.py +++ b/retriever/src/retriever/recall/core.py @@ -41,6 +41,7 @@ class RecallConfig: # top candidates with full-precision vectors to eliminate SQ quantization error. nprobes: int = 0 refine_factor: int = 10 + hybrid: bool = False # Local HF knobs (only used when endpoints are missing). local_hf_device: Optional[str] = None local_hf_cache_dir: Optional[str] = None @@ -171,6 +172,8 @@ def _search_lancedb( vector_column_name: str = "vector", nprobes: int = 0, refine_factor: int = 10, + query_texts: Optional[List[str]] = None, + hybrid: bool = False, ) -> List[List[Dict[str, Any]]]: import lancedb # type: ignore @@ -194,16 +197,34 @@ def _search_lancedb( effective_nprobes = 16 # safe fallback matching default index config results: List[List[Dict[str, Any]]] = [] - for v in query_vectors: + for i, v in enumerate(query_vectors): q = np.asarray(v, dtype="float32") - hits = ( - table.search(q, vector_column_name=vector_column_name) - .nprobes(effective_nprobes) - .refine_factor(refine_factor) - .select(["text", "metadata", "source", "_distance"]) - .limit(top_k) - .to_list() - ) + + if hybrid and query_texts is not None: + from lancedb.rerankers import RRFReranker # type: ignore + + text = query_texts[i] + hits = ( + table.search(query_type="hybrid") + .vector(q) + .text(text) + .nprobes(effective_nprobes) + .refine_factor(refine_factor) + .select(["text", "metadata", "source"]) + .limit(top_k) + .rerank(RRFReranker()) + .to_list() + ) + else: + hits = ( + table.search(q, vector_column_name=vector_column_name) + .nprobes(effective_nprobes) + .refine_factor(refine_factor) + .select(["text", "metadata", "source", "_distance"]) + .limit(top_k) + .to_list() + ) + results.append(hits) return results @@ -299,6 +320,8 @@ def retrieve_and_score( vector_column_name=vector_column_name, nprobes=int(cfg.nprobes), refine_factor=int(cfg.refine_factor), + query_texts=queries, + hybrid=bool(cfg.hybrid), ) retrieved_keys = _hits_to_keys(raw_hits) metrics = {f"recall@{k}": _recall_at_k(gold, retrieved_keys, int(k)) for k in cfg.ks} diff --git a/retriever/src/retriever/vector_store/__init__.py b/retriever/src/retriever/vector_store/__init__.py index 1d050ff3f..1c05e4e8f 100644 --- a/retriever/src/retriever/vector_store/__init__.py +++ b/retriever/src/retriever/vector_store/__init__.py @@ -3,11 +3,17 @@ # SPDX-License-Identifier: Apache-2.0 from .__main__ import app -from .lancedb_store import LanceDBConfig, write_embeddings_to_lancedb, write_text_embeddings_dir_to_lancedb +from .lancedb_store import ( + LanceDBConfig, + create_lancedb_index, + write_embeddings_to_lancedb, + write_text_embeddings_dir_to_lancedb, +) __all__ = [ "app", "LanceDBConfig", + "create_lancedb_index", "write_embeddings_to_lancedb", "write_text_embeddings_dir_to_lancedb", ] diff --git a/retriever/src/retriever/vector_store/lancedb_store.py 
b/retriever/src/retriever/vector_store/lancedb_store.py index 4f85789a1..181ab0613 100644 --- a/retriever/src/retriever/vector_store/lancedb_store.py +++ b/retriever/src/retriever/vector_store/lancedb_store.py @@ -36,6 +36,9 @@ class LanceDBConfig: num_partitions: int = 16 num_sub_vectors: int = 256 + hybrid: bool = False + fts_language: str = "English" + def _read_text_embeddings_json_df(path: Path) -> pd.DataFrame: """ @@ -177,6 +180,30 @@ def _infer_vector_dim(rows: Sequence[Dict[str, Any]]) -> int: return 0 +def create_lancedb_index(table: Any, *, cfg: LanceDBConfig, text_column: str = "text") -> None: + """Create vector (IVF_HNSW_SQ) and optionally FTS indices on a LanceDB table.""" + try: + table.create_index( + index_type=cfg.index_type, + metric=cfg.metric, + num_partitions=int(cfg.num_partitions), + num_sub_vectors=int(cfg.num_sub_vectors), + vector_column_name="vector", + ) + except TypeError: + table.create_index(vector_column_name="vector") + + if cfg.hybrid: + try: + table.create_fts_index(text_column, language=cfg.fts_language) + except Exception: + logger.warning( + "FTS index creation failed on column %r; continuing with vector-only search.", + text_column, + exc_info=True, + ) + + def _write_rows_to_lancedb(rows: Sequence[Dict[str, Any]], *, cfg: LanceDBConfig) -> None: if not rows: logger.warning("No embeddings rows provided; nothing to write to LanceDB.") @@ -213,17 +240,7 @@ def _write_rows_to_lancedb(rows: Sequence[Dict[str, Any]], *, cfg: LanceDBConfig table = db.create_table(cfg.table_name, data=list(rows), schema=schema, mode=mode) if cfg.create_index: - try: - table.create_index( - index_type=cfg.index_type, - metric=cfg.metric, - num_partitions=int(cfg.num_partitions), - num_sub_vectors=int(cfg.num_sub_vectors), - vector_column_name="vector", - ) - except TypeError: - # Older/newer LanceDB versions may have different signatures; fall back to minimal call. 
- table.create_index(vector_column_name="vector") + create_lancedb_index(table, cfg=cfg) def write_embeddings_to_lancedb(df_with_embeddings: pd.DataFrame, *, cfg: LanceDBConfig) -> None: From c3526bce853ba9b6073c0bfeeb65d6a778f1f38c Mon Sep 17 00:00:00 2001 From: Jacob Ioffe Date: Thu, 26 Feb 2026 19:06:41 +0000 Subject: [PATCH 2/4] Fuse PDF split+extract, numpy passthrough, ActorPool for CPU stages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fuse PDFSplit + PDFExtraction into single PDFSplitAndExtractActor that opens each multi-page PDF once (eliminates redundant single-page PDF serialization) - Render pages to numpy arrays instead of PNG→base64 encoding; downstream page_elements and OCR stages accept numpy directly, only encoding to base64 for remote NIM endpoints - Switch PDF extraction from TaskPoolStrategy to ActorPoolStrategy (eliminates per-task process creation + library import overhead) - Increase pdf_split_batch_size default from 1 to 4 - Remove repartition barrier between extract and page_elements stages - Add BGR→RGB channel swap via OpenCV (aligned with api/ pdfium.py) Co-Authored-By: Claude Opus 4.6 Signed-off-by: Jacob Ioffe (cherry picked from commit d2dc2c7a5202fb52e7103d53a20a762f4414efa1) --- .../src/retriever/examples/batch_pipeline.py | 41 +-- retriever/src/retriever/ingest_modes/batch.py | 247 +++++++++--------- retriever/src/retriever/ocr/ocr.py | 38 +-- .../retriever/page_elements/page_elements.py | 40 ++- retriever/src/retriever/pdf/extract.py | 173 +++++++++++- 5 files changed, 378 insertions(+), 161 deletions(-) diff --git a/retriever/src/retriever/examples/batch_pipeline.py b/retriever/src/retriever/examples/batch_pipeline.py index 4cf7d6277..a70211e47 100644 --- a/retriever/src/retriever/examples/batch_pipeline.py +++ b/retriever/src/retriever/examples/batch_pipeline.py @@ -18,6 +18,7 @@ from typing import Optional, TextIO import lancedb +import pandas as pd import ray import typer from retriever import create_ingestor @@ -97,11 +98,11 @@ def _configure_logging(log_file: Optional[Path]) -> tuple[Optional[TextIO], Text return fh, original_stdout, original_stderr -def _estimate_processed_pages(uri: str, table_name: str) -> Optional[int]: +def _load_metadata_columns(uri: str, table_name: str) -> Optional[pd.DataFrame]: """ - Estimate pages processed by counting unique (source_id, page_number) pairs. + Load only the metadata columns from LanceDB, skipping the large vector column. - Falls back to table row count if page-level fields are unavailable. + Returns a DataFrame with [source_id, page_number, metadata] or None on failure. """ try: db = lancedb.connect(uri) @@ -110,15 +111,28 @@ def _estimate_processed_pages(uri: str, table_name: str) -> Optional[int]: return None try: - df = table.to_pandas()[["source_id", "page_number"]] - return int(df.dropna(subset=["source_id", "page_number"]).drop_duplicates().shape[0]) + return table.to_lance().to_table(columns=["source_id", "page_number", "metadata"]).to_pandas() except Exception: try: - return int(table.count_rows()) + return table.to_pandas()[["source_id", "page_number", "metadata"]] except Exception: return None +def _estimate_processed_pages(df: Optional[pd.DataFrame]) -> Optional[int]: + """ + Estimate pages processed by counting unique (source_id, page_number) pairs. 
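+
+    Returns None if the metadata DataFrame is missing or malformed.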
+ """ + if df is None: + return None + try: + return int( + df[["source_id", "page_number"]].dropna(subset=["source_id", "page_number"]).drop_duplicates().shape[0] + ) + except Exception: + return None + + def _to_int(value: object, default: int = 0) -> int: try: if value is None: @@ -128,18 +142,14 @@ def _to_int(value: object, default: int = 0) -> int: return default -def _collect_detection_summary(uri: str, table_name: str) -> Optional[dict]: +def _collect_detection_summary(df: Optional[pd.DataFrame]) -> Optional[dict]: """ Collect per-model detection totals deduped by (source_id, page_number). Counts are read from LanceDB row `metadata`, which is populated during batch ingestion by the Ray write stage. """ - try: - db = lancedb.connect(uri) - table = db.open_table(table_name) - df = table.to_pandas()[["source_id", "page_number", "metadata"]] - except Exception: + if df is None: return None # Deduplicate exploded rows by page key; keep max per-page counts. @@ -397,7 +407,7 @@ def main( help="Batch size for PDF extraction stage.", ), pdf_split_batch_size: int = typer.Option( - 1, + 4, "--pdf-split-batch-size", min=1, help="Batch size for PDF split stage.", @@ -735,8 +745,9 @@ def main( ) ) ingest_elapsed_s = time.perf_counter() - ingest_start - processed_pages = _estimate_processed_pages(lancedb_uri, LANCEDB_TABLE) - detection_summary = _collect_detection_summary(lancedb_uri, LANCEDB_TABLE) + metadata_df = _load_metadata_columns(lancedb_uri, LANCEDB_TABLE) + processed_pages = _estimate_processed_pages(metadata_df) + detection_summary = _collect_detection_summary(metadata_df) print("Extraction complete.") _print_detection_summary(detection_summary) if detection_summary_file is not None: diff --git a/retriever/src/retriever/ingest_modes/batch.py b/retriever/src/retriever/ingest_modes/batch.py index 0b4b509c1..1cff3f084 100644 --- a/retriever/src/retriever/ingest_modes/batch.py +++ b/retriever/src/retriever/ingest_modes/batch.py @@ -27,15 +27,13 @@ from retriever.utils.convert import DocToPdfConversionActor from retriever.page_elements import PageElementDetectionActor from retriever.ocr.ocr import OCRActor -from retriever.pdf.extract import PDFExtractionActor -from retriever.pdf.split import PDFSplitActor +from retriever.pdf.extract import PDFSplitAndExtractActor from ..ingest import Ingestor from ..params import EmbedParams from ..params import ExtractParams from ..params import HtmlChunkParams from ..params import IngestExecuteParams -from ..params import PdfSplitParams from ..params import TextChunkParams from ..params import VdbUploadParams @@ -80,11 +78,6 @@ class _LanceDBWriteActor: """ def __init__(self, params: VdbUploadParams | None = None) -> None: - import json - from pathlib import Path - - self._json = json - self._Path = Path lancedb_params = (params or VdbUploadParams()).lancedb self._lancedb_uri = lancedb_params.lancedb_uri @@ -126,108 +119,134 @@ def __init__(self, params: VdbUploadParams | None = None) -> None: mode=mode, ) - def _build_rows(self, df: Any) -> list: - """Build LanceDB rows from a pandas DataFrame batch. + def _build_arrow_table(self, df: Any) -> Any: + """Build a PyArrow table from a pandas DataFrame batch (vectorized).""" + import pandas as pd - Mirrors the row-building logic from - ``upload_embeddings_to_lancedb_inprocess`` in inprocess.py. 
- """ - rows: list = [] - for row in df.itertuples(index=False): - # Extract embedding - emb = None - meta = getattr(row, "metadata", None) - if isinstance(meta, dict): - emb = meta.get("embedding") - if not (isinstance(emb, list) and emb): - emb = None - if emb is None: - payload = getattr(row, self._embedding_column, None) - if isinstance(payload, dict): - emb = payload.get(self._embedding_key) - if not (isinstance(emb, list) and emb): - emb = None - if emb is None: - continue + pa = self._pa + n = len(df) + if n == 0: + return None - # Extract source path and page number - path = "" - page = -1 - v = getattr(row, "path", None) - if isinstance(v, str) and v.strip(): - path = v.strip() - v = getattr(row, "page_number", None) - try: - if v is not None: - page = int(v) - except Exception: - pass - if isinstance(meta, dict): - sp = meta.get("source_path") - if isinstance(sp, str) and sp.strip(): - path = sp.strip() - - p = self._Path(path) if path else None - filename = p.name if p is not None else "" - pdf_basename = p.stem if p is not None else "" - pdf_page = f"{pdf_basename}_{page}" if (pdf_basename and page >= 0) else "" - source_id = path or filename or pdf_basename - - metadata_obj = {"page_number": int(page) if page is not None else -1} + # --- Extract embeddings --- + meta_col = df.get("metadata", pd.Series([None] * n)) + + emb_from_meta = meta_col.map(lambda m: m.get("embedding") if isinstance(m, dict) else None) + valid_meta = emb_from_meta.map(lambda x: isinstance(x, list) and len(x) > 0) + emb_from_meta = emb_from_meta.where(valid_meta) + + if self._embedding_column in df.columns: + emb_from_col = df[self._embedding_column].map( + lambda p: p.get(self._embedding_key) if isinstance(p, dict) else None + ) + valid_col = emb_from_col.map(lambda x: isinstance(x, list) and len(x) > 0) + emb_from_col = emb_from_col.where(valid_col) + else: + emb_from_col = pd.Series([None] * n, index=df.index) + + embeddings = emb_from_meta.fillna(emb_from_col) + mask = embeddings.notna() + if not mask.any(): + return None + + df = df[mask].reset_index(drop=True) + embeddings = embeddings[mask].reset_index(drop=True) + nr = len(df) + + # --- Paths with metadata override --- + meta_col = df.get("metadata", pd.Series([None] * nr)) + raw_paths = df.get("path", pd.Series([""] * nr)).fillna("").astype(str).str.strip() + source_paths = meta_col.map(lambda m: m.get("source_path", "").strip() if isinstance(m, dict) else "") + paths = source_paths.where(source_paths.astype(bool), raw_paths) + + # --- Derived columns --- + page_numbers = ( + pd.to_numeric(df.get("page_number", pd.Series([None] * nr)), errors="coerce").fillna(-1).astype(int) + ) + + filenames = paths.map(lambda p: Path(p).name if p else "") + pdf_basenames = paths.map(lambda p: Path(p).stem if p else "") + + has_basename = pdf_basenames.astype(bool) + has_page = page_numbers >= 0 + pdf_pages = (pdf_basenames + "_" + page_numbers.astype(str)).where(has_basename & has_page, "") + + source_ids = paths.where(paths.astype(bool), filenames) + source_ids = source_ids.where(source_ids.astype(bool), pdf_basenames) + + # --- Text --- + if self._include_text and self._text_column in df.columns: + texts = df[self._text_column].map(lambda t: str(t) if isinstance(t, str) else "") + else: + texts = pd.Series([""] * nr) + + # --- Metadata JSON --- + pe_nums = df.get("page_elements_v3_num_detections", pd.Series([None] * nr)) + pe_counts_col = df.get("page_elements_v3_counts_by_label", pd.Series([None] * nr)) + ocr_tables = df.get("table", pd.Series([None] * nr)) 
+ ocr_charts = df.get("chart", pd.Series([None] * nr)) + ocr_infos = df.get("infographic", pd.Series([None] * nr)) + + def _meta_json(page, pdf_page, pe_num, pe_count, tbl, chart, info): + obj = {"page_number": int(page)} if pdf_page: - metadata_obj["pdf_page"] = pdf_page - # Persist per-page detection counters for end-of-run summaries. - # These may be duplicated across exploded content rows; downstream - # summary logic should dedupe by (source_id, page_number). - pe_num = getattr(row, "page_elements_v3_num_detections", None) + obj["pdf_page"] = pdf_page if pe_num is not None: try: - metadata_obj["page_elements_v3_num_detections"] = int(pe_num) + obj["page_elements_v3_num_detections"] = int(pe_num) except Exception: pass - pe_counts = getattr(row, "page_elements_v3_counts_by_label", None) - if isinstance(pe_counts, dict): - metadata_obj["page_elements_v3_counts_by_label"] = { - str(k): int(v) for k, v in pe_counts.items() if isinstance(k, str) and v is not None + if isinstance(pe_count, dict): + obj["page_elements_v3_counts_by_label"] = { + str(k): int(v) for k, v in pe_count.items() if isinstance(k, str) and v is not None } - for ocr_col in ("table", "chart", "infographic"): - entries = getattr(row, ocr_col, None) + for name, entries in [("table", tbl), ("chart", chart), ("infographic", info)]: if isinstance(entries, list): - metadata_obj[f"ocr_{ocr_col}_detections"] = int(len(entries)) - source_obj = {"source_id": str(path)} - - row_out = { - "vector": emb, - "pdf_page": pdf_page, - "filename": filename, - "pdf_basename": pdf_basename, - "page_number": int(page) if page is not None else -1, - "source_id": str(source_id), - "path": str(path), - "metadata": self._json.dumps(metadata_obj, ensure_ascii=False), - "source": self._json.dumps(source_obj, ensure_ascii=False), - } - - if self._include_text: - t = getattr(row, self._text_column, None) - row_out["text"] = str(t) if isinstance(t, str) else "" - else: - row_out["text"] = "" - - rows.append(row_out) - return rows + obj[f"ocr_{name}_detections"] = len(entries) + return json.dumps(obj, ensure_ascii=False) + + metadata_jsons = [ + _meta_json(pn, pp, pen, pec, ot, oc, oi) + for pn, pp, pen, pec, ot, oc, oi in zip( + page_numbers, + pdf_pages, + pe_nums, + pe_counts_col, + ocr_tables, + ocr_charts, + ocr_infos, + ) + ] + + # --- Source JSON --- + source_jsons = paths.map(lambda p: json.dumps({"source_id": str(p)})).tolist() + + # --- Build PyArrow table --- + return pa.table( + { + "vector": pa.array(embeddings.tolist(), type=pa.list_(pa.float32(), 2048)), + "pdf_page": pa.array(pdf_pages.tolist(), type=pa.string()), + "filename": pa.array(filenames.tolist(), type=pa.string()), + "pdf_basename": pa.array(pdf_basenames.tolist(), type=pa.string()), + "page_number": pa.array(page_numbers.tolist(), type=pa.int32()), + "source_id": pa.array(source_ids.astype(str).tolist(), type=pa.string()), + "path": pa.array(paths.astype(str).tolist(), type=pa.string()), + "text": pa.array(texts.tolist(), type=pa.string()), + "metadata": pa.array(metadata_jsons, type=pa.string()), + "source": pa.array(source_jsons, type=pa.string()), + }, + schema=self._schema, + ) def __call__(self, batch_df: Any) -> Any: - rows = self._build_rows(batch_df) - if rows: - # Infer schema from first batch - if self._table is None: - self._table = self._db.open_table(self._table_name) - self._table.add(rows) + import pandas as pd - self._total_rows += len(rows) + arrow_table = self._build_arrow_table(batch_df) + if arrow_table is not None and arrow_table.num_rows > 0: + 
self._table.add(arrow_table) + self._total_rows += arrow_table.num_rows - return batch_df + return pd.DataFrame({"_written": range(len(batch_df))}) class _BatchEmbedActor: @@ -393,8 +412,8 @@ def _endpoint_count(raw: Any) -> int: return len([p for p in s.split(",") if p.strip()]) debug_run_id = str(kwargs.pop("debug_run_id", "unknown")) - pdf_split_batch_size = kwargs.pop("pdf_split_batch_size", 1) - pdf_extract_batch_size = kwargs.pop("pdf_extract_batch_size", 4) + pdf_split_batch_size = kwargs.pop("pdf_split_batch_size", 4) + kwargs.pop("pdf_extract_batch_size", None) # consumed but unused after fusing split+extract pdf_extract_num_cpus = float(kwargs.pop("pdf_extract_num_cpus", 2)) page_elements_batch_size = kwargs.pop("page_elements_batch_size", 24) detect_batch_size = kwargs.pop("detect_batch_size", 24) @@ -568,33 +587,21 @@ def _endpoint_count(raw: Any) -> int: batch_format="pandas", ) - # Splitting pdfs is broken into a separate stage to help amortize downstream - # processing if PDFs have vastly different numbers of pages. - pdf_split_actor = PDFSplitActor( - split_params=PdfSplitParams( - start_page=kwargs.get("start_page"), - end_page=kwargs.get("end_page"), - ) - ) + # Fused split + extraction: open each multi-page PDF once and + # extract per-page text/images in a single pass. Eliminates the + # redundant single-page PDF serialization from the old Split stage. + fused_kwargs = dict(kwargs) + fused_kwargs["start_page"] = kwargs.get("start_page") + fused_kwargs["end_page"] = kwargs.get("end_page") self._rd_dataset = self._rd_dataset.map_batches( - pdf_split_actor, + PDFSplitAndExtractActor, batch_size=pdf_split_batch_size, - num_cpus=1, - num_gpus=0, - batch_format="pandas", - ) - - # Pre-split pdfs are now ready for extraction — the main CPU bottleneck. - extraction_actor = PDFExtractionActor(**kwargs) - self._rd_dataset = self._rd_dataset.map_batches( - extraction_actor, - batch_size=pdf_extract_batch_size, batch_format="pandas", num_cpus=pdf_extract_num_cpus, num_gpus=0, - compute=rd.TaskPoolStrategy(size=pdf_extract_workers), + compute=rd.ActorPoolStrategy(size=pdf_extract_workers), + fn_constructor_kwargs=fused_kwargs, ) - self._rd_dataset = self._rd_dataset.repartition(target_num_rows_per_block=24) # Page-element detection with a GPU actor pool. # For ActorPoolStrategy, Ray Data expects a *callable class* (so it can # construct one instance per actor). Passing an already-constructed diff --git a/retriever/src/retriever/ocr/ocr.py b/retriever/src/retriever/ocr/ocr.py index 380b82f5e..0d2e6a0bd 100644 --- a/retriever/src/retriever/ocr/ocr.py +++ b/retriever/src/retriever/ocr/ocr.py @@ -110,13 +110,15 @@ def _clamp_int(v: float, lo: int, hi: int) -> int: def _crop_all_from_page( - page_image_b64: str, + page_image: Any, detections: List[Dict[str, Any]], wanted_labels: set, ) -> List[Tuple[str, List[float], np.ndarray]]: """ Decode the page image **once** and crop all matching detections. + *page_image* may be a numpy HWC uint8 array or a base64-encoded PNG string. + Returns a list of ``(label_name, bbox_xyxy_norm, crop_array)`` tuples for detections whose ``label_name`` is in *wanted_labels* and whose crop is valid. Skips detections that fail to crop (bad bbox, tiny region, etc.). 
@@ -127,15 +129,17 @@ def _crop_all_from_page( if Image is None: # pragma: no cover raise ImportError("Cropping requires pillow.") - if not isinstance(page_image_b64, str) or not page_image_b64: - return [] - - try: - raw = base64.b64decode(page_image_b64) - im0 = Image.open(io.BytesIO(raw)) - im = im0.convert("RGB") - im0.close() - except Exception: + if isinstance(page_image, np.ndarray): + im = Image.fromarray(page_image.astype(np.uint8), mode="RGB") + elif isinstance(page_image, str) and page_image: + try: + raw = base64.b64decode(page_image) + im0 = Image.open(io.BytesIO(raw)) + im = im0.convert("RGB") + im0.close() + except Exception: + return [] + else: return [] w, h = im.size @@ -448,11 +452,15 @@ def ocr_page_elements( dets = [] # --- get page image --- - page_image = getattr(row, "page_image", None) or {} - page_image_b64 = page_image.get("image_b64") if isinstance(page_image, dict) else None + page_image_data = getattr(row, "page_image", None) or {} + if isinstance(page_image_data, dict) and isinstance(page_image_data.get("image_array"), np.ndarray): + page_img = page_image_data["image_array"] + elif isinstance(page_image_data, dict): + page_img = page_image_data.get("image_b64") + else: + page_img = None - if not isinstance(page_image_b64, str) or not page_image_b64: - # No image available — nothing to crop/OCR. + if page_img is None: all_table.append(table_items) all_chart.append(chart_items) all_infographic.append(infographic_items) @@ -460,7 +468,7 @@ def ocr_page_elements( continue # --- decode page image once, crop all matching detections --- - crops = _crop_all_from_page(page_image_b64, dets, wanted_labels) + crops = _crop_all_from_page(page_img, dets, wanted_labels) if use_remote: crop_b64s: List[str] = [] diff --git a/retriever/src/retriever/page_elements/page_elements.py b/retriever/src/retriever/page_elements/page_elements.py index 101e3b7c0..8bfa40af6 100644 --- a/retriever/src/retriever/page_elements/page_elements.py +++ b/retriever/src/retriever/page_elements/page_elements.py @@ -99,6 +99,14 @@ def _error_payload(*, stage: str, exc: BaseException) -> Dict[str, Any]: } +def _np_to_b64_png(arr: "np.ndarray") -> str: + """Encode an HWC uint8 numpy array to a base64-encoded PNG string.""" + img = Image.fromarray(arr.astype(np.uint8), mode="RGB") + buf = io.BytesIO() + img.save(buf, format="PNG") + return base64.b64encode(buf.getvalue()).decode("ascii") + + def _decode_b64_image_to_chw_tensor(image_b64: str) -> Tuple["torch.Tensor", Tuple[int, int]]: if torch is None or Image is None or np is None: # pragma: no cover raise ImportError("page element detection requires torch, pillow, and numpy.") @@ -499,17 +507,29 @@ def detect_page_elements_v3( for _, row in pages_df.iterrows(): try: - b64 = row.get("page_image")["image_b64"] - if not b64: - raise ValueError("No usable image_b64 found in row.") - row_b64.append(b64) - if use_remote: - row_tensors.append(None) - row_shapes.append(None) + page_img = row.get("page_image") + if isinstance(page_img, dict) and isinstance(page_img.get("image_array"), np.ndarray): + arr = page_img["image_array"] + h, w = arr.shape[:2] + row_shapes.append((h, w)) + if use_remote: + row_b64.append(_np_to_b64_png(arr)) + row_tensors.append(None) + else: + row_b64.append(None) + row_tensors.append(arr) else: - t, orig_shape = _decode_b64_image_to_np_array(b64) - row_tensors.append(t) - row_shapes.append(orig_shape) + b64 = page_img["image_b64"] + if not b64: + raise ValueError("No usable image_b64 found in row.") + row_b64.append(b64) + if 
use_remote: + row_tensors.append(None) + row_shapes.append(None) + else: + t, orig_shape = _decode_b64_image_to_np_array(b64) + row_tensors.append(t) + row_shapes.append(orig_shape) row_payloads.append({"detections": []}) except BaseException as e: row_tensors.append(None) diff --git a/retriever/src/retriever/pdf/extract.py b/retriever/src/retriever/pdf/extract.py index 3a5727aaa..2d37b1fea 100644 --- a/retriever/src/retriever/pdf/extract.py +++ b/retriever/src/retriever/pdf/extract.py @@ -27,6 +27,11 @@ except Exception: # pragma: no cover np = None # type: ignore[assignment] +try: + import cv2 +except Exception: # pragma: no cover + cv2 = None # type: ignore[assignment] + try: from PIL import Image except Exception: # pragma: no cover @@ -104,6 +109,37 @@ def _render_page_to_base64(page: Any, *, dpi: int = 200, image_format: str = "pn raise RuntimeError("Failed to render page to an image representation.") +def _bitmap_to_rgb_numpy(bitmap: Any) -> Any: + """Convert a PdfBitmap to an RGB HWC uint8 numpy array. + + Handles BGRA/BGRX/BGR modes that PDFium may return, using an in-place + SIMD-optimized channel swap via OpenCV (mirrors the approach in + ``api/.../pdfium.py:convert_bitmap_to_corrected_numpy``). + """ + arr = bitmap.to_numpy().copy() + mode = getattr(bitmap, "mode", None) + if cv2 is not None: + if mode in {"BGRA", "BGRX"}: + cv2.cvtColor(arr, cv2.COLOR_BGRA2RGBA, dst=arr) + elif mode == "BGR": + cv2.cvtColor(arr, cv2.COLOR_BGR2RGB, dst=arr) + return arr + + +def _render_page_to_numpy(page: Any, *, dpi: int = 200) -> Dict[str, Any]: + """Render a page to a numpy array (HWC uint8 RGB). No PNG/base64 overhead.""" + scale = max(float(dpi) / 72.0, 0.01) + bitmap = page.render(scale=scale) + arr = _bitmap_to_rgb_numpy(bitmap) + h, w = arr.shape[:2] + return { + "image_array": arr, + "image_b64": None, + "encoding": "numpy", + "orig_shape_hw": (h, w), + } + + def _error_record( *, source_path: Optional[str], @@ -268,7 +304,7 @@ def pdf_extraction( ) render_info: Optional[Dict[str, Any]] = None if want_any_raster: - render_info = _render_page_to_base64(page, dpi=dpi, image_format=image_format) + render_info = _render_page_to_numpy(page, dpi=dpi) page_record: Dict[str, Any] = { "path": pdf_path, @@ -322,6 +358,141 @@ def pdf_extraction( raise NotImplementedError("pdf_extraction currently only supports pandas.DataFrame input.") +def split_and_extract_pdf( + pdf_batch: Any, + extract_text: bool = False, + extract_images: bool = False, + extract_tables: bool = False, + extract_charts: bool = False, + extract_infographics: bool = False, + dpi: int = 200, + text_extraction_method: str = "pdfium_hybrid", + text_depth: str = "page", + start_page: int | None = None, + end_page: int | None = None, + **kwargs: Any, +) -> pd.DataFrame: + """Fused split + extraction: open each multi-page PDF once, extract per-page.""" + if not isinstance(pdf_batch, pd.DataFrame): + raise NotImplementedError("split_and_extract_pdf currently only supports pandas.DataFrame input.") + + if pdfium is None: # pragma: no cover + outputs: List[Dict[str, Any]] = [] + for _, row in pdf_batch.iterrows(): + pdf_path = row.get("path") + outputs.append( + _error_record( + source_path=str(pdf_path) if pdf_path is not None else None, + stage="import_pypdfium2", + exc=( + _PDFIUM_IMPORT_ERROR + if _PDFIUM_IMPORT_ERROR is not None + else RuntimeError("pypdfium2 unavailable") + ), + ) + ) + return pd.DataFrame(outputs) + + outputs: List[Dict[str, Any]] = [] + for _, row in pdf_batch.iterrows(): + pdf_bytes = row.get("bytes") + 
pdf_path = row.get("path") + try: + if not isinstance(pdf_bytes, (bytes, bytearray, memoryview)): + raise RuntimeError(f"Unsupported bytes payload type: {type(pdf_bytes)!r}") + + try: + doc = pdfium.PdfDocument(pdf_bytes) + except Exception: + doc = pdfium.PdfDocument(BytesIO(bytes(pdf_bytes))) + + n_pages = len(doc) + start_idx = 0 if start_page is None else max(int(start_page) - 1, 0) + end_idx = (n_pages - 1) if end_page is None else min(int(end_page) - 1, n_pages - 1) + + for page_idx in range(start_idx, end_idx + 1): + page = doc.get_page(page_idx) + try: + is_scanned = _is_scanned_page(page) + ocr_needed = extract_text and ( + (text_extraction_method == "pdfium_hybrid" and is_scanned) or text_extraction_method == "ocr" + ) + + text = "" + if extract_text and not ocr_needed: + text = _extract_page_text(page) + + want_raster = bool( + extract_images or extract_tables or extract_charts or extract_infographics or ocr_needed + ) + page_image = None + if want_raster: + page_image = _render_page_to_numpy(page, dpi=dpi) + + outputs.append( + { + "path": pdf_path, + "page_number": page_idx + 1, + "text": text if extract_text else "", + "page_image": page_image, + "images": [], + "tables": [], + "charts": [], + "infographics": [], + "metadata": { + "has_text": bool(text.strip()) if extract_text else False, + "needs_ocr_for_text": ocr_needed, + "dpi": dpi, + "source_path": pdf_path, + "error": None, + }, + } + ) + finally: + if hasattr(page, "close"): + page.close() + doc.close() + except BaseException as e: + outputs.append( + _error_record( + source_path=str(pdf_path) if pdf_path else None, + stage="split_and_extract", + exc=e, + ) + ) + return pd.DataFrame(outputs) + + +@dataclass(slots=True) +class PDFSplitAndExtractActor: + """Fused split + extraction actor for Ray Data ActorPoolStrategy.""" + + extract_kwargs: Dict[str, Any] + + def __init__(self, **extract_kwargs: Any) -> None: + self.extract_kwargs = dict(extract_kwargs) + + def __call__(self, pdf_batch: Any, **override_kwargs: Any) -> Any: + try: + return split_and_extract_pdf(pdf_batch, **self.extract_kwargs, **override_kwargs) + except BaseException as e: + source_path = None + try: + if isinstance(pdf_batch, pd.DataFrame) and "path" in pdf_batch.columns and len(pdf_batch.index) > 0: + source_path = str(pdf_batch.iloc[0]["path"]) + except Exception: + source_path = None + return pd.DataFrame( + [ + _error_record( + source_path=source_path, + stage="actor_call", + exc=e, + ) + ] + ) + + @dataclass(slots=True) class PDFExtractionActor: """ From 944e3844e6b470f5077a1c53dec67213f7bd1f65 Mon Sep 17 00:00:00 2001 From: Jacob Ioffe Date: Thu, 26 Feb 2026 21:38:11 +0000 Subject: [PATCH 3/4] Revert numpy page-image passthrough for extraction pipeline Keep page-image transport as base64 to avoid object-store bloat and spill pressure while preserving fused split+extract and ActorPool improvements. 
Signed-off-by: Jacob Ioffe Made-with: Cursor (cherry picked from commit e84da71b618fd34a4666b6676c2b8866e8721ded) --- retriever/src/retriever/ocr/ocr.py | 34 +++++--------- .../retriever/page_elements/page_elements.py | 39 +++++----------- retriever/src/retriever/pdf/extract.py | 45 +------------------ 3 files changed, 24 insertions(+), 94 deletions(-) diff --git a/retriever/src/retriever/ocr/ocr.py b/retriever/src/retriever/ocr/ocr.py index 0d2e6a0bd..55b78a403 100644 --- a/retriever/src/retriever/ocr/ocr.py +++ b/retriever/src/retriever/ocr/ocr.py @@ -110,15 +110,13 @@ def _clamp_int(v: float, lo: int, hi: int) -> int: def _crop_all_from_page( - page_image: Any, + page_image_b64: str, detections: List[Dict[str, Any]], wanted_labels: set, ) -> List[Tuple[str, List[float], np.ndarray]]: """ Decode the page image **once** and crop all matching detections. - *page_image* may be a numpy HWC uint8 array or a base64-encoded PNG string. - Returns a list of ``(label_name, bbox_xyxy_norm, crop_array)`` tuples for detections whose ``label_name`` is in *wanted_labels* and whose crop is valid. Skips detections that fail to crop (bad bbox, tiny region, etc.). @@ -129,17 +127,14 @@ def _crop_all_from_page( if Image is None: # pragma: no cover raise ImportError("Cropping requires pillow.") - if isinstance(page_image, np.ndarray): - im = Image.fromarray(page_image.astype(np.uint8), mode="RGB") - elif isinstance(page_image, str) and page_image: - try: - raw = base64.b64decode(page_image) - im0 = Image.open(io.BytesIO(raw)) - im = im0.convert("RGB") - im0.close() - except Exception: - return [] - else: + if not page_image_b64: + return [] + try: + raw = base64.b64decode(page_image_b64) + im0 = Image.open(io.BytesIO(raw)) + im = im0.convert("RGB") + im0.close() + except Exception: return [] w, h = im.size @@ -453,14 +448,9 @@ def ocr_page_elements( # --- get page image --- page_image_data = getattr(row, "page_image", None) or {} - if isinstance(page_image_data, dict) and isinstance(page_image_data.get("image_array"), np.ndarray): - page_img = page_image_data["image_array"] - elif isinstance(page_image_data, dict): - page_img = page_image_data.get("image_b64") - else: - page_img = None + page_img_b64 = page_image_data.get("image_b64") if isinstance(page_image_data, dict) else None - if page_img is None: + if not page_img_b64: all_table.append(table_items) all_chart.append(chart_items) all_infographic.append(infographic_items) @@ -468,7 +458,7 @@ def ocr_page_elements( continue # --- decode page image once, crop all matching detections --- - crops = _crop_all_from_page(page_img, dets, wanted_labels) + crops = _crop_all_from_page(page_img_b64, dets, wanted_labels) if use_remote: crop_b64s: List[str] = [] diff --git a/retriever/src/retriever/page_elements/page_elements.py b/retriever/src/retriever/page_elements/page_elements.py index 8bfa40af6..4d7957981 100644 --- a/retriever/src/retriever/page_elements/page_elements.py +++ b/retriever/src/retriever/page_elements/page_elements.py @@ -99,14 +99,6 @@ def _error_payload(*, stage: str, exc: BaseException) -> Dict[str, Any]: } -def _np_to_b64_png(arr: "np.ndarray") -> str: - """Encode an HWC uint8 numpy array to a base64-encoded PNG string.""" - img = Image.fromarray(arr.astype(np.uint8), mode="RGB") - buf = io.BytesIO() - img.save(buf, format="PNG") - return base64.b64encode(buf.getvalue()).decode("ascii") - - def _decode_b64_image_to_chw_tensor(image_b64: str) -> Tuple["torch.Tensor", Tuple[int, int]]: if torch is None or Image is None or np is None: # pragma: no 
cover raise ImportError("page element detection requires torch, pillow, and numpy.") @@ -508,28 +500,17 @@ def detect_page_elements_v3( for _, row in pages_df.iterrows(): try: page_img = row.get("page_image") - if isinstance(page_img, dict) and isinstance(page_img.get("image_array"), np.ndarray): - arr = page_img["image_array"] - h, w = arr.shape[:2] - row_shapes.append((h, w)) - if use_remote: - row_b64.append(_np_to_b64_png(arr)) - row_tensors.append(None) - else: - row_b64.append(None) - row_tensors.append(arr) + b64 = page_img["image_b64"] + if not b64: + raise ValueError("No usable image_b64 found in row.") + row_b64.append(b64) + if use_remote: + row_tensors.append(None) + row_shapes.append(None) else: - b64 = page_img["image_b64"] - if not b64: - raise ValueError("No usable image_b64 found in row.") - row_b64.append(b64) - if use_remote: - row_tensors.append(None) - row_shapes.append(None) - else: - t, orig_shape = _decode_b64_image_to_np_array(b64) - row_tensors.append(t) - row_shapes.append(orig_shape) + t, orig_shape = _decode_b64_image_to_np_array(b64) + row_tensors.append(t) + row_shapes.append(orig_shape) row_payloads.append({"detections": []}) except BaseException as e: row_tensors.append(None) diff --git a/retriever/src/retriever/pdf/extract.py b/retriever/src/retriever/pdf/extract.py index 2d37b1fea..cfdd3338b 100644 --- a/retriever/src/retriever/pdf/extract.py +++ b/retriever/src/retriever/pdf/extract.py @@ -22,16 +22,6 @@ else: # pragma: no cover _PDFIUM_IMPORT_ERROR = None -try: - import numpy as np -except Exception: # pragma: no cover - np = None # type: ignore[assignment] - -try: - import cv2 -except Exception: # pragma: no cover - cv2 = None # type: ignore[assignment] - try: from PIL import Image except Exception: # pragma: no cover @@ -109,37 +99,6 @@ def _render_page_to_base64(page: Any, *, dpi: int = 200, image_format: str = "pn raise RuntimeError("Failed to render page to an image representation.") -def _bitmap_to_rgb_numpy(bitmap: Any) -> Any: - """Convert a PdfBitmap to an RGB HWC uint8 numpy array. - - Handles BGRA/BGRX/BGR modes that PDFium may return, using an in-place - SIMD-optimized channel swap via OpenCV (mirrors the approach in - ``api/.../pdfium.py:convert_bitmap_to_corrected_numpy``). - """ - arr = bitmap.to_numpy().copy() - mode = getattr(bitmap, "mode", None) - if cv2 is not None: - if mode in {"BGRA", "BGRX"}: - cv2.cvtColor(arr, cv2.COLOR_BGRA2RGBA, dst=arr) - elif mode == "BGR": - cv2.cvtColor(arr, cv2.COLOR_BGR2RGB, dst=arr) - return arr - - -def _render_page_to_numpy(page: Any, *, dpi: int = 200) -> Dict[str, Any]: - """Render a page to a numpy array (HWC uint8 RGB). 
No PNG/base64 overhead.""" - scale = max(float(dpi) / 72.0, 0.01) - bitmap = page.render(scale=scale) - arr = _bitmap_to_rgb_numpy(bitmap) - h, w = arr.shape[:2] - return { - "image_array": arr, - "image_b64": None, - "encoding": "numpy", - "orig_shape_hw": (h, w), - } - - def _error_record( *, source_path: Optional[str], @@ -304,7 +263,7 @@ def pdf_extraction( ) render_info: Optional[Dict[str, Any]] = None if want_any_raster: - render_info = _render_page_to_numpy(page, dpi=dpi) + render_info = _render_page_to_base64(page, dpi=dpi) page_record: Dict[str, Any] = { "path": pdf_path, @@ -427,7 +386,7 @@ def split_and_extract_pdf( ) page_image = None if want_raster: - page_image = _render_page_to_numpy(page, dpi=dpi) + page_image = _render_page_to_base64(page, dpi=dpi) outputs.append( { From 02ec0748346433d33a648d143140dd0eba2b293a Mon Sep 17 00:00:00 2001 From: Jacob Ioffe Date: Fri, 27 Feb 2026 15:19:29 +0000 Subject: [PATCH 4/4] Use TaskPool for fused PDF split/extract stage - Run split_and_extract_pdf with TaskPoolStrategy(size=workers) - Keep GPU page-elements and OCR stages on ActorPoolStrategy - Remove unused PDFSplitAndExtractActor wrapper class Signed-off-by: Jacob Ioffe Made-with: Cursor --- retriever/src/retriever/ingest_modes/batch.py | 8 ++--- retriever/src/retriever/pdf/extract.py | 30 ------------------- 2 files changed, 4 insertions(+), 34 deletions(-) diff --git a/retriever/src/retriever/ingest_modes/batch.py b/retriever/src/retriever/ingest_modes/batch.py index 1cff3f084..d73fdb17c 100644 --- a/retriever/src/retriever/ingest_modes/batch.py +++ b/retriever/src/retriever/ingest_modes/batch.py @@ -27,7 +27,7 @@ from retriever.utils.convert import DocToPdfConversionActor from retriever.page_elements import PageElementDetectionActor from retriever.ocr.ocr import OCRActor -from retriever.pdf.extract import PDFSplitAndExtractActor +from retriever.pdf.extract import split_and_extract_pdf from ..ingest import Ingestor from ..params import EmbedParams @@ -594,13 +594,13 @@ def _endpoint_count(raw: Any) -> int: fused_kwargs["start_page"] = kwargs.get("start_page") fused_kwargs["end_page"] = kwargs.get("end_page") self._rd_dataset = self._rd_dataset.map_batches( - PDFSplitAndExtractActor, + split_and_extract_pdf, batch_size=pdf_split_batch_size, batch_format="pandas", num_cpus=pdf_extract_num_cpus, num_gpus=0, - compute=rd.ActorPoolStrategy(size=pdf_extract_workers), - fn_constructor_kwargs=fused_kwargs, + compute=rd.TaskPoolStrategy(size=pdf_extract_workers), + fn_kwargs=fused_kwargs, ) # Page-element detection with a GPU actor pool. 
# For ActorPoolStrategy, Ray Data expects a *callable class* (so it can diff --git a/retriever/src/retriever/pdf/extract.py b/retriever/src/retriever/pdf/extract.py index cfdd3338b..b1dfc382b 100644 --- a/retriever/src/retriever/pdf/extract.py +++ b/retriever/src/retriever/pdf/extract.py @@ -422,36 +422,6 @@ def split_and_extract_pdf( return pd.DataFrame(outputs) -@dataclass(slots=True) -class PDFSplitAndExtractActor: - """Fused split + extraction actor for Ray Data ActorPoolStrategy.""" - - extract_kwargs: Dict[str, Any] - - def __init__(self, **extract_kwargs: Any) -> None: - self.extract_kwargs = dict(extract_kwargs) - - def __call__(self, pdf_batch: Any, **override_kwargs: Any) -> Any: - try: - return split_and_extract_pdf(pdf_batch, **self.extract_kwargs, **override_kwargs) - except BaseException as e: - source_path = None - try: - if isinstance(pdf_batch, pd.DataFrame) and "path" in pdf_batch.columns and len(pdf_batch.index) > 0: - source_path = str(pdf_batch.iloc[0]["path"]) - except Exception: - source_path = None - return pd.DataFrame( - [ - _error_record( - source_path=source_path, - stage="actor_call", - exc=e, - ) - ] - ) - - @dataclass(slots=True) class PDFExtractionActor: """