diff --git a/docling_eval/dataset_builders/dataset_builder.py b/docling_eval/dataset_builders/dataset_builder.py index acd6d829..ea881046 100644 --- a/docling_eval/dataset_builders/dataset_builder.py +++ b/docling_eval/dataset_builders/dataset_builder.py @@ -7,17 +7,10 @@ from typing import TYPE_CHECKING, Iterable, Optional, Union from docling.utils.utils import chunkify -from docling_core.types.doc.document import ImageRefMode from huggingface_hub import snapshot_download from pydantic import BaseModel -from docling_eval.datamodels.dataset_record import ( - DatasetRecord, - DatasetRecordWithPrediction, -) -from docling_eval.prediction_providers.base_prediction_provider import ( - TRUE_HTML_EXPORT_LABELS, -) +from docling_eval.datamodels.dataset_record import DatasetRecord from docling_eval.utils.utils import ( insert_images_from_pil, save_shard_to_disk, @@ -192,23 +185,20 @@ def retrieve_input_dataset(self) -> Path: Path to the retrieved dataset """ if isinstance(self.dataset_source, HFSource): + download_kwargs = { + "repo_id": self.dataset_source.repo_id, + "revision": self.dataset_source.revision, + "repo_type": "dataset", + "token": self.dataset_source.hf_token, + } + if not self.dataset_local_path: - path_str = snapshot_download( - repo_id=self.dataset_source.repo_id, - revision=self.dataset_source.revision, - repo_type="dataset", - token=self.dataset_source.hf_token, - ) + path_str = snapshot_download(**download_kwargs) path: Path = Path(path_str) self.dataset_local_path = path else: - path_str = snapshot_download( - repo_id=self.dataset_source.repo_id, - revision=self.dataset_source.revision, - repo_type="dataset", - token=self.dataset_source.hf_token, - local_dir=self.dataset_local_path, - ) + download_kwargs["local_dir"] = str(self.dataset_local_path) + path_str = snapshot_download(**download_kwargs) path = Path(path_str) elif isinstance(self.dataset_source, Path): path = self.dataset_source diff --git 
a/docling_eval/dataset_builders/omnidocbench_builder.py b/docling_eval/dataset_builders/omnidocbench_builder.py index 68d0d9a7..d18e0fbe 100644 --- a/docling_eval/dataset_builders/omnidocbench_builder.py +++ b/docling_eval/dataset_builders/omnidocbench_builder.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Dict, Iterable, List, Tuple +from datasets import load_dataset from docling_core.types import DoclingDocument from docling_core.types.doc import ( BoundingBox, @@ -17,6 +18,8 @@ Size, ) from docling_core.types.io import DocumentStream +from huggingface_hub import snapshot_download +from PIL import Image as PILImage from PIL.Image import Image from tqdm import tqdm @@ -88,11 +91,18 @@ class OmniDocBenchDatasetBuilder(BaseEvaluationDatasetBuilder): This builder processes the OmniDocBench dataset, which contains document layout annotations for a variety of document types. + + Supports two modes: + - Raw mode: Downloads raw files via snapshot_download (many requests) + - Parquet mode: Downloads Parquet shards and extracts files (few requests) """ def __init__( self, target: Path, + repo_id: str = "opendatalab/OmniDocBench", + revision: str = "v1_0", + use_parquet: bool = False, split: str = "test", begin_index: int = 0, end_index: int = -1, @@ -102,23 +112,58 @@ def __init__( Args: target: Path where processed dataset will be saved + repo_id: HuggingFace repository ID (default: opendatalab/OmniDocBench) + revision: Repository revision/branch + use_parquet: If True, download Parquet and extract files (avoids rate limits) split: Dataset split to use begin_index: Start index for processing (inclusive) end_index: End index for processing (exclusive), -1 means process all """ super().__init__( name="OmniDocBench: end-to-end", - dataset_source=HFSource( - repo_id="opendatalab/OmniDocBench", revision="v1_0" - ), + dataset_source=HFSource(repo_id=repo_id, revision=revision), target=target, split=split, begin_index=begin_index, end_index=end_index, ) + 
def retrieve_input_dataset(self) -> Path:
    """
    Download and retrieve the input dataset.

    In Parquet mode nothing is downloaded here: the builder is only marked
    as retrieved, because iterate() loads the data through load_dataset.
    In raw mode the repository snapshot is downloaded into
    ``dataset_local_path`` (``<target>/source_data`` when it was not set).

    Returns:
        ``self.target`` in Parquet mode, otherwise the local directory the
        snapshot was downloaded into.
    """
    if self.use_parquet:
        # Parquet mode: iterate() uses load_dataset directly, no download needed
        _log.info("Parquet mode: skipping download (data loaded in iterate)")
        self.retrieved = True
        return self.target

    # Raw mode: download all raw files
    if not self.dataset_local_path:
        self.dataset_local_path = self.target / "source_data"

    self.dataset_local_path.mkdir(parents=True, exist_ok=True)

    _log.info("Downloading files (raw mode)...")
    assert isinstance(
        self.dataset_source, HFSource
    ), "dataset_source must be HFSource"
    snapshot_download(
        repo_id=self.dataset_source.repo_id,
        revision=self.dataset_source.revision,
        repo_type="dataset",
        token=self.dataset_source.hf_token,
        local_dir=str(self.dataset_local_path),
    )
    self.retrieved = True
    return self.dataset_local_path


def _iterate_parquet(self) -> Iterable[DatasetRecord]:
    """
    Iterate through the Parquet dataset and yield DatasetRecord objects.

    The dataset is loaded with load_dataset, narrowed to the half-open
    ``[begin, end)`` window returned by get_effective_indices, and every
    remaining row is turned into a single-page ground-truth document plus
    the row's PDF bytes.

    Each row is read through the keys ``filename``, ``ground_truth`` (parsed
    with json.loads), ``pdf`` (wrapped in BytesIO) and ``image`` (converted
    to RGB and used as the page image).

    Yields:
        One DatasetRecord per selected row.
    """
    _log.info("Loading dataset via load_dataset (Parquet mode)...")

    assert isinstance(
        self.dataset_source, HFSource
    ), "dataset_source must be HFSource"
    # Forward the access token exactly like the raw-mode snapshot_download
    # call does, so private/gated repositories also work in Parquet mode.
    # With hf_token=None this is the same as omitting the argument.
    #
    # NOTE(review): the split is fixed to "train" and neither self.split nor
    # dataset_source.revision is consulted here — presumably the Parquet
    # export only ships a train split on its default branch; confirm.
    ds = load_dataset(
        self.dataset_source.repo_id,
        split="train",
        token=self.dataset_source.hf_token,
    )

    total_items = len(ds)
    begin, end = self.get_effective_indices(total_items)
    ds = ds.select(range(begin, end))
    selected_items = len(ds)

    self.log_dataset_stats(total_items, selected_items)

    for item in tqdm(
        ds, total=selected_items, ncols=128, desc="Processing Parquet records"
    ):
        filename = item["filename"]
        gt_data = json.loads(item["ground_truth"])
        pdf_bytes = item["pdf"]
        page_image: PILImage.Image = item["image"]

        # Create document and add page; the page size is taken from the
        # RGB-converted page image.
        true_doc = DoclingDocument(name=f"ground-truth {filename}")
        page_image_rgb = page_image.convert("RGB")
        page_width = float(page_image_rgb.width)
        page_height = float(page_image_rgb.height)

        page_item = PageItem(
            page_no=1,
            size=Size(width=page_width, height=page_height),
        )
        true_doc.pages[1] = page_item

        # Update document with ground truth
        true_doc = self.update_doc_with_gt(
            gt=gt_data,
            true_doc=true_doc,
            page=true_doc.pages[1],
            page_image=page_image_rgb,
            page_width=page_width,
            page_height=page_height,
        )

        # Extract images from the ground truth document
        true_doc, true_pictures, true_page_images = extract_images(
            document=true_doc,
            pictures_column=BenchMarkColumns.GROUNDTRUTH_PICTURES.value,
            page_images_column=BenchMarkColumns.GROUNDTRUTH_PAGE_IMAGES.value,
        )

        # Create PDF stream from bytes
        pdf_stream = DocumentStream(
            name=Path(filename).stem + ".pdf",
            stream=BytesIO(pdf_bytes),
        )

        # Create dataset record
        record = DatasetRecord(
            doc_id=filename,
            doc_hash=get_binhash(pdf_bytes),
            ground_truth_doc=true_doc,
            ground_truth_pictures=true_pictures,
            ground_truth_page_images=true_page_images,
            original=pdf_stream,
            mime_type="application/pdf",
        )

        yield record
import os
from pathlib import Path

import pytest

from docling_eval.cli.main import evaluate, get_prediction_provider, visualize
from docling_eval.datamodels.types import (
    BenchMarkNames,
    EvaluationModality,
    PredictionProviderType,
)
from docling_eval.dataset_builders.omnidocbench_builder import (
    OmniDocBenchDatasetBuilder,
)

IS_CI = bool(os.getenv("CI"))


@pytest.mark.skipif(
    IS_CI, reason="Skipping test in CI because the dataset is too heavy."
)
def test_run_omnidocbench_parquet_e2e():
    """
    End-to-end run of OmniDocBench in Parquet mode (use_parquet=True).

    Builds the ground-truth dataset from the `samiuc/OmniDocBench-parquet`
    repository (first 5 records), creates the Docling prediction dataset,
    then evaluates and visualizes the layout, reading-order and
    markdown-text modalities in that order.
    """
    root = Path(f"./scratch/{BenchMarkNames.OMNIDOCBENCH.value}-parquet/")
    gt_dir = root / "gt_dataset"
    eval_dir = root / "eval_dataset"
    provider = get_prediction_provider(PredictionProviderType.DOCLING)

    builder = OmniDocBenchDatasetBuilder(
        target=gt_dir,
        repo_id="samiuc/OmniDocBench-parquet",
        use_parquet=True,
        end_index=5,
    )

    # retrieve_input_dataset() downloads nothing in Parquet mode;
    # save_to_disk() iterates the dataset and writes the parquet shards.
    builder.retrieve_input_dataset()
    builder.save_to_disk()

    provider.create_prediction_dataset(
        name=builder.name,
        gt_dataset_dir=gt_dir,
        target_dataset_dir=eval_dir,
    )

    # Each modality is evaluated first and visualized right after, reading
    # from the shared prediction dataset and writing to its own folder.
    modalities = (
        EvaluationModality.LAYOUT,
        EvaluationModality.READING_ORDER,
        EvaluationModality.MARKDOWN_TEXT,
    )
    for modality in modalities:
        out_dir = root / "evaluations" / modality.value
        for step in (evaluate, visualize):
            step(
                modality=modality,
                benchmark=BenchMarkNames.OMNIDOCBENCH,
                idir=eval_dir,
                odir=out_dir,
            )