Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions docling_eval/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from docling_eval.dataset_builders.doclingdpbench_builder import (
DoclingDPBenchDatasetBuilder,
)
from docling_eval.dataset_builders.doclingsdg_builder import DoclingSDGDatasetBuilder
from docling_eval.dataset_builders.docvqa_builder import DocVQADatasetBuilder
from docling_eval.dataset_builders.dpbench_builder import DPBenchDatasetBuilder
from docling_eval.dataset_builders.file_dataset_builder import FileDatasetBuilder
Expand Down Expand Up @@ -136,9 +137,17 @@
from docling_eval.prediction_providers.google_prediction_provider import (
GoogleDocAIPredictionProvider,
)
from docling_eval.prediction_providers.tableformer_provider import (
TableFormerPredictionProvider,
)

# TableFormer provider may not be available for all docling installations.
try:
from docling_eval.prediction_providers.tableformer_provider import (
TableFormerPredictionProvider,
)

TABLEFORMER_AVAILABLE = True
except ImportError:
TABLEFORMER_AVAILABLE = False
TableFormerPredictionProvider = None # type: ignore
from docling_eval.utils.external_docling_document_loader import (
ExternalDoclingDocumentLoader,
)
Expand Down Expand Up @@ -315,6 +324,11 @@ def get_dataset_builder(
elif benchmark == BenchMarkNames.DOCLING_DPBENCH:
return DoclingDPBenchDatasetBuilder(**common_params) # type: ignore

elif benchmark == BenchMarkNames.DOCLING_SDG:
if dataset_source is None:
raise ValueError("dataset_source is required for DOCLING_SDG")
return DoclingSDGDatasetBuilder(dataset_source=dataset_source, **common_params) # type: ignore

elif benchmark == BenchMarkNames.DOCLAYNETV1:
return DocLayNetV1DatasetBuilder(**common_params) # type: ignore

Expand Down Expand Up @@ -620,6 +634,12 @@ def get_prediction_provider(
ignore_missing_predictions=True,
)
elif provider_type == PredictionProviderType.TABLEFORMER:
if not TABLEFORMER_AVAILABLE:
raise ImportError(
"TableFormer provider is not available in this environment. "
"Please install a compatible docling/docling-eval setup "
"that provides `docling.models.stages`."
)
return TableFormerPredictionProvider(
do_visualization=do_visualization,
ignore_missing_predictions=True,
Expand Down
1 change: 1 addition & 0 deletions docling_eval/datamodels/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class BenchMarkNames(str, Enum):
# End-to-End
DPBENCH = "DPBench"
DOCLING_DPBENCH = "DoclingDPBench"
DOCLING_SDG = "DoclingSDG"
OMNIDOCBENCH = "OmniDocBench"
WORDSCAPE = "WordScape"
CANVA_A = "canva_a"
Expand Down
247 changes: 247 additions & 0 deletions docling_eval/dataset_builders/doclingsdg_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
import logging
import re
from io import BytesIO
from pathlib import Path
from typing import Iterable, List

from docling_core.types import DoclingDocument
from docling_core.types.doc import ImageRef, PageItem, Size
from docling_core.types.io import DocumentStream
from PIL import Image
from pydantic import ValidationError
from tqdm import tqdm

from docling_eval.datamodels.dataset_record import DatasetRecord
from docling_eval.datamodels.types import BenchMarkColumns
from docling_eval.dataset_builders.dataset_builder import BaseEvaluationDatasetBuilder
from docling_eval.utils.utils import (
extract_images,
from_pil_to_base64uri,
get_binary,
get_binhash,
)

_log = logging.getLogger(__name__)

_PAGE_SUFFIX_PATTERN = re.compile(r"_page_(\\d+)$", re.IGNORECASE)


class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
    """
    Dataset builder for local Docling JSON + PNG source files.

    Expected input layout in ``dataset_source``:
    - one Docling JSON file per document (``<doc_id>.json``), and
    - either one PNG (``<doc_id>.png``) or page-wise PNGs
      (``<doc_id>_page_000001.png``, ...).

    ``iterate`` pairs each JSON ground-truth document with its PNG page
    images and yields one :class:`DatasetRecord` per document, skipping
    (with a warning) any document whose JSON fails to load or whose PNGs
    are missing or unreadable.
    """

    def __init__(
        self,
        dataset_source: Path,
        target: Path,
        split: str = "test",
        begin_index: int = 0,
        end_index: int = -1,
    ) -> None:
        """Create the builder.

        Args:
            dataset_source: Directory containing the ``*.json`` and ``*.png``
                source files.
            target: Output directory for the built dataset.
            split: Dataset split name written to the output (default "test").
            begin_index: First document index to process (inclusive).
            end_index: Last document index (exclusive); -1 means "to the end".
        """
        super().__init__(
            name="DoclingSDG",
            dataset_source=dataset_source,
            target=target,
            split=split,
            begin_index=begin_index,
            end_index=end_index,
        )

        # NOTE(review): presumably tells the base class there is nothing to
        # download/retrieve since all inputs are local files — confirm the
        # flag's semantics in BaseEvaluationDatasetBuilder.
        self.must_retrieve = False

    @staticmethod
    def _sort_by_page_suffix(path: Path) -> tuple[int, str]:
        """Sort key for paged PNGs: (page number, lowercased filename).

        Files without a ``_page_<n>`` suffix sort first (page number 0);
        the filename tiebreaker keeps the ordering deterministic.
        """
        match = _PAGE_SUFFIX_PATTERN.search(path.stem)
        page_no = int(match.group(1)) if match else 0
        return page_no, path.name.lower()

    def _find_json_files(self) -> List[Path]:
        """Return all ``*.json`` / ``*.JSON`` files in ``dataset_source``,
        deduplicated and sorted case-insensitively by name."""
        assert isinstance(self.dataset_source, Path)

        files = list(self.dataset_source.glob("*.json"))
        files.extend(self.dataset_source.glob("*.JSON"))

        # Dedup by resolved path: on case-insensitive filesystems both glob
        # patterns can yield the same file.
        deduped = {f.resolve(): f for f in files}
        return sorted(deduped.values(), key=lambda p: p.name.lower())

    def _find_png_files_for_doc(self, doc_id: str) -> List[Path]:
        """Locate the PNG(s) belonging to ``doc_id``.

        Prefers an exact ``<doc_id>.png`` match; otherwise falls back to
        page-wise ``<doc_id>_page_*.png`` files sorted by page number.
        Returns an empty list when neither form exists.
        """
        assert isinstance(self.dataset_source, Path)

        # Tolerate doc ids that already carry a ".png" suffix by also trying
        # the stem without it.
        base_names = [doc_id]
        if doc_id.lower().endswith(".png"):
            base_names.append(doc_id[:-4])

        exact_matches: List[Path] = []
        paged_matches: List[Path] = []

        # dict.fromkeys preserves order while removing duplicate base names.
        for base_name in dict.fromkeys(base_names):
            if not base_name:
                continue

            exact_matches.extend(self.dataset_source.glob(f"{base_name}.png"))
            exact_matches.extend(self.dataset_source.glob(f"{base_name}.PNG"))

            paged_matches.extend(self.dataset_source.glob(f"{base_name}_page_*.png"))
            paged_matches.extend(self.dataset_source.glob(f"{base_name}_page_*.PNG"))

        # A single exact match wins over page-wise files.
        if exact_matches:
            deduped = {f.resolve(): f for f in exact_matches}
            return sorted(deduped.values(), key=lambda p: p.name.lower())

        deduped = {f.resolve(): f for f in paged_matches}
        return sorted(deduped.values(), key=self._sort_by_page_suffix)

    @staticmethod
    def _load_png_images(files: List[Path]) -> List[Image.Image]:
        """Open each PNG and return RGB-converted copies.

        ``convert("RGB")`` produces a new image, so the source files are
        closed by the ``with`` block while the returned images stay usable.
        """
        images: List[Image.Image] = []

        for png_path in files:
            with Image.open(png_path) as img:
                images.append(img.convert("RGB"))

        return images

    @staticmethod
    def _attach_page_images(
        document: DoclingDocument,
        page_images: List[Image.Image],
    ) -> DoclingDocument:
        """Attach ``page_images`` to ``document`` as base64 ImageRefs.

        Images are mapped to 1-based page numbers in order. Existing pages
        get their image (and, if missing/invalid, their size) filled in;
        pages absent from the document are created. Mutates ``document``
        in place and also returns it.
        """
        for page_no, page_image in enumerate(page_images, start=1):
            image_ref = ImageRef(
                mimetype="image/png",
                # NOTE(review): dpi is hard-coded; actual PNG resolution
                # metadata is not consulted — confirm 72 is intended.
                dpi=72,
                size=Size(width=page_image.width, height=page_image.height),
                uri=from_pil_to_base64uri(page_image),
            )

            if page_no in document.pages:
                page_item = document.pages[page_no]
                page_item.image = image_ref
                # Only backfill the size when the JSON carried none (or a
                # degenerate one); otherwise trust the ground truth.
                if (
                    page_item.size is None
                    or page_item.size.width <= 0
                    or page_item.size.height <= 0
                ):
                    page_item.size = Size(
                        width=float(page_image.width),
                        height=float(page_image.height),
                    )
            else:
                document.pages[page_no] = PageItem(
                    page_no=page_no,
                    size=Size(
                        width=float(page_image.width),
                        height=float(page_image.height),
                    ),
                    image=image_ref,
                )

        return document

    @staticmethod
    def _page_images_to_pdf_bytes(page_images: List[Image.Image]) -> bytes:
        """Render the page images into a single multi-page PDF and return
        its bytes.

        Raises:
            ValueError: If ``page_images`` is empty.
        """
        if not page_images:
            raise ValueError("page_images must not be empty")

        pdf_buffer = BytesIO()
        first_page = page_images[0]
        other_pages = page_images[1:]

        # Pillow writes a multi-page PDF when save_all=True and the
        # remaining pages are passed via append_images.
        first_page.save(
            pdf_buffer,
            format="PDF",
            save_all=True,
            append_images=other_pages,
        )

        return pdf_buffer.getvalue()

    def iterate(self) -> Iterable[DatasetRecord]:
        """Yield one :class:`DatasetRecord` per JSON/PNG document pair.

        Documents with unloadable JSON, missing PNGs, or unreadable PNGs
        are logged and skipped rather than aborting the run.
        """
        assert isinstance(self.dataset_source, Path)

        json_files = self._find_json_files()

        # Apply the begin/end window configured on the base class.
        begin, end = self.get_effective_indices(len(json_files))
        selected_json_files = json_files[begin:end]

        self.log_dataset_stats(len(json_files), len(selected_json_files))
        _log.info(
            "Processing DoclingSDG dataset with %d documents",
            len(selected_json_files),
        )

        for json_path in tqdm(
            selected_json_files,
            desc="Processing files for DoclingSDG",
            ncols=128,
        ):
            # The JSON stem doubles as the document id and the PNG base name.
            doc_id = json_path.stem

            try:
                document = DoclingDocument.load_from_json(json_path)
            except ValidationError as exc:
                _log.warning("Validation error for %s: %s. Skipping.", json_path, exc)
                continue
            except Exception as exc:  # noqa: BLE001
                _log.warning("Failed to load %s: %s. Skipping.", json_path, exc)
                continue

            png_files = self._find_png_files_for_doc(doc_id)
            if len(png_files) == 0:
                _log.warning(
                    "No matching PNG found for %s. Expected '%s.png' or '%s_page_*.png'. Skipping.",
                    json_path.name,
                    doc_id,
                    doc_id,
                )
                continue

            try:
                page_images = self._load_png_images(png_files)
            except Exception as exc:  # noqa: BLE001
                _log.warning(
                    "Failed to read PNG files for %s: %s. Skipping.", doc_id, exc
                )
                continue

            # Embed the page images into the document, then split images out
            # into the dedicated ground-truth columns.
            self._attach_page_images(document, page_images)
            document, pictures, extracted_page_images = extract_images(
                document=document,
                pictures_column=BenchMarkColumns.GROUNDTRUTH_PICTURES.value,
                page_images_column=BenchMarkColumns.GROUNDTRUTH_PAGE_IMAGES.value,
            )

            # Fall back to the raw PNGs if extraction yielded no page images.
            if len(extracted_page_images) == 0:
                extracted_page_images = page_images

            # Original payload: a single PNG is stored as-is; multiple pages
            # are packed into one PDF so the record has a single stream.
            if len(png_files) == 1:
                original_bytes = get_binary(png_files[0])
                original_stream = DocumentStream(
                    name=png_files[0].name,
                    stream=BytesIO(original_bytes),
                )
                mime_type = "image/png"
            else:
                original_bytes = self._page_images_to_pdf_bytes(page_images)
                original_stream = DocumentStream(
                    name=f"{doc_id}.pdf",
                    stream=BytesIO(original_bytes),
                )
                mime_type = "application/pdf"

            yield DatasetRecord(
                doc_id=doc_id,
                doc_path=json_path,
                doc_hash=get_binhash(original_bytes),
                ground_truth_doc=document,
                ground_truth_pictures=pictures,
                ground_truth_page_images=extracted_page_images,
                original=original_stream,
                mime_type=mime_type,
            )
Loading