diff --git a/retriever/src/retriever/examples/batch_pipeline.py b/retriever/src/retriever/examples/batch_pipeline.py index cb3ab8c81..58f3620e2 100644 --- a/retriever/src/retriever/examples/batch_pipeline.py +++ b/retriever/src/retriever/examples/batch_pipeline.py @@ -18,6 +18,7 @@ from pathlib import Path from typing import Optional, TextIO +import pandas as pd import ray import typer from retriever import create_ingestor @@ -102,11 +103,11 @@ def _configure_logging(log_file: Optional[Path]) -> tuple[Optional[TextIO], Text return fh, original_stdout, original_stderr -def _estimate_processed_pages(uri: str, table_name: str) -> Optional[int]: +def _load_metadata_columns(uri: str, table_name: str) -> Optional[pd.DataFrame]: """ - Estimate pages processed by counting unique (source_id, page_number) pairs. + Load only the metadata columns from LanceDB, skipping the large vector column. - Falls back to table row count if page-level fields are unavailable. + Returns a DataFrame with [source_id, page_number, metadata] or None on failure. """ try: db = _lancedb().connect(uri) @@ -115,15 +116,28 @@ def _estimate_processed_pages(uri: str, table_name: str) -> Optional[int]: return None try: - df = table.to_pandas()[["source_id", "page_number"]] - return int(df.dropna(subset=["source_id", "page_number"]).drop_duplicates().shape[0]) + return table.to_lance().to_table(columns=["source_id", "page_number", "metadata"]).to_pandas() except Exception: try: - return int(table.count_rows()) + return table.to_pandas()[["source_id", "page_number", "metadata"]] except Exception: return None +def _estimate_processed_pages(df: Optional[pd.DataFrame]) -> Optional[int]: + """ + Estimate pages processed by counting unique (source_id, page_number) pairs. + """ + if df is None: + return None + try: + return int( + df[["source_id", "page_number"]].dropna(subset=["source_id", "page_number"]).drop_duplicates().shape[0] + ) + except Exception: + return None + + def _to_int(value: object, default: int = 0) -> int: try: if value is None: @@ -133,18 +147,14 @@ def _to_int(value: object, default: int = 0) -> int: return default -def _collect_detection_summary(uri: str, table_name: str) -> Optional[dict]: +def _collect_detection_summary(df: Optional[pd.DataFrame]) -> Optional[dict]: """ Collect per-model detection totals deduped by (source_id, page_number). Counts are read from LanceDB row `metadata`, which is populated during batch ingestion by the Ray write stage. """ - try: - db = _lancedb().connect(uri) - table = db.open_table(table_name) - df = table.to_pandas()[["source_id", "page_number", "metadata"]] - except Exception: + if df is None: return None # Deduplicate exploded rows by page key; keep max per-page counts. 
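
The change above swaps a full `to_pandas()` scan for a column projection through the underlying Lance dataset, so the wide `vector` column is never materialized. A minimal standalone sketch of the same projection, assuming a local LanceDB store at `./lancedb` and a table named `pages` (both hypothetical; `to_lance()` and `to_table(columns=...)` are the same calls the patch relies on):

    import lancedb

    db = lancedb.connect("./lancedb")  # hypothetical store location
    tbl = db.open_table("pages")       # hypothetical table name

    # Project only the narrow metadata columns; the vector column is skipped.
    df = tbl.to_lance().to_table(columns=["source_id", "page_number", "metadata"]).to_pandas()
    print(df.drop_duplicates(["source_id", "page_number"]).shape[0], "unique pages")
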
@@ -402,7 +412,7 @@ def main( help="Batch size for PDF extraction stage.", ), pdf_split_batch_size: int = typer.Option( - 1, + 4, "--pdf-split-batch-size", min=1, help="Batch size for PDF split stage.", @@ -778,8 +788,9 @@ def main( ) ) ingest_elapsed_s = time.perf_counter() - ingest_start - processed_pages = _estimate_processed_pages(lancedb_uri, LANCEDB_TABLE) - detection_summary = _collect_detection_summary(lancedb_uri, LANCEDB_TABLE) + metadata_df = _load_metadata_columns(lancedb_uri, LANCEDB_TABLE) + processed_pages = _estimate_processed_pages(metadata_df) + detection_summary = _collect_detection_summary(metadata_df) print("Extraction complete.") _print_detection_summary(detection_summary) if detection_summary_file is not None: diff --git a/retriever/src/retriever/ingest_modes/batch.py b/retriever/src/retriever/ingest_modes/batch.py index 86ddf01dd..ac74612ac 100644 --- a/retriever/src/retriever/ingest_modes/batch.py +++ b/retriever/src/retriever/ingest_modes/batch.py @@ -27,15 +27,13 @@ from retriever.utils.convert import DocToPdfConversionActor from retriever.page_elements import PageElementDetectionActor from retriever.ocr.ocr import OCRActor -from retriever.pdf.extract import PDFExtractionActor -from retriever.pdf.split import PDFSplitActor +from retriever.pdf.extract import split_and_extract_pdf from ..ingest import Ingestor from ..params import EmbedParams from ..params import ExtractParams from ..params import HtmlChunkParams from ..params import IngestExecuteParams -from ..params import PdfSplitParams from ..params import TextChunkParams from ..params import VdbUploadParams @@ -80,11 +78,6 @@ class _LanceDBWriteActor: """ def __init__(self, params: VdbUploadParams | None = None) -> None: - import json - from pathlib import Path - - self._json = json - self._Path = Path lancedb_params = (params or VdbUploadParams()).lancedb self._lancedb_uri = lancedb_params.lancedb_uri @@ -126,108 +119,134 @@ def __init__(self, params: VdbUploadParams | None = None) -> None: mode=mode, ) - def _build_rows(self, df: Any) -> list: - """Build LanceDB rows from a pandas DataFrame batch. + def _build_arrow_table(self, df: Any) -> Any: + """Build a PyArrow table from a pandas DataFrame batch (vectorized).""" + import pandas as pd - Mirrors the row-building logic from - ``upload_embeddings_to_lancedb_inprocess`` in inprocess.py. 
-        """
-        rows: list = []
-        for row in df.itertuples(index=False):
-            # Extract embedding
-            emb = None
-            meta = getattr(row, "metadata", None)
-            if isinstance(meta, dict):
-                emb = meta.get("embedding")
-                if not (isinstance(emb, list) and emb):
-                    emb = None
-            if emb is None:
-                payload = getattr(row, self._embedding_column, None)
-                if isinstance(payload, dict):
-                    emb = payload.get(self._embedding_key)
-                    if not (isinstance(emb, list) and emb):
-                        emb = None
-            if emb is None:
-                continue
+        # Local imports (json/Path were previously bound in __init__).
+        import json
+        from pathlib import Path
+
+        pa = self._pa
+        n = len(df)
+        if n == 0:
+            return None

-            # Extract source path and page number
-            path = ""
-            page = -1
-            v = getattr(row, "path", None)
-            if isinstance(v, str) and v.strip():
-                path = v.strip()
-            v = getattr(row, "page_number", None)
-            try:
-                if v is not None:
-                    page = int(v)
-            except Exception:
-                pass
-            if isinstance(meta, dict):
-                sp = meta.get("source_path")
-                if isinstance(sp, str) and sp.strip():
-                    path = sp.strip()
-
-            p = self._Path(path) if path else None
-            filename = p.name if p is not None else ""
-            pdf_basename = p.stem if p is not None else ""
-            pdf_page = f"{pdf_basename}_{page}" if (pdf_basename and page >= 0) else ""
-            source_id = path or filename or pdf_basename
-
-            metadata_obj = {"page_number": int(page) if page is not None else -1}
+        # --- Extract embeddings ---
+        meta_col = df.get("metadata", pd.Series([None] * n))
+
+        emb_from_meta = meta_col.map(lambda m: m.get("embedding") if isinstance(m, dict) else None)
+        valid_meta = emb_from_meta.map(lambda x: isinstance(x, list) and len(x) > 0)
+        emb_from_meta = emb_from_meta.where(valid_meta)
+
+        if self._embedding_column in df.columns:
+            emb_from_col = df[self._embedding_column].map(
+                lambda p: p.get(self._embedding_key) if isinstance(p, dict) else None
+            )
+            valid_col = emb_from_col.map(lambda x: isinstance(x, list) and len(x) > 0)
+            emb_from_col = emb_from_col.where(valid_col)
+        else:
+            emb_from_col = pd.Series([None] * n, index=df.index)
+
+        embeddings = emb_from_meta.fillna(emb_from_col)
+        mask = embeddings.notna()
+        if not mask.any():
+            return None
+
+        df = df[mask].reset_index(drop=True)
+        embeddings = embeddings[mask].reset_index(drop=True)
+        nr = len(df)
+
+        # --- Paths with metadata override ---
+        meta_col = df.get("metadata", pd.Series([None] * nr))
+        raw_paths = df.get("path", pd.Series([""] * nr)).fillna("").astype(str).str.strip()
+        # Only accept string source_path values, mirroring the old row-wise guard.
+        source_paths = meta_col.map(
+            lambda m: m["source_path"].strip()
+            if isinstance(m, dict) and isinstance(m.get("source_path"), str)
+            else ""
+        )
+        paths = source_paths.where(source_paths.astype(bool), raw_paths)
+
+        # --- Derived columns ---
+        page_numbers = (
+            pd.to_numeric(df.get("page_number", pd.Series([None] * nr)), errors="coerce").fillna(-1).astype(int)
+        )
+
+        filenames = paths.map(lambda p: Path(p).name if p else "")
+        pdf_basenames = paths.map(lambda p: Path(p).stem if p else "")
+
+        has_basename = pdf_basenames.astype(bool)
+        has_page = page_numbers >= 0
+        pdf_pages = (pdf_basenames + "_" + page_numbers.astype(str)).where(has_basename & has_page, "")
+
+        source_ids = paths.where(paths.astype(bool), filenames)
+        source_ids = source_ids.where(source_ids.astype(bool), pdf_basenames)
+
+        # --- Text ---
+        if self._include_text and self._text_column in df.columns:
+            texts = df[self._text_column].map(lambda t: str(t) if isinstance(t, str) else "")
+        else:
+            texts = pd.Series([""] * nr)
+
+        # --- Metadata JSON ---
+        pe_nums = df.get("page_elements_v3_num_detections", pd.Series([None] * nr))
+        pe_counts_col = df.get("page_elements_v3_counts_by_label", pd.Series([None] * nr))
+        ocr_tables = df.get("table", pd.Series([None] * nr))
+        ocr_charts = df.get("chart", pd.Series([None] * nr))
+        ocr_infos = df.get("infographic", pd.Series([None] * nr))
+
+        def _meta_json(page, pdf_page, pe_num, pe_count, tbl, chart, info):
+            obj = {"page_number": int(page)}
             if pdf_page:
-                metadata_obj["pdf_page"] = pdf_page
-            # Persist per-page detection counters for end-of-run summaries.
-            # These may be duplicated across exploded content rows; downstream
-            # summary logic should dedupe by (source_id, page_number).
-            pe_num = getattr(row, "page_elements_v3_num_detections", None)
+                obj["pdf_page"] = pdf_page
             if pe_num is not None:
                 try:
-                    metadata_obj["page_elements_v3_num_detections"] = int(pe_num)
+                    obj["page_elements_v3_num_detections"] = int(pe_num)
                 except Exception:
                     pass
-            pe_counts = getattr(row, "page_elements_v3_counts_by_label", None)
-            if isinstance(pe_counts, dict):
-                metadata_obj["page_elements_v3_counts_by_label"] = {
-                    str(k): int(v) for k, v in pe_counts.items() if isinstance(k, str) and v is not None
+            if isinstance(pe_count, dict):
+                obj["page_elements_v3_counts_by_label"] = {
+                    str(k): int(v) for k, v in pe_count.items() if isinstance(k, str) and v is not None
                 }
-            for ocr_col in ("table", "chart", "infographic"):
-                entries = getattr(row, ocr_col, None)
+            for name, entries in [("table", tbl), ("chart", chart), ("infographic", info)]:
                 if isinstance(entries, list):
-                    metadata_obj[f"ocr_{ocr_col}_detections"] = int(len(entries))
-            source_obj = {"source_id": str(path)}
-
-            row_out = {
-                "vector": emb,
-                "pdf_page": pdf_page,
-                "filename": filename,
-                "pdf_basename": pdf_basename,
-                "page_number": int(page) if page is not None else -1,
-                "source_id": str(source_id),
-                "path": str(path),
-                "metadata": self._json.dumps(metadata_obj, ensure_ascii=False),
-                "source": self._json.dumps(source_obj, ensure_ascii=False),
-            }
-
-            if self._include_text:
-                t = getattr(row, self._text_column, None)
-                row_out["text"] = str(t) if isinstance(t, str) else ""
-            else:
-                row_out["text"] = ""
-
-            rows.append(row_out)
-        return rows
+                    obj[f"ocr_{name}_detections"] = len(entries)
+            return json.dumps(obj, ensure_ascii=False)
+
+        metadata_jsons = [
+            _meta_json(pn, pp, pen, pec, ot, oc, oi)
+            for pn, pp, pen, pec, ot, oc, oi in zip(
+                page_numbers,
+                pdf_pages,
+                pe_nums,
+                pe_counts_col,
+                ocr_tables,
+                ocr_charts,
+                ocr_infos,
+            )
+        ]
+
+        # --- Source JSON ---
+        source_jsons = paths.map(lambda p: json.dumps({"source_id": str(p)}, ensure_ascii=False)).tolist()
+
+        # --- Build PyArrow table ---
+        return pa.table(
+            {
+                "vector": pa.array(embeddings.tolist(), type=pa.list_(pa.float32(), 2048)),
+                "pdf_page": pa.array(pdf_pages.tolist(), type=pa.string()),
+                "filename": pa.array(filenames.tolist(), type=pa.string()),
+                "pdf_basename": pa.array(pdf_basenames.tolist(), type=pa.string()),
+                "page_number": pa.array(page_numbers.tolist(), type=pa.int32()),
+                "source_id": pa.array(source_ids.astype(str).tolist(), type=pa.string()),
+                "path": pa.array(paths.astype(str).tolist(), type=pa.string()),
+                "text": pa.array(texts.tolist(), type=pa.string()),
+                "metadata": pa.array(metadata_jsons, type=pa.string()),
+                "source": pa.array(source_jsons, type=pa.string()),
+            },
+            schema=self._schema,
+        )

     def __call__(self, batch_df: Any) -> Any:
-        rows = self._build_rows(batch_df)
-        if rows:
-            # Infer schema from first batch
-            if self._table is None:
-                self._table = self._db.open_table(self._table_name)
-            self._table.add(rows)
+        import pandas as pd

-        self._total_rows += len(rows)
+        arrow_table = self._build_arrow_table(batch_df)
+        if arrow_table is not None and arrow_table.num_rows > 0:
+            self._table.add(arrow_table)
+            self._total_rows += arrow_table.num_rows

-        return batch_df
+        # The write is the side effect; return a minimal marker frame
+        # instead of shipping the full batch downstream.
+        return pd.DataFrame({"_written": range(len(batch_df))})


 class _BatchEmbedActor:
@@ -393,8 +412,8 @@ def _endpoint_count(raw: Any) -> int:
         return len([p for p in s.split(",") if p.strip()])

     debug_run_id = str(kwargs.pop("debug_run_id", "unknown"))
-    pdf_split_batch_size = kwargs.pop("pdf_split_batch_size", 1)
-    pdf_extract_batch_size = kwargs.pop("pdf_extract_batch_size", 4)
+    pdf_split_batch_size = kwargs.pop("pdf_split_batch_size", 4)
+    kwargs.pop("pdf_extract_batch_size", None)  # consumed but unused after fusing split+extract
     pdf_extract_num_cpus = float(kwargs.pop("pdf_extract_num_cpus", 2))
     page_elements_batch_size = kwargs.pop("page_elements_batch_size", 24)
     detect_batch_size = kwargs.pop("detect_batch_size", 24)
@@ -568,33 +587,21 @@ def _endpoint_count(raw: Any) -> int:
             batch_format="pandas",
         )

-        # Splitting pdfs is broken into a separate stage to help amortize downstream
-        # processing if PDFs have vastly different numbers of pages.
-        pdf_split_actor = PDFSplitActor(
-            split_params=PdfSplitParams(
-                start_page=kwargs.get("start_page"),
-                end_page=kwargs.get("end_page"),
-            )
-        )
+        # Fused split + extraction: open each multi-page PDF once and
+        # extract per-page text/images in a single pass. Eliminates the
+        # redundant single-page PDF serialization from the old Split stage.
+        fused_kwargs = dict(kwargs)
+        fused_kwargs["start_page"] = kwargs.get("start_page")
+        fused_kwargs["end_page"] = kwargs.get("end_page")
         self._rd_dataset = self._rd_dataset.map_batches(
-            pdf_split_actor,
+            split_and_extract_pdf,
             batch_size=pdf_split_batch_size,
-            num_cpus=1,
-            num_gpus=0,
-            batch_format="pandas",
-        )
-
-        # Pre-split pdfs are now ready for extraction — the main CPU bottleneck.
-        extraction_actor = PDFExtractionActor(**kwargs)
-        self._rd_dataset = self._rd_dataset.map_batches(
-            extraction_actor,
-            batch_size=pdf_extract_batch_size,
             batch_format="pandas",
             num_cpus=pdf_extract_num_cpus,
             num_gpus=0,
             compute=rd.TaskPoolStrategy(size=pdf_extract_workers),
+            fn_kwargs=fused_kwargs,
         )
-        self._rd_dataset = self._rd_dataset.repartition(target_num_rows_per_block=24)

         # Page-element detection with a GPU actor pool.
         # For ActorPoolStrategy, Ray Data expects a *callable class* (so it can
         # construct one instance per actor). Passing an already-constructed
diff --git a/retriever/src/retriever/ocr/ocr.py b/retriever/src/retriever/ocr/ocr.py
index 380b82f5e..55b78a403 100644
--- a/retriever/src/retriever/ocr/ocr.py
+++ b/retriever/src/retriever/ocr/ocr.py
@@ -127,9 +127,8 @@ def _crop_all_from_page(
     if Image is None:  # pragma: no cover
         raise ImportError("Cropping requires pillow.")

-    if not isinstance(page_image_b64, str) or not page_image_b64:
+    if not page_image_b64:
         return []
-
     try:
         raw = base64.b64decode(page_image_b64)
         im0 = Image.open(io.BytesIO(raw))
@@ -448,11 +447,10 @@ def ocr_page_elements(
             dets = []

         # --- get page image ---
-        page_image = getattr(row, "page_image", None) or {}
-        page_image_b64 = page_image.get("image_b64") if isinstance(page_image, dict) else None
+        page_image_data = getattr(row, "page_image", None) or {}
+        page_img_b64 = page_image_data.get("image_b64") if isinstance(page_image_data, dict) else None

-        if not isinstance(page_image_b64, str) or not page_image_b64:
-            # No image available — nothing to crop/OCR.
+        if not page_img_b64:
             all_table.append(table_items)
             all_chart.append(chart_items)
             all_infographic.append(infographic_items)
@@ -460,7 +458,7 @@ def ocr_page_elements(
             continue

         # --- decode page image once, crop all matching detections ---
-        crops = _crop_all_from_page(page_image_b64, dets, wanted_labels)
+        crops = _crop_all_from_page(page_img_b64, dets, wanted_labels)

         if use_remote:
             crop_b64s: List[str] = []
diff --git a/retriever/src/retriever/page_elements/page_elements.py b/retriever/src/retriever/page_elements/page_elements.py
index 101e3b7c0..4d7957981 100644
--- a/retriever/src/retriever/page_elements/page_elements.py
+++ b/retriever/src/retriever/page_elements/page_elements.py
@@ -499,7 +499,8 @@ def detect_page_elements_v3(

     for _, row in pages_df.iterrows():
         try:
-            b64 = row.get("page_image")["image_b64"]
+            page_img = row.get("page_image")
+            b64 = page_img.get("image_b64") if isinstance(page_img, dict) else None
             if not b64:
                 raise ValueError("No usable image_b64 found in row.")
             row_b64.append(b64)
diff --git a/retriever/src/retriever/pdf/extract.py b/retriever/src/retriever/pdf/extract.py
index 3a5727aaa..b1dfc382b 100644
--- a/retriever/src/retriever/pdf/extract.py
+++ b/retriever/src/retriever/pdf/extract.py
@@ -22,11 +22,6 @@
 else:  # pragma: no cover
     _PDFIUM_IMPORT_ERROR = None

-try:
-    import numpy as np
-except Exception:  # pragma: no cover
-    np = None  # type: ignore[assignment]
-
 try:
     from PIL import Image
 except Exception:  # pragma: no cover
@@ -268,7 +263,7 @@ def pdf_extraction(
         )
         render_info: Optional[Dict[str, Any]] = None
         if want_any_raster:
-            render_info = _render_page_to_base64(page, dpi=dpi, image_format=image_format)
+            render_info = _render_page_to_base64(page, dpi=dpi)

         page_record: Dict[str, Any] = {
             "path": pdf_path,
@@ -322,6 +317,111 @@ def pdf_extraction(
     raise NotImplementedError("pdf_extraction currently only supports pandas.DataFrame input.")


+def split_and_extract_pdf(
+    pdf_batch: Any,
+    extract_text: bool = False,
+    extract_images: bool = False,
+    extract_tables: bool = False,
+    extract_charts: bool = False,
+    extract_infographics: bool = False,
+    dpi: int = 200,
+    text_extraction_method: str = "pdfium_hybrid",
+    text_depth: str = "page",
+    start_page: int | None = None,
+    end_page: int | None = None,
+    **kwargs: Any,
+) -> pd.DataFrame:
+    """Fused split + extraction: open each multi-page PDF once, extract per-page."""
+    if not isinstance(pdf_batch, pd.DataFrame):
+        raise NotImplementedError("split_and_extract_pdf currently only supports pandas.DataFrame input.")
+
+    if pdfium is None:  # pragma: no cover
+        outputs: List[Dict[str, Any]] = []
+        for _, row in pdf_batch.iterrows():
+            pdf_path = row.get("path")
+            outputs.append(
+                _error_record(
+                    source_path=str(pdf_path) if pdf_path is not None else None,
+                    stage="import_pypdfium2",
+                    exc=(
+                        _PDFIUM_IMPORT_ERROR
+                        if _PDFIUM_IMPORT_ERROR is not None
+                        else RuntimeError("pypdfium2 unavailable")
+                    ),
+                )
+            )
+        return pd.DataFrame(outputs)
+
+    outputs = []
+    for _, row in pdf_batch.iterrows():
+        pdf_bytes = row.get("bytes")
+        pdf_path = row.get("path")
+        try:
+            if not isinstance(pdf_bytes, (bytes, bytearray, memoryview)):
+                raise RuntimeError(f"Unsupported bytes payload type: {type(pdf_bytes)!r}")
+
+            try:
+                doc = pdfium.PdfDocument(pdf_bytes)
+            except Exception:
+                doc = pdfium.PdfDocument(BytesIO(bytes(pdf_bytes)))
+
+            n_pages = len(doc)
+            start_idx = 0 if start_page is None else max(int(start_page) - 1, 0)
+            end_idx = (n_pages - 1) if end_page is None else min(int(end_page) - 1, n_pages - 1)
+
+            for page_idx in range(start_idx, end_idx + 1):
+                page = doc.get_page(page_idx)
+                try:
+                    is_scanned = _is_scanned_page(page)
+                    ocr_needed = extract_text and (
+                        (text_extraction_method == "pdfium_hybrid" and is_scanned) or text_extraction_method == "ocr"
+                    )
+
+                    text = ""
+                    if extract_text and not ocr_needed:
+                        text = _extract_page_text(page)
+
+                    want_raster = bool(
+                        extract_images or extract_tables or extract_charts or extract_infographics or ocr_needed
+                    )
+                    page_image = None
+                    if want_raster:
+                        page_image = _render_page_to_base64(page, dpi=dpi)
+
+                    outputs.append(
+                        {
+                            "path": pdf_path,
+                            "page_number": page_idx + 1,
+                            "text": text if extract_text else "",
+                            "page_image": page_image,
+                            "images": [],
+                            "tables": [],
+                            "charts": [],
+                            "infographics": [],
+                            "metadata": {
+                                "has_text": bool(text.strip()) if extract_text else False,
+                                "needs_ocr_for_text": ocr_needed,
+                                "dpi": dpi,
+                                "source_path": pdf_path,
+                                "error": None,
+                            },
+                        }
+                    )
+                finally:
+                    if hasattr(page, "close"):
+                        page.close()
+            doc.close()
+        except BaseException as e:
+            outputs.append(
+                _error_record(
+                    source_path=str(pdf_path) if pdf_path else None,
+                    stage="split_and_extract",
+                    exc=e,
+                )
+            )
+    return pd.DataFrame(outputs)
+
+
 @dataclass(slots=True)
 class PDFExtractionActor:
     """
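
Because the fused stage is a plain function rather than an actor, it can be exercised directly with a pandas batch, the same shape Ray Data passes to `map_batches`. A minimal sketch, assuming a local `sample.pdf` (hypothetical file; the `path`/`bytes` columns and keyword arguments follow the `split_and_extract_pdf` signature above):

    import pandas as pd

    from retriever.pdf.extract import split_and_extract_pdf

    with open("sample.pdf", "rb") as f:  # hypothetical input PDF
        payload = f.read()

    batch = pd.DataFrame([{"path": "sample.pdf", "bytes": payload}])
    pages = split_and_extract_pdf(batch, extract_text=True, start_page=1, end_page=3)

    # One output row per page, with 1-based page_number and per-page metadata.
    print(pages[["path", "page_number"]])
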