From e69acd40abf1703a9dc17f653a7546854db6866f Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 6 Mar 2026 17:22:50 -0500 Subject: [PATCH 01/12] Extract shared embedder factory into model/__init__.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Centralises the resolve → branch → construct pattern for local HF embedding models (VL and non-VL) that was duplicated across batch, inprocess, fused, gpu_pool, recall, retriever, and text_embed code paths into a single `create_local_embedder` factory function. Made-with: Cursor --- .../src/nemo_retriever/ingest_modes/batch.py | 41 ++----- .../src/nemo_retriever/ingest_modes/fused.py | 23 ++-- .../nemo_retriever/ingest_modes/gpu_pool.py | 21 +--- .../nemo_retriever/ingest_modes/inprocess.py | 33 ++---- .../src/nemo_retriever/model/__init__.py | 40 +++++++ .../src/nemo_retriever/recall/core.py | 15 +-- .../src/nemo_retriever/retriever.py | 21 +--- .../nemo_retriever/text_embed/processor.py | 9 +- .../nemo_retriever/text_embed/text_embed.py | 28 ++--- .../tests/test_create_local_embedder.py | 110 ++++++++++++++++++ 10 files changed, 195 insertions(+), 146 deletions(-) create mode 100644 nemo_retriever/tests/test_create_local_embedder.py diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py index a265572ed..2b6f6d262 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py @@ -266,38 +266,15 @@ def __init__(self, params: EmbedParams) -> None: self._model = None return - device = self._kwargs.get("device") - hf_cache_dir = self._kwargs.get("hf_cache_dir") - normalize = bool(self._kwargs.get("normalize", True)) - max_length = int(self._kwargs.get("max_length", 8192)) - model_name_raw = self._kwargs.get("model_name") - - from nemo_retriever.model import is_vl_embed_model, resolve_embed_model - - model_id = resolve_embed_model(model_name_raw) - - if is_vl_embed_model(model_name_raw): - from nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import ( - LlamaNemotronEmbedVL1BV2Embedder, - ) - - self._model = LlamaNemotronEmbedVL1BV2Embedder( - device=str(device) if device else None, - hf_cache_dir=str(hf_cache_dir) if hf_cache_dir else None, - model_id=model_id, - ) - else: - from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import ( - LlamaNemotronEmbed1BV2Embedder, - ) - - self._model = LlamaNemotronEmbed1BV2Embedder( - device=str(device) if device else None, - hf_cache_dir=str(hf_cache_dir) if hf_cache_dir else None, - normalize=normalize, - max_length=max_length, - model_id=model_id, - ) + from nemo_retriever.model import create_local_embedder + + self._model = create_local_embedder( + self._kwargs.get("model_name"), + device=str(self._kwargs["device"]) if self._kwargs.get("device") else None, + hf_cache_dir=str(self._kwargs["hf_cache_dir"]) if self._kwargs.get("hf_cache_dir") else None, + normalize=bool(self._kwargs.get("normalize", True)), + max_length=int(self._kwargs.get("max_length", 8192)), + ) def __call__(self, batch_df: Any) -> Any: from nemo_retriever.ingest_modes.inprocess import embed_text_main_text_embed diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/fused.py b/nemo_retriever/src/nemo_retriever/ingest_modes/fused.py index 842053bd0..cd4571aa9 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/fused.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/fused.py @@ -55,10 +55,8 @@ class _FusedModelActor: def __init__(self, **kwargs: Any) -> None: _assert_no_remote_endpoints(dict(kwargs), context="actor init") + from nemo_retriever.model import create_local_embedder from nemo_retriever.model.local import NemotronOCRV1, NemotronPageElementsV3 - from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import ( - LlamaNemotronEmbed1BV2Embedder, - ) self._detect_kwargs = { "inference_batch_size": int(kwargs.get("inference_batch_size", 8)), @@ -89,13 +87,6 @@ def __init__(self, **kwargs: Any) -> None: "has_embedding_column": str(kwargs.get("has_embedding_column", "text_embeddings_1b_v2_has_embedding")), } - device = kwargs.get("device") - hf_cache_dir = kwargs.get("hf_cache_dir") - normalize = bool(kwargs.get("normalize", True)) - max_length = int(kwargs.get("max_length", 8192)) - model_name_raw = kwargs.get("model_name") - model_id = model_name_raw if (isinstance(model_name_raw, str) and "/" in model_name_raw) else None - self._page_elements_model = NemotronPageElementsV3() self._ocr_model = NemotronOCRV1() self._table_structure_model = None @@ -103,12 +94,12 @@ def __init__(self, **kwargs: Any) -> None: from nemo_retriever.model.local import NemotronTableStructureV1 self._table_structure_model = NemotronTableStructureV1() - self._embed_model = LlamaNemotronEmbed1BV2Embedder( - device=str(device) if device else None, - hf_cache_dir=str(hf_cache_dir) if hf_cache_dir else None, - normalize=normalize, - max_length=max_length, - model_id=model_id, + self._embed_model = create_local_embedder( + kwargs.get("model_name"), + device=str(kwargs["device"]) if kwargs.get("device") else None, + hf_cache_dir=str(kwargs["hf_cache_dir"]) if kwargs.get("hf_cache_dir") else None, + normalize=bool(kwargs.get("normalize", True)), + max_length=int(kwargs.get("max_length", 8192)), ) def __call__(self, batch_df: Any) -> Any: diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/gpu_pool.py b/nemo_retriever/src/nemo_retriever/ingest_modes/gpu_pool.py index 775ded97a..cb1aa019a 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/gpu_pool.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/gpu_pool.py @@ -77,29 +77,14 @@ class EmbeddingModelConfig: model_id: Optional[str] = None def create(self) -> Any: - from nemo_retriever.model import is_vl_embed_model + from nemo_retriever.model import create_local_embedder - if is_vl_embed_model(self.model_id): - from nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import ( - LlamaNemotronEmbedVL1BV2Embedder, - ) - - return LlamaNemotronEmbedVL1BV2Embedder( - device=self.device, - hf_cache_dir=self.hf_cache_dir, - model_id=self.model_id, - ) - - from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import ( - LlamaNemotronEmbed1BV2Embedder, - ) - - return LlamaNemotronEmbed1BV2Embedder( + return create_local_embedder( + self.model_id, device=self.device, hf_cache_dir=self.hf_cache_dir, normalize=self.normalize, max_length=self.max_length, - model_id=self.model_id, ) diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py index 087c10749..0e09dd787 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py @@ -29,7 +29,6 @@ import pandas as pd from nemo_retriever.model.local import NemotronOCRV1, NemotronPageElementsV3, NemotronParseV12 -from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder from nemo_retriever.page_elements import detect_page_elements_v3 from nemo_retriever.ocr.ocr import _crop_b64_image_by_norm_bbox, nemotron_parse_page_elements, ocr_page_elements from nemo_retriever.table.table_detection import table_structure_ocr_page_elements @@ -1507,38 +1506,22 @@ def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "InProcessI return self # Local HF embedder path. - # Allow callers to control device / max_length to avoid OOMs. device = embed_kwargs.pop("device", None) hf_cache_dir = embed_kwargs.pop("hf_cache_dir", None) normalize = bool(embed_kwargs.pop("normalize", True)) max_length = int(embed_kwargs.pop("max_length", 8192)) - model_name_raw = embed_kwargs.pop("model_name", None) - from nemo_retriever.model import is_vl_embed_model, resolve_embed_model - - model_id = resolve_embed_model(model_name_raw) + from nemo_retriever.model import create_local_embedder embed_kwargs.setdefault("input_type", "passage") - - if is_vl_embed_model(model_name_raw): - from nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import ( - LlamaNemotronEmbedVL1BV2Embedder, - ) - - embed_kwargs["model"] = LlamaNemotronEmbedVL1BV2Embedder( - device=str(device) if device is not None else None, - hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None, - model_id=model_id, - ) - else: - embed_kwargs["model"] = LlamaNemotronEmbed1BV2Embedder( - device=str(device) if device is not None else None, - hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None, - normalize=normalize, - max_length=max_length, - model_id=model_id, - ) + embed_kwargs["model"] = create_local_embedder( + model_name_raw, + device=str(device) if device is not None else None, + hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None, + normalize=normalize, + max_length=max_length, + ) self._tasks.append((embed_text_main_text_embed, embed_kwargs)) return self diff --git a/nemo_retriever/src/nemo_retriever/model/__init__.py b/nemo_retriever/src/nemo_retriever/model/__init__.py index dc763d548..cef002494 100644 --- a/nemo_retriever/src/nemo_retriever/model/__init__.py +++ b/nemo_retriever/src/nemo_retriever/model/__init__.py @@ -33,3 +33,43 @@ def resolve_embed_model(model_name: str | None) -> str: def is_vl_embed_model(model_name: str | None) -> bool: """Return True if *model_name* refers to the VL embedding model.""" return resolve_embed_model(model_name) in _VL_EMBED_MODEL_IDS + + +def create_local_embedder( + model_name: str | None = None, + *, + device: str | None = None, + hf_cache_dir: str | None = None, + normalize: bool = True, + max_length: int = 8192, +): + """Create the appropriate local embedding model (VL or non-VL). + + Centralises the resolve -> branch -> construct pattern that was previously + duplicated across batch, inprocess, fused, gpu_pool, recall, retriever, + and text_embed code paths. + """ + model_id = resolve_embed_model(model_name) + + if is_vl_embed_model(model_name): + from nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import ( + LlamaNemotronEmbedVL1BV2Embedder, + ) + + return LlamaNemotronEmbedVL1BV2Embedder( + device=device, + hf_cache_dir=hf_cache_dir, + model_id=model_id, + ) + + from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import ( + LlamaNemotronEmbed1BV2Embedder, + ) + + return LlamaNemotronEmbed1BV2Embedder( + device=device, + hf_cache_dir=hf_cache_dir, + normalize=normalize, + max_length=max_length, + model_id=model_id, + ) diff --git a/nemo_retriever/src/nemo_retriever/recall/core.py b/nemo_retriever/src/nemo_retriever/recall/core.py index f5dbe1e68..b684a3f6d 100644 --- a/nemo_retriever/src/nemo_retriever/recall/core.py +++ b/nemo_retriever/src/nemo_retriever/recall/core.py @@ -168,25 +168,14 @@ def _embed_queries_local_hf( batch_size: int, model_name: Optional[str] = None, ) -> List[List[float]]: - # Lazy import: only load torch/HF when needed. - from nemo_retriever.model import is_vl_embed_model, resolve_embed_model + from nemo_retriever.model import create_local_embedder, is_vl_embed_model - model_id = resolve_embed_model(model_name) + embedder = create_local_embedder(model_name, device=device, hf_cache_dir=cache_dir) if is_vl_embed_model(model_name): - from nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import LlamaNemotronEmbedVL1BV2Embedder - - embedder = LlamaNemotronEmbedVL1BV2Embedder(device=device, hf_cache_dir=cache_dir, model_id=model_id) - # VL model handles query formatting internally via encode_queries(). vecs = embedder.embed_queries(queries, batch_size=int(batch_size)) else: - from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder - - embedder = LlamaNemotronEmbed1BV2Embedder( - device=device, hf_cache_dir=cache_dir, normalize=True, model_id=model_id - ) vecs = embedder.embed(["query: " + q for q in queries], batch_size=int(batch_size)) - # Ensure list-of-list floats. return vecs.detach().to("cpu").tolist() diff --git a/nemo_retriever/src/nemo_retriever/retriever.py b/nemo_retriever/src/nemo_retriever/retriever.py index 28ab35d4b..e018aa426 100644 --- a/nemo_retriever/src/nemo_retriever/retriever.py +++ b/nemo_retriever/src/nemo_retriever/retriever.py @@ -66,31 +66,14 @@ def _embed_queries_nim( return out def _embed_queries_local_hf(self, query_texts: list[str], *, model_name: str) -> list[list[float]]: - from nemo_retriever.model import is_vl_embed_model, resolve_embed_model + from nemo_retriever.model import create_local_embedder, is_vl_embed_model - model_id = resolve_embed_model(model_name) cache_dir = str(self.local_hf_cache_dir) if self.local_hf_cache_dir else None + embedder = create_local_embedder(model_name, device=self.local_hf_device, hf_cache_dir=cache_dir) if is_vl_embed_model(model_name): - from nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import ( - LlamaNemotronEmbedVL1BV2Embedder, - ) - - embedder = LlamaNemotronEmbedVL1BV2Embedder( - device=self.local_hf_device, - hf_cache_dir=cache_dir, - model_id=model_id, - ) vectors = embedder.embed_queries(query_texts, batch_size=int(self.local_hf_batch_size)) else: - from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder - - embedder = LlamaNemotronEmbed1BV2Embedder( - device=self.local_hf_device, - hf_cache_dir=cache_dir, - normalize=True, - model_id=model_id, - ) vectors = embedder.embed(["query: " + q for q in query_texts], batch_size=int(self.local_hf_batch_size)) return vectors.detach().to("cpu").tolist() diff --git a/nemo_retriever/src/nemo_retriever/text_embed/processor.py b/nemo_retriever/src/nemo_retriever/text_embed/processor.py index 555f7e28b..81dd4b8a6 100644 --- a/nemo_retriever/src/nemo_retriever/text_embed/processor.py +++ b/nemo_retriever/src/nemo_retriever/text_embed/processor.py @@ -89,14 +89,17 @@ def maybe_inject_local_hf_embedder(task_config: Dict[str, Any], transform_config if has_endpoint or not use_local: return - # Lazy import: only load torch/HF when we truly need local embeddings. - from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder + from nemo_retriever.model import create_local_embedder local_device = task_config.get("local_hf_device") local_cache_dir = task_config.get("local_hf_cache_dir") local_batch_size = int(task_config.get("local_hf_batch_size") or 64) - embedder = LlamaNemotronEmbed1BV2Embedder(device=local_device, hf_cache_dir=local_cache_dir, normalize=True) + embedder = create_local_embedder( + task_config.get("model_name"), + device=local_device, + hf_cache_dir=local_cache_dir, + ) def _embed(texts): prefix = f"{transform_config.input_type}: " if getattr(transform_config, "input_type", None) else "" diff --git a/nemo_retriever/src/nemo_retriever/text_embed/text_embed.py b/nemo_retriever/src/nemo_retriever/text_embed/text_embed.py index 6d5d30610..5889d7627 100644 --- a/nemo_retriever/src/nemo_retriever/text_embed/text_embed.py +++ b/nemo_retriever/src/nemo_retriever/text_embed/text_embed.py @@ -190,28 +190,16 @@ def __init__(self, **detect_kwargs: Any) -> None: hf_cache_dir = self.detect_kwargs.pop("hf_cache_dir", None) normalize = bool(self.detect_kwargs.pop("normalize", True)) max_length = self.detect_kwargs.pop("max_length", 4096) - model_name = self.detect_kwargs.get("model_name") - from nemo_retriever.model import is_vl_embed_model + from nemo_retriever.model import create_local_embedder - if is_vl_embed_model(model_name): - from nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import ( - LlamaNemotronEmbedVL1BV2Embedder, - ) - - self._model = LlamaNemotronEmbedVL1BV2Embedder( - device=str(device) if device is not None else None, - hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None, - ) - else: - from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder - - self._model = LlamaNemotronEmbed1BV2Embedder( - device=str(device) if device is not None else None, - hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None, - normalize=normalize, - max_length=int(max_length), - ) + self._model = create_local_embedder( + self.detect_kwargs.get("model_name"), + device=str(device) if device is not None else None, + hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None, + normalize=normalize, + max_length=int(max_length), + ) def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: try: diff --git a/nemo_retriever/tests/test_create_local_embedder.py b/nemo_retriever/tests/test_create_local_embedder.py new file mode 100644 index 000000000..6ba3fb9c5 --- /dev/null +++ b/nemo_retriever/tests/test_create_local_embedder.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for nemo_retriever.model.create_local_embedder factory.""" + +import sys +from types import ModuleType +from unittest.mock import MagicMock + +import pytest + +from nemo_retriever.model import create_local_embedder + + +@pytest.fixture(autouse=True) +def _patch_embedders(monkeypatch): + """Prevent real model downloads by stubbing both embedder classes. + + The ``nemo_retriever.model.local`` package uses a custom ``__getattr__`` + that only exposes specific class names — not submodule names. Because + ``monkeypatch.setattr`` resolves each path segment via ``getattr``, it + cannot traverse to the submodule. We work around this by injecting fake + modules directly into ``sys.modules``, which Python checks first when + handling ``from … import`` statements. + """ + fake_text = MagicMock(name="LlamaNemotronEmbed1BV2Embedder") + fake_vl = MagicMock(name="LlamaNemotronEmbedVL1BV2Embedder") + + text_mod = ModuleType("nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder") + text_mod.LlamaNemotronEmbed1BV2Embedder = fake_text + + vl_mod = ModuleType("nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder") + vl_mod.LlamaNemotronEmbedVL1BV2Embedder = fake_vl + + monkeypatch.setitem(sys.modules, "nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder", text_mod) + monkeypatch.setitem(sys.modules, "nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder", vl_mod) + + yield fake_text, fake_vl + + +def test_default_returns_text_embedder(_patch_embedders): + fake_text, _ = _patch_embedders + result = create_local_embedder() + fake_text.assert_called_once() + assert result is fake_text.return_value + + +def test_none_model_name_returns_text_embedder(_patch_embedders): + fake_text, _ = _patch_embedders + result = create_local_embedder(None) + fake_text.assert_called_once() + assert result is fake_text.return_value + + +def test_alias_resolved_to_text_embedder(_patch_embedders): + fake_text, _ = _patch_embedders + result = create_local_embedder("nemo_retriever_v1") + call_kwargs = fake_text.call_args + assert call_kwargs.kwargs["model_id"] == "nvidia/llama-3.2-nv-embedqa-1b-v2" + assert result is fake_text.return_value + + +def test_vl_model_returns_vl_embedder(_patch_embedders): + _, fake_vl = _patch_embedders + result = create_local_embedder("nvidia/llama-nemotron-embed-vl-1b-v2") + fake_vl.assert_called_once() + assert result is fake_vl.return_value + + +def test_vl_short_alias_returns_vl_embedder(_patch_embedders): + _, fake_vl = _patch_embedders + result = create_local_embedder("llama-nemotron-embed-vl-1b-v2") + fake_vl.assert_called_once() + assert result is fake_vl.return_value + + +def test_kwargs_forwarded_to_text_embedder(_patch_embedders): + fake_text, _ = _patch_embedders + create_local_embedder( + device="cuda:1", + hf_cache_dir="/tmp/cache", + normalize=False, + max_length=4096, + ) + kw = fake_text.call_args.kwargs + assert kw["device"] == "cuda:1" + assert kw["hf_cache_dir"] == "/tmp/cache" + assert kw["normalize"] is False + assert kw["max_length"] == 4096 + + +def test_kwargs_forwarded_to_vl_embedder(_patch_embedders): + _, fake_vl = _patch_embedders + create_local_embedder( + "nvidia/llama-nemotron-embed-vl-1b-v2", + device="cuda:0", + hf_cache_dir="/models", + ) + kw = fake_vl.call_args.kwargs + assert kw["device"] == "cuda:0" + assert kw["hf_cache_dir"] == "/models" + assert kw["model_id"] == "nvidia/llama-nemotron-embed-vl-1b-v2" + + +def test_unknown_model_passes_through(_patch_embedders): + fake_text, _ = _patch_embedders + create_local_embedder("custom-org/my-embed-model") + kw = fake_text.call_args.kwargs + assert kw["model_id"] == "custom-org/my-embed-model" From 7fe8d21c3b22ce99cb16c07a082cbbfa1385cefa Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 6 Mar 2026 17:24:44 -0500 Subject: [PATCH 02/12] Consolidate LanceDB row construction, schema, and table creation Extracts duplicated LanceDB row-building, schema definition, and table-creation logic from batch.py and inprocess.py into a shared ingest_modes/lancedb_utils.py module. Made-with: Cursor --- .../src/nemo_retriever/ingest_modes/batch.py | 127 ++-------- .../nemo_retriever/ingest_modes/inprocess.py | 169 ++----------- .../ingest_modes/lancedb_utils.py | 226 ++++++++++++++++++ nemo_retriever/tests/test_lancedb_utils.py | 194 +++++++++++++++ 4 files changed, 452 insertions(+), 264 deletions(-) create mode 100644 nemo_retriever/src/nemo_retriever/ingest_modes/lancedb_utils.py create mode 100644 nemo_retriever/tests/test_lancedb_utils.py diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py index 2b6f6d262..b9b697fdd 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py @@ -86,11 +86,8 @@ class _LanceDBWriteActor: """ def __init__(self, params: VdbUploadParams | None = None) -> None: - import json - from pathlib import Path + from nemo_retriever.ingest_modes.lancedb_utils import lancedb_schema - self._json = json - self._Path = Path lancedb_params = (params or VdbUploadParams()).lancedb self._lancedb_uri = lancedb_params.lancedb_uri @@ -102,30 +99,13 @@ def __init__(self, params: VdbUploadParams | None = None) -> None: self._text_column = lancedb_params.text_column import lancedb # type: ignore - import pyarrow as pa # type: ignore - self._pa = pa self._db = lancedb.connect(uri=self._lancedb_uri) - self._table = None - self._schema = None - self._first_batch = True self._total_rows = 0 - self._table = None - mode = "overwrite" if self._overwrite else "create" - fields = [ - pa.field("vector", pa.list_(pa.float32(), 2048)), - pa.field("pdf_page", pa.string()), - pa.field("filename", pa.string()), - pa.field("pdf_basename", pa.string()), - pa.field("page_number", pa.int32()), - pa.field("source_id", pa.string()), - pa.field("path", pa.string()), - pa.field("text", pa.string()), - pa.field("metadata", pa.string()), - pa.field("source", pa.string()), - ] - self._schema = pa.schema(fields) + # Use a default dim for the initial empty table; rows are appended via add(). + self._schema = lancedb_schema(2048) + mode = "overwrite" if self._overwrite else "create" self._table = self._db.create_table( self._table_name, schema=self._schema, @@ -133,95 +113,16 @@ def __init__(self, params: VdbUploadParams | None = None) -> None: ) def _build_rows(self, df: Any) -> list: - """Build LanceDB rows from a pandas DataFrame batch. - - Mirrors the row-building logic from - ``upload_embeddings_to_lancedb_inprocess`` in inprocess.py. - """ - rows: list = [] - for row in df.itertuples(index=False): - # Extract embedding - emb = None - meta = getattr(row, "metadata", None) - if isinstance(meta, dict): - emb = meta.get("embedding") - if not (isinstance(emb, list) and emb): - emb = None - if emb is None: - payload = getattr(row, self._embedding_column, None) - if isinstance(payload, dict): - emb = payload.get(self._embedding_key) - if not (isinstance(emb, list) and emb): - emb = None - if emb is None: - continue - - # Extract source path and page number - path = "" - page = -1 - v = getattr(row, "path", None) - if isinstance(v, str) and v.strip(): - path = v.strip() - v = getattr(row, "page_number", None) - try: - if v is not None: - page = int(v) - except Exception: - pass - if isinstance(meta, dict): - sp = meta.get("source_path") - if isinstance(sp, str) and sp.strip(): - path = sp.strip() - - p = self._Path(path) if path else None - filename = p.name if p is not None else "" - pdf_basename = p.stem if p is not None else "" - pdf_page = f"{pdf_basename}_{page}" if (pdf_basename and page >= 0) else "" - source_id = path or filename or pdf_basename - - metadata_obj = {"page_number": int(page) if page is not None else -1} - if pdf_page: - metadata_obj["pdf_page"] = pdf_page - # Persist per-page detection counters for end-of-run summaries. - # These may be duplicated across exploded content rows; downstream - # summary logic should dedupe by (source_id, page_number). - pe_num = getattr(row, "page_elements_v3_num_detections", None) - if pe_num is not None: - try: - metadata_obj["page_elements_v3_num_detections"] = int(pe_num) - except Exception: - pass - pe_counts = getattr(row, "page_elements_v3_counts_by_label", None) - if isinstance(pe_counts, dict): - metadata_obj["page_elements_v3_counts_by_label"] = { - str(k): int(v) for k, v in pe_counts.items() if isinstance(k, str) and v is not None - } - for ocr_col in ("table", "chart", "infographic"): - entries = getattr(row, ocr_col, None) - if isinstance(entries, list): - metadata_obj[f"ocr_{ocr_col}_detections"] = int(len(entries)) - source_obj = {"source_id": str(path)} - - row_out = { - "vector": emb, - "pdf_page": pdf_page, - "filename": filename, - "pdf_basename": pdf_basename, - "page_number": int(page) if page is not None else -1, - "source_id": str(source_id), - "path": str(path), - "metadata": self._json.dumps(metadata_obj, ensure_ascii=False), - "source": self._json.dumps(source_obj, ensure_ascii=False), - } - - if self._include_text: - t = getattr(row, self._text_column, None) - row_out["text"] = str(t) if isinstance(t, str) else "" - else: - row_out["text"] = "" - - rows.append(row_out) - return rows + """Build LanceDB rows from a pandas DataFrame batch.""" + from nemo_retriever.ingest_modes.lancedb_utils import build_lancedb_rows + + return build_lancedb_rows( + df, + embedding_column=self._embedding_column, + embedding_key=self._embedding_key, + text_column=self._text_column, + include_text=self._include_text, + ) def __call__(self, batch_df: Any) -> Any: rows = self._build_rows(batch_df) diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py index 0e09dd787..f4ed21596 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py @@ -634,67 +634,14 @@ def save_dataframe_to_disk_json(df: Any, *, output_directory: str) -> Any: return df -def _extract_embedding_from_row( - row: Any, - *, - embedding_column: str = "text_embeddings_1b_v2", - embedding_key: str = "embedding", -) -> Optional[List[float]]: - """ - Extract an embedding vector from a row (namedtuple or pd.Series). - - Supports: - - `metadata.embedding` (preferred if present) - - `embedding_column` payloads like `{"embedding": [...], ...}` (from `embed_text_1b_v2`) - """ - meta = getattr(row, "metadata", None) - if isinstance(meta, dict): - emb = meta.get("embedding") - if isinstance(emb, list) and emb: - return emb # type: ignore[return-value] - - payload = getattr(row, embedding_column, None) - if isinstance(payload, dict): - emb = payload.get(embedding_key) - if isinstance(emb, list) and emb: - return emb # type: ignore[return-value] - return None - - -def _extract_source_path_and_page(row: Any) -> Tuple[str, int]: - """ - Best-effort extract of source path and page number for LanceDB row metadata. - """ - path = "" - page = -1 - - v = getattr(row, "path", None) - if isinstance(v, str) and v.strip(): - path = v.strip() - - v = getattr(row, "page_number", None) - try: - if v is not None: - page = int(v) - except Exception: - pass - - meta = getattr(row, "metadata", None) - if isinstance(meta, dict): - sp = meta.get("source_path") - if isinstance(sp, str) and sp.strip(): - path = sp.strip() - # Some schemas store page under content metadata; support if present. - cm = meta.get("content_metadata") - if isinstance(cm, dict) and page == -1: - h = cm.get("hierarchy") - if isinstance(h, dict) and "page" in h: - try: - page = int(h.get("page")) - except Exception: - pass - - return path, page +from nemo_retriever.ingest_modes.lancedb_utils import ( + build_lancedb_rows, + create_or_append_lancedb_table, + extract_embedding_from_row as _extract_embedding_from_row, + extract_source_path_and_page as _extract_source_path_and_page, + infer_vector_dim, + lancedb_schema, +) def upload_embeddings_to_lancedb_inprocess( @@ -744,112 +691,32 @@ def upload_embeddings_to_lancedb_inprocess( if not isinstance(df, pd.DataFrame): raise TypeError(f"upload_embeddings_to_lancedb_inprocess expects pandas.DataFrame, got {type(df)!r}") - rows: List[Dict[str, Any]] = [] - for r in df.itertuples(index=False): - emb = _extract_embedding_from_row(r, embedding_column=str(embedding_column), embedding_key=str(embedding_key)) - if emb is None: - continue - - path, page_number = _extract_source_path_and_page(r) - p = Path(path) if path else None - filename = p.name if p is not None else "" - pdf_basename = p.stem if p is not None else "" - pdf_page = f"{pdf_basename}_{page_number}" if (pdf_basename and page_number >= 0) else "" - source_id = path or filename or pdf_basename - - # Provide fields compatible with `nemo_retriever.recall.core` which expects LanceDB hits - # to include JSON-encoded `metadata` and `source` strings. - metadata_obj: Dict[str, Any] = {"page_number": int(page_number) if page_number is not None else -1} - if pdf_page: - metadata_obj["pdf_page"] = pdf_page - # Persist per-page detection counters for end-of-run summaries. - # Mirrors batch.py so LanceDB-based summary reads also work. - pe_num = getattr(r, "page_elements_v3_num_detections", None) - if pe_num is not None: - try: - metadata_obj["page_elements_v3_num_detections"] = int(pe_num) - except Exception: - pass - pe_counts = getattr(r, "page_elements_v3_counts_by_label", None) - if isinstance(pe_counts, dict): - metadata_obj["page_elements_v3_counts_by_label"] = { - str(k): int(v) for k, v in pe_counts.items() if isinstance(k, str) and v is not None - } - for ocr_col in ("table", "chart", "infographic"): - entries = getattr(r, ocr_col, None) - if isinstance(entries, list): - metadata_obj[f"ocr_{ocr_col}_detections"] = int(len(entries)) - source_obj: Dict[str, Any] = {"source_id": str(path)} - - row_out: Dict[str, Any] = { - "vector": emb, - "pdf_page": pdf_page, - "filename": filename, - "pdf_basename": pdf_basename, - "page_number": int(page_number) if page_number is not None else -1, - "source_id": str(source_id), - "path": str(path), - "metadata": json.dumps(metadata_obj, ensure_ascii=False), - "source": json.dumps(source_obj, ensure_ascii=False), - } - - if include_text: - t = getattr(r, text_column, None) - row_out["text"] = str(t) if isinstance(t, str) else "" - else: - # Still include the column for compatibility with the recall script's `.select(["text",...])`. - row_out["text"] = "" - - rows.append(row_out) + rows = build_lancedb_rows( + df, + embedding_column=str(embedding_column), + embedding_key=str(embedding_key), + text_column=str(text_column), + include_text=bool(include_text), + ) if not rows: print("No embeddings found to upload to LanceDB (no rows had embeddings).") return df - # Infer vector dim from first row. - dim = 0 - for rr in rows: - v = rr.get("vector") - if isinstance(v, list) and v: - dim = int(len(v)) - break + dim = infer_vector_dim(rows) if dim <= 0: raise ValueError("Failed to infer embedding dimension from DataFrame rows.") try: import lancedb # type: ignore - import pyarrow as pa # type: ignore except Exception as e: raise RuntimeError( "LanceDB upload requested but dependencies are missing. Install `lancedb` and `pyarrow`." ) from e db = lancedb.connect(uri=str(lancedb_uri)) - - fields = [ - pa.field("vector", pa.list_(pa.float32(), dim)), - pa.field("pdf_page", pa.string()), - pa.field("filename", pa.string()), - pa.field("pdf_basename", pa.string()), - pa.field("page_number", pa.int32()), - pa.field("source_id", pa.string()), - pa.field("path", pa.string()), - # Compatibility columns expected by `nemo_retriever.recall.core`: - pa.field("text", pa.string()), - pa.field("metadata", pa.string()), - pa.field("source", pa.string()), - ] - schema = pa.schema(fields) - - # Overwrite vs append. - if overwrite: - table = db.create_table(str(table_name), data=list(rows), schema=schema, mode="overwrite") - else: - try: - table = db.open_table(str(table_name)) - table.add(list(rows)) - except Exception: - table = db.create_table(str(table_name), data=list(rows), schema=schema, mode="create") + schema = lancedb_schema(dim) + table = create_or_append_lancedb_table(db, str(table_name), rows, schema, overwrite=overwrite) if create_index: # LanceDB IVF-based indexes train k-means with K=num_partitions. K must be < N vectors. diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/lancedb_utils.py b/nemo_retriever/src/nemo_retriever/ingest_modes/lancedb_utils.py new file mode 100644 index 000000000..f74f787eb --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/lancedb_utils.py @@ -0,0 +1,226 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Shared LanceDB row construction, schema, and table helpers. + +Consolidates the duplicated logic that previously lived independently in +``inprocess.py`` (``upload_embeddings_to_lancedb_inprocess``) and +``batch.py`` (``_LanceDBWriteActor._build_rows``). +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +def extract_embedding_from_row( + row: Any, + *, + embedding_column: str = "text_embeddings_1b_v2", + embedding_key: str = "embedding", +) -> Optional[List[float]]: + """Extract an embedding vector from a row (namedtuple or pd.Series). + + Supports: + - ``metadata.embedding`` (preferred if present) + - *embedding_column* payloads like ``{"embedding": [...], ...}`` + """ + meta = getattr(row, "metadata", None) + if isinstance(meta, dict): + emb = meta.get("embedding") + if isinstance(emb, list) and emb: + return emb # type: ignore[return-value] + + payload = getattr(row, embedding_column, None) + if isinstance(payload, dict): + emb = payload.get(embedding_key) + if isinstance(emb, list) and emb: + return emb # type: ignore[return-value] + return None + + +def extract_source_path_and_page(row: Any) -> Tuple[str, int]: + """Best-effort extract of source path and page number from a row.""" + path = "" + page = -1 + + v = getattr(row, "path", None) + if isinstance(v, str) and v.strip(): + path = v.strip() + + v = getattr(row, "page_number", None) + try: + if v is not None: + page = int(v) + except Exception: + pass + + meta = getattr(row, "metadata", None) + if isinstance(meta, dict): + sp = meta.get("source_path") + if isinstance(sp, str) and sp.strip(): + path = sp.strip() + cm = meta.get("content_metadata") + if isinstance(cm, dict) and page == -1: + h = cm.get("hierarchy") + if isinstance(h, dict) and "page" in h: + try: + page = int(h.get("page")) + except Exception: + pass + + return path, page + + +def _build_detection_metadata(row: Any) -> Dict[str, Any]: + """Extract per-page detection counters from a row for LanceDB metadata.""" + out: Dict[str, Any] = {} + + pe_num = getattr(row, "page_elements_v3_num_detections", None) + if pe_num is not None: + try: + out["page_elements_v3_num_detections"] = int(pe_num) + except Exception: + pass + + pe_counts = getattr(row, "page_elements_v3_counts_by_label", None) + if isinstance(pe_counts, dict): + out["page_elements_v3_counts_by_label"] = { + str(k): int(v) for k, v in pe_counts.items() if isinstance(k, str) and v is not None + } + + for ocr_col in ("table", "chart", "infographic"): + entries = getattr(row, ocr_col, None) + if isinstance(entries, list): + out[f"ocr_{ocr_col}_detections"] = int(len(entries)) + + return out + + +def build_lancedb_row( + row: Any, + *, + embedding_column: str = "text_embeddings_1b_v2", + embedding_key: str = "embedding", + text_column: str = "text", + include_text: bool = True, +) -> Optional[Dict[str, Any]]: + """Build a single LanceDB-ready dict from a DataFrame row. + + Returns ``None`` when no embedding is found in the row. + """ + emb = extract_embedding_from_row(row, embedding_column=embedding_column, embedding_key=embedding_key) + if emb is None: + return None + + path, page_number = extract_source_path_and_page(row) + p = Path(path) if path else None + filename = p.name if p is not None else "" + pdf_basename = p.stem if p is not None else "" + pdf_page = f"{pdf_basename}_{page_number}" if (pdf_basename and page_number >= 0) else "" + source_id = path or filename or pdf_basename + + metadata_obj: Dict[str, Any] = {"page_number": int(page_number) if page_number is not None else -1} + if pdf_page: + metadata_obj["pdf_page"] = pdf_page + metadata_obj.update(_build_detection_metadata(row)) + + source_obj: Dict[str, Any] = {"source_id": str(path)} + + row_out: Dict[str, Any] = { + "vector": emb, + "pdf_page": pdf_page, + "filename": filename, + "pdf_basename": pdf_basename, + "page_number": int(page_number) if page_number is not None else -1, + "source_id": str(source_id), + "path": str(path), + "metadata": json.dumps(metadata_obj, ensure_ascii=False), + "source": json.dumps(source_obj, ensure_ascii=False), + } + + if include_text: + t = getattr(row, text_column, None) + row_out["text"] = str(t) if isinstance(t, str) else "" + else: + row_out["text"] = "" + + return row_out + + +def build_lancedb_rows( + df: Any, + *, + embedding_column: str = "text_embeddings_1b_v2", + embedding_key: str = "embedding", + text_column: str = "text", + include_text: bool = True, +) -> List[Dict[str, Any]]: + """Build LanceDB rows from a pandas DataFrame. + + Iterates with ``itertuples`` and delegates to :func:`build_lancedb_row`. + Rows without an embedding are silently skipped. + """ + rows: List[Dict[str, Any]] = [] + for r in df.itertuples(index=False): + row_out = build_lancedb_row( + r, + embedding_column=embedding_column, + embedding_key=embedding_key, + text_column=text_column, + include_text=include_text, + ) + if row_out is not None: + rows.append(row_out) + return rows + + +def lancedb_schema(vector_dim: int) -> Any: + """Return a PyArrow schema for the standard LanceDB table layout.""" + import pyarrow as pa # type: ignore + + return pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), vector_dim)), + pa.field("pdf_page", pa.string()), + pa.field("filename", pa.string()), + pa.field("pdf_basename", pa.string()), + pa.field("page_number", pa.int32()), + pa.field("source_id", pa.string()), + pa.field("path", pa.string()), + pa.field("text", pa.string()), + pa.field("metadata", pa.string()), + pa.field("source", pa.string()), + ] + ) + + +def infer_vector_dim(rows: List[Dict[str, Any]]) -> int: + """Return the embedding dimension from the first row that has a vector.""" + for r in rows: + v = r.get("vector") + if isinstance(v, list) and v: + return len(v) + return 0 + + +def create_or_append_lancedb_table( + db: Any, + table_name: str, + rows: List[Dict[str, Any]], + schema: Any, + overwrite: bool = True, +) -> Any: + """Create or append to a LanceDB table, returning the table object.""" + if overwrite: + return db.create_table(str(table_name), data=list(rows), schema=schema, mode="overwrite") + + try: + table = db.open_table(str(table_name)) + table.add(list(rows)) + return table + except Exception: + return db.create_table(str(table_name), data=list(rows), schema=schema, mode="create") diff --git a/nemo_retriever/tests/test_lancedb_utils.py b/nemo_retriever/tests/test_lancedb_utils.py new file mode 100644 index 000000000..cf4d3cdcc --- /dev/null +++ b/nemo_retriever/tests/test_lancedb_utils.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for nemo_retriever.ingest_modes.lancedb_utils.""" + +import json +from types import SimpleNamespace + +import pytest + +from nemo_retriever.ingest_modes.lancedb_utils import ( + build_lancedb_row, + build_lancedb_rows, + create_or_append_lancedb_table, + extract_embedding_from_row, + extract_source_path_and_page, + infer_vector_dim, + lancedb_schema, +) + + +class TestExtractEmbeddingFromRow: + def test_from_metadata(self): + row = SimpleNamespace(metadata={"embedding": [1.0, 2.0, 3.0]}) + assert extract_embedding_from_row(row) == [1.0, 2.0, 3.0] + + def test_from_embedding_column(self): + row = SimpleNamespace( + metadata=None, + text_embeddings_1b_v2={"embedding": [4.0, 5.0]}, + ) + assert extract_embedding_from_row(row) == [4.0, 5.0] + + def test_custom_column(self): + row = SimpleNamespace(metadata=None, my_col={"vec": [6.0]}) + assert extract_embedding_from_row(row, embedding_column="my_col", embedding_key="vec") == [6.0] + + def test_returns_none_when_missing(self): + row = SimpleNamespace(metadata=None) + assert extract_embedding_from_row(row) is None + + def test_empty_embedding_returns_none(self): + row = SimpleNamespace(metadata={"embedding": []}) + assert extract_embedding_from_row(row) is None + + +class TestExtractSourcePathAndPage: + def test_from_direct_attrs(self): + row = SimpleNamespace(path="/docs/file.pdf", page_number=3, metadata=None) + assert extract_source_path_and_page(row) == ("/docs/file.pdf", 3) + + def test_from_metadata_source_path(self): + row = SimpleNamespace(path="", page_number=None, metadata={"source_path": "/meta/path.pdf"}) + assert extract_source_path_and_page(row) == ("/meta/path.pdf", -1) + + def test_from_content_metadata_hierarchy(self): + row = SimpleNamespace( + path="", + page_number=None, + metadata={"content_metadata": {"hierarchy": {"page": 7}}}, + ) + path, page = extract_source_path_and_page(row) + assert page == 7 + + def test_defaults_when_missing(self): + row = SimpleNamespace() + assert extract_source_path_and_page(row) == ("", -1) + + +class TestBuildLancedbRow: + def _row(self, **kwargs): + defaults = { + "metadata": {"embedding": [0.1, 0.2]}, + "path": "/docs/test.pdf", + "page_number": 1, + "text": "hello world", + } + defaults.update(kwargs) + return SimpleNamespace(**defaults) + + def test_returns_dict_with_expected_keys(self): + result = build_lancedb_row(self._row()) + assert result is not None + assert set(result.keys()) == { + "vector", "pdf_page", "filename", "pdf_basename", + "page_number", "source_id", "path", "metadata", "source", "text", + } + + def test_vector_extracted(self): + result = build_lancedb_row(self._row()) + assert result["vector"] == [0.1, 0.2] + + def test_path_fields(self): + result = build_lancedb_row(self._row()) + assert result["filename"] == "test.pdf" + assert result["pdf_basename"] == "test" + assert result["pdf_page"] == "test_1" + + def test_text_included(self): + result = build_lancedb_row(self._row()) + assert result["text"] == "hello world" + + def test_text_excluded(self): + result = build_lancedb_row(self._row(), include_text=False) + assert result["text"] == "" + + def test_metadata_json(self): + result = build_lancedb_row(self._row()) + meta = json.loads(result["metadata"]) + assert meta["page_number"] == 1 + assert meta["pdf_page"] == "test_1" + + def test_returns_none_when_no_embedding(self): + row = SimpleNamespace(metadata=None, path="/x.pdf", page_number=1, text="hi") + assert build_lancedb_row(row) is None + + def test_detection_metadata_included(self): + row = self._row( + page_elements_v3_num_detections=5, + page_elements_v3_counts_by_label={"text": 3, "figure": 2}, + table=[{}, {}], + ) + result = build_lancedb_row(row) + meta = json.loads(result["metadata"]) + assert meta["page_elements_v3_num_detections"] == 5 + assert meta["page_elements_v3_counts_by_label"] == {"text": 3, "figure": 2} + assert meta["ocr_table_detections"] == 2 + + +class TestBuildLancedbRows: + def test_filters_rows_without_embeddings(self): + import pandas as pd + + df = pd.DataFrame([ + {"metadata": {"embedding": [1.0]}, "path": "/a.pdf", "page_number": 1, "text": "a"}, + {"metadata": {}, "path": "/b.pdf", "page_number": 1, "text": "b"}, + ]) + rows = build_lancedb_rows(df) + assert len(rows) == 1 + assert rows[0]["vector"] == [1.0] + + +class TestLancedbSchema: + def test_returns_schema_with_correct_fields(self): + schema = lancedb_schema(768) + names = [f.name for f in schema] + assert "vector" in names + assert "text" in names + assert "metadata" in names + assert "source" in names + assert len(names) == 10 + + +class TestInferVectorDim: + def test_returns_dim(self): + assert infer_vector_dim([{"vector": [1.0, 2.0, 3.0]}]) == 3 + + def test_returns_zero_when_empty(self): + assert infer_vector_dim([]) == 0 + assert infer_vector_dim([{"vector": []}]) == 0 + + +class TestCreateOrAppendLancedbTable: + def test_overwrite_calls_create(self): + from unittest.mock import MagicMock + + db = MagicMock() + schema = MagicMock() + rows = [{"a": 1}] + create_or_append_lancedb_table(db, "test", rows, schema, overwrite=True) + db.create_table.assert_called_once_with("test", data=[{"a": 1}], schema=schema, mode="overwrite") + + def test_append_opens_then_adds(self): + from unittest.mock import MagicMock + + db = MagicMock() + table = MagicMock() + db.open_table.return_value = table + rows = [{"a": 1}] + result = create_or_append_lancedb_table(db, "t", rows, MagicMock(), overwrite=False) + db.open_table.assert_called_once_with("t") + table.add.assert_called_once() + assert result is table + + def test_append_falls_back_to_create(self): + from unittest.mock import MagicMock + + db = MagicMock() + db.open_table.side_effect = Exception("not found") + schema = MagicMock() + rows = [{"a": 1}] + create_or_append_lancedb_table(db, "t", rows, schema, overwrite=False) + db.create_table.assert_called_once_with("t", data=[{"a": 1}], schema=schema, mode="create") From 13933f7cd9d9cdbed92bafb94a4a97bcaaebebfc Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 6 Mar 2026 19:48:58 -0500 Subject: [PATCH 03/12] Fix lint: remove unused imports, apply black formatting - Remove unused Path import and unused _extract_* aliases from inprocess.py - Remove unused pytest import from test_lancedb_utils.py - Apply black formatting to set literal and DataFrame constructor Made-with: Cursor --- .../nemo_retriever/ingest_modes/inprocess.py | 3 --- nemo_retriever/tests/test_lancedb_utils.py | 24 ++++++++++++------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py index 0191abc4d..c07d2f44c 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py @@ -23,7 +23,6 @@ from datetime import datetime, timezone from io import BytesIO from collections.abc import Callable, Iterator -from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union @@ -637,8 +636,6 @@ def save_dataframe_to_disk_json(df: Any, *, output_directory: str) -> Any: from nemo_retriever.ingest_modes.lancedb_utils import ( build_lancedb_rows, create_or_append_lancedb_table, - extract_embedding_from_row as _extract_embedding_from_row, - extract_source_path_and_page as _extract_source_path_and_page, infer_vector_dim, lancedb_schema, ) diff --git a/nemo_retriever/tests/test_lancedb_utils.py b/nemo_retriever/tests/test_lancedb_utils.py index cf4d3cdcc..36a5296c3 100644 --- a/nemo_retriever/tests/test_lancedb_utils.py +++ b/nemo_retriever/tests/test_lancedb_utils.py @@ -7,8 +7,6 @@ import json from types import SimpleNamespace -import pytest - from nemo_retriever.ingest_modes.lancedb_utils import ( build_lancedb_row, build_lancedb_rows, @@ -83,8 +81,16 @@ def test_returns_dict_with_expected_keys(self): result = build_lancedb_row(self._row()) assert result is not None assert set(result.keys()) == { - "vector", "pdf_page", "filename", "pdf_basename", - "page_number", "source_id", "path", "metadata", "source", "text", + "vector", + "pdf_page", + "filename", + "pdf_basename", + "page_number", + "source_id", + "path", + "metadata", + "source", + "text", } def test_vector_extracted(self): @@ -132,10 +138,12 @@ class TestBuildLancedbRows: def test_filters_rows_without_embeddings(self): import pandas as pd - df = pd.DataFrame([ - {"metadata": {"embedding": [1.0]}, "path": "/a.pdf", "page_number": 1, "text": "a"}, - {"metadata": {}, "path": "/b.pdf", "page_number": 1, "text": "b"}, - ]) + df = pd.DataFrame( + [ + {"metadata": {"embedding": [1.0]}, "path": "/a.pdf", "page_number": 1, "text": "a"}, + {"metadata": {}, "path": "/b.pdf", "page_number": 1, "text": "b"}, + ] + ) rows = build_lancedb_rows(df) assert len(rows) == 1 assert rows[0]["vector"] == [1.0] From 894434f554bb10a2cab041a1c4c297a7decc08e1 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 6 Mar 2026 19:51:47 -0500 Subject: [PATCH 04/12] Fix test collection: stub heavy sibling modules before lancedb_utils import The ingest_modes __init__.py eagerly imports batch/fused/inprocess/online which pull in ray, torch, etc. Pre-populate sys.modules with MagicMock stubs so lancedb_utils tests can run in lightweight CI without those deps. Made-with: Cursor --- nemo_retriever/tests/test_lancedb_utils.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/nemo_retriever/tests/test_lancedb_utils.py b/nemo_retriever/tests/test_lancedb_utils.py index 36a5296c3..e037083dc 100644 --- a/nemo_retriever/tests/test_lancedb_utils.py +++ b/nemo_retriever/tests/test_lancedb_utils.py @@ -5,9 +5,22 @@ """Unit tests for nemo_retriever.ingest_modes.lancedb_utils.""" import json +import sys from types import SimpleNamespace - -from nemo_retriever.ingest_modes.lancedb_utils import ( +from unittest.mock import MagicMock + +# The ingest_modes __init__.py eagerly imports batch/fused/inprocess/online, +# which pull in ray, torch, etc. Stub them so lancedb_utils can be imported +# in lightweight CI (matching the pattern in test_multimodal_embed.py). +for _mod_name in [ + "nemo_retriever.ingest_modes.batch", + "nemo_retriever.ingest_modes.fused", + "nemo_retriever.ingest_modes.inprocess", + "nemo_retriever.ingest_modes.online", +]: + sys.modules.setdefault(_mod_name, MagicMock()) + +from nemo_retriever.ingest_modes.lancedb_utils import ( # noqa: E402 build_lancedb_row, build_lancedb_rows, create_or_append_lancedb_table, From 7e63f8aedd0e74528463328079f473d337f8a8b4 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 6 Mar 2026 18:03:39 -0500 Subject: [PATCH 05/12] Move duplicated recall helpers to recall/core.py and examples/common.py Centralises gold_to_doc_page, hit_key_and_distance, estimate_processed_pages, and print_pages_per_second that were duplicated across batch, inprocess, online, and fused pipeline examples. Fixes broken imports in fused_pipeline.py that referenced non-existent functions in batch_pipeline.py. Made-with: Cursor --- .../nemo_retriever/examples/batch_pipeline.py | 73 ++------------- .../src/nemo_retriever/examples/common.py | 51 +++++++++++ .../nemo_retriever/examples/fused_pipeline.py | 30 ++++--- .../examples/inprocess_pipeline.py | 90 +++---------------- .../examples/online_pipeline.py | 48 +++------- .../src/nemo_retriever/recall/core.py | 31 +++++++ 6 files changed, 127 insertions(+), 196 deletions(-) create mode 100644 nemo_retriever/src/nemo_retriever/examples/common.py diff --git a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py index 36cb58034..137f31bc5 100644 --- a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py @@ -257,26 +257,15 @@ def _write_detection_summary(path: Path, summary: Optional[dict]) -> None: target.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") -def _print_pages_per_second(processed_pages: Optional[int], ingest_elapsed_s: float) -> None: - if ingest_elapsed_s <= 0: - print("Pages/sec: unavailable (ingest elapsed time was non-positive).") - return - if processed_pages is None: - print("Pages/sec: unavailable (could not estimate processed pages). " f"Ingest time: {ingest_elapsed_s:.2f}s") - return - - pps = processed_pages / ingest_elapsed_s - print(f"Pages processed: {processed_pages}") - print(f"Pages/sec (ingest only; excludes Ray startup and recall): {pps:.2f}") def _ensure_lancedb_table(uri: str, table_name: str) -> None: - """ - Ensure the local LanceDB URI exists and table can be opened. + """Ensure the local LanceDB URI exists and table can be opened. Creates an empty table with the expected schema if it does not exist yet. """ - # Local path URI in this pipeline. + from nemo_retriever.ingest_modes.lancedb_utils import lancedb_schema + Path(uri).mkdir(parents=True, exist_ok=True) db = _lancedb().connect(uri) @@ -288,63 +277,11 @@ def _ensure_lancedb_table(uri: str, table_name: str) -> None: import pyarrow as pa # type: ignore - schema = pa.schema( - [ - pa.field("vector", pa.list_(pa.float32(), 2048)), - pa.field("pdf_page", pa.string()), - pa.field("filename", pa.string()), - pa.field("pdf_basename", pa.string()), - pa.field("page_number", pa.int32()), - pa.field("source_id", pa.string()), - pa.field("path", pa.string()), - pa.field("text", pa.string()), - pa.field("metadata", pa.string()), - pa.field("source", pa.string()), - ] - ) - empty = pa.table( - { - "vector": [], - "pdf_page": [], - "filename": [], - "pdf_basename": [], - "page_number": [], - "source_id": [], - "path": [], - "text": [], - "metadata": [], - "source": [], - }, - schema=schema, - ) + schema = lancedb_schema(2048) + empty = pa.table({f.name: [] for f in schema}, schema=schema) db.create_table(table_name, data=empty, schema=schema, mode="create") -def _gold_to_doc_page(golden_key: str) -> tuple[str, str]: - s = str(golden_key) - if "_" not in s: - return s, "" - doc, page = s.rsplit("_", 1) - return doc, page - - -def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]: - try: - res = json.loads(hit.get("metadata", "{}")) - source = json.loads(hit.get("source", "{}")) - except Exception: - return None, None - - source_id = source.get("source_id") - page_number = res.get("page_number") - if not source_id or page_number is None: - return None, float(hit.get("_distance")) if "_distance" in hit else None - - key = f"{Path(str(source_id)).stem}_{page_number}" - dist = float(hit["_distance"]) if "_distance" in hit else float(hit["_score"]) if "_score" in hit else None - return key, dist - - @app.command() def main( ctx: typer.Context, diff --git a/nemo_retriever/src/nemo_retriever/examples/common.py b/nemo_retriever/src/nemo_retriever/examples/common.py new file mode 100644 index 000000000..70a165bf6 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/examples/common.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Shared helpers used by multiple example pipeline scripts.""" + +from __future__ import annotations + +from typing import Optional + + +def estimate_processed_pages(uri: str, table_name: str) -> Optional[int]: + """Estimate pages processed by counting unique (source_id, page_number) pairs. + + Falls back to table row count if page-level fields are unavailable. + """ + try: + import lancedb # type: ignore + + db = lancedb.connect(uri) + table = db.open_table(table_name) + except Exception: + return None + + try: + df = table.to_pandas()[["source_id", "page_number"]] + return int(df.dropna(subset=["source_id", "page_number"]).drop_duplicates().shape[0]) + except Exception: + try: + return int(table.count_rows()) + except Exception: + return None + + +def print_pages_per_second( + processed_pages: Optional[int], + ingest_elapsed_s: float, + *, + label: str = "ingest only", +) -> None: + """Print a throughput summary line.""" + if ingest_elapsed_s <= 0: + print("Pages/sec: unavailable (ingest elapsed time was non-positive).") + return + if processed_pages is None: + print(f"Pages/sec: unavailable (could not estimate processed pages). Ingest time: {ingest_elapsed_s:.2f}s") + return + + pps = processed_pages / ingest_elapsed_s + print(f"Pages processed: {processed_pages}") + print(f"Pages/sec ({label}): {pps:.2f}") diff --git a/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py index b378ba69f..da327ff1f 100644 --- a/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py @@ -27,18 +27,20 @@ from nemo_retriever.examples.batch_pipeline import ( LANCEDB_TABLE, LANCEDB_URI, + _collect_detection_summary, _configure_logging, _ensure_lancedb_table, - _estimate_processed_pages, - _gold_to_doc_page, - _hit_key_and_distance, - _is_hit_at_k, _print_detection_summary, - _print_pages_per_second, _write_detection_summary, - _collect_detection_summary, ) -from nemo_retriever.recall.core import RecallConfig, retrieve_and_score +from nemo_retriever.examples.common import estimate_processed_pages, print_pages_per_second +from nemo_retriever.recall.core import ( + RecallConfig, + gold_to_doc_page, + hit_key_and_distance, + is_hit_at_k, + retrieve_and_score, +) app = typer.Typer() @@ -242,7 +244,7 @@ def main( ) ) ingest_elapsed_s = time.perf_counter() - ingest_start - processed_pages = _estimate_processed_pages(lancedb_uri, LANCEDB_TABLE) + processed_pages = estimate_processed_pages(lancedb_uri, LANCEDB_TABLE) detection_summary = _collect_detection_summary(lancedb_uri, LANCEDB_TABLE) print("Extraction complete.") _print_detection_summary(detection_summary) @@ -255,7 +257,7 @@ def main( query_csv = Path(query_csv) if not query_csv.exists(): print(f"Query CSV not found at {query_csv}; skipping recall evaluation.") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) return db = lancedb.connect(lancedb_uri) @@ -277,7 +279,7 @@ def main( try: if int(table.count_rows()) == 0: print(f"LanceDB table {LANCEDB_TABLE!r} exists but is empty; skipping recall evaluation.") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) return except Exception: pass @@ -305,16 +307,16 @@ def main( _raw_hits, ) ): - doc, page = _gold_to_doc_page(g) + doc, page = gold_to_doc_page(g) scored_hits: list[tuple[str, float | None]] = [] for h in hits: - key, dist = _hit_key_and_distance(h) + key, dist = hit_key_and_distance(h) if key: scored_hits.append((key, dist)) top_keys = [k for (k, _d) in scored_hits] - hit = _is_hit_at_k(g, top_keys, cfg.top_k) + hit = is_hit_at_k(g, top_keys, cfg.top_k, match_mode="pdf_page") if not no_recall_details: print(f"\nQuery {i}: {q}") @@ -345,7 +347,7 @@ def main( print("\nRecall metrics (matching nemo_retriever.recall.core):") for k, v in metrics.items(): print(f" {k}: {v:.4f}") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) finally: # Restore real stdio before closing the mirror file so exception hooks # and late flushes never write to a closed stream wrapper. diff --git a/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py index 238da6f2c..cec6ec3bf 100644 --- a/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py @@ -7,7 +7,6 @@ Run with: uv run python -m nemo_retriever.examples.inprocess_pipeline """ -import json import time from pathlib import Path from typing import Optional @@ -15,12 +14,19 @@ import lancedb import typer from nemo_retriever import create_ingestor +from nemo_retriever.examples.common import estimate_processed_pages, print_pages_per_second from nemo_retriever.params import EmbedParams from nemo_retriever.params import ExtractParams from nemo_retriever.params import IngestExecuteParams from nemo_retriever.params import TextChunkParams from nemo_retriever.params import VdbUploadParams -from nemo_retriever.recall.core import RecallConfig, retrieve_and_score +from nemo_retriever.recall.core import ( + RecallConfig, + gold_to_doc_page, + hit_key_and_distance, + is_hit_at_k, + retrieve_and_score, +) app = typer.Typer() @@ -28,74 +34,6 @@ LANCEDB_TABLE = "nv-ingest" -def _estimate_processed_pages(uri: str, table_name: str) -> Optional[int]: - """ - Estimate pages processed by counting unique (source_id, page_number) pairs. - - Falls back to table row count if page-level fields are unavailable. - """ - try: - db = lancedb.connect(uri) - table = db.open_table(table_name) - except Exception: - return None - - try: - df = table.to_pandas()[["source_id", "page_number"]] - return int(df.dropna(subset=["source_id", "page_number"]).drop_duplicates().shape[0]) - except Exception: - try: - return int(table.count_rows()) - except Exception: - return None - - -def _print_pages_per_second(processed_pages: Optional[int], ingest_elapsed_s: float) -> None: - if ingest_elapsed_s <= 0: - print("Pages/sec: unavailable (ingest elapsed time was non-positive).") - return - if processed_pages is None: - print("Pages/sec: unavailable (could not estimate processed pages). " f"Ingest time: {ingest_elapsed_s:.2f}s") - return - - pps = processed_pages / ingest_elapsed_s - print(f"Pages processed: {processed_pages}") - print(f"Pages/sec (ingest only): {pps:.2f}") - - -def _gold_to_doc_page(golden_key: str) -> tuple[str, str]: - s = str(golden_key) - if "_" not in s: - return s, "" - doc, page = s.rsplit("_", 1) - return doc, page - - -def _is_hit_at_k(golden_key: str, retrieved_keys: list[str], k: int) -> bool: - doc, page = _gold_to_doc_page(golden_key) - specific_page = f"{doc}_{page}" - entire_document = f"{doc}_-1" - top = (retrieved_keys or [])[: int(k)] - return (specific_page in top) or (entire_document in top) - - -def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]: - try: - res = json.loads(hit.get("metadata", "{}")) - source = json.loads(hit.get("source", "{}")) - except Exception: - return None, None - - source_id = source.get("source_id") - page_number = res.get("page_number") - if not source_id or page_number is None: - return None, float(hit.get("_distance")) if "_distance" in hit else None - - key = f"{Path(str(source_id)).stem}_{page_number}" - dist = float(hit.get("_distance")) if "_distance" in hit else None - return key, dist - - @app.command() def main( input_path: Path = typer.Argument( @@ -388,7 +326,7 @@ def main( ) ) ingest_elapsed_s = time.perf_counter() - ingest_start - processed_pages = _estimate_processed_pages(LANCEDB_URI, LANCEDB_TABLE) + processed_pages = estimate_processed_pages(LANCEDB_URI, LANCEDB_TABLE) print("Extraction complete.") # --------------------------------------------------------------------------- @@ -397,7 +335,7 @@ def main( query_csv = Path(query_csv) if not query_csv.exists(): print(f"Query CSV not found at {query_csv}; skipping recall evaluation.") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) return db = lancedb.connect(f"./{LANCEDB_URI}") @@ -432,16 +370,16 @@ def main( _raw_hits, ) ): - doc, page = _gold_to_doc_page(g) + doc, page = gold_to_doc_page(g) scored_hits: list[tuple[str, float | None]] = [] for h in hits: - key, dist = _hit_key_and_distance(h) + key, dist = hit_key_and_distance(h) if key: scored_hits.append((key, dist)) top_keys = [k for (k, _d) in scored_hits] - hit = _is_hit_at_k(g, top_keys, cfg.top_k) + hit = is_hit_at_k(g, top_keys, cfg.top_k, match_mode="pdf_page") if not no_recall_details: ext = ( @@ -482,7 +420,7 @@ def main( print("\nRecall metrics (matching nemo_retriever.recall.core):") for k, v in metrics.items(): print(f" {k}: {v:.4f}") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) if __name__ == "__main__": diff --git a/nemo_retriever/src/nemo_retriever/examples/online_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/online_pipeline.py index d72fea7cf..f2da17b5d 100644 --- a/nemo_retriever/src/nemo_retriever/examples/online_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/online_pipeline.py @@ -15,7 +15,6 @@ --run-mode online --base-url http://localhost:7670 """ -import json from pathlib import Path import lancedb @@ -27,7 +26,13 @@ from nemo_retriever.params import IngestorCreateParams from nemo_retriever.params import TextChunkParams from nemo_retriever.params import VdbUploadParams -from nemo_retriever.recall.core import RecallConfig, retrieve_and_score +from nemo_retriever.recall.core import ( + RecallConfig, + gold_to_doc_page, + hit_key_and_distance, + is_hit_at_k, + retrieve_and_score, +) app = typer.Typer() @@ -35,39 +40,6 @@ LANCEDB_TABLE = "nv-ingest" -def _gold_to_doc_page(golden_key: str) -> tuple[str, str]: - s = str(golden_key) - if "_" not in s: - return s, "" - doc, page = s.rsplit("_", 1) - return doc, page - - -def _is_hit_at_k(golden_key: str, retrieved_keys: list[str], k: int) -> bool: - doc, page = _gold_to_doc_page(golden_key) - specific_page = f"{doc}_{page}" - entire_document = f"{doc}_-1" - top = (retrieved_keys or [])[: int(k)] - return (specific_page in top) or (entire_document in top) - - -def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]: - try: - res = json.loads(hit.get("metadata", "{}")) - source = json.loads(hit.get("source", "{}")) - except Exception: - return None, None - - source_id = source.get("source_id") - page_number = res.get("page_number") - if not source_id or page_number is None: - return None, float(hit.get("_distance")) if "_distance" in hit else None - - key = f"{Path(str(source_id)).stem}_{page_number}" - dist = float(hit.get("_distance")) if "_distance" in hit else None - return key, dist - - @app.command() def main( input_path: Path = typer.Argument( @@ -236,14 +208,14 @@ def main( _raw_hits, ) ): - doc, page = _gold_to_doc_page(g) + doc, page = gold_to_doc_page(g) scored_hits: list[tuple[str, float | None]] = [] for h in hits: - key, dist = _hit_key_and_distance(h) + key, dist = hit_key_and_distance(h) if key: scored_hits.append((key, dist)) top_keys = [k for (k, _d) in scored_hits] - hit = _is_hit_at_k(g, top_keys, cfg.top_k) + hit = is_hit_at_k(g, top_keys, cfg.top_k, match_mode="pdf_page") if not no_recall_details: ext = ".txt" if input_type == "txt" else (".docx" if input_type == "doc" else ".pdf") typer.echo(f"\nQuery {i}: {q}") diff --git a/nemo_retriever/src/nemo_retriever/recall/core.py b/nemo_retriever/src/nemo_retriever/recall/core.py index b684a3f6d..95c540850 100644 --- a/nemo_retriever/src/nemo_retriever/recall/core.py +++ b/nemo_retriever/src/nemo_retriever/recall/core.py @@ -299,6 +299,37 @@ def is_hit_at_k(golden_key: str, retrieved: Sequence[str], k: int, *, match_mode return _is_hit(str(golden_key), list(retrieved), int(k), match_mode=str(match_mode)) +def gold_to_doc_page(golden_key: str) -> tuple[str, str]: + """Split a golden key like ``"docname_page"`` into ``(doc, page)``.""" + s = str(golden_key) + if "_" not in s: + return s, "" + doc, page = s.rsplit("_", 1) + return doc, page + + +def hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]: + """Extract ``(pdf_page key, distance)`` from a single LanceDB hit dict. + + Supports both ``_distance`` and ``_score`` fields for compatibility across + LanceDB query types (vector vs hybrid). + """ + try: + res = json.loads(hit.get("metadata", "{}")) + source = json.loads(hit.get("source", "{}")) + except Exception: + return None, None + + source_id = source.get("source_id") + page_number = res.get("page_number") + if not source_id or page_number is None: + return None, float(hit.get("_distance")) if "_distance" in hit else None + + key = f"{Path(str(source_id)).stem}_{page_number}" + dist = float(hit["_distance"]) if "_distance" in hit else float(hit["_score"]) if "_score" in hit else None + return key, dist + + def _recall_at_k(gold: List[str], retrieved: List[List[str]], k: int, *, match_mode: str) -> float: hits = sum(is_hit_at_k(g, r, k, match_mode=match_mode) for g, r in zip(gold, retrieved)) return hits / max(1, len(gold)) From 5109a34ab2730d1887f3bfd5dfd75ce24b156300 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 6 Mar 2026 18:04:59 -0500 Subject: [PATCH 06/12] Unify detection summary logic between batch and inprocess pipelines Extracts duplicated detection summary computation and printing into a shared utils/detection_summary.py module, replacing ~200 lines of near-identical logic in batch_pipeline.py and inprocess.py with thin wrappers around the shared implementation. Made-with: Cursor --- .../nemo_retriever/examples/batch_pipeline.py | 104 +--------- .../nemo_retriever/ingest_modes/inprocess.py | 114 +---------- .../nemo_retriever/utils/detection_summary.py | 177 ++++++++++++++++++ 3 files changed, 189 insertions(+), 206 deletions(-) create mode 100644 nemo_retriever/src/nemo_retriever/utils/detection_summary.py diff --git a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py index 137f31bc5..a470e553d 100644 --- a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py @@ -12,7 +12,6 @@ import os import sys import time -from collections import defaultdict from importlib import import_module from pathlib import Path from typing import Optional, TextIO @@ -116,107 +115,16 @@ def _to_int(value: object, default: int = 0) -> int: def _collect_detection_summary(uri: str, table_name: str) -> Optional[dict]: - """ - Collect per-model detection totals deduped by (source_id, page_number). - - Counts are read from LanceDB row `metadata`, which is populated during batch - ingestion by the Ray write stage. - """ - try: - db = _lancedb().connect(uri) - table = db.open_table(table_name) - df = table.to_pandas()[["source_id", "page_number", "metadata"]] - except Exception: - return None - - # Deduplicate exploded rows by page key; keep max per-page counts. - per_page: dict[tuple[str, int], dict] = {} - for row in df.itertuples(index=False): - source_id = str(getattr(row, "source_id", "") or "") - page_number = _to_int(getattr(row, "page_number", -1), default=-1) - key = (source_id, page_number) - - raw_metadata = getattr(row, "metadata", None) - meta: dict = {} - if isinstance(raw_metadata, str) and raw_metadata.strip(): - try: - parsed = json.loads(raw_metadata) - if isinstance(parsed, dict): - meta = parsed - except Exception: - meta = {} - - entry = per_page.setdefault( - key, - { - "page_elements_total": 0, - "ocr_table_total": 0, - "ocr_chart_total": 0, - "ocr_infographic_total": 0, - "page_elements_by_label": defaultdict(int), - }, - ) - - pe_total = _to_int(meta.get("page_elements_v3_num_detections"), default=0) - entry["page_elements_total"] = max(entry["page_elements_total"], pe_total) - - ocr_table = _to_int(meta.get("ocr_table_detections"), default=0) - ocr_chart = _to_int(meta.get("ocr_chart_detections"), default=0) - ocr_infographic = _to_int(meta.get("ocr_infographic_detections"), default=0) - entry["ocr_table_total"] = max(entry["ocr_table_total"], ocr_table) - entry["ocr_chart_total"] = max(entry["ocr_chart_total"], ocr_chart) - entry["ocr_infographic_total"] = max(entry["ocr_infographic_total"], ocr_infographic) - - label_counts = meta.get("page_elements_v3_counts_by_label") - if isinstance(label_counts, dict): - for label, count in label_counts.items(): - if not isinstance(label, str): - continue - entry["page_elements_by_label"][label] = max( - entry["page_elements_by_label"][label], - _to_int(count, default=0), - ) + """Collect per-model detection totals deduped by (source_id, page_number).""" + from nemo_retriever.utils.detection_summary import collect_detection_summary_from_lancedb - pe_by_label_totals: dict[str, int] = defaultdict(int) - page_elements_total = 0 - ocr_table_total = 0 - ocr_chart_total = 0 - ocr_infographic_total = 0 - for page_entry in per_page.values(): - page_elements_total += int(page_entry["page_elements_total"]) - ocr_table_total += int(page_entry["ocr_table_total"]) - ocr_chart_total += int(page_entry["ocr_chart_total"]) - ocr_infographic_total += int(page_entry["ocr_infographic_total"]) - for label, count in page_entry["page_elements_by_label"].items(): - pe_by_label_totals[label] += int(count) - - return { - "pages_seen": int(len(per_page)), - "page_elements_v3_total_detections": int(page_elements_total), - "page_elements_v3_counts_by_label": dict(sorted(pe_by_label_totals.items())), - "ocr_table_total_detections": int(ocr_table_total), - "ocr_chart_total_detections": int(ocr_chart_total), - "ocr_infographic_total_detections": int(ocr_infographic_total), - } + return collect_detection_summary_from_lancedb(uri, table_name) def _print_detection_summary(summary: Optional[dict]) -> None: - if summary is None: - print("Detection summary: unavailable (could not read LanceDB metadata).") - return - print("\nDetection summary (deduped by source_id/page_number):") - print(f" Pages seen: {summary['pages_seen']}") - print(f" PageElements v3 total detections: {summary['page_elements_v3_total_detections']}") - print(f" OCR table detections: {summary['ocr_table_total_detections']}") - print(f" OCR chart detections: {summary['ocr_chart_total_detections']}") - print(f" OCR infographic detections: {summary['ocr_infographic_total_detections']}") - print(" PageElements v3 counts by label:") - by_label = summary.get("page_elements_v3_counts_by_label") or {} - if not by_label: - print(" (none)") - else: - for label, count in by_label.items(): - print(f" {label}: {count}") + from nemo_retriever.utils.detection_summary import print_detection_summary + + print_detection_summary(summary) def _extract_error_payloads(v: object) -> list[object]: diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py index c07d2f44c..755ba0d18 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py @@ -907,105 +907,16 @@ def _process_chunk_cpu(chunk_df: pd.DataFrame, cpu_tasks: list) -> pd.DataFrame: def _collect_summary_from_df(df: pd.DataFrame) -> dict: - """Compute detection summary from a result DataFrame. + """Compute detection summary from a result DataFrame.""" + from nemo_retriever.utils.detection_summary import collect_detection_summary_from_df - Mirrors the batch pipeline's ``_collect_detection_summary`` but reads - directly from the in-memory DataFrame instead of LanceDB. Rows are - deduplicated by ``(path, page_number)`` so exploded content rows don't - inflate counts. - """ - per_page: dict[tuple, dict] = {} - - for _, row in df.iterrows(): - row_dict = row.to_dict() - - path = str(row_dict.get("path") or row_dict.get("source_id") or "") - page_number = -1 - try: - page_number = int(row_dict.get("page_number", -1)) - except (TypeError, ValueError): - pass - - key = (path, page_number) - - meta = row_dict.get("metadata") - if isinstance(meta, str): - try: - meta = json.loads(meta) - except Exception: - meta = {} - if not isinstance(meta, dict): - meta = {} - - entry = per_page.setdefault( - key, - { - "pe": 0, - "ocr_table": 0, - "ocr_chart": 0, - "ocr_infographic": 0, - "pe_by_label": defaultdict(int), - }, - ) - - # Check metadata first, then fall back to direct DataFrame columns. - # The batch pipeline stores these inside the metadata JSON, but the - # inprocess pipeline keeps them as top-level DataFrame columns. - try: - pe = int( - meta.get("page_elements_v3_num_detections") or row_dict.get("page_elements_v3_num_detections") or 0 - ) - except (TypeError, ValueError): - pe = 0 - entry["pe"] = max(entry["pe"], pe) - - for field, meta_key, col_key in [ - ("ocr_table", "ocr_table_detections", "table"), - ("ocr_chart", "ocr_chart_detections", "chart"), - ("ocr_infographic", "ocr_infographic_detections", "infographic"), - ]: - try: - val = int(meta.get(meta_key, 0) or 0) - except (TypeError, ValueError): - val = 0 - # Fall back to counting direct list columns (e.g. row["table"]). - if val == 0: - col_val = row_dict.get(col_key) - if isinstance(col_val, list): - val = len(col_val) - entry[field] = max(entry[field], val) - - label_counts = meta.get("page_elements_v3_counts_by_label") or row_dict.get("page_elements_v3_counts_by_label") - if isinstance(label_counts, dict): - for label, count in label_counts.items(): - try: - c = int(count or 0) - except (TypeError, ValueError): - c = 0 - entry["pe_by_label"][str(label)] = max(entry["pe_by_label"][str(label)], c) - - pe_by_label_totals: dict[str, int] = defaultdict(int) - pe_total = ocr_table_total = ocr_chart_total = ocr_infographic_total = 0 - for e in per_page.values(): - pe_total += e["pe"] - ocr_table_total += e["ocr_table"] - ocr_chart_total += e["ocr_chart"] - ocr_infographic_total += e["ocr_infographic"] - for label, count in e["pe_by_label"].items(): - pe_by_label_totals[label] += count - - return { - "pages_seen": len(per_page), - "page_elements_v3_total_detections": pe_total, - "page_elements_v3_counts_by_label": dict(sorted(pe_by_label_totals.items())), - "ocr_table_total_detections": ocr_table_total, - "ocr_chart_total_detections": ocr_chart_total, - "ocr_infographic_total_detections": ocr_infographic_total, - } + return collect_detection_summary_from_df(df) def _print_ingest_summary(results: list, elapsed_s: float) -> None: """Print end-of-ingest summary matching batch pipeline output format.""" + from nemo_retriever.utils.detection_summary import print_detection_summary + dfs = [r for r in results if isinstance(r, pd.DataFrame) and not r.empty] if not dfs: print(f"\nIngest time: {elapsed_s:.2f}s (no documents processed)") @@ -1013,20 +924,7 @@ def _print_ingest_summary(results: list, elapsed_s: float) -> None: combined = pd.concat(dfs, ignore_index=True) if len(dfs) > 1 else dfs[0] summary = _collect_summary_from_df(combined) - - print("\nDetection summary (deduped by source/page_number):") - print(f" Pages seen: {summary['pages_seen']}") - print(f" PageElements v3 total detections: {summary['page_elements_v3_total_detections']}") - print(f" OCR table detections: {summary['ocr_table_total_detections']}") - print(f" OCR chart detections: {summary['ocr_chart_total_detections']}") - print(f" OCR infographic detections: {summary['ocr_infographic_total_detections']}") - print(" PageElements v3 counts by label:") - by_label = summary.get("page_elements_v3_counts_by_label", {}) - if not by_label: - print(" (none)") - else: - for label, count in by_label.items(): - print(f" {label}: {count}") + print_detection_summary(summary) pages = summary["pages_seen"] if elapsed_s > 0 and pages > 0: diff --git a/nemo_retriever/src/nemo_retriever/utils/detection_summary.py b/nemo_retriever/src/nemo_retriever/utils/detection_summary.py new file mode 100644 index 000000000..b2ecdb7f5 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/utils/detection_summary.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Shared detection summary logic. + +Provides a single function that accumulates per-page detection counters from +an iterable of ``(page_key, metadata_dict, row_dict)`` tuples. Both the +batch pipeline (reading from LanceDB) and inprocess pipeline (reading from +a DataFrame) can produce these tuples, allowing the summary computation to +be shared. +""" + +from __future__ import annotations + +import json +from collections import defaultdict +from typing import Any, Dict, Iterable, Optional, Tuple + + +def _safe_int(value: object, default: int = 0) -> int: + try: + if value is None: + return default + return int(value) + except Exception: + return default + + +def compute_detection_summary( + rows: Iterable[Tuple[Any, Dict[str, Any], Dict[str, Any]]], +) -> Dict[str, Any]: + """Compute deduped detection totals from an iterable of page data. + + Each element is ``(page_key, metadata_dict, row_dict)`` where: + + - *page_key* is a hashable value used to deduplicate exploded content rows + (e.g. ``(source_id, page_number)``). + - *metadata_dict* is the parsed JSON metadata (may contain counters from the + LanceDB metadata column or from direct DataFrame columns). + - *row_dict* is the raw row dict, used as fallback for counters stored as + top-level DataFrame columns (e.g. ``table``, ``chart`` lists). + """ + per_page: dict[Any, dict] = {} + + for page_key, meta, raw_row in rows: + entry = per_page.setdefault( + page_key, + { + "pe": 0, + "ocr_table": 0, + "ocr_chart": 0, + "ocr_infographic": 0, + "pe_by_label": defaultdict(int), + }, + ) + + pe = _safe_int( + meta.get("page_elements_v3_num_detections") or raw_row.get("page_elements_v3_num_detections") + ) + entry["pe"] = max(entry["pe"], pe) + + for field, meta_key, col_key in [ + ("ocr_table", "ocr_table_detections", "table"), + ("ocr_chart", "ocr_chart_detections", "chart"), + ("ocr_infographic", "ocr_infographic_detections", "infographic"), + ]: + val = _safe_int(meta.get(meta_key)) + if val == 0: + col_val = raw_row.get(col_key) + if isinstance(col_val, list): + val = len(col_val) + entry[field] = max(entry[field], val) + + label_counts = meta.get("page_elements_v3_counts_by_label") or raw_row.get( + "page_elements_v3_counts_by_label" + ) + if isinstance(label_counts, dict): + for label, count in label_counts.items(): + entry["pe_by_label"][str(label)] = max( + entry["pe_by_label"][str(label)], + _safe_int(count), + ) + + pe_by_label_totals: dict[str, int] = defaultdict(int) + pe_total = ocr_table_total = ocr_chart_total = ocr_infographic_total = 0 + for e in per_page.values(): + pe_total += e["pe"] + ocr_table_total += e["ocr_table"] + ocr_chart_total += e["ocr_chart"] + ocr_infographic_total += e["ocr_infographic"] + for label, count in e["pe_by_label"].items(): + pe_by_label_totals[label] += count + + return { + "pages_seen": len(per_page), + "page_elements_v3_total_detections": pe_total, + "page_elements_v3_counts_by_label": dict(sorted(pe_by_label_totals.items())), + "ocr_table_total_detections": ocr_table_total, + "ocr_chart_total_detections": ocr_chart_total, + "ocr_infographic_total_detections": ocr_infographic_total, + } + + +def iter_lancedb_rows(uri: str, table_name: str): + """Yield ``(page_key, meta, row_dict)`` tuples from a LanceDB table.""" + import lancedb # type: ignore + + db = lancedb.connect(uri) + table = db.open_table(table_name) + df = table.to_pandas()[["source_id", "page_number", "metadata"]] + + for row in df.itertuples(index=False): + source_id = str(getattr(row, "source_id", "") or "") + page_number = _safe_int(getattr(row, "page_number", -1), default=-1) + raw_metadata = getattr(row, "metadata", None) + meta: dict = {} + if isinstance(raw_metadata, str) and raw_metadata.strip(): + try: + parsed = json.loads(raw_metadata) + if isinstance(parsed, dict): + meta = parsed + except Exception: + pass + yield (source_id, page_number), meta, {} + + +def iter_dataframe_rows(df): + """Yield ``(page_key, meta, row_dict)`` tuples from a pandas DataFrame.""" + for _, row in df.iterrows(): + row_dict = row.to_dict() + path = str(row_dict.get("path") or row_dict.get("source_id") or "") + page_number = _safe_int(row_dict.get("page_number", -1), default=-1) + + meta = row_dict.get("metadata") + if isinstance(meta, str): + try: + meta = json.loads(meta) + except Exception: + meta = {} + if not isinstance(meta, dict): + meta = {} + + yield (path, page_number), meta, row_dict + + +def collect_detection_summary_from_lancedb(uri: str, table_name: str) -> Optional[Dict[str, Any]]: + """Collect detection summary from a LanceDB table.""" + try: + return compute_detection_summary(iter_lancedb_rows(uri, table_name)) + except Exception: + return None + + +def collect_detection_summary_from_df(df) -> Dict[str, Any]: + """Collect detection summary from a pandas DataFrame.""" + return compute_detection_summary(iter_dataframe_rows(df)) + + +def print_detection_summary(summary: Optional[Dict[str, Any]]) -> None: + """Print a detection summary to stdout.""" + if summary is None: + print("Detection summary: unavailable (could not read metadata).") + return + print("\nDetection summary (deduped by source_id/page_number):") + print(f" Pages seen: {summary['pages_seen']}") + print(f" PageElements v3 total detections: {summary['page_elements_v3_total_detections']}") + print(f" OCR table detections: {summary['ocr_table_total_detections']}") + print(f" OCR chart detections: {summary['ocr_chart_total_detections']}") + print(f" OCR infographic detections: {summary['ocr_infographic_total_detections']}") + print(" PageElements v3 counts by label:") + by_label = summary.get("page_elements_v3_counts_by_label") or {} + if not by_label: + print(" (none)") + else: + for label, count in by_label.items(): + print(f" {label}: {count}") From f9b62ac03f5f659f34fd4091392742413ffdf422 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 6 Mar 2026 18:05:47 -0500 Subject: [PATCH 07/12] Extract shared parameter coercion and embed kwargs helpers Consolidates the duplicated _coerce_params pattern and embed parameter flattening logic from batch.py and inprocess.py into a shared params/utils.py module with coerce_params and build_embed_kwargs helpers. Made-with: Cursor --- .../src/nemo_retriever/ingest_modes/batch.py | 20 ++----- .../nemo_retriever/ingest_modes/inprocess.py | 18 ++---- .../src/nemo_retriever/params/utils.py | 43 ++++++++++++++ nemo_retriever/tests/test_params_utils.py | 57 +++++++++++++++++++ 4 files changed, 108 insertions(+), 30 deletions(-) create mode 100644 nemo_retriever/src/nemo_retriever/params/utils.py create mode 100644 nemo_retriever/tests/test_params_utils.py diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py index d38765a3f..4b26e9887 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py @@ -70,12 +70,7 @@ def _debug_log(*, logger: logging.Logger, location: str, message: str, data: dic logger.debug("%s | %s | %r", location, message, data) -def _coerce_params[T](params: T | None, model_cls: type[T], kwargs: dict[str, Any]) -> T: - if params is None: - return model_cls(**kwargs) - if kwargs: - return params.model_copy(update=kwargs) # type: ignore[return-value] - return params +from nemo_retriever.params.utils import coerce_params as _coerce_params class _LanceDBWriteActor: @@ -668,17 +663,10 @@ def embed( "No Ray Dataset to embed. Provide input_dataset or run .files(...) / .extract(...) first." ) - resolved = _coerce_params(params, EmbedParams, kwargs) - kwargs = { - **resolved.model_dump( - mode="python", exclude={"runtime", "batch_tuning", "fused_tuning"}, exclude_none=True - ), - **resolved.runtime.model_dump(mode="python", exclude_none=True), - **resolved.batch_tuning.model_dump(mode="python", exclude_none=True), - } + from nemo_retriever.params.utils import build_embed_kwargs - if "embedding_endpoint" not in kwargs and kwargs.get("embed_invoke_url"): - kwargs["embedding_endpoint"] = kwargs.get("embed_invoke_url") + resolved = _coerce_params(params, EmbedParams, kwargs) + kwargs = build_embed_kwargs(resolved, include_batch_tuning=True) # Remaining kwargs are forwarded to the actor constructor. embed_modality = resolved.embed_modality diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py index 755ba0d18..9709d2147 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py @@ -63,12 +63,7 @@ _CONTENT_COLUMNS = ("table", "chart", "infographic") -def _coerce_params[T](params: T | None, model_cls: type[T], kwargs: dict[str, Any]) -> T: - if params is None: - return model_cls(**kwargs) - if kwargs: - return params.model_copy(update=kwargs) # type: ignore[return-value] - return params +from nemo_retriever.params.utils import coerce_params as _coerce_params def _combine_text_with_content(row, text_column, content_columns): @@ -1248,14 +1243,9 @@ def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "InProcessI ) ) - embed_kwargs = { - **resolved.model_dump( - mode="python", exclude={"runtime", "batch_tuning", "fused_tuning"}, exclude_none=True - ), - **resolved.runtime.model_dump(mode="python", exclude_none=True), - } - if "embedding_endpoint" not in embed_kwargs and embed_kwargs.get("embed_invoke_url"): - embed_kwargs["embedding_endpoint"] = embed_kwargs.get("embed_invoke_url") + from nemo_retriever.params.utils import build_embed_kwargs + + embed_kwargs = build_embed_kwargs(resolved) # Ensure embed_modality is forwarded to the embedding function. embed_kwargs["embed_modality"] = embed_modality diff --git a/nemo_retriever/src/nemo_retriever/params/utils.py b/nemo_retriever/src/nemo_retriever/params/utils.py new file mode 100644 index 000000000..8b076e169 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/params/utils.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Shared parameter coercion and building helpers used by ingest modes.""" + +from __future__ import annotations + +from typing import Any, Dict + + +def coerce_params[T](params: T | None, model_cls: type[T], kwargs: dict[str, Any]) -> T: + """Merge *params* and *kwargs* into an instance of *model_cls*. + + - If *params* is ``None``, construct from *kwargs*. + - If *kwargs* is non-empty, apply them as overrides via ``model_copy``. + - Otherwise return *params* unchanged. + """ + if params is None: + return model_cls(**kwargs) + if kwargs: + return params.model_copy(update=kwargs) # type: ignore[return-value] + return params + + +def build_embed_kwargs(resolved: Any, *, include_batch_tuning: bool = False) -> Dict[str, Any]: + """Flatten an ``EmbedParams`` instance into a dict ready for actor/task kwargs. + + Merges ``runtime`` (always) and optionally ``batch_tuning`` sub-models. + Also normalises ``embed_invoke_url`` → ``embedding_endpoint``. + """ + exclude = {"runtime", "batch_tuning", "fused_tuning"} + kwargs: Dict[str, Any] = { + **resolved.model_dump(mode="python", exclude=exclude, exclude_none=True), + **resolved.runtime.model_dump(mode="python", exclude_none=True), + } + if include_batch_tuning: + kwargs.update(resolved.batch_tuning.model_dump(mode="python", exclude_none=True)) + + if "embedding_endpoint" not in kwargs and kwargs.get("embed_invoke_url"): + kwargs["embedding_endpoint"] = kwargs["embed_invoke_url"] + + return kwargs diff --git a/nemo_retriever/tests/test_params_utils.py b/nemo_retriever/tests/test_params_utils.py new file mode 100644 index 000000000..ea96fb3d0 --- /dev/null +++ b/nemo_retriever/tests/test_params_utils.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for nemo_retriever.params.utils.""" + +import pytest + +from nemo_retriever.params.models import EmbedParams +from nemo_retriever.params.utils import build_embed_kwargs, coerce_params + + +class TestCoerceParams: + def test_none_params_constructs_from_kwargs(self): + result = coerce_params(None, EmbedParams, {"embed_modality": "image"}) + assert isinstance(result, EmbedParams) + assert result.embed_modality == "image" + + def test_params_without_kwargs_returned_unchanged(self): + original = EmbedParams(embed_modality="text") + result = coerce_params(original, EmbedParams, {}) + assert result is original + + def test_params_with_kwargs_applies_overrides(self): + original = EmbedParams(embed_modality="text") + result = coerce_params(original, EmbedParams, {"embed_modality": "image"}) + assert result.embed_modality == "image" + assert result is not original + + +class TestBuildEmbedKwargs: + def test_normalises_embed_invoke_url(self): + params = EmbedParams(embed_invoke_url="http://nim:8000/v1") + kwargs = build_embed_kwargs(params) + assert kwargs["embedding_endpoint"] == "http://nim:8000/v1" + + def test_does_not_overwrite_existing_embedding_endpoint(self): + params = EmbedParams( + embed_invoke_url="http://old:8000/v1", + ) + kwargs = build_embed_kwargs(params) + assert "embedding_endpoint" in kwargs + + def test_includes_batch_tuning_when_requested(self): + params = EmbedParams() + with_bt = build_embed_kwargs(params, include_batch_tuning=True) + without_bt = build_embed_kwargs(params, include_batch_tuning=False) + # batch_tuning keys should be present when included + assert isinstance(with_bt, dict) + assert isinstance(without_bt, dict) + + def test_excludes_nested_sub_models(self): + params = EmbedParams() + kwargs = build_embed_kwargs(params) + assert "runtime" not in kwargs + assert "batch_tuning" not in kwargs + assert "fused_tuning" not in kwargs From 14aff09d8e2b2c0d5c0c2ac1da79f0b22a936c41 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Sat, 7 Mar 2026 07:47:21 -0500 Subject: [PATCH 08/12] Remove ray balancing variants --- .github/workflows/pypi-nightly-publish.yml | 13 +- .github/workflows/reusable-pypi-build.yml | 13 +- nemo_retriever/RAY_BALANCING.md | 162 ------------ nemo_retriever/chart_stage_config.yaml | 53 ---- nemo_retriever/embedding_stage_config.yaml | 43 ---- nemo_retriever/infographic_stage_config.yaml | 33 --- nemo_retriever/pdf_stage_config.yaml | 49 ---- nemo_retriever/pyproject.toml | 2 +- nemo_retriever/ray_balance_variants.csv | 240 ------------------ nemo_retriever/src/nemo_retriever/__init__.py | 2 +- nemo_retriever/src/nemo_retriever/__main__.py | 9 - .../src/nemo_retriever/_build_info.py | 8 - nemo_retriever/src/nemo_retriever/version.py | 10 +- nemo_retriever/table_stage_config.yaml | 38 --- 14 files changed, 19 insertions(+), 656 deletions(-) delete mode 100644 nemo_retriever/RAY_BALANCING.md delete mode 100644 nemo_retriever/chart_stage_config.yaml delete mode 100644 nemo_retriever/embedding_stage_config.yaml delete mode 100644 nemo_retriever/infographic_stage_config.yaml delete mode 100644 nemo_retriever/pdf_stage_config.yaml delete mode 100644 nemo_retriever/ray_balance_variants.csv delete mode 100644 nemo_retriever/src/nemo_retriever/__main__.py delete mode 100644 nemo_retriever/src/nemo_retriever/_build_info.py delete mode 100644 nemo_retriever/table_stage_config.yaml diff --git a/.github/workflows/pypi-nightly-publish.yml b/.github/workflows/pypi-nightly-publish.yml index 98ea26927..40549a429 100644 --- a/.github/workflows/pypi-nightly-publish.yml +++ b/.github/workflows/pypi-nightly-publish.yml @@ -121,15 +121,16 @@ jobs: run: | cd nemo_retriever python - <<'PY' + import re from datetime import datetime, timezone from pathlib import Path - Path("src/nemo_retriever/_build_info.py").write_text( - '"""Build metadata written by CI before packaging."""\n\n' - 'BUILD_GIT_SHA = "${{ github.sha }}"\n' - f'BUILD_DATE = "{datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")}"\n', - encoding="utf-8", - ) + vf = Path("src/nemo_retriever/version.py") + src = vf.read_text(encoding="utf-8") + src = re.sub(r'^_PACKAGE_BUILD_GIT_SHA = .*$', '_PACKAGE_BUILD_GIT_SHA = "${{ github.sha }}"', src, flags=re.M) + build_date = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + src = re.sub(r'^_PACKAGE_BUILD_DATE = .*$', f'_PACKAGE_BUILD_DATE = "{build_date}"', src, flags=re.M) + vf.write_text(src, encoding="utf-8") PY RETRIEVER_RELEASE_TYPE=${{ env.RELEASE_TYPE }} \ RETRIEVER_VERSION=${{ env.VERSION }} \ diff --git a/.github/workflows/reusable-pypi-build.yml b/.github/workflows/reusable-pypi-build.yml index 7f9947d1d..7df9091a8 100644 --- a/.github/workflows/reusable-pypi-build.yml +++ b/.github/workflows/reusable-pypi-build.yml @@ -91,15 +91,16 @@ jobs: run: | cd nemo_retriever python - <<'PY' + import re from datetime import datetime, timezone from pathlib import Path - Path("src/nemo_retriever/_build_info.py").write_text( - '"""Build metadata written by CI before packaging."""\n\n' - 'BUILD_GIT_SHA = "${{ github.sha }}"\n' - f'BUILD_DATE = "{datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")}"\n', - encoding="utf-8", - ) + vf = Path("src/nemo_retriever/version.py") + src = vf.read_text(encoding="utf-8") + src = re.sub(r'^_PACKAGE_BUILD_GIT_SHA = .*$', '_PACKAGE_BUILD_GIT_SHA = "${{ github.sha }}"', src, flags=re.M) + build_date = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + src = re.sub(r'^_PACKAGE_BUILD_DATE = .*$', f'_PACKAGE_BUILD_DATE = "{build_date}"', src, flags=re.M) + vf.write_text(src, encoding="utf-8") PY RETRIEVER_RELEASE_TYPE=${{ inputs.release-type }} \ RETRIEVER_VERSION=${{ steps.set-version.outputs.version }} \ diff --git a/nemo_retriever/RAY_BALANCING.md b/nemo_retriever/RAY_BALANCING.md deleted file mode 100644 index ec71d3142..000000000 --- a/nemo_retriever/RAY_BALANCING.md +++ /dev/null @@ -1,162 +0,0 @@ -# Ray Balancing Strategy - -This document describes the default Ray Data balancing strategy used by -`nemo_retriever/src/nemo_retriever/ray_balance_dag.py`, why each test family exists, and -what to try next. - -## Goal - -Keep the default experiment set broad enough to find bottlenecks, but small -enough to run repeatedly across multiple machines. The default matrix is now -designed to stay under 1,000 variants. - -## Design Approach - -The default matrix uses a practical DOE-style approach: - -1. **Baseline run** - A stable reference point used for quick comparisons. -2. **One-factor-at-a-time (OFAT) sweeps** - Change one knob while keeping others at baseline to isolate sensitivity. -3. **Targeted interaction sweeps** - Test only high-value parameter interactions where coupling is expected. - -This avoids a full Cartesian product over all knobs (which grows to millions of -runs and is usually not actionable). - -## Default Matrix Definition - -The script’s default matrix includes the following families. - -### A) Baseline - -- Single baseline config with balanced CPU/GPU and midrange batch sizes. -- Purpose: anchor for all deltas and detect regressions quickly. - -### B) OFAT Sweeps - -- `pdf_workers`: `[4, 8, 12, 16]` -- `pdf_num_cpus`: `[1.0, 2.0, 3.0, 4.0]` -- `pdf_split_bs`: `[1, 4, 8]` -- `pdf_bs`: `[8, 16, 24, 32]` -- `page_elements_bs`: `[8, 16, 24, 32]` -- `page_elements_workers`: `[1, 2, 3]` -- `ocr_workers`: `[1, 2, 3]` -- `ocr_bs`: `[8, 16, 24, 32]` -- `embed_workers`: `[1, 2, 3]` -- `embed_bs`: `[128, 256, 512, 768]` -- `page_elements_cpus_per_actor`: `[1.0, 2.0, 4.0]` -- `ocr_cpus_per_actor`: `[1.0, 2.0, 4.0]` -- `embed_cpus_per_actor`: `[1.0, 2.0, 4.0]` -- `gpu_page_elements`: `[0.25, 0.5, 0.75]` -- `gpu_ocr`: `[0.75, 1.0]` -- `gpu_embed`: `[0.25, 0.5, 0.75]` - -Why this matters: - -- Identifies which knobs are low-impact (can be fixed) vs high-impact (worth - deeper search). -- Narrows the search space before trying interactions. - -### C) Targeted Interaction Grids - -1. **OCR throughput coupling** - - `ocr_bs x ocr_workers x gpu_ocr` -2. **Embedding throughput coupling** - - `embed_bs x embed_workers x gpu_embed` -3. **Page-elements throughput coupling** - - `page_elements_bs x page_elements_workers x gpu_page_elements` -4. **CPU extraction balance** - - `pdf_workers x pdf_num_cpus x pdf_bs` -5. **Actor CPU pressure** - - `page_elements_cpus_per_actor x ocr_cpus_per_actor x embed_cpus_per_actor` -6. **Pipeline batch-shape interaction** - - `pdf_bs x ocr_bs x embed_bs` - -Why this matters: - -- These are the pairs/triples most likely to create backpressure or starvation. -- Captures non-linear behavior without exploding matrix size. - -## What Has Been Tried So Far - -- Full/fat sweeps were tested early and found to be too large operationally. -- Matrix generation now deduplicates repeated variants and focuses on high-signal - combinations. -- Row-range sharding support (`--row-start`, `--row-end`) is used for distributed - execution across machines. -- Runtime metrics are captured per run: - - Ray Data operator stats (`rd_dataset.stats()`) - - Ray timeline (`ray.timeline(...)`) - -## GPU Constraint Handling - -Some deployments reject fractional `num_gpus` values above `1.0` per actor. - -To avoid invalid scheduling requests, matrix generation/loading normalizes any -`gpu_* > 1.0` request by: - -- setting per-actor GPU to `1.0`, and -- multiplying the corresponding actor count (`*_workers`) by `ceil(gpu_*)`. - -This keeps total requested GPU capacity similar while using valid actor specs. - -## Runtime Metrics Artifacts - -For each run, the pipeline writes metrics files under the run logs directory -(`runtime_metrics/` subdir) with the run prefix: - -- `.rd_dataset.stats.txt` (per-operator Ray Data stats) -- `.ray.timeline.json` (cluster task timeline) -- `.runtime.summary.json` (top-level run summary) - -## LanceDB Isolation and Recall Guarantees - -To prevent cross-run contamination: - -- The matrix runner deletes the configured LanceDB URI path before each run. -- Each run then recreates and writes a fresh `nv-ingest` table. - -To ensure recall is actually executed: - -- The batch pipeline now treats a missing LanceDB table as a hard failure - (after a short retry), instead of silently skipping recall. -- The matrix results CSV includes a `recall_ran` flag and marks runs as failed - if recall metrics are absent. - -## How to Generate and Run - -Generate matrix CSV only: - -```bash -python nemo_retriever/src/nemo_retriever/ray_balance_dag.py \ - --input-dir /path/to/pdfs \ - --write-default-matrix-csv nemo_retriever/ray_balance_variants.csv \ - --exit-after-writing-matrix -``` - -Run a shard: - -```bash -python nemo_retriever/src/nemo_retriever/ray_balance_dag.py \ - --input-dir /path/to/pdfs \ - --matrix-csv nemo_retriever/ray_balance_variants.csv \ - --row-start 1 \ - --row-end 200 \ - --output-csv nemo_retriever/ray_balance_results_001_200.csv -``` - -## Recommended Next Experiments - -1. **Adaptive second-pass search** - - Keep top 10-20% by throughput and run local neighborhood sweeps. -2. **Constraint-aware optimization** - - Add objective penalties for GPU OOM, high object-store pressure, or low - recall to avoid fragile winners. -3. **Dataset-stratified tests** - - Split small/medium/large PDFs and optimize per segment; mixed corpora often - hide better settings. -4. **Stability runs** - - Re-run top candidates 3-5 times and compare variance, not just best mean. -5. **Multi-objective scoring** - - Rank by weighted score: throughput, recall@k, and cost (GPU-hours). diff --git a/nemo_retriever/chart_stage_config.yaml b/nemo_retriever/chart_stage_config.yaml deleted file mode 100644 index fada15b68..000000000 --- a/nemo_retriever/chart_stage_config.yaml +++ /dev/null @@ -1,53 +0,0 @@ -# Example config for chart extraction. -# -# Intended usage (once the chart stage CLI is wired up similarly to table stage): -# - `retriever chart stage run --config --input ` -# - `retriever local stage4 run --config --input ` -# -# This YAML is parsed into `nv_ingest_api.internal.schemas.extract.extract_chart_schema.ChartExtractorSchema` -# via `nemo_retriever.chart.config.load_chart_extractor_schema_from_dict`. -# -# IMPORTANT: -# If `endpoint_config.yolox_endpoints` is null/empty, chart extraction will fall back to the local -# HuggingFace model (`nemo_retriever.model.local.nemotron_graphic_elements_v1`). -# If `endpoint_config.ocr_endpoints` is null/empty, chart extraction falls back to local Nemotron OCR -# (`nemo_retriever.model.local.nemotron_ocr_v1`) with default HuggingFace model loading. -# - -# Optional worker settings -max_queue_size: 1 -n_workers: 2 -raise_on_failure: false - -# Endpoint configuration for chart extraction (YOLOX graphic-elements + OCR). -endpoint_config: - # Optional auth token for secured services (NIM / hosted endpoints) - auth_token: null - - # Tuple/list in the form: [grpc, http] - # - # Chart extraction uses the YOLOX *graphic-elements* model (not page-elements). - # - # For the provided `docker-compose.yaml`, the host-mapped ports are: - # - graphic-elements HTTP: 8003 (container 8000) - # - graphic-elements gRPC: 8004 (container 8001) - # - # If you're running from inside the docker compose network instead, these often look like: - # - "graphic-elements:8001" and "http://graphic-elements:8000/v1/infer" - # yolox_endpoints: ["localhost:8004", "http://localhost:8003/v1/infer"] - yolox_endpoints: null - # Optional; if omitted it is inferred from which endpoint is present. - # yolox_infer_protocol: grpc - - # OCR model endpoints (same pattern: [grpc, http]). - # For the provided `docker-compose.yaml`, the host-mapped ports are: - # - ocr HTTP: 8019 (container 8000) - # - ocr gRPC: 8010 (container 8001) - # ocr_endpoints: ["localhost:8010", "http://localhost:8019/v1/infer"] - ocr_endpoints: null - # Optional; if omitted it is inferred from which endpoint is present. - # ocr_infer_protocol: grpc - - # Optional performance knobs - nim_batch_size: 2 - workers_per_progress_engine: 5 diff --git a/nemo_retriever/embedding_stage_config.yaml b/nemo_retriever/embedding_stage_config.yaml deleted file mode 100644 index f04c29841..000000000 --- a/nemo_retriever/embedding_stage_config.yaml +++ /dev/null @@ -1,43 +0,0 @@ -# Text embedding stage config (nemo_retriever.text_embed) -# -# This YAML is passed to: -# nemo_retriever.text_embed.config.load_text_embedding_schema_from_dict(...) -# which validates against nv-ingest-api's `TextEmbeddingSchema`. -# -# Minimal required fields are optional (schema provides defaults), but you -# typically set api_key / endpoint / model to point at your embedding service. - -# Auth (optional; can also be provided via task_config overrides) -api_key: "" # e.g. $NGC_API_KEY or $NVIDIA_API_KEY - -# Embedding service settings -# If set to null/empty, `retriever local stage5` will fall back to local HF embeddings -# via `nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder`. -embedding_nim_endpoint: null -# embedding_nim_endpoint: "http://localhost:8012/v1" -embedding_model: "nvidia/llama-nemotron-embed-1b-v2" - -# Request formatting -encoding_format: "float" # usually "float" -input_type: "passage" # "passage" (docs) or "query" (queries) -truncate: "END" # how the service truncates long inputs - -# Batch sizing (NIM-side batching is handled internally; this is stage batching) -batch_size: 4 - -# Modalities for multi-modal models (leave as "text" for text-only models) -text_elements_modality: "text" -image_elements_modality: "text" -structured_elements_modality: "text" -audio_elements_modality: "text" - -# Behavior -raise_on_failure: false -httpx_log_level: "WARNING" # DEBUG | INFO | WARNING | ERROR | CRITICAL - -# Optional: embed custom content from metadata.custom_content via glom path -# custom_content_field: "my_field" # e.g. "foo" or "nested.foo" -# result_target_field: "my_field_embedding" # where to write embedding under custom_content - -# Optional: request embedding vector size if the backend supports it -# dimensions: 1024 diff --git a/nemo_retriever/infographic_stage_config.yaml b/nemo_retriever/infographic_stage_config.yaml deleted file mode 100644 index 67ac68546..000000000 --- a/nemo_retriever/infographic_stage_config.yaml +++ /dev/null @@ -1,33 +0,0 @@ -# Example config for: -# - `retriever infographic stage run --config --input ` -# - `retriever local stage2 run --config --input ` -# -# This YAML is parsed into `nv_ingest_api.internal.schemas.extract.extract_infographic_schema.InfographicExtractorSchema` -# via `nemo_retriever.infographic.config.load_infographic_extractor_schema_from_dict`. -# -# IMPORTANT: -# `endpoint_config.ocr_endpoints` must provide at least one endpoint (gRPC or HTTP). Both cannot be null/empty. -# - -# Optional worker settings -max_queue_size: 1 -n_workers: 2 -raise_on_failure: false - -# Endpoint configuration for OCR used to enrich infographic primitives. -endpoint_config: - # Tuple/list in the form: [grpc, http] - # - gRPC example: "ocr:8001" - # - HTTP example: "http://ocr:8000/v1/infer" - #ocr_endpoints: ["localhost:8001", "http://localhost:8019/v1/infer"] - ocr_endpoints: null - - # Optional; if omitted it is inferred from which endpoint is present. - # ocr_infer_protocol: grpc - - # Optional auth token for secured services (NIM) - auth_token: null - - # Optional performance knobs - nim_batch_size: 2 - workers_per_progress_engine: 5 diff --git a/nemo_retriever/pdf_stage_config.yaml b/nemo_retriever/pdf_stage_config.yaml deleted file mode 100644 index 33126ba61..000000000 --- a/nemo_retriever/pdf_stage_config.yaml +++ /dev/null @@ -1,49 +0,0 @@ -# Example config for: `retriever pdf stage page-elements --config ` -# -# CLI override rule: -# - If you pass an option explicitly on the CLI, it wins. -# - Otherwise the value from this YAML file is used. -# -# You can run repeatedly: -# retriever pdf stage page-elements --config nemo_retriever/pdf_stage_config.yaml -# - -# Directory containing PDFs (scanned recursively for *.pdf) -input_dir: /home/local/jdyer/datasets/bo767 - -# PDF extraction method: pdfium | pdfium_hybrid | ocr | nemotron_parse | tika | unstructured_io | adobe | llama -method: pdfium - -# Optional auth token for NIM-backed services -auth_token: null - -endpoints: - yolox: - # If set to null then HuggingFace models will be used instead of NIMs - # grpc: localhost:8001 - # http: http://localhost:8000/v1/infer - grpc: null - http: null - - # Only required for method: nemotron_parse - nemotron_parse: - grpc: null - http: null - model_name: null - -extract: - text: true - # Text depth: page | document - text_depth: page - images: false - tables: true - charts: true - infographics: true - page_as_image: false - -outputs: - write_json: true - json_output_dir: /home/local/jdyer/datasets/bo767-results-hf-standalone/ - -# Optionally limit number of PDFs processed -limit: null diff --git a/nemo_retriever/pyproject.toml b/nemo_retriever/pyproject.toml index b9d1ac476..e8ae7fa0e 100644 --- a/nemo_retriever/pyproject.toml +++ b/nemo_retriever/pyproject.toml @@ -78,7 +78,7 @@ dev = [ ] [project.scripts] -retriever = "nemo_retriever.__main__:main" +retriever = "nemo_retriever.adapters.cli.main:main" [tool.setuptools.dynamic] version = {attr = "nemo_retriever.version.get_build_version"} diff --git a/nemo_retriever/ray_balance_variants.csv b/nemo_retriever/ray_balance_variants.csv deleted file mode 100644 index 309b3cdf4..000000000 --- a/nemo_retriever/ray_balance_variants.csv +++ /dev/null @@ -1,240 +0,0 @@ -run_id,pdf_workers,pdf_num_cpus,pdf_split_bs,pdf_bs,page_elements_bs,page_elements_workers,ocr_workers,ocr_bs,embed_workers,embed_bs,page_elements_cpus_per_actor,ocr_cpus_per_actor,embed_cpus_per_actor,gpu_page_elements,gpu_ocr,gpu_embed,ray_address,start_ray -V00001,8,2.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00002,4,2.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00003,12,2.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00004,16,2.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00005,8,1.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00006,8,3.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00007,8,4.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00008,8,2.0,4,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00009,8,2.0,8,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00010,8,2.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00011,8,2.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00012,8,2.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00013,8,2.0,1,16,8,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00014,8,2.0,1,16,24,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00015,8,2.0,1,16,32,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00016,8,2.0,1,16,16,2,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00017,8,2.0,1,16,16,3,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00018,8,2.0,1,16,16,1,1,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00019,8,2.0,1,16,16,1,3,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00020,8,2.0,1,16,16,1,2,8,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00021,8,2.0,1,16,16,1,2,24,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00022,8,2.0,1,16,16,1,2,32,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00023,8,2.0,1,16,16,1,2,16,2,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00024,8,2.0,1,16,16,1,2,16,3,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00025,8,2.0,1,16,16,1,2,16,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00026,8,2.0,1,16,16,1,2,16,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00027,8,2.0,1,16,16,1,2,16,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00028,8,2.0,1,16,16,1,2,16,1,256,2.0,2.0,1.0,0.5,1.0,0.5,,false -V00029,8,2.0,1,16,16,1,2,16,1,256,4.0,2.0,1.0,0.5,1.0,0.5,,false -V00030,8,2.0,1,16,16,1,2,16,1,256,1.0,1.0,1.0,0.5,1.0,0.5,,false -V00031,8,2.0,1,16,16,1,2,16,1,256,1.0,4.0,1.0,0.5,1.0,0.5,,false -V00032,8,2.0,1,16,16,1,2,16,1,256,1.0,2.0,2.0,0.5,1.0,0.5,,false -V00033,8,2.0,1,16,16,1,2,16,1,256,1.0,2.0,4.0,0.5,1.0,0.5,,false -V00034,8,2.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.25,1.0,0.5,,false -V00035,8,2.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.75,1.0,0.5,,false -V00036,8,2.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,0.75,0.5,,false -V00037,8,2.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.25,,false -V00038,8,2.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.75,,false -V00039,8,2.0,1,16,16,1,1,8,1,256,1.0,2.0,1.0,0.5,0.75,0.5,,false -V00040,8,2.0,1,16,16,1,1,8,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00041,8,2.0,1,16,16,1,2,8,1,256,1.0,2.0,1.0,0.5,0.75,0.5,,false -V00042,8,2.0,1,16,16,1,3,8,1,256,1.0,2.0,1.0,0.5,0.75,0.5,,false -V00043,8,2.0,1,16,16,1,3,8,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00044,8,2.0,1,16,16,1,1,16,1,256,1.0,2.0,1.0,0.5,0.75,0.5,,false -V00045,8,2.0,1,16,16,1,3,16,1,256,1.0,2.0,1.0,0.5,0.75,0.5,,false -V00046,8,2.0,1,16,16,1,1,24,1,256,1.0,2.0,1.0,0.5,0.75,0.5,,false -V00047,8,2.0,1,16,16,1,1,24,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00048,8,2.0,1,16,16,1,2,24,1,256,1.0,2.0,1.0,0.5,0.75,0.5,,false -V00049,8,2.0,1,16,16,1,3,24,1,256,1.0,2.0,1.0,0.5,0.75,0.5,,false -V00050,8,2.0,1,16,16,1,3,24,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00051,8,2.0,1,16,16,1,1,32,1,256,1.0,2.0,1.0,0.5,0.75,0.5,,false -V00052,8,2.0,1,16,16,1,1,32,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00053,8,2.0,1,16,16,1,2,32,1,256,1.0,2.0,1.0,0.5,0.75,0.5,,false -V00054,8,2.0,1,16,16,1,3,32,1,256,1.0,2.0,1.0,0.5,0.75,0.5,,false -V00055,8,2.0,1,16,16,1,3,32,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00056,8,2.0,1,16,16,1,2,16,1,128,1.0,2.0,1.0,0.5,1.0,0.25,,false -V00057,8,2.0,1,16,16,1,2,16,1,128,1.0,2.0,1.0,0.5,1.0,0.75,,false -V00058,8,2.0,1,16,16,1,2,16,2,128,1.0,2.0,1.0,0.5,1.0,0.25,,false -V00059,8,2.0,1,16,16,1,2,16,2,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00060,8,2.0,1,16,16,1,2,16,2,128,1.0,2.0,1.0,0.5,1.0,0.75,,false -V00061,8,2.0,1,16,16,1,2,16,3,128,1.0,2.0,1.0,0.5,1.0,0.25,,false -V00062,8,2.0,1,16,16,1,2,16,3,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00063,8,2.0,1,16,16,1,2,16,3,128,1.0,2.0,1.0,0.5,1.0,0.75,,false -V00064,8,2.0,1,16,16,1,2,16,2,256,1.0,2.0,1.0,0.5,1.0,0.25,,false -V00065,8,2.0,1,16,16,1,2,16,2,256,1.0,2.0,1.0,0.5,1.0,0.75,,false -V00066,8,2.0,1,16,16,1,2,16,3,256,1.0,2.0,1.0,0.5,1.0,0.25,,false -V00067,8,2.0,1,16,16,1,2,16,3,256,1.0,2.0,1.0,0.5,1.0,0.75,,false -V00068,8,2.0,1,16,16,1,2,16,1,512,1.0,2.0,1.0,0.5,1.0,0.25,,false -V00069,8,2.0,1,16,16,1,2,16,1,512,1.0,2.0,1.0,0.5,1.0,0.75,,false -V00070,8,2.0,1,16,16,1,2,16,2,512,1.0,2.0,1.0,0.5,1.0,0.25,,false -V00071,8,2.0,1,16,16,1,2,16,2,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00072,8,2.0,1,16,16,1,2,16,2,512,1.0,2.0,1.0,0.5,1.0,0.75,,false -V00073,8,2.0,1,16,16,1,2,16,3,512,1.0,2.0,1.0,0.5,1.0,0.25,,false -V00074,8,2.0,1,16,16,1,2,16,3,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00075,8,2.0,1,16,16,1,2,16,3,512,1.0,2.0,1.0,0.5,1.0,0.75,,false -V00076,8,2.0,1,16,16,1,2,16,1,768,1.0,2.0,1.0,0.5,1.0,0.25,,false -V00077,8,2.0,1,16,16,1,2,16,1,768,1.0,2.0,1.0,0.5,1.0,0.75,,false -V00078,8,2.0,1,16,16,1,2,16,2,768,1.0,2.0,1.0,0.5,1.0,0.25,,false -V00079,8,2.0,1,16,16,1,2,16,2,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00080,8,2.0,1,16,16,1,2,16,2,768,1.0,2.0,1.0,0.5,1.0,0.75,,false -V00081,8,2.0,1,16,16,1,2,16,3,768,1.0,2.0,1.0,0.5,1.0,0.25,,false -V00082,8,2.0,1,16,16,1,2,16,3,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00083,8,2.0,1,16,16,1,2,16,3,768,1.0,2.0,1.0,0.5,1.0,0.75,,false -V00084,8,2.0,1,16,8,1,2,16,1,256,1.0,2.0,1.0,0.25,1.0,0.5,,false -V00085,8,2.0,1,16,8,1,2,16,1,256,1.0,2.0,1.0,0.75,1.0,0.5,,false -V00086,8,2.0,1,16,8,2,2,16,1,256,1.0,2.0,1.0,0.25,1.0,0.5,,false -V00087,8,2.0,1,16,8,2,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00088,8,2.0,1,16,8,2,2,16,1,256,1.0,2.0,1.0,0.75,1.0,0.5,,false -V00089,8,2.0,1,16,8,3,2,16,1,256,1.0,2.0,1.0,0.25,1.0,0.5,,false -V00090,8,2.0,1,16,8,3,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00091,8,2.0,1,16,8,3,2,16,1,256,1.0,2.0,1.0,0.75,1.0,0.5,,false -V00092,8,2.0,1,16,16,2,2,16,1,256,1.0,2.0,1.0,0.25,1.0,0.5,,false -V00093,8,2.0,1,16,16,2,2,16,1,256,1.0,2.0,1.0,0.75,1.0,0.5,,false -V00094,8,2.0,1,16,16,3,2,16,1,256,1.0,2.0,1.0,0.25,1.0,0.5,,false -V00095,8,2.0,1,16,16,3,2,16,1,256,1.0,2.0,1.0,0.75,1.0,0.5,,false -V00096,8,2.0,1,16,24,1,2,16,1,256,1.0,2.0,1.0,0.25,1.0,0.5,,false -V00097,8,2.0,1,16,24,1,2,16,1,256,1.0,2.0,1.0,0.75,1.0,0.5,,false -V00098,8,2.0,1,16,24,2,2,16,1,256,1.0,2.0,1.0,0.25,1.0,0.5,,false -V00099,8,2.0,1,16,24,2,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00100,8,2.0,1,16,24,2,2,16,1,256,1.0,2.0,1.0,0.75,1.0,0.5,,false -V00101,8,2.0,1,16,24,3,2,16,1,256,1.0,2.0,1.0,0.25,1.0,0.5,,false -V00102,8,2.0,1,16,24,3,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00103,8,2.0,1,16,24,3,2,16,1,256,1.0,2.0,1.0,0.75,1.0,0.5,,false -V00104,8,2.0,1,16,32,1,2,16,1,256,1.0,2.0,1.0,0.25,1.0,0.5,,false -V00105,8,2.0,1,16,32,1,2,16,1,256,1.0,2.0,1.0,0.75,1.0,0.5,,false -V00106,8,2.0,1,16,32,2,2,16,1,256,1.0,2.0,1.0,0.25,1.0,0.5,,false -V00107,8,2.0,1,16,32,2,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00108,8,2.0,1,16,32,2,2,16,1,256,1.0,2.0,1.0,0.75,1.0,0.5,,false -V00109,8,2.0,1,16,32,3,2,16,1,256,1.0,2.0,1.0,0.25,1.0,0.5,,false -V00110,8,2.0,1,16,32,3,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00111,8,2.0,1,16,32,3,2,16,1,256,1.0,2.0,1.0,0.75,1.0,0.5,,false -V00112,4,1.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00113,4,1.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00114,4,1.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00115,4,1.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00116,4,2.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00117,4,2.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00118,4,2.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00119,4,3.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00120,4,3.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00121,4,3.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00122,4,3.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00123,4,4.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00124,4,4.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00125,4,4.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00126,4,4.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00127,8,1.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00128,8,1.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00129,8,1.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00130,8,3.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00131,8,3.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00132,8,3.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00133,8,4.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00134,8,4.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00135,8,4.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00136,12,1.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00137,12,1.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00138,12,1.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00139,12,1.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00140,12,2.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00141,12,2.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00142,12,2.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00143,12,3.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00144,12,3.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00145,12,3.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00146,12,3.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00147,12,4.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00148,12,4.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00149,12,4.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00150,12,4.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00151,16,1.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00152,16,1.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00153,16,1.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00154,16,1.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00155,16,2.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00156,16,2.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00157,16,2.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00158,16,3.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00159,16,3.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00160,16,3.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00161,16,3.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00162,16,4.0,1,8,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00163,16,4.0,1,16,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00164,16,4.0,1,24,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00165,16,4.0,1,32,16,1,2,16,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00166,8,2.0,1,16,16,1,2,16,1,256,1.0,1.0,2.0,0.5,1.0,0.5,,false -V00167,8,2.0,1,16,16,1,2,16,1,256,1.0,1.0,4.0,0.5,1.0,0.5,,false -V00168,8,2.0,1,16,16,1,2,16,1,256,1.0,4.0,2.0,0.5,1.0,0.5,,false -V00169,8,2.0,1,16,16,1,2,16,1,256,1.0,4.0,4.0,0.5,1.0,0.5,,false -V00170,8,2.0,1,16,16,1,2,16,1,256,2.0,1.0,1.0,0.5,1.0,0.5,,false -V00171,8,2.0,1,16,16,1,2,16,1,256,2.0,1.0,2.0,0.5,1.0,0.5,,false -V00172,8,2.0,1,16,16,1,2,16,1,256,2.0,1.0,4.0,0.5,1.0,0.5,,false -V00173,8,2.0,1,16,16,1,2,16,1,256,2.0,2.0,2.0,0.5,1.0,0.5,,false -V00174,8,2.0,1,16,16,1,2,16,1,256,2.0,2.0,4.0,0.5,1.0,0.5,,false -V00175,8,2.0,1,16,16,1,2,16,1,256,2.0,4.0,1.0,0.5,1.0,0.5,,false -V00176,8,2.0,1,16,16,1,2,16,1,256,2.0,4.0,2.0,0.5,1.0,0.5,,false -V00177,8,2.0,1,16,16,1,2,16,1,256,2.0,4.0,4.0,0.5,1.0,0.5,,false -V00178,8,2.0,1,16,16,1,2,16,1,256,4.0,1.0,1.0,0.5,1.0,0.5,,false -V00179,8,2.0,1,16,16,1,2,16,1,256,4.0,1.0,2.0,0.5,1.0,0.5,,false -V00180,8,2.0,1,16,16,1,2,16,1,256,4.0,1.0,4.0,0.5,1.0,0.5,,false -V00181,8,2.0,1,16,16,1,2,16,1,256,4.0,2.0,2.0,0.5,1.0,0.5,,false -V00182,8,2.0,1,16,16,1,2,16,1,256,4.0,2.0,4.0,0.5,1.0,0.5,,false -V00183,8,2.0,1,16,16,1,2,16,1,256,4.0,4.0,1.0,0.5,1.0,0.5,,false -V00184,8,2.0,1,16,16,1,2,16,1,256,4.0,4.0,2.0,0.5,1.0,0.5,,false -V00185,8,2.0,1,16,16,1,2,16,1,256,4.0,4.0,4.0,0.5,1.0,0.5,,false -V00186,8,2.0,1,8,16,1,2,8,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00187,8,2.0,1,8,16,1,2,8,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00188,8,2.0,1,8,16,1,2,8,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00189,8,2.0,1,8,16,1,2,8,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00190,8,2.0,1,8,16,1,2,16,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00191,8,2.0,1,8,16,1,2,16,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00192,8,2.0,1,8,16,1,2,16,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00193,8,2.0,1,8,16,1,2,24,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00194,8,2.0,1,8,16,1,2,24,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00195,8,2.0,1,8,16,1,2,24,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00196,8,2.0,1,8,16,1,2,24,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00197,8,2.0,1,8,16,1,2,32,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00198,8,2.0,1,8,16,1,2,32,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00199,8,2.0,1,8,16,1,2,32,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00200,8,2.0,1,8,16,1,2,32,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00201,8,2.0,1,16,16,1,2,8,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00202,8,2.0,1,16,16,1,2,8,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00203,8,2.0,1,16,16,1,2,8,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00204,8,2.0,1,16,16,1,2,24,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00205,8,2.0,1,16,16,1,2,24,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00206,8,2.0,1,16,16,1,2,24,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00207,8,2.0,1,16,16,1,2,32,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00208,8,2.0,1,16,16,1,2,32,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00209,8,2.0,1,16,16,1,2,32,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00210,8,2.0,1,24,16,1,2,8,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00211,8,2.0,1,24,16,1,2,8,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00212,8,2.0,1,24,16,1,2,8,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00213,8,2.0,1,24,16,1,2,8,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00214,8,2.0,1,24,16,1,2,16,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00215,8,2.0,1,24,16,1,2,16,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00216,8,2.0,1,24,16,1,2,16,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00217,8,2.0,1,24,16,1,2,24,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00218,8,2.0,1,24,16,1,2,24,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00219,8,2.0,1,24,16,1,2,24,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00220,8,2.0,1,24,16,1,2,24,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00221,8,2.0,1,24,16,1,2,32,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00222,8,2.0,1,24,16,1,2,32,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00223,8,2.0,1,24,16,1,2,32,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00224,8,2.0,1,24,16,1,2,32,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00225,8,2.0,1,32,16,1,2,8,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00226,8,2.0,1,32,16,1,2,8,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00227,8,2.0,1,32,16,1,2,8,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00228,8,2.0,1,32,16,1,2,8,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00229,8,2.0,1,32,16,1,2,16,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00230,8,2.0,1,32,16,1,2,16,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00231,8,2.0,1,32,16,1,2,16,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00232,8,2.0,1,32,16,1,2,24,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00233,8,2.0,1,32,16,1,2,24,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00234,8,2.0,1,32,16,1,2,24,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00235,8,2.0,1,32,16,1,2,24,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00236,8,2.0,1,32,16,1,2,32,1,128,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00237,8,2.0,1,32,16,1,2,32,1,256,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00238,8,2.0,1,32,16,1,2,32,1,512,1.0,2.0,1.0,0.5,1.0,0.5,,false -V00239,8,2.0,1,32,16,1,2,32,1,768,1.0,2.0,1.0,0.5,1.0,0.5,,false diff --git a/nemo_retriever/src/nemo_retriever/__init__.py b/nemo_retriever/src/nemo_retriever/__init__.py index 2d0dbc01b..e4b32d309 100644 --- a/nemo_retriever/src/nemo_retriever/__init__.py +++ b/nemo_retriever/src/nemo_retriever/__init__.py @@ -2,7 +2,7 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Retriever application package.""" +"""NeMo Retriever application package.""" from __future__ import annotations diff --git a/nemo_retriever/src/nemo_retriever/__main__.py b/nemo_retriever/src/nemo_retriever/__main__.py deleted file mode 100644 index 652df9193..000000000 --- a/nemo_retriever/src/nemo_retriever/__main__.py +++ /dev/null @@ -1,9 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from .adapters.cli.main import app, main - -__all__ = ["app", "main"] diff --git a/nemo_retriever/src/nemo_retriever/_build_info.py b/nemo_retriever/src/nemo_retriever/_build_info.py deleted file mode 100644 index c446b4406..000000000 --- a/nemo_retriever/src/nemo_retriever/_build_info.py +++ /dev/null @@ -1,8 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Build metadata written by CI before packaging.""" - -BUILD_GIT_SHA = "unknown" -BUILD_DATE = "unknown" diff --git a/nemo_retriever/src/nemo_retriever/version.py b/nemo_retriever/src/nemo_retriever/version.py index 81c099013..553b2d464 100644 --- a/nemo_retriever/src/nemo_retriever/version.py +++ b/nemo_retriever/src/nemo_retriever/version.py @@ -13,13 +13,9 @@ import os import subprocess -try: - from ._build_info import BUILD_DATE as _PACKAGE_BUILD_DATE - from ._build_info import BUILD_GIT_SHA as _PACKAGE_BUILD_GIT_SHA -except ImportError: - # During setuptools build isolation the package may not be importable - _PACKAGE_BUILD_DATE = "unknown" - _PACKAGE_BUILD_GIT_SHA = "unknown" +# Overwritten by CI before packaging; see .github/workflows/*.yml. +_PACKAGE_BUILD_GIT_SHA = "unknown" +_PACKAGE_BUILD_DATE = "unknown" _PKG_NAME = "nemo-retriever" _UNKNOWN = "unknown" diff --git a/nemo_retriever/table_stage_config.yaml b/nemo_retriever/table_stage_config.yaml deleted file mode 100644 index e268c8ba7..000000000 --- a/nemo_retriever/table_stage_config.yaml +++ /dev/null @@ -1,38 +0,0 @@ -# Example config for: -# - `retriever table stage run --config --input ` -# - `retriever local stage3 run --config --input ` -# -# This YAML is parsed into `nv_ingest_api.internal.schemas.extract.extract_table_schema.TableExtractorSchema` -# via `nemo_retriever.table.config.load_table_extractor_schema_from_dict`. -# -# IMPORTANT: -# `endpoint_config.yolox_endpoints` and `endpoint_config.ocr_endpoints` must each provide at least one -# endpoint (gRPC or HTTP). Both cannot be null/empty for either entry. -# - -# Optional worker settings -max_queue_size: 1 -n_workers: 2 -raise_on_failure: false - -# Endpoint configuration for table extraction (YOLOX table-structure + OCR). -endpoint_config: - # Optional auth token for secured services (NIM) - auth_token: null - - # Tuple/list in the form: [grpc, http] - # YOLOX table-structure model endpoints - # yolox_endpoints: ["localhost:8007", "http://localhost:8006/v1/infer"] - yolox_endpoints: null - # Optional; if omitted it is inferred from which endpoint is present. - # yolox_infer_protocol: grpc - - # OCR model endpoints - # ocr_endpoints: ["localhost:8010", "http://localhost:8019/v1/infer"] - ocr_endpoints: null - # Optional; if omitted it is inferred from which endpoint is present. - # ocr_infer_protocol: grpc - - # Optional performance knobs - nim_batch_size: 2 - workers_per_progress_engine: 5 From 27aa293afb04ce4f3eb6c474400daddd60e72b6a Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 9 Mar 2026 11:00:50 -0400 Subject: [PATCH 09/12] checkpoint commit --- nemo_retriever/src/nemo_retriever/chart/config.py | 2 +- nemo_retriever/src/nemo_retriever/infographic/config.py | 2 +- nemo_retriever/src/nemo_retriever/pdf/config.py | 2 +- nemo_retriever/src/nemo_retriever/table/config.py | 2 +- nemo_retriever/src/nemo_retriever/table/table_detection.py | 2 +- nemo_retriever/src/nemo_retriever/util/__init__.py | 3 --- nemo_retriever/src/nemo_retriever/{ => utils}/config_utils.py | 0 .../src/nemo_retriever/{util => utils}/table_and_chart.py | 0 nemo_retriever/tests/test_table_structure.py | 2 +- 9 files changed, 6 insertions(+), 9 deletions(-) delete mode 100644 nemo_retriever/src/nemo_retriever/util/__init__.py rename nemo_retriever/src/nemo_retriever/{ => utils}/config_utils.py (100%) rename nemo_retriever/src/nemo_retriever/{util => utils}/table_and_chart.py (100%) diff --git a/nemo_retriever/src/nemo_retriever/chart/config.py b/nemo_retriever/src/nemo_retriever/chart/config.py index a9dbadf44..68c3a0301 100644 --- a/nemo_retriever/src/nemo_retriever/chart/config.py +++ b/nemo_retriever/src/nemo_retriever/chart/config.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from typing import Any, Dict -from nemo_retriever.config_utils import endpoints_from_yaml +from nemo_retriever.utils.config_utils import endpoints_from_yaml from nv_ingest_api.internal.schemas.extract.extract_chart_schema import ChartExtractorSchema diff --git a/nemo_retriever/src/nemo_retriever/infographic/config.py b/nemo_retriever/src/nemo_retriever/infographic/config.py index f44495890..48ae71ac0 100644 --- a/nemo_retriever/src/nemo_retriever/infographic/config.py +++ b/nemo_retriever/src/nemo_retriever/infographic/config.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from typing import Any, Dict -from nemo_retriever.config_utils import endpoints_from_yaml +from nemo_retriever.utils.config_utils import endpoints_from_yaml from nv_ingest_api.internal.schemas.extract.extract_infographic_schema import InfographicExtractorSchema diff --git a/nemo_retriever/src/nemo_retriever/pdf/config.py b/nemo_retriever/src/nemo_retriever/pdf/config.py index f91b0a685..ab831a45d 100644 --- a/nemo_retriever/src/nemo_retriever/pdf/config.py +++ b/nemo_retriever/src/nemo_retriever/pdf/config.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from typing import Any, Dict -from nemo_retriever.config_utils import endpoints_from_yaml +from nemo_retriever.utils.config_utils import endpoints_from_yaml from nv_ingest_api.internal.schemas.extract.extract_pdf_schema import PDFExtractorSchema diff --git a/nemo_retriever/src/nemo_retriever/table/config.py b/nemo_retriever/src/nemo_retriever/table/config.py index 412dba1c0..535c4df8b 100644 --- a/nemo_retriever/src/nemo_retriever/table/config.py +++ b/nemo_retriever/src/nemo_retriever/table/config.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from typing import Any, Dict -from nemo_retriever.config_utils import endpoints_from_yaml +from nemo_retriever.utils.config_utils import endpoints_from_yaml from nv_ingest_api.internal.schemas.extract.extract_table_schema import TableExtractorSchema diff --git a/nemo_retriever/src/nemo_retriever/table/table_detection.py b/nemo_retriever/src/nemo_retriever/table/table_detection.py index 14df8f22c..bfb82a187 100644 --- a/nemo_retriever/src/nemo_retriever/table/table_detection.py +++ b/nemo_retriever/src/nemo_retriever/table/table_detection.py @@ -261,7 +261,7 @@ def table_structure_ocr_page_elements( _np_rgb_to_b64_png, _parse_ocr_result, ) - from nemo_retriever.util.table_and_chart import join_table_structure_and_ocr_output + from nemo_retriever.utils.table_and_chart import join_table_structure_and_ocr_output retry = remote_retry or RemoteRetryParams( remote_max_pool_workers=int(kwargs.get("remote_max_pool_workers", 16)), diff --git a/nemo_retriever/src/nemo_retriever/util/__init__.py b/nemo_retriever/src/nemo_retriever/util/__init__.py deleted file mode 100644 index 6aa2e3d5b..000000000 --- a/nemo_retriever/src/nemo_retriever/util/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: Apache-2.0 diff --git a/nemo_retriever/src/nemo_retriever/config_utils.py b/nemo_retriever/src/nemo_retriever/utils/config_utils.py similarity index 100% rename from nemo_retriever/src/nemo_retriever/config_utils.py rename to nemo_retriever/src/nemo_retriever/utils/config_utils.py diff --git a/nemo_retriever/src/nemo_retriever/util/table_and_chart.py b/nemo_retriever/src/nemo_retriever/utils/table_and_chart.py similarity index 100% rename from nemo_retriever/src/nemo_retriever/util/table_and_chart.py rename to nemo_retriever/src/nemo_retriever/utils/table_and_chart.py diff --git a/nemo_retriever/tests/test_table_structure.py b/nemo_retriever/tests/test_table_structure.py index 7499dff1a..d200cbae1 100644 --- a/nemo_retriever/tests/test_table_structure.py +++ b/nemo_retriever/tests/test_table_structure.py @@ -14,7 +14,7 @@ import pandas as pd import pytest -from nemo_retriever.util.table_and_chart import join_table_structure_and_ocr_output +from nemo_retriever.utils.table_and_chart import join_table_structure_and_ocr_output def _can_import(mod: str) -> bool: From 1450af523c4f8f2a2efbc975106b2a7ec1f809fc Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 10 Mar 2026 12:31:47 -0400 Subject: [PATCH 10/12] Remove dead code at nemo_retriever/src/nemo_retriever/config --- .../src/nemo_retriever/config/__init__.py | 7 ---- .../src/nemo_retriever/config/loader.py | 37 ------------------- 2 files changed, 44 deletions(-) delete mode 100644 nemo_retriever/src/nemo_retriever/config/__init__.py delete mode 100644 nemo_retriever/src/nemo_retriever/config/loader.py diff --git a/nemo_retriever/src/nemo_retriever/config/__init__.py b/nemo_retriever/src/nemo_retriever/config/__init__.py deleted file mode 100644 index 23432d208..000000000 --- a/nemo_retriever/src/nemo_retriever/config/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from .loader import load_config_file, load_config_section, resolve_config_path - -__all__ = ["load_config_file", "load_config_section", "resolve_config_path"] diff --git a/nemo_retriever/src/nemo_retriever/config/loader.py b/nemo_retriever/src/nemo_retriever/config/loader.py deleted file mode 100644 index a21532d33..000000000 --- a/nemo_retriever/src/nemo_retriever/config/loader.py +++ /dev/null @@ -1,37 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from pathlib import Path -from typing import Any, Dict, Optional, Tuple - -from nemo_retriever.ingest_config import ( - load_ingest_config_file, - load_ingest_config_section, - resolve_ingest_config_path, -) - - -def resolve_config_path(explicit: Optional[Path]) -> Tuple[Optional[Path], str]: - return resolve_ingest_config_path(explicit) - - -def load_config_file(explicit: Optional[Path], *, verbose: bool = True) -> Tuple[Dict[str, Any], Optional[Path], str]: - return load_ingest_config_file(explicit, verbose=verbose) - - -def load_config_section( - explicit: Optional[Path], - *, - section: str, - verbose: bool = True, - warn_if_missing_section: bool = True, -) -> Dict[str, Any]: - return load_ingest_config_section( - explicit, - section=section, - verbose=verbose, - warn_if_missing_section=warn_if_missing_section, - ) From 2f914a6a34671a576b8b47e143acb99f1aeab0f9 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 10 Mar 2026 13:01:50 -0400 Subject: [PATCH 11/12] Do not pin to specific cuda device and allow underlying framework to determine device for nemotron_parse --- .../src/nemo_retriever/model/local/nemotron_parse_v1_2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_retriever/src/nemo_retriever/model/local/nemotron_parse_v1_2.py b/nemo_retriever/src/nemo_retriever/model/local/nemotron_parse_v1_2.py index 07e47ddcd..93467887c 100644 --- a/nemo_retriever/src/nemo_retriever/model/local/nemotron_parse_v1_2.py +++ b/nemo_retriever/src/nemo_retriever/model/local/nemotron_parse_v1_2.py @@ -37,7 +37,7 @@ def __init__( self._model_path = model_path self._task_prompt = task_prompt - self._device = torch.device(device or ("cuda:0" if torch.cuda.is_available() else "cpu")) + self._device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) self._dtype = torch.bfloat16 if self._device.type == "cuda" else torch.float32 hf_cache_dir = configure_global_hf_cache_base(hf_cache_dir) _revision = get_hf_revision(self._model_path) From e5753e0c28363ac8fbc1e2ec7f97ef7bcdefa1fc Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 10 Mar 2026 13:46:54 -0400 Subject: [PATCH 12/12] Cleanup common functions --- .../nemo_retriever/chart/chart_detection.py | 182 ++---------------- .../src/nemo_retriever/harness/nightly.py | 4 +- .../src/nemo_retriever/harness/run.py | 6 +- .../src/nemo_retriever/html/convert.py | 2 +- .../infographic/infographic_detection.py | 161 +--------------- .../src/nemo_retriever/ingest_modes/batch.py | 4 +- .../src/nemo_retriever/ingest_modes/fused.py | 6 +- .../nemo_retriever/ingest_modes/inprocess.py | 10 +- .../src/nemo_retriever/ocr/__init__.py | 24 ++- nemo_retriever/src/nemo_retriever/ocr/ocr.py | 38 ++-- .../src/nemo_retriever/pdf/split.py | 6 +- .../src/nemo_retriever/recall/vdb_recall.py | 2 +- .../nemo_retriever/table/table_detection.py | 137 ++----------- .../src/nemo_retriever/txt/split.py | 6 +- .../src/nemo_retriever/utils/detection.py | 118 ++++++++++++ .../nemo_retriever/utils/table_and_chart.py | 6 +- .../tests/test_audio_pipeline_batch.py | 2 +- .../tests/test_chart_graphic_elements.py | 10 +- nemo_retriever/tests/test_multimodal_embed.py | 2 +- 19 files changed, 228 insertions(+), 498 deletions(-) create mode 100644 nemo_retriever/src/nemo_retriever/utils/detection.py diff --git a/nemo_retriever/src/nemo_retriever/chart/chart_detection.py b/nemo_retriever/src/nemo_retriever/chart/chart_detection.py index 76578a5ad..10cf19512 100644 --- a/nemo_retriever/src/nemo_retriever/chart/chart_detection.py +++ b/nemo_retriever/src/nemo_retriever/chart/chart_detection.py @@ -14,6 +14,7 @@ import pandas as pd from nemo_retriever.nim.nim import invoke_image_inference_batches from nemo_retriever.params import RemoteRetryParams +from nemo_retriever.utils.detection import prediction_to_detections try: import numpy as np @@ -58,64 +59,6 @@ def _decode_b64_image_to_chw_tensor(image_b64: str) -> Tuple["torch.Tensor", Tup return t, (int(h), int(w)) -def _crop_b64_image_by_norm_bbox( - page_image_b64: str, - *, - bbox_xyxy_norm: Sequence[float], - image_format: str = "png", -) -> Tuple[Optional[str], Optional[Tuple[int, int]]]: - """ - Crop a base64-encoded RGB image by a normalized xyxy bbox. - - Returns: - - cropped_image_b64 (png) or None - - cropped_shape_hw (H,W) or None - """ - if Image is None: # pragma: no cover - raise ImportError("Cropping requires pillow.") - if not isinstance(page_image_b64, str) or not page_image_b64: - return None, None - try: - x1n, y1n, x2n, y2n = [float(x) for x in bbox_xyxy_norm] - except Exception: - return None, None - - try: - raw = base64.b64decode(page_image_b64) - with Image.open(io.BytesIO(raw)) as im0: - im = im0.convert("RGB") - w, h = im.size - if w <= 1 or h <= 1: - return None, None - - def _clamp_int(v: float, lo: int, hi: int) -> int: - if v != v: # NaN - return lo - return int(min(max(v, float(lo)), float(hi))) - - x1 = _clamp_int(x1n * w, 0, w) - x2 = _clamp_int(x2n * w, 0, w) - y1 = _clamp_int(y1n * h, 0, h) - y2 = _clamp_int(y2n * h, 0, h) - - if x2 <= x1 or y2 <= y1: - return None, None - - crop = im.crop((x1, y1, x2, y2)) - cw, ch = crop.size - if cw <= 1 or ch <= 1: - return None, None - - buf = io.BytesIO() - fmt = str(image_format or "png").lower() - if fmt not in {"png"}: - fmt = "png" - crop.save(buf, format=fmt.upper()) - return base64.b64encode(buf.getvalue()).decode("ascii"), (int(ch), int(cw)) - except Exception: - return None, None - - def _labels_from_model(model: Any) -> List[str]: try: labels = getattr(getattr(model, "_model", None), "labels", None) @@ -136,107 +79,6 @@ def _labels_from_model(model: Any) -> List[str]: return [] -def _prediction_to_detections(pred: Any, *, label_names: List[str]) -> List[Dict[str, Any]]: - if torch is None: # pragma: no cover - raise ImportError("torch required for prediction parsing.") - - boxes = labels = scores = None - if isinstance(pred, dict): - # IMPORTANT: do not use `or` chains here. torch.Tensor truthiness is ambiguous and raises. - def _get_any(d: Dict[str, Any], *keys: str) -> Any: - for k in keys: - if k in d: - v = d.get(k) - if v is not None: - return v - return None - - boxes = _get_any(pred, "boxes", "bboxes", "bbox", "box") - labels = _get_any(pred, "labels", "classes", "class_ids", "class") - scores = _get_any(pred, "scores", "conf", "confidences", "score") - elif isinstance(pred, (list, tuple)) and len(pred) >= 3: - boxes, labels, scores = pred[0], pred[1], pred[2] - - if boxes is None or labels is None: - return [] - - def _to_tensor(x: Any) -> Optional["torch.Tensor"]: - if x is None: - return None - if isinstance(x, torch.Tensor): - return x.detach().cpu() - try: - return torch.as_tensor(x).detach().cpu() - except Exception: - return None - - # Handle string labels (e.g. NIM returns ["chart_title", "xlabel", ...]). - # torch.as_tensor cannot convert strings, so handle them before tensor conversion. - _string_labels: Optional[List[str]] = None - if isinstance(labels, (list, tuple)) and labels and isinstance(labels[0], str): - _string_labels = [str(x) for x in labels] - - b = _to_tensor(boxes) - labels_t = _to_tensor(labels) if _string_labels is None else None - s = _to_tensor(scores) if scores is not None else None - if b is None: - return [] - if labels_t is None and _string_labels is None: - return [] - - if b.ndim != 2 or int(b.shape[-1]) != 4: - return [] - if labels_t is not None: - if labels_t.ndim == 2 and int(labels_t.shape[-1]) == 1: - labels_t = labels_t.squeeze(-1) - if labels_t.ndim != 1: - return [] - - n_labels = len(_string_labels) if _string_labels is not None else int(labels_t.shape[0]) - n = int(min(b.shape[0], n_labels)) - dets: List[Dict[str, Any]] = [] - for i in range(n): - try: - x1, y1, x2, y2 = [float(x) for x in b[i].tolist()] - except Exception: - continue - - if _string_labels is not None: - label_i = i - label_name = _string_labels[i] - else: - label_i: Optional[int] - try: - label_i = int(labels_t[i].item()) - except Exception: - label_i = None - - label_name = None - if label_i is not None and 0 <= label_i < len(label_names): - label_name = label_names[label_i] - if not label_name: - label_name = f"label_{label_i}" if label_i is not None else "unknown" - - score_f: Optional[float] - if s is not None and s.ndim >= 1 and int(s.shape[0]) > i: - try: - score_f = float(s[i].item()) - except Exception: - score_f = None - else: - score_f = None - - dets.append( - { - "bbox_xyxy_norm": [x1, y1, x2, y2], - "label": label_i, - "label_name": str(label_name), - "score": score_f, - } - ) - return dets - - def _counts_by_label(detections: Sequence[Dict[str, Any]]) -> Dict[str, int]: out: Dict[str, int] = {} for d in detections: @@ -354,11 +196,11 @@ def graphic_elements_ocr_page_elements( Original columns plus ``chart`` and ``graphic_elements_ocr_v1``. """ from nemo_retriever.ocr.ocr import ( - _blocks_to_text, - _crop_all_from_page, - _extract_remote_ocr_item, - _np_rgb_to_b64_png, - _parse_ocr_result, + blocks_to_text, + crop_all_from_page, + extract_remote_ocr_item, + np_rgb_to_b64_png, + parse_ocr_result, ) from nemo_retriever.util.table_and_chart import join_graphic_elements_and_ocr_output @@ -413,7 +255,7 @@ def graphic_elements_ocr_page_elements( continue # --- Crop all chart detections --- - crops = _crop_all_from_page(page_image_b64, dets, {"chart"}) + crops = crop_all_from_page(page_image_b64, dets, {"chart"}) if not crops: all_chart.append(chart_items) @@ -422,7 +264,7 @@ def graphic_elements_ocr_page_elements( # Pre-compute base64 encodings once for remote paths. crop_b64s = ( - [_np_rgb_to_b64_png(crop_array) for _, _, crop_array in crops] + [np_rgb_to_b64_png(crop_array) for _, _, crop_array in crops] if (use_remote_ge or use_remote_ocr) else [] ) @@ -457,7 +299,7 @@ def graphic_elements_ocr_page_elements( if isinstance(pre, torch.Tensor) and pre.ndim == 3: pre = pre.unsqueeze(0) pred = graphic_elements_model.invoke(pre, (h, w)) - ge_dets = _prediction_to_detections(pred, label_names=label_names) + ge_dets = prediction_to_detections(pred, label_names=label_names) ge_results.append(ge_dets) # --- Run OCR on all crops --- @@ -476,7 +318,7 @@ def graphic_elements_ocr_page_elements( if len(ocr_response_items) != len(crops): raise RuntimeError(f"Expected {len(crops)} OCR responses, got {len(ocr_response_items)}") for resp in ocr_response_items: - ocr_results.append(_extract_remote_ocr_item(resp)) + ocr_results.append(extract_remote_ocr_item(resp)) else: for _, _, crop_array in crops: ocr_results.append(ocr_model.invoke(crop_array, merge_level="word")) @@ -492,8 +334,8 @@ def graphic_elements_ocr_page_elements( # Fallback: if no GE detections matched, use OCR-only text. if not text: - blocks = _parse_ocr_result(ocr_preds) - text = _blocks_to_text(blocks) + blocks = parse_ocr_result(ocr_preds) + text = blocks_to_text(blocks) chart_items.append({"bbox_xyxy_norm": bbox, "text": text}) diff --git a/nemo_retriever/src/nemo_retriever/harness/nightly.py b/nemo_retriever/src/nemo_retriever/harness/nightly.py index cfc07f2d5..b8c50469c 100644 --- a/nemo_retriever/src/nemo_retriever/harness/nightly.py +++ b/nemo_retriever/src/nemo_retriever/harness/nightly.py @@ -11,7 +11,7 @@ from nemo_retriever.harness.artifacts import write_session_summary from nemo_retriever.harness.config import DEFAULT_NIGHTLY_CONFIG_PATH, load_nightly_config -from nemo_retriever.harness.run import _normalize_tags, execute_runs +from nemo_retriever.harness.run import normalize_tags, execute_runs from nemo_retriever.harness.slack import load_replay_report, load_session_report, post_report_to_slack @@ -58,7 +58,7 @@ def nightly_command( ), dry_run: bool = typer.Option(False, "--dry-run", help="Print nightly run plan without executing."), ) -> None: - normalized_tags = _normalize_tags(tag) + normalized_tags = normalize_tags(tag) nightly_cfg = load_nightly_config(runs_config) runs = nightly_cfg["runs"] slack_config = nightly_cfg["slack"] diff --git a/nemo_retriever/src/nemo_retriever/harness/run.py b/nemo_retriever/src/nemo_retriever/harness/run.py index f0d10509c..c76cd03a6 100644 --- a/nemo_retriever/src/nemo_retriever/harness/run.py +++ b/nemo_retriever/src/nemo_retriever/harness/run.py @@ -91,7 +91,7 @@ def _collect_run_metadata() -> dict[str, Any]: } -def _normalize_tags(tags: list[str] | None) -> list[str]: +def normalize_tags(tags: list[str] | None) -> list[str]: normalized: list[str] = [] seen: set[str] = set() @@ -466,7 +466,7 @@ def _run_entry( artifact_dir.mkdir(parents=True, exist_ok=True) resolved_run_name = run_name or cfg.dataset_label - normalized_tags = _normalize_tags(tags) + normalized_tags = normalize_tags(tags) result = _run_single(cfg, artifact_dir, run_id=resolved_run_name, tags=normalized_tags) run_result = { "run_name": resolved_run_name, @@ -548,7 +548,7 @@ def sweep_command( tag: list[str] = typer.Option([], "--tag", help="Session tag to persist on each run. Repeatable."), dry_run: bool = typer.Option(False, "--dry-run", help="Print run plan without executing."), ) -> None: - normalized_tags = _normalize_tags(tag) + normalized_tags = normalize_tags(tag) runs = load_runs_config(runs_config) if dry_run: typer.echo("Sweep dry run:") diff --git a/nemo_retriever/src/nemo_retriever/html/convert.py b/nemo_retriever/src/nemo_retriever/html/convert.py index a4968a670..222f84af3 100644 --- a/nemo_retriever/src/nemo_retriever/html/convert.py +++ b/nemo_retriever/src/nemo_retriever/html/convert.py @@ -23,7 +23,7 @@ DEFAULT_TOKENIZER_MODEL_ID, split_text_by_tokens, ) -from ..txt.split import _get_tokenizer as _get_txt_tokenizer +from ..txt.split import get_tokenizer as _get_txt_tokenizer def html_to_markdown(html_content: Union[str, bytes, Path]) -> str: diff --git a/nemo_retriever/src/nemo_retriever/infographic/infographic_detection.py b/nemo_retriever/src/nemo_retriever/infographic/infographic_detection.py index 84ca7058f..acd9fa2f0 100644 --- a/nemo_retriever/src/nemo_retriever/infographic/infographic_detection.py +++ b/nemo_retriever/src/nemo_retriever/infographic/infographic_detection.py @@ -22,6 +22,8 @@ import pandas as pd from nemo_retriever.params import RemoteRetryParams from nemo_retriever.nim.nim import invoke_image_inference_batches +from nemo_retriever.ocr.ocr import crop_b64_image_by_norm_bbox +from nemo_retriever.utils.detection import prediction_to_detections try: import numpy as np @@ -66,64 +68,6 @@ def _decode_b64_image_to_chw_tensor(image_b64: str) -> Tuple["torch.Tensor", Tup return t, (int(h), int(w)) -def _crop_b64_image_by_norm_bbox( - page_image_b64: str, - *, - bbox_xyxy_norm: Sequence[float], - image_format: str = "png", -) -> Tuple[Optional[str], Optional[Tuple[int, int]]]: - """ - Crop a base64-encoded RGB image by a normalized xyxy bbox. - - Returns: - - cropped_image_b64 (png) or None - - cropped_shape_hw (H,W) or None - """ - if Image is None: # pragma: no cover - raise ImportError("Cropping requires pillow.") - if not isinstance(page_image_b64, str) or not page_image_b64: - return None, None - try: - x1n, y1n, x2n, y2n = [float(x) for x in bbox_xyxy_norm] - except Exception: - return None, None - - try: - raw = base64.b64decode(page_image_b64) - with Image.open(io.BytesIO(raw)) as im0: - im = im0.convert("RGB") - w, h = im.size - if w <= 1 or h <= 1: - return None, None - - def _clamp_int(v: float, lo: int, hi: int) -> int: - if v != v: # NaN - return lo - return int(min(max(v, float(lo)), float(hi))) - - x1 = _clamp_int(x1n * w, 0, w) - x2 = _clamp_int(x2n * w, 0, w) - y1 = _clamp_int(y1n * h, 0, h) - y2 = _clamp_int(y2n * h, 0, h) - - if x2 <= x1 or y2 <= y1: - return None, None - - crop = im.crop((x1, y1, x2, y2)) - cw, ch = crop.size - if cw <= 1 or ch <= 1: - return None, None - - buf = io.BytesIO() - fmt = str(image_format or "png").lower() - if fmt not in {"png"}: - fmt = "png" - crop.save(buf, format=fmt.upper()) - return base64.b64encode(buf.getvalue()).decode("ascii"), (int(ch), int(cw)) - except Exception: - return None, None - - def _labels_from_model(model: Any) -> List[str]: try: labels = getattr(getattr(model, "_model", None), "labels", None) @@ -144,93 +88,6 @@ def _labels_from_model(model: Any) -> List[str]: return [] -def _prediction_to_detections(pred: Any, *, label_names: List[str]) -> List[Dict[str, Any]]: - if torch is None: # pragma: no cover - raise ImportError("torch required for prediction parsing.") - - boxes = labels = scores = None - if isinstance(pred, dict): - # IMPORTANT: do not use `or` chains here. torch.Tensor truthiness is ambiguous and raises. - def _get_any(d: Dict[str, Any], *keys: str) -> Any: - for k in keys: - if k in d: - v = d.get(k) - if v is not None: - return v - return None - - boxes = _get_any(pred, "boxes", "bboxes", "bbox", "box") - labels = _get_any(pred, "labels", "classes", "class_ids", "class") - scores = _get_any(pred, "scores", "conf", "confidences", "score") - elif isinstance(pred, (list, tuple)) and len(pred) >= 3: - boxes, labels, scores = pred[0], pred[1], pred[2] - - if boxes is None or labels is None: - return [] - - def _to_tensor(x: Any) -> Optional["torch.Tensor"]: - if x is None: - return None - if isinstance(x, torch.Tensor): - return x.detach().cpu() - try: - return torch.as_tensor(x).detach().cpu() - except Exception: - return None - - b = _to_tensor(boxes) - l = _to_tensor(labels) # noqa: E741 - s = _to_tensor(scores) if scores is not None else None - if b is None or l is None: - return [] - - if b.ndim != 2 or int(b.shape[-1]) != 4: - return [] - if l.ndim == 2 and int(l.shape[-1]) == 1: - l = l.squeeze(-1) # noqa: E741 - if l.ndim != 1: - return [] - - n = int(min(b.shape[0], l.shape[0])) - dets: List[Dict[str, Any]] = [] - for i in range(n): - try: - x1, y1, x2, y2 = [float(x) for x in b[i].tolist()] - except Exception: - continue - - label_i: Optional[int] - try: - label_i = int(l[i].item()) - except Exception: - label_i = None - - score_f: Optional[float] - if s is not None and s.ndim >= 1 and int(s.shape[0]) > i: - try: - score_f = float(s[i].item()) - except Exception: - score_f = None - else: - score_f = None - - label_name = None - if label_i is not None and 0 <= label_i < len(label_names): - label_name = label_names[label_i] - if not label_name: - label_name = f"label_{label_i}" if label_i is not None else "unknown" - - dets.append( - { - "bbox_xyxy_norm": [x1, y1, x2, y2], - "label": label_i, - "label_name": str(label_name), - "score": score_f, - } - ) - return dets - - def _extract_remote_pred_item(response_item: Any) -> Any: if isinstance(response_item, dict): for k in ("prediction", "predictions", "output", "outputs", "data"): @@ -344,7 +201,7 @@ def detect_infographic_elements_v1( raise RuntimeError(f"Expected {len(valid)} remote predictions, got {len(response_items)}") for local_j, row_i in enumerate(valid): pred_item = _extract_remote_pred_item(response_items[local_j]) - dets = _prediction_to_detections(pred_item, label_names=label_names) + dets = prediction_to_detections(pred_item, label_names=label_names) payloads[row_i] = {"detections": dets, "timing": {"seconds": float(elapsed)}, "error": None} except BaseException as e: elapsed = time.perf_counter() - t0 @@ -393,7 +250,7 @@ def detect_infographic_elements_v1( if len(preds_list) != len(idxs): raise RuntimeError("Batched invoke returned unexpected output shape; falling back to per-image calls.") for local_j, row_i in enumerate(idxs): - dets = _prediction_to_detections(preds_list[local_j], label_names=label_names) + dets = prediction_to_detections(preds_list[local_j], label_names=label_names) payloads[row_i] = {"detections": dets, "timing": {"seconds": float(elapsed)}, "error": None} except BaseException: for local_j, row_i in enumerate(idxs): @@ -411,7 +268,7 @@ def detect_infographic_elements_v1( if isinstance(pre, torch.Tensor) and pre.ndim == 3: pre = pre.unsqueeze(0) pred = model.invoke(pre, sh) - dets = _prediction_to_detections(pred, label_names=label_names) + dets = prediction_to_detections(pred, label_names=label_names) payloads[row_i] = { "detections": dets, "timing": {"seconds": float(time.perf_counter() - t1)}, @@ -541,7 +398,7 @@ def detect_infographic_elements_v1_from_page_elements_v3( if not isinstance(bbox, (list, tuple)) or len(bbox) != 4: continue - crop_b64, crop_shape_hw = _crop_b64_image_by_norm_bbox( + crop_b64, crop_shape_hw = crop_b64_image_by_norm_bbox( page_image_b64, bbox_xyxy_norm=cast(Sequence[float], bbox) ) if not crop_b64 or crop_shape_hw is None: @@ -587,7 +444,7 @@ def detect_infographic_elements_v1_from_page_elements_v3( raise RuntimeError(f"Expected {len(crop_b64s)} remote predictions, got {len(response_items)}") for resp in response_items: pred_item = _extract_remote_pred_item(resp) - dets = _prediction_to_detections(pred_item, label_names=label_names) + dets = prediction_to_detections(pred_item, label_names=label_names) crop_payloads.append({"detections": dets, "timing": {"seconds": float(elapsed)}, "error": None}) except BaseException as e: elapsed = time.perf_counter() - t0 @@ -650,7 +507,7 @@ def detect_infographic_elements_v1_from_page_elements_v3( "Batched invoke returned unexpected output shape; falling back to per-image calls." ) for local_j, crop_i in enumerate(idxs): - dets = _prediction_to_detections(preds_list[local_j], label_names=label_names) + dets = prediction_to_detections(preds_list[local_j], label_names=label_names) crop_payloads[crop_i] = { "detections": dets, "timing": {"seconds": float(elapsed)}, @@ -672,7 +529,7 @@ def detect_infographic_elements_v1_from_page_elements_v3( if isinstance(pre, torch.Tensor) and pre.ndim == 3: pre = pre.unsqueeze(0) pred = model.invoke(pre, sh) - dets = _prediction_to_detections(pred, label_names=label_names) + dets = prediction_to_detections(pred, label_names=label_names) crop_payloads[crop_i] = { "detections": dets, "timing": {"seconds": float(time.perf_counter() - t1)}, diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py index 1557c81ad..703a51fde 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py @@ -142,7 +142,7 @@ def __call__(self, batch_df: Any) -> Any: return batch_df -class _BatchEmbedActor: +class BatchEmbedActor: """Ray Data actor that holds a local text embedder on a single GPU. When ``embedding_endpoint`` is provided in kwargs, the actor skips local @@ -754,7 +754,7 @@ def embed( embed_actor_num_gpus = self._requested_plan.get_embed_gpus_per_actor() self._rd_dataset = self._rd_dataset.map_batches( - _BatchEmbedActor, + BatchEmbedActor, batch_size=self._requested_plan.get_embed_batch_size(), batch_format="pandas", num_gpus=embed_actor_num_gpus, # pulled from if statement above diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/fused.py b/nemo_retriever/src/nemo_retriever/ingest_modes/fused.py index 68c33ad58..ff580a087 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/fused.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/fused.py @@ -28,7 +28,7 @@ from ..params import EmbedParams from ..params import ExtractParams from ..params import PdfSplitParams -from .batch import _BatchEmbedActor +from .batch import BatchEmbedActor from .batch import BatchIngestor from .inprocess import collapse_content_to_page_rows from .inprocess import embed_text_main_text_embed @@ -233,7 +233,7 @@ def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "FusedInges Run page-elements + OCR + explode + embed in one GPU actor stage. `fused` mode intentionally does not support remote NIM invocation. - When _pipeline_type == "audio", uses explode + _BatchEmbedActor (no PDF stages). + When _pipeline_type == "audio", uses explode + BatchEmbedActor (no PDF stages). """ resolved = params or EmbedParams(**kwargs) if params is not None and kwargs: @@ -262,7 +262,7 @@ def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "FusedInges num_gpus=0, ) self._rd_dataset = self._rd_dataset.map_batches( - _BatchEmbedActor, + BatchEmbedActor, batch_size=embed_batch_size, batch_format="pandas", num_cpus=embed_cpus_per_actor, diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py index dd2854ae7..83b337f02 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py @@ -30,7 +30,7 @@ from nemo_retriever.model.local import NemotronOCRV1, NemotronPageElementsV3, NemotronParseV12 from nemo_retriever.chart.chart_detection import graphic_elements_ocr_page_elements from nemo_retriever.page_elements import detect_page_elements_v3 -from nemo_retriever.ocr.ocr import _crop_b64_image_by_norm_bbox, nemotron_parse_page_elements, ocr_page_elements +from nemo_retriever.ocr.ocr import crop_b64_image_by_norm_bbox, nemotron_parse_page_elements, ocr_page_elements from nemo_retriever.table.table_detection import table_structure_ocr_page_elements from nemo_retriever.text_embed.main_text_embed import TextEmbeddingConfig, create_text_embeddings_for_df @@ -57,7 +57,7 @@ from ..params import TextChunkParams from ..params import VdbUploadParams from ..pdf.extract import pdf_extraction -from ..pdf.split import _split_pdf_to_single_page_bytes, pdf_path_to_pages_df +from ..pdf.split import split_pdf_to_single_page_bytes, pdf_path_to_pages_df from ..txt import txt_file_to_chunks_df from ..html import html_file_to_chunks_df @@ -207,7 +207,7 @@ def explode_content_to_rows( if struct_mod in IMAGE_MODALITIES and page_image_b64: bbox = item.get("bbox_xyxy_norm") if bbox and len(bbox) == 4: - cropped_b64, _ = _crop_b64_image_by_norm_bbox(page_image_b64, bbox_xyxy_norm=bbox) + cropped_b64, _ = crop_b64_image_by_norm_bbox(page_image_b64, bbox_xyxy_norm=bbox) content_row["_image_b64"] = cropped_b64 else: content_row["_image_b64"] = page_image_b64 @@ -550,7 +550,7 @@ def pages_df_from_pdf_bytes(pdf_bytes: Union[bytes, bytearray], source_path: str Used by the online ingest mode to run the same pipeline on document bytes received via REST. Columns: bytes, path, page_number. """ - pages = _split_pdf_to_single_page_bytes(pdf_bytes) + pages = split_pdf_to_single_page_bytes(pdf_bytes) out_rows = [{"bytes": b, "path": source_path, "page_number": i + 1} for i, b in enumerate(pages)] return pd.DataFrame(out_rows) @@ -1649,7 +1649,7 @@ def _loader(p: str) -> pd.DataFrame: with open(abs_path, "rb") as f: file_bytes = f.read() pdf_bytes = convert_to_pdf_bytes(file_bytes, ext) - pages = _split_pdf_to_single_page_bytes(pdf_bytes) + pages = split_pdf_to_single_page_bytes(pdf_bytes) out_rows = [{"bytes": b, "path": abs_path, "page_number": i + 1} for i, b in enumerate(pages)] return pd.DataFrame(out_rows) except BaseException as e: diff --git a/nemo_retriever/src/nemo_retriever/ocr/__init__.py b/nemo_retriever/src/nemo_retriever/ocr/__init__.py index 49e6be1a4..bdffed321 100644 --- a/nemo_retriever/src/nemo_retriever/ocr/__init__.py +++ b/nemo_retriever/src/nemo_retriever/ocr/__init__.py @@ -2,6 +2,26 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from .ocr import OCRActor, ocr_page_elements +from .ocr import ( + OCRActor, + blocks_to_pseudo_markdown, + blocks_to_text, + crop_all_from_page, + crop_b64_image_by_norm_bbox, + extract_remote_ocr_item, + np_rgb_to_b64_png, + ocr_page_elements, + parse_ocr_result, +) -__all__ = ["OCRActor", "ocr_page_elements"] +__all__ = [ + "OCRActor", + "blocks_to_pseudo_markdown", + "blocks_to_text", + "crop_all_from_page", + "crop_b64_image_by_norm_bbox", + "extract_remote_ocr_item", + "np_rgb_to_b64_png", + "ocr_page_elements", + "parse_ocr_result", +] diff --git a/nemo_retriever/src/nemo_retriever/ocr/ocr.py b/nemo_retriever/src/nemo_retriever/ocr/ocr.py index a99955f24..b2ca71c0d 100644 --- a/nemo_retriever/src/nemo_retriever/ocr/ocr.py +++ b/nemo_retriever/src/nemo_retriever/ocr/ocr.py @@ -48,7 +48,7 @@ def _error_payload(*, stage: str, exc: BaseException) -> Dict[str, Any]: } -def _crop_b64_image_by_norm_bbox( +def crop_b64_image_by_norm_bbox( page_image_b64: str, *, bbox_xyxy_norm: Sequence[float], @@ -110,7 +110,7 @@ def _clamp_int(v: float, lo: int, hi: int) -> int: return None, None -def _crop_all_from_page( +def crop_all_from_page( page_image_b64: str, detections: List[Dict[str, Any]], wanted_labels: set, @@ -188,7 +188,7 @@ def _clamp_int(v: float, lo: int, hi: int) -> int: return results -def _np_rgb_to_b64_png(crop_array: np.ndarray) -> str: +def np_rgb_to_b64_png(crop_array: np.ndarray) -> str: if Image is None: # pragma: no cover raise ImportError("Pillow is required for image encoding.") img = Image.fromarray(crop_array.astype(np.uint8), mode="RGB") @@ -197,7 +197,7 @@ def _np_rgb_to_b64_png(crop_array: np.ndarray) -> str: return base64.b64encode(buf.getvalue()).decode("ascii") -def _extract_remote_ocr_item(response_item: Any) -> Any: +def extract_remote_ocr_item(response_item: Any) -> Any: if isinstance(response_item, dict): # NIM text_detections format: return full list (not v[0]) td = response_item.get("text_detections") @@ -212,7 +212,7 @@ def _extract_remote_ocr_item(response_item: Any) -> Any: return response_item -def _parse_ocr_result(preds: Any) -> List[Dict[str, Any]]: +def parse_ocr_result(preds: Any) -> List[Dict[str, Any]]: """ Parse the output of ``NemotronOCRV1.invoke()`` into a flat list of ``{"text": str, "sort_y": float, "sort_x": float}`` blocks. @@ -317,13 +317,13 @@ def _parse_ocr_result(preds: Any) -> List[Dict[str, Any]]: return blocks -def _blocks_to_text(blocks: List[Dict[str, Any]]) -> str: +def blocks_to_text(blocks: List[Dict[str, Any]]) -> str: """Sort text blocks by reading order (y then x) and join with newlines.""" blocks.sort(key=lambda b: (b.get("sort_y", 0.0), b.get("sort_x", 0.0))) return "\n".join(b["text"] for b in blocks if b.get("text")) -def _blocks_to_pseudo_markdown(blocks: List[Dict[str, Any]]) -> str: +def blocks_to_pseudo_markdown(blocks: List[Dict[str, Any]]) -> str: """Convert OCR text blocks into pseudo-markdown table format. Uses DBSCAN clustering on y-coordinates to identify rows, then @@ -501,13 +501,13 @@ def ocr_page_elements( continue # --- decode page image once, crop all matching detections --- - crops = _crop_all_from_page(page_image_b64, dets, wanted_labels) + crops = crop_all_from_page(page_image_b64, dets, wanted_labels) if use_remote: crop_b64s: List[str] = [] crop_meta: List[Tuple[str, List[float], Tuple[int, int]]] = [] for label_name, bbox, crop_array in crops: - crop_b64s.append(_np_rgb_to_b64_png(crop_array)) + crop_b64s.append(np_rgb_to_b64_png(crop_array)) crop_meta.append((label_name, bbox, (crop_array.shape[0], crop_array.shape[1]))) if crop_b64s: @@ -525,7 +525,7 @@ def ocr_page_elements( raise RuntimeError(f"Expected {len(crop_meta)} OCR responses, got {len(response_items)}") for i, (label_name, bbox, crop_hw) in enumerate(crop_meta): - preds = _extract_remote_ocr_item(response_items[i]) + preds = extract_remote_ocr_item(response_items[i]) if label_name == "chart" and use_graphic_elements: ge_dets = _find_ge_detections_for_bbox(row, bbox) @@ -535,11 +535,11 @@ def ocr_page_elements( chart_items.append({"bbox_xyxy_norm": bbox, "text": text}) continue - blocks = _parse_ocr_result(preds) + blocks = parse_ocr_result(preds) if label_name == "table": - text = _blocks_to_pseudo_markdown(blocks) or _blocks_to_text(blocks) + text = blocks_to_pseudo_markdown(blocks) or blocks_to_text(blocks) else: - text = _blocks_to_text(blocks) + text = blocks_to_text(blocks) entry = {"bbox_xyxy_norm": bbox, "text": text} if label_name == "table": table_items.append(entry) @@ -572,13 +572,13 @@ def _append_local_result( if text: chart_items.append({"bbox_xyxy_norm": bbox, "text": text}) return - blocks = _parse_ocr_result(preds) + blocks = parse_ocr_result(preds) if label_name == "table": - text = _blocks_to_pseudo_markdown(blocks) + text = blocks_to_pseudo_markdown(blocks) if not text: - text = _blocks_to_text(blocks) + text = blocks_to_text(blocks) else: - text = _blocks_to_text(blocks) + text = blocks_to_text(blocks) entry = {"bbox_xyxy_norm": bbox, "text": text} if label_name == "table": table_items.append(entry) @@ -832,7 +832,7 @@ def nemotron_parse_page_elements( all_meta.append({"timing": None, "error": None}) continue - crops = _crop_all_from_page(page_image_b64, dets, wanted_labels) + crops = crop_all_from_page(page_image_b64, dets, wanted_labels) # Parse-only mode may skip page-elements detection entirely. In that # case, parse the full page once and fan out the text to enabled # content channels. @@ -849,7 +849,7 @@ def nemotron_parse_page_elements( crop_b64s: List[str] = [] crop_meta: List[Tuple[str, List[float]]] = [] for label_name, bbox, crop_array in crops: - crop_b64s.append(_np_rgb_to_b64_png(crop_array)) + crop_b64s.append(np_rgb_to_b64_png(crop_array)) crop_meta.append((label_name, bbox)) if crop_b64s: diff --git a/nemo_retriever/src/nemo_retriever/pdf/split.py b/nemo_retriever/src/nemo_retriever/pdf/split.py index c750d54fc..ea04069f3 100644 --- a/nemo_retriever/src/nemo_retriever/pdf/split.py +++ b/nemo_retriever/src/nemo_retriever/pdf/split.py @@ -46,7 +46,7 @@ def _error_record( } -def _split_pdf_to_single_page_bytes(pdf_binary: Any) -> List[bytes]: +def split_pdf_to_single_page_bytes(pdf_binary: Any) -> List[bytes]: """ Split a PDF into single-page PDFs (raw bytes) using pypdfium2. """ @@ -103,7 +103,7 @@ def pdf_path_to_pages_df(path: str) -> pd.DataFrame: out_rows: List[Dict[str, Any]] = [] try: raw_bytes = Path(abs_path).read_bytes() - pages = _split_pdf_to_single_page_bytes(raw_bytes) + pages = split_pdf_to_single_page_bytes(raw_bytes) for page_idx, page_bytes in enumerate(pages): out_rows.append( { @@ -141,7 +141,7 @@ def split_pdf_batch(pdf_batch: Any, params: PdfSplitParams | None = None) -> pd. if not isinstance(pdf_bytes, (bytes, bytearray, memoryview)): raise ValueError(f"Unsupported bytes payload type: {type(pdf_bytes)!r}") - pages = _split_pdf_to_single_page_bytes(pdf_bytes) + pages = split_pdf_to_single_page_bytes(pdf_bytes) start_idx = 0 if start_page is None else max(int(start_page) - 1, 0) end_idx = (len(pages) - 1) if end_page is None else min(int(end_page) - 1, len(pages) - 1) if len(pages) == 0 or start_idx > end_idx: diff --git a/nemo_retriever/src/nemo_retriever/recall/vdb_recall.py b/nemo_retriever/src/nemo_retriever/recall/vdb_recall.py index f223cbf1a..a2d212fda 100644 --- a/nemo_retriever/src/nemo_retriever/recall/vdb_recall.py +++ b/nemo_retriever/src/nemo_retriever/recall/vdb_recall.py @@ -12,7 +12,7 @@ import pandas as pd # noqa: F401 from rich.console import Console -from .core import RecallConfig, evaluate_recall, retrieve_and_score, _normalize_query_df # noqa: F401 +from .core import RecallConfig, retrieve_and_score app = typer.Typer(help="Embed query CSV rows, search LanceDB, print hits, and compute recall@k.") console = Console() diff --git a/nemo_retriever/src/nemo_retriever/table/table_detection.py b/nemo_retriever/src/nemo_retriever/table/table_detection.py index bfb82a187..974ecfefa 100644 --- a/nemo_retriever/src/nemo_retriever/table/table_detection.py +++ b/nemo_retriever/src/nemo_retriever/table/table_detection.py @@ -11,6 +11,7 @@ import pandas as pd from nemo_retriever.params import RemoteRetryParams +from nemo_retriever.utils.detection import prediction_to_detections try: import torch @@ -41,114 +42,6 @@ def _labels_from_model(model: Any) -> List[str]: return [] -def _prediction_to_detections(pred: Any, *, label_names: List[str]) -> List[Dict[str, Any]]: - """ - Best-effort conversion of model output into a standard detection list. - - Produces dicts of the form: - {"bbox_xyxy_norm": [...], "label": int|None, "label_name": str, "score": float|None} - """ - if torch is None: # pragma: no cover - raise ImportError("torch required for prediction parsing.") - - boxes = labels = scores = None - if isinstance(pred, dict): - # IMPORTANT: do not use `or` chains here. torch.Tensor truthiness is ambiguous and raises. - def _get_any(d: Dict[str, Any], *keys: str) -> Any: - for k in keys: - if k in d: - v = d.get(k) - if v is not None: - return v - return None - - boxes = _get_any(pred, "boxes", "bboxes", "bbox", "box") - labels = _get_any(pred, "labels", "classes", "class_ids", "class") - scores = _get_any(pred, "scores", "conf", "confidences", "score") - elif isinstance(pred, (list, tuple)) and len(pred) >= 3: - boxes, labels, scores = pred[0], pred[1], pred[2] - - if boxes is None or labels is None: - return [] - - # Normalize to torch tensors. - def _to_tensor(x: Any) -> Optional["torch.Tensor"]: - if x is None: - return None - if isinstance(x, torch.Tensor): - return x.detach().cpu() - try: - return torch.as_tensor(x).detach().cpu() - except Exception: - return None - - # Handle string labels (e.g. NIM returns ["cell", "row", "column", ...]). - # torch.as_tensor cannot convert strings, so handle them before tensor conversion. - _string_labels: Optional[List[str]] = None - if isinstance(labels, (list, tuple)) and labels and isinstance(labels[0], str): - _string_labels = [str(x) for x in labels] - - b = _to_tensor(boxes) - labels_t = _to_tensor(labels) if _string_labels is None else None - s = _to_tensor(scores) if scores is not None else None - if b is None: - return [] - if labels_t is None and _string_labels is None: - return [] - - # Expect boxes (N,4), labels (N,) - if b.ndim != 2 or int(b.shape[-1]) != 4: - return [] - if labels_t is not None: - if labels_t.ndim == 2 and int(labels_t.shape[-1]) == 1: - labels_t = labels_t.squeeze(-1) - if labels_t.ndim != 1: - return [] - - n_labels = len(_string_labels) if _string_labels is not None else int(labels_t.shape[0]) - n = int(min(b.shape[0], n_labels)) - dets: List[Dict[str, Any]] = [] - for i in range(n): - try: - x1, y1, x2, y2 = [float(x) for x in b[i].tolist()] - except Exception: - continue - - if _string_labels is not None: - label_i = i - label_name = _string_labels[i] - else: - try: - label_i = int(labels_t[i].item()) - except Exception: - label_i = None - - label_name = None - if label_i is not None and 0 <= label_i < len(label_names): - label_name = label_names[label_i] - if not label_name: - label_name = f"label_{label_i}" if label_i is not None else "unknown" - - score_f: Optional[float] - if s is not None and s.ndim >= 1 and int(s.shape[0]) > i: - try: - score_f = float(s[i].item()) - except Exception: - score_f = None - else: - score_f = None - - dets.append( - { - "bbox_xyxy_norm": [x1, y1, x2, y2], - "label": label_i, - "label_name": str(label_name), - "score": score_f, - } - ) - return dets - - def _parse_nim_bounding_boxes(response_item: Any) -> List[Dict[str, Any]]: """Parse the ``bounding_boxes`` NIM response format. @@ -255,11 +148,11 @@ def table_structure_ocr_page_elements( """ from nemo_retriever.nim.nim import invoke_image_inference_batches from nemo_retriever.ocr.ocr import ( - _blocks_to_pseudo_markdown, - _crop_all_from_page, - _extract_remote_ocr_item, - _np_rgb_to_b64_png, - _parse_ocr_result, + blocks_to_pseudo_markdown, + crop_all_from_page, + extract_remote_ocr_item, + np_rgb_to_b64_png, + parse_ocr_result, ) from nemo_retriever.utils.table_and_chart import join_table_structure_and_ocr_output @@ -316,7 +209,7 @@ def table_structure_ocr_page_elements( continue # --- Pass 1: Collect table crops --- - crops = _crop_all_from_page(page_image_b64, dets, {"table"}) + crops = crop_all_from_page(page_image_b64, dets, {"table"}) if not crops: all_table.append(table_items) @@ -325,7 +218,7 @@ def table_structure_ocr_page_elements( # Pre-compute base64 encodings once for remote paths. crop_b64s = ( - [_np_rgb_to_b64_png(crop_array) for _, _, crop_array in crops] + [np_rgb_to_b64_png(crop_array) for _, _, crop_array in crops] if (use_remote_ts or use_remote_ocr) else [] ) @@ -350,7 +243,7 @@ def table_structure_ocr_page_elements( parsed = _parse_nim_bounding_boxes(resp) if not parsed: pred_item = _extract_remote_pred_item(resp) - parsed = _prediction_to_detections(pred_item, label_names=label_names) + parsed = prediction_to_detections(pred_item, label_names=label_names) structure_results.append(parsed) else: # Local batched inference. @@ -365,7 +258,7 @@ def table_structure_ocr_page_elements( if isinstance(pre, torch.Tensor) and pre.ndim == 3: pre = pre.unsqueeze(0) pred = table_structure_model.invoke(pre, (h, w)) - dets = _prediction_to_detections(pred, label_names=label_names) + dets = prediction_to_detections(pred, label_names=label_names) structure_results.append(dets) # --- Pass 3: Run OCR on all crops --- @@ -384,7 +277,7 @@ def table_structure_ocr_page_elements( if len(ocr_response_items) != len(crops): raise RuntimeError(f"Expected {len(crops)} OCR responses, got {len(ocr_response_items)}") for resp in ocr_response_items: - ocr_results.append(_extract_remote_ocr_item(resp)) + ocr_results.append(extract_remote_ocr_item(resp)) else: for _, _, crop_array in crops: ocr_results.append(ocr_model.invoke(crop_array, merge_level="word")) @@ -400,13 +293,13 @@ def table_structure_ocr_page_elements( # Fallback: if no cells were detected, use OCR-only pseudo-markdown. if not markdown: - blocks = _parse_ocr_result(ocr_preds) - markdown = _blocks_to_pseudo_markdown(blocks) + blocks = parse_ocr_result(ocr_preds) + markdown = blocks_to_pseudo_markdown(blocks) if not markdown: # Last resort: plain text. - from nemo_retriever.ocr.ocr import _blocks_to_text + from nemo_retriever.ocr.ocr import blocks_to_text - markdown = _blocks_to_text(blocks) + markdown = blocks_to_text(blocks) table_items.append({"bbox_xyxy_norm": bbox, "text": markdown}) diff --git a/nemo_retriever/src/nemo_retriever/txt/split.py b/nemo_retriever/src/nemo_retriever/txt/split.py index d47b8dfd3..7c5cb4172 100644 --- a/nemo_retriever/src/nemo_retriever/txt/split.py +++ b/nemo_retriever/src/nemo_retriever/txt/split.py @@ -22,7 +22,7 @@ DEFAULT_OVERLAP_TOKENS = 0 -def _get_tokenizer(model_id: str, cache_dir: Optional[str] = None): # noqa: ANN201 +def get_tokenizer(model_id: str, cache_dir: Optional[str] = None): # noqa: ANN201 """Lazy-load HuggingFace tokenizer.""" from transformers import AutoTokenizer @@ -131,7 +131,7 @@ def txt_file_to_chunks_df( path = str(Path(path).resolve()) raw = Path(path).read_text(encoding=encoding, errors="replace") model_id = tokenizer_model_id or DEFAULT_TOKENIZER_MODEL_ID - tokenizer = _get_tokenizer(model_id, cache_dir=tokenizer_cache_dir) + tokenizer = get_tokenizer(model_id, cache_dir=tokenizer_cache_dir) chunk_texts = split_text_by_tokens( raw, tokenizer=tokenizer, @@ -183,7 +183,7 @@ def txt_bytes_to_chunks_df( path = str(Path(path).resolve()) raw = content_bytes.decode(encoding, errors="replace") model_id = tokenizer_model_id or DEFAULT_TOKENIZER_MODEL_ID - tokenizer = _get_tokenizer(model_id, cache_dir=tokenizer_cache_dir) + tokenizer = get_tokenizer(model_id, cache_dir=tokenizer_cache_dir) chunk_texts = split_text_by_tokens( raw, tokenizer=tokenizer, diff --git a/nemo_retriever/src/nemo_retriever/utils/detection.py b/nemo_retriever/src/nemo_retriever/utils/detection.py new file mode 100644 index 000000000..882cfe5ae --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/utils/detection.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +try: + import torch +except Exception: # pragma: no cover + torch = None # type: ignore[assignment] + + +def prediction_to_detections(pred: Any, *, label_names: List[str]) -> List[Dict[str, Any]]: + """ + Best-effort conversion of model output into a standard detection list. + + Produces dicts of the form: + {"bbox_xyxy_norm": [...], "label": int|None, "label_name": str, "score": float|None} + """ + if torch is None: # pragma: no cover + raise ImportError("torch required for prediction parsing.") + + boxes = labels = scores = None + if isinstance(pred, dict): + # IMPORTANT: do not use `or` chains here. torch.Tensor truthiness is ambiguous and raises. + def _get_any(d: Dict[str, Any], *keys: str) -> Any: + for k in keys: + if k in d: + v = d.get(k) + if v is not None: + return v + return None + + boxes = _get_any(pred, "boxes", "bboxes", "bbox", "box") + labels = _get_any(pred, "labels", "classes", "class_ids", "class") + scores = _get_any(pred, "scores", "conf", "confidences", "score") + elif isinstance(pred, (list, tuple)) and len(pred) >= 3: + boxes, labels, scores = pred[0], pred[1], pred[2] + + if boxes is None or labels is None: + return [] + + def _to_tensor(x: Any) -> Optional["torch.Tensor"]: + if x is None: + return None + if isinstance(x, torch.Tensor): + return x.detach().cpu() + try: + return torch.as_tensor(x).detach().cpu() + except Exception: + return None + + # Handle string labels (e.g. NIM returns ["chart_title", "xlabel", ...]). + # torch.as_tensor cannot convert strings, so handle them before tensor conversion. + _string_labels: Optional[List[str]] = None + if isinstance(labels, (list, tuple)) and labels and isinstance(labels[0], str): + _string_labels = [str(x) for x in labels] + + b = _to_tensor(boxes) + labels_t = _to_tensor(labels) if _string_labels is None else None + s = _to_tensor(scores) if scores is not None else None + if b is None: + return [] + if labels_t is None and _string_labels is None: + return [] + + if b.ndim != 2 or int(b.shape[-1]) != 4: + return [] + if labels_t is not None: + if labels_t.ndim == 2 and int(labels_t.shape[-1]) == 1: + labels_t = labels_t.squeeze(-1) + if labels_t.ndim != 1: + return [] + + n_labels = len(_string_labels) if _string_labels is not None else int(labels_t.shape[0]) + n = int(min(b.shape[0], n_labels)) + dets: List[Dict[str, Any]] = [] + for i in range(n): + try: + x1, y1, x2, y2 = [float(x) for x in b[i].tolist()] + except Exception: + continue + + if _string_labels is not None: + label_i = i + label_name = _string_labels[i] + else: + try: + label_i = int(labels_t[i].item()) + except Exception: + label_i = None + + label_name = None + if label_i is not None and 0 <= label_i < len(label_names): + label_name = label_names[label_i] + if not label_name: + label_name = f"label_{label_i}" if label_i is not None else "unknown" + + score_f: Optional[float] + if s is not None and s.ndim >= 1 and int(s.shape[0]) > i: + try: + score_f = float(s[i].item()) + except Exception: + score_f = None + else: + score_f = None + + dets.append( + { + "bbox_xyxy_norm": [x1, y1, x2, y2], + "label": label_i, + "label_name": str(label_name), + "score": score_f, + } + ) + return dets diff --git a/nemo_retriever/src/nemo_retriever/utils/table_and_chart.py b/nemo_retriever/src/nemo_retriever/utils/table_and_chart.py index f28f05567..bcf26abf1 100644 --- a/nemo_retriever/src/nemo_retriever/utils/table_and_chart.py +++ b/nemo_retriever/src/nemo_retriever/utils/table_and_chart.py @@ -541,7 +541,7 @@ def _structure_dets_to_class_boxes( Parameters ---------- dets : list[dict] - Output of ``_prediction_to_detections()`` — each dict has + Output of ``prediction_to_detections()`` — each dict has ``bbox_xyxy_norm`` (normalized [0,1]) and ``label_name``. crop_hw : (int, int) ``(H, W)`` of the crop image. @@ -575,7 +575,7 @@ def join_table_structure_and_ocr_output( Parameters ---------- structure_dets : list[dict] - From ``_prediction_to_detections()`` with label_names cell/row/column + From ``prediction_to_detections()`` with label_names cell/row/column and ``bbox_xyxy_norm`` in [0, 1]. ocr_preds : list | dict Raw OCR output from ``NemotronOCRV1.invoke()``. @@ -613,7 +613,7 @@ def join_graphic_elements_and_ocr_output( Parameters ---------- ge_dets : list[dict] - From ``_prediction_to_detections()`` with chart-element label_names + From ``prediction_to_detections()`` with chart-element label_names and ``bbox_xyxy_norm`` in [0, 1]. ocr_preds : list | dict Raw OCR output from ``NemotronOCRV1.invoke()``. diff --git a/nemo_retriever/tests/test_audio_pipeline_batch.py b/nemo_retriever/tests/test_audio_pipeline_batch.py index 09ce1a7ad..9325ca610 100644 --- a/nemo_retriever/tests/test_audio_pipeline_batch.py +++ b/nemo_retriever/tests/test_audio_pipeline_batch.py @@ -177,7 +177,7 @@ def test_inprocess_audio_pipeline_local_asr_mocked(tmp_path: Path): @pytest.mark.skipif(not is_media_available(), reason="ffmpeg not available") def test_fused_audio_pipeline_with_mocked_asr(tmp_path: Path): - """Fused: same as batch but FusedIngestor; embed() uses explode + _BatchEmbedActor when _pipeline_type==audio.""" + """Fused: same as batch but FusedIngestor; embed() uses explode + BatchEmbedActor when _pipeline_type==audio.""" ray = pytest.importorskip("ray") pytest.importorskip("lancedb") diff --git a/nemo_retriever/tests/test_chart_graphic_elements.py b/nemo_retriever/tests/test_chart_graphic_elements.py index 25fd499a7..b462af553 100644 --- a/nemo_retriever/tests/test_chart_graphic_elements.py +++ b/nemo_retriever/tests/test_chart_graphic_elements.py @@ -382,7 +382,7 @@ def test_graphic_elements_flag_does_not_affect_table_stages(self) -> None: # --------------------------------------------------------------------------- -# _prediction_to_detections string labels test +# prediction_to_detections string labels test # --------------------------------------------------------------------------- @@ -390,28 +390,28 @@ def test_graphic_elements_flag_does_not_affect_table_stages(self) -> None: class TestPredictionToDetectionsStringLabels: def test_string_labels_handled(self) -> None: import torch - from nemo_retriever.chart.chart_detection import _prediction_to_detections + from nemo_retriever.utils.detection import prediction_to_detections pred = { "boxes": torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]), "labels": ["chart_title", "xlabel"], "scores": torch.tensor([0.9, 0.8]), } - dets = _prediction_to_detections(pred, label_names=[]) + dets = prediction_to_detections(pred, label_names=[]) assert len(dets) == 2 assert dets[0]["label_name"] == "chart_title" assert dets[1]["label_name"] == "xlabel" def test_integer_labels_still_work(self) -> None: import torch - from nemo_retriever.chart.chart_detection import _prediction_to_detections + from nemo_retriever.utils.detection import prediction_to_detections pred = { "boxes": torch.tensor([[0.1, 0.2, 0.3, 0.4]]), "labels": torch.tensor([1]), "scores": torch.tensor([0.9]), } - dets = _prediction_to_detections(pred, label_names=["chart_title", "xlabel"]) + dets = prediction_to_detections(pred, label_names=["chart_title", "xlabel"]) assert len(dets) == 1 assert dets[0]["label_name"] == "xlabel" assert dets[0]["label"] == 1 diff --git a/nemo_retriever/tests/test_multimodal_embed.py b/nemo_retriever/tests/test_multimodal_embed.py index f357193ef..caa8259ed 100644 --- a/nemo_retriever/tests/test_multimodal_embed.py +++ b/nemo_retriever/tests/test_multimodal_embed.py @@ -204,7 +204,7 @@ def test_text_mode_tags_modality(self): assert list(result["_embed_modality"]) == ["text", "text"] assert "_image_b64" not in result.columns - @patch("nemo_retriever.ingest_modes.inprocess._crop_b64_image_by_norm_bbox") + @patch("nemo_retriever.ingest_modes.inprocess.crop_b64_image_by_norm_bbox") def test_text_image_carries_image(self, mock_crop): """text_image mode copies page image to _image_b64, crops for structured content.""" mock_crop.return_value = ("cropped_b64", None)