Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 11 additions & 21 deletions docling_eval/dataset_builders/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,10 @@
from typing import TYPE_CHECKING, Iterable, Optional, Union

from docling.utils.utils import chunkify
from docling_core.types.doc.document import ImageRefMode
from huggingface_hub import snapshot_download
from pydantic import BaseModel

from docling_eval.datamodels.dataset_record import (
DatasetRecord,
DatasetRecordWithPrediction,
)
from docling_eval.prediction_providers.base_prediction_provider import (
TRUE_HTML_EXPORT_LABELS,
)
from docling_eval.datamodels.dataset_record import DatasetRecord
from docling_eval.utils.utils import (
insert_images_from_pil,
save_shard_to_disk,
Expand Down Expand Up @@ -192,23 +185,20 @@ def retrieve_input_dataset(self) -> Path:
Path to the retrieved dataset
"""
if isinstance(self.dataset_source, HFSource):
download_kwargs = {
"repo_id": self.dataset_source.repo_id,
"revision": self.dataset_source.revision,
"repo_type": "dataset",
"token": self.dataset_source.hf_token,
}

if not self.dataset_local_path:
path_str = snapshot_download(
repo_id=self.dataset_source.repo_id,
revision=self.dataset_source.revision,
repo_type="dataset",
token=self.dataset_source.hf_token,
)
path_str = snapshot_download(**download_kwargs)
path: Path = Path(path_str)
self.dataset_local_path = path
else:
path_str = snapshot_download(
repo_id=self.dataset_source.repo_id,
revision=self.dataset_source.revision,
repo_type="dataset",
token=self.dataset_source.hf_token,
local_dir=self.dataset_local_path,
)
download_kwargs["local_dir"] = str(self.dataset_local_path)
path_str = snapshot_download(**download_kwargs)
path = Path(path_str)
elif isinstance(self.dataset_source, Path):
path = self.dataset_source
Expand Down
136 changes: 133 additions & 3 deletions docling_eval/dataset_builders/omnidocbench_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

from datasets import load_dataset
from docling_core.types import DoclingDocument
from docling_core.types.doc import (
BoundingBox,
Expand All @@ -17,6 +18,8 @@
Size,
)
from docling_core.types.io import DocumentStream
from huggingface_hub import snapshot_download
from PIL import Image as PILImage
from PIL.Image import Image
from tqdm import tqdm

Expand Down Expand Up @@ -88,11 +91,18 @@ class OmniDocBenchDatasetBuilder(BaseEvaluationDatasetBuilder):

This builder processes the OmniDocBench dataset, which contains document
layout annotations for a variety of document types.

Supports two modes:
- Raw mode: Downloads raw files via snapshot_download (many requests)
- Parquet mode: Downloads Parquet shards and extracts files (few requests)
"""

def __init__(
self,
target: Path,
repo_id: str = "opendatalab/OmniDocBench",
revision: str = "v1_0",
use_parquet: bool = False,
split: str = "test",
begin_index: int = 0,
end_index: int = -1,
Expand All @@ -102,23 +112,58 @@ def __init__(

Args:
target: Path where processed dataset will be saved
repo_id: HuggingFace repository ID (default: opendatalab/OmniDocBench)
revision: Repository revision/branch
use_parquet: If True, download Parquet and extract files (avoids rate limits)
split: Dataset split to use
begin_index: Start index for processing (inclusive)
end_index: End index for processing (exclusive), -1 means process all
"""
super().__init__(
name="OmniDocBench: end-to-end",
dataset_source=HFSource(
repo_id="opendatalab/OmniDocBench", revision="v1_0"
),
dataset_source=HFSource(repo_id=repo_id, revision=revision),
target=target,
split=split,
begin_index=begin_index,
end_index=end_index,
)

self.use_parquet = use_parquet
self.must_retrieve = True

def retrieve_input_dataset(self) -> Path:
    """
    Download and retrieve the input dataset.

    In Parquet mode this is a no-op: ``iterate()`` loads records directly
    via ``load_dataset``, so nothing is downloaded here. In raw mode, all
    repository files are downloaded with ``snapshot_download`` into
    ``self.dataset_local_path`` (defaulting to ``self.target / "source_data"``).

    Returns:
        Path to the retrieved dataset: ``self.target`` in Parquet mode,
        ``self.dataset_local_path`` in raw mode.

    Raises:
        TypeError: If ``self.dataset_source`` is not an ``HFSource``.
    """
    if self.use_parquet:
        # Parquet mode: iterate() uses load_dataset directly, no download needed
        _log.info("Parquet mode: skipping download (data loaded in iterate)")
        self.retrieved = True
        return self.target

    # Raw mode: download all raw files into a local cache directory.
    if not self.dataset_local_path:
        self.dataset_local_path = self.target / "source_data"

    self.dataset_local_path.mkdir(parents=True, exist_ok=True)

    _log.info("Downloading files (raw mode)...")
    # Explicit check instead of `assert`: asserts are stripped under -O,
    # and a wrongly-typed source should fail loudly either way.
    if not isinstance(self.dataset_source, HFSource):
        raise TypeError("dataset_source must be HFSource")
    snapshot_download(
        repo_id=self.dataset_source.repo_id,
        revision=self.dataset_source.revision,
        repo_type="dataset",
        token=self.dataset_source.hf_token,
        local_dir=str(self.dataset_local_path),
    )
    self.retrieved = True
    return self.dataset_local_path

def update_gt_into_map(self, gt: List[Dict]) -> Dict[str, Dict]:
"""
Convert list of annotation items to a map keyed by image path.
Expand Down Expand Up @@ -330,6 +375,11 @@ def iterate(self) -> Iterable[DatasetRecord]:
Yields:
DatasetRecord objects
"""
# Parquet mode: use load_dataset directly
if self.use_parquet:
yield from self._iterate_parquet()
return

if not self.retrieved and self.must_retrieve:
raise RuntimeError(
"You must first retrieve the source dataset. Call retrieve_input_dataset()."
Expand Down Expand Up @@ -421,3 +471,83 @@ def iterate(self) -> Iterable[DatasetRecord]:
)

yield record

def _iterate_parquet(self) -> Iterable[DatasetRecord]:
    """
    Iterate through the Parquet dataset and yield DatasetRecord objects.

    Loads rows directly via ``load_dataset`` (Parquet shards), avoiding the
    HuggingFace rate limits incurred when downloading many individual raw
    files. Each row is expected to carry ``filename``, ``ground_truth``
    (JSON string), ``pdf`` (bytes), and ``image`` (PIL image) columns.

    Yields:
        DatasetRecord objects, one per selected Parquet row.

    Raises:
        TypeError: If ``self.dataset_source`` is not an ``HFSource``.
    """
    _log.info("Loading dataset via load_dataset (Parquet mode)...")

    # Explicit check instead of `assert` (asserts are stripped under -O).
    if not isinstance(self.dataset_source, HFSource):
        raise TypeError("dataset_source must be HFSource")
    # NOTE(review): revision and hf_token from dataset_source are NOT
    # forwarded to load_dataset -- confirm the Parquet repo needs neither.
    ds = load_dataset(
        self.dataset_source.repo_id,
        split="train",
    )

    total_items = len(ds)
    begin, end = self.get_effective_indices(total_items)
    ds = ds.select(range(begin, end))
    selected_items = len(ds)

    self.log_dataset_stats(total_items, selected_items)

    for item in tqdm(
        ds, total=selected_items, ncols=128, desc="Processing Parquet records"
    ):
        filename = item["filename"]
        gt_data = json.loads(item["ground_truth"])
        pdf_bytes = item["pdf"]
        page_image: PILImage.Image = item["image"]

        # Create an empty document with a single page sized to the image.
        # Plain string literal: the original used an f-string with no
        # placeholders (lint F541).
        true_doc = DoclingDocument(name="ground-truth (unknown)")
        page_image_rgb = page_image.convert("RGB")
        page_width = float(page_image_rgb.width)
        page_height = float(page_image_rgb.height)

        page_item = PageItem(
            page_no=1,
            size=Size(width=page_width, height=page_height),
        )
        true_doc.pages[1] = page_item

        # Populate the document with the ground-truth annotations.
        true_doc = self.update_doc_with_gt(
            gt=gt_data,
            true_doc=true_doc,
            page=true_doc.pages[1],
            page_image=page_image_rgb,
            page_width=page_width,
            page_height=page_height,
        )

        # Split picture and page images out into dedicated dataset columns.
        true_doc, true_pictures, true_page_images = extract_images(
            document=true_doc,
            pictures_column=BenchMarkColumns.GROUNDTRUTH_PICTURES.value,
            page_images_column=BenchMarkColumns.GROUNDTRUTH_PAGE_IMAGES.value,
        )

        # Wrap the raw PDF bytes as the original document stream.
        pdf_stream = DocumentStream(
            name=Path(filename).stem + ".pdf",
            stream=BytesIO(pdf_bytes),
        )

        # Assemble the final record; doc_hash keys deduplication downstream.
        record = DatasetRecord(
            doc_id=filename,
            doc_hash=get_binhash(pdf_bytes),
            ground_truth_doc=true_doc,
            ground_truth_pictures=true_pictures,
            ground_truth_page_images=true_page_images,
            original=pdf_stream,
            mime_type="application/pdf",
        )

        yield record
92 changes: 92 additions & 0 deletions tests/test_omnidocbench_parquet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os
from pathlib import Path

import pytest

from docling_eval.cli.main import evaluate, get_prediction_provider, visualize
from docling_eval.datamodels.types import (
BenchMarkNames,
EvaluationModality,
PredictionProviderType,
)
from docling_eval.dataset_builders.omnidocbench_builder import (
OmniDocBenchDatasetBuilder,
)

# True when running under CI (the "CI" env var is set to a non-empty value).
IS_CI = bool(os.environ.get("CI"))


@pytest.mark.skipif(
    IS_CI, reason="Skipping test in CI because the dataset is too heavy."
)
def test_run_omnidocbench_parquet_e2e():
    """
    End-to-end test of OmniDocBench in Parquet mode (use_parquet=True).

    Uses the `samiuc/OmniDocBench-parquet` dataset, which packs the
    filename, image, pdf, and ground_truth columns into Parquet shards,
    avoiding HuggingFace rate limits from downloading individual files.
    """
    target_path = Path(f"./scratch/{BenchMarkNames.OMNIDOCBENCH.value}-parquet/")
    docling_provider = get_prediction_provider(PredictionProviderType.DOCLING)

    gt_builder = OmniDocBenchDatasetBuilder(
        target=target_path / "gt_dataset",
        repo_id="samiuc/OmniDocBench-parquet",
        use_parquet=True,
        end_index=5,
    )

    gt_builder.retrieve_input_dataset()  # No-op in Parquet mode
    gt_builder.save_to_disk()  # Iterates dataset and saves as parquet shards

    docling_provider.create_prediction_dataset(
        name=gt_builder.name,
        gt_dataset_dir=target_path / "gt_dataset",
        target_dataset_dir=target_path / "eval_dataset",
    )

    # Evaluate and visualize each modality in turn; the iteration order
    # matches the original explicit call sequence exactly.
    for modality in (
        EvaluationModality.LAYOUT,
        EvaluationModality.READING_ORDER,
        EvaluationModality.MARKDOWN_TEXT,
    ):
        modality_out_dir = target_path / "evaluations" / modality.value

        evaluate(
            modality=modality,
            benchmark=BenchMarkNames.OMNIDOCBENCH,
            idir=target_path / "eval_dataset",
            odir=modality_out_dir,
        )

        visualize(
            modality=modality,
            benchmark=BenchMarkNames.OMNIDOCBENCH,
            idir=target_path / "eval_dataset",
            odir=modality_out_dir,
        )
Loading