diff --git a/docling_eval/cli/main.py b/docling_eval/cli/main.py index f898a40d..27eba766 100644 --- a/docling_eval/cli/main.py +++ b/docling_eval/cli/main.py @@ -70,6 +70,7 @@ from docling_eval.dataset_builders.doclingdpbench_builder import ( DoclingDPBenchDatasetBuilder, ) +from docling_eval.dataset_builders.doclingsdg_builder import DoclingSDGDatasetBuilder from docling_eval.dataset_builders.docvqa_builder import DocVQADatasetBuilder from docling_eval.dataset_builders.dpbench_builder import DPBenchDatasetBuilder from docling_eval.dataset_builders.file_dataset_builder import FileDatasetBuilder @@ -136,9 +137,17 @@ from docling_eval.prediction_providers.google_prediction_provider import ( GoogleDocAIPredictionProvider, ) -from docling_eval.prediction_providers.tableformer_provider import ( - TableFormerPredictionProvider, -) + +# TableFormer provider may not be available for all docling installations. +try: + from docling_eval.prediction_providers.tableformer_provider import ( + TableFormerPredictionProvider, + ) + + TABLEFORMER_AVAILABLE = True +except ImportError: + TABLEFORMER_AVAILABLE = False + TableFormerPredictionProvider = None # type: ignore from docling_eval.utils.external_docling_document_loader import ( ExternalDoclingDocumentLoader, ) @@ -315,6 +324,11 @@ def get_dataset_builder( elif benchmark == BenchMarkNames.DOCLING_DPBENCH: return DoclingDPBenchDatasetBuilder(**common_params) # type: ignore + elif benchmark == BenchMarkNames.DOCLING_SDG: + if dataset_source is None: + raise ValueError("dataset_source is required for DOCLING_SDG") + return DoclingSDGDatasetBuilder(dataset_source=dataset_source, **common_params) # type: ignore + elif benchmark == BenchMarkNames.DOCLAYNETV1: return DocLayNetV1DatasetBuilder(**common_params) # type: ignore @@ -620,6 +634,12 @@ def get_prediction_provider( ignore_missing_predictions=True, ) elif provider_type == PredictionProviderType.TABLEFORMER: + if not TABLEFORMER_AVAILABLE: + raise ImportError( + "TableFormer provider is not available in this environment. " + "Please install a compatible docling/docling-eval setup " + "that provides `docling.models.stages`." + ) return TableFormerPredictionProvider( do_visualization=do_visualization, ignore_missing_predictions=True, diff --git a/docling_eval/datamodels/types.py b/docling_eval/datamodels/types.py index 8786f48b..0a3057c2 100644 --- a/docling_eval/datamodels/types.py +++ b/docling_eval/datamodels/types.py @@ -59,6 +59,7 @@ class BenchMarkNames(str, Enum): # End-to-End DPBENCH = "DPBench" DOCLING_DPBENCH = "DoclingDPBench" + DOCLING_SDG = "DoclingSDG" OMNIDOCBENCH = "OmniDocBench" WORDSCAPE = "WordScape" CANVA_A = "canva_a" diff --git a/docling_eval/dataset_builders/doclingsdg_builder.py b/docling_eval/dataset_builders/doclingsdg_builder.py new file mode 100644 index 00000000..5b79c71f --- /dev/null +++ b/docling_eval/dataset_builders/doclingsdg_builder.py @@ -0,0 +1,285 @@ +import logging +import re +from io import BytesIO +from pathlib import Path +from typing import Dict, Iterable, List + +from docling_core.types import DoclingDocument +from docling_core.types.doc import ImageRef, PageItem, Size +from docling_core.types.io import DocumentStream +from PIL import Image +from pydantic import ValidationError +from tqdm import tqdm + +from docling_eval.datamodels.dataset_record import DatasetRecord +from docling_eval.datamodels.types import BenchMarkColumns +from docling_eval.dataset_builders.dataset_builder import BaseEvaluationDatasetBuilder +from docling_eval.utils.utils import ( + extract_images, + from_pil_to_base64uri, + get_binary, + get_binhash, +) + +_log = logging.getLogger(__name__) + +_PAGE_SUFFIX_PATTERN = re.compile(r"_page_(\d+)$", re.IGNORECASE) + + +class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder): + """ + Dataset builder for local Docling JSON + PNG source files. + + Expected input layout in ``dataset_source``: + - one Docling JSON file per document (``.json``), and + - either one PNG (``.png``) or page-wise PNGs + (``_page_000001.png``, ...). + """ + + def __init__( + self, + dataset_source: Path, + target: Path, + split: str = "test", + begin_index: int = 0, + end_index: int = -1, + ): + super().__init__( + name="DoclingSDG", + dataset_source=dataset_source, + target=target, + split=split, + begin_index=begin_index, + end_index=end_index, + ) + + self.must_retrieve = False + + @staticmethod + def _sort_by_page_suffix(path: Path) -> tuple[int, str]: + match = _PAGE_SUFFIX_PATTERN.search(path.stem) + page_no = int(match.group(1)) if match else 0 + return page_no, path.name.lower() + + def _find_json_files(self) -> List[Path]: + assert isinstance(self.dataset_source, Path) + + files = list(self.dataset_source.glob("*.json")) + files.extend(self.dataset_source.glob("*.JSON")) + + deduped = {f.resolve(): f for f in files} + return sorted(deduped.values(), key=lambda p: p.name.lower()) + + def _build_png_indices(self) -> tuple[Dict[str, List[Path]], Dict[str, List[Path]]]: + assert isinstance(self.dataset_source, Path) + + png_candidates = list(self.dataset_source.glob("*.png")) + png_candidates.extend(self.dataset_source.glob("*.PNG")) + + exact_index: Dict[str, List[Path]] = {} + paged_index: Dict[str, List[Path]] = {} + + for png_path in png_candidates: + stem = png_path.stem + page_match = _PAGE_SUFFIX_PATTERN.search(stem) + + if page_match is None: + exact_index.setdefault(stem, []).append(png_path) + continue + + base_name = stem[: page_match.start()] + paged_index.setdefault(base_name, []).append(png_path) + + for key, values in exact_index.items(): + exact_index[key] = sorted( + {f.resolve(): f for f in values}.values(), + key=lambda p: p.name.lower(), + ) + + for key, values in paged_index.items(): + paged_index[key] = sorted( + {f.resolve(): f for f in values}.values(), + key=self._sort_by_page_suffix, + ) + + return exact_index, paged_index + + def _find_png_files_for_doc( + self, + doc_id: str, + exact_index: Dict[str, List[Path]], + paged_index: Dict[str, List[Path]], + ) -> List[Path]: + base_names = [doc_id] + if doc_id.lower().endswith(".png"): + base_names.append(doc_id[:-4]) + + for base_name in dict.fromkeys(base_names): + if not base_name: + continue + exact_matches = exact_index.get(base_name) + if exact_matches: + return exact_matches + + for base_name in dict.fromkeys(base_names): + if not base_name: + continue + paged_matches = paged_index.get(base_name) + if paged_matches: + return paged_matches + + return [] + + @staticmethod + def _load_png_images(files: List[Path]) -> List[Image.Image]: + images: List[Image.Image] = [] + + for png_path in files: + with Image.open(png_path) as img: + images.append(img.convert("RGB")) + + return images + + @staticmethod + def _attach_page_images( + document: DoclingDocument, + page_images: List[Image.Image], + ) -> DoclingDocument: + for page_no, page_image in enumerate(page_images, start=1): + image_ref = ImageRef( + mimetype="image/png", + dpi=72, + size=Size(width=page_image.width, height=page_image.height), + uri=from_pil_to_base64uri(page_image), + ) + + if page_no in document.pages: + page_item = document.pages[page_no] + page_item.image = image_ref + if ( + page_item.size is None + or page_item.size.width <= 0 + or page_item.size.height <= 0 + ): + page_item.size = Size( + width=float(page_image.width), + height=float(page_image.height), + ) + else: + document.pages[page_no] = PageItem( + page_no=page_no, + size=Size( + width=float(page_image.width), + height=float(page_image.height), + ), + image=image_ref, + ) + + return document + + @staticmethod + def _page_images_to_pdf_bytes(page_images: List[Image.Image]) -> bytes: + if not page_images: + raise ValueError("page_images must not be empty") + + pdf_buffer = BytesIO() + first_page = page_images[0] + other_pages = page_images[1:] + + first_page.save( + pdf_buffer, + format="PDF", + save_all=True, + append_images=other_pages, + ) + + return pdf_buffer.getvalue() + + def iterate(self) -> Iterable[DatasetRecord]: + assert isinstance(self.dataset_source, Path) + + json_files = self._find_json_files() + exact_png_index, paged_png_index = self._build_png_indices() + + begin, end = self.get_effective_indices(len(json_files)) + selected_json_files = json_files[begin:end] + + self.log_dataset_stats(len(json_files), len(selected_json_files)) + _log.info( + "Processing DoclingSDG dataset with %d documents", + len(selected_json_files), + ) + + for json_path in tqdm( + selected_json_files, + desc="Processing files for DoclingSDG", + ncols=128, + ): + doc_id = json_path.stem + + try: + document = DoclingDocument.load_from_json(json_path) + except ValidationError as exc: + _log.warning("Validation error for %s: %s. Skipping.", json_path, exc) + continue + except Exception as exc: # noqa: BLE001 + _log.warning("Failed to load %s: %s. Skipping.", json_path, exc) + continue + + png_files = self._find_png_files_for_doc( + doc_id=doc_id, + exact_index=exact_png_index, + paged_index=paged_png_index, + ) + if len(png_files) == 0: + _log.warning( + "No matching PNG found for %s. Expected '%s.png' or '%s_page_*.png'. Skipping.", + json_path.name, + doc_id, + doc_id, + ) + continue + + try: + page_images = self._load_png_images(png_files) + except Exception as exc: # noqa: BLE001 + _log.warning( + "Failed to read PNG files for %s: %s. Skipping.", doc_id, exc + ) + continue + + self._attach_page_images(document, page_images) + document, pictures, extracted_page_images = extract_images( + document=document, + pictures_column=BenchMarkColumns.GROUNDTRUTH_PICTURES.value, + page_images_column=BenchMarkColumns.GROUNDTRUTH_PAGE_IMAGES.value, + ) + + if len(extracted_page_images) == 0: + extracted_page_images = page_images + + if len(png_files) == 1: + original_bytes = get_binary(png_files[0]) + original_stream = DocumentStream( + name=png_files[0].name, + stream=BytesIO(original_bytes), + ) + mime_type = "image/png" + else: + original_bytes = self._page_images_to_pdf_bytes(page_images) + original_stream = DocumentStream( + name=f"{doc_id}.pdf", + stream=BytesIO(original_bytes), + ) + mime_type = "application/pdf" + + yield DatasetRecord( + doc_id=doc_id, + doc_path=json_path, + doc_hash=get_binhash(original_bytes), + ground_truth_doc=document, + ground_truth_pictures=pictures, + ground_truth_page_images=extracted_page_images, + original=original_stream, + mime_type=mime_type, + ) diff --git a/docling_eval/utils/split_input_data.py b/docling_eval/utils/split_input_data.py new file mode 100644 index 00000000..0d130e5d --- /dev/null +++ b/docling_eval/utils/split_input_data.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +""" +Split paired (name.json, name.png) files from an input directory into train/test/val +output directories by moving the files. +""" + +from __future__ import annotations + +import argparse +import random +import shutil +import sys +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(frozen=True) +class FilePair: + stem: str + json_path: Path + png_path: Path + + +def _collect_pairs(input_dir: Path) -> tuple[list[FilePair], list[str], list[str]]: + json_candidates = list(input_dir.glob("*.json")) + json_candidates.extend(input_dir.glob("*.JSON")) + + # Dedupe in case of case-insensitive overlap + json_files = sorted({p.resolve(): p for p in json_candidates}.values()) + + pairs: list[FilePair] = [] + missing_png: list[str] = [] + + for json_path in json_files: + stem = json_path.stem + png_lower = input_dir / f"{stem}.png" + png_upper = input_dir / f"{stem}.PNG" + + if png_lower.exists(): + png_path = png_lower + elif png_upper.exists(): + png_path = png_upper + else: + missing_png.append(stem) + continue + + pairs.append(FilePair(stem=stem, json_path=json_path, png_path=png_path)) + + paired_stems = {p.stem for p in pairs} + png_candidates = list(input_dir.glob("*.png")) + png_candidates.extend(input_dir.glob("*.PNG")) + orphan_png = sorted( + { + png_path.stem + for png_path in png_candidates + if png_path.stem not in paired_stems + } + ) + + return pairs, missing_png, orphan_png + + +def _split_counts( + total: int, train_ratio: float, test_ratio: float, val_ratio: float +) -> tuple[int, int, int]: + if total == 0: + return 0, 0, 0 + + train_count = int(total * train_ratio) + test_count = int(total * test_ratio) + # Make sure all items are assigned. + val_count = total - train_count - test_count + return train_count, test_count, val_count + + +def _validate_ratios( + train_ratio: float, test_ratio: float, val_ratio: float +) -> tuple[float, float, float]: + ratios = [train_ratio, test_ratio, val_ratio] + if any(r < 0 for r in ratios): + raise ValueError("Split ratios must be non-negative.") + + ratio_sum = sum(ratios) + if ratio_sum <= 0: + raise ValueError("At least one split ratio must be > 0.") + + if abs(ratio_sum - 1.0) < 1e-9: + return train_ratio, test_ratio, val_ratio + + # Normalize if the user passed ratios that do not sum to 1 exactly. + return train_ratio / ratio_sum, test_ratio / ratio_sum, val_ratio / ratio_sum + + +def _ensure_targets(*dirs: Path) -> None: + for directory in dirs: + directory.mkdir(parents=True, exist_ok=True) + + +def _assert_no_collisions(target_dir: Path, pair: FilePair) -> None: + json_target = target_dir / pair.json_path.name + png_target = target_dir / pair.png_path.name + + if json_target.exists() or png_target.exists(): + raise FileExistsError( + f"Target collision in '{target_dir}': " + f"{json_target.name if json_target.exists() else ''} " + f"{png_target.name if png_target.exists() else ''}".strip() + ) + + +def _move_pair(pair: FilePair, target_dir: Path, dry_run: bool) -> None: + _assert_no_collisions(target_dir, pair) + + if dry_run: + print(f"[DRY-RUN] {pair.json_path} -> {target_dir / pair.json_path.name}") + print(f"[DRY-RUN] {pair.png_path} -> {target_dir / pair.png_path.name}") + return + + shutil.move(str(pair.json_path), str(target_dir / pair.json_path.name)) + shutil.move(str(pair.png_path), str(target_dir / pair.png_path.name)) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Move paired (name.json, name.png) files from input directory into " + "train/test/val directories according to split ratios." + ) + ) + + parser.add_argument( + "input_dir", type=Path, help="Directory with source JSON/PNG pairs" + ) + parser.add_argument( + "train_output_dir", type=Path, help="Output directory for train split" + ) + parser.add_argument( + "test_output_dir", type=Path, help="Output directory for test split" + ) + parser.add_argument( + "val_output_dir", type=Path, help="Output directory for val split" + ) + + parser.add_argument( + "--train-ratio", + type=float, + default=0.8, + help="Train split ratio (default: 0.8)", + ) + parser.add_argument( + "--test-ratio", type=float, default=0.1, help="Test split ratio (default: 0.1)" + ) + parser.add_argument( + "--val-ratio", type=float, default=0.1, help="Val split ratio (default: 0.1)" + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducible splitting (default: 42)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print planned moves without changing files", + ) + + return parser + + +def main() -> int: + args = build_parser().parse_args() + + input_dir: Path = args.input_dir + if not input_dir.exists() or not input_dir.is_dir(): + print( + f"Error: input directory does not exist or is not a directory: {input_dir}", + file=sys.stderr, + ) + return 1 + + try: + train_ratio, test_ratio, val_ratio = _validate_ratios( + args.train_ratio, + args.test_ratio, + args.val_ratio, + ) + except ValueError as exc: + print(f"Error: {exc}", file=sys.stderr) + return 1 + + train_out: Path = args.train_output_dir + test_out: Path = args.test_output_dir + val_out: Path = args.val_output_dir + + _ensure_targets(train_out, test_out, val_out) + + pairs, missing_png, orphan_png = _collect_pairs(input_dir) + if not pairs: + print("No valid (name.json, name.png) pairs found. Nothing to move.") + if missing_png: + print(f"JSON without PNG pairs: {len(missing_png)}") + if orphan_png: + print(f"PNG without JSON pairs: {len(orphan_png)}") + return 0 + + rng = random.Random(args.seed) + rng.shuffle(pairs) + + train_count, test_count, val_count = _split_counts( + len(pairs), train_ratio, test_ratio, val_ratio + ) + + train_pairs = pairs[:train_count] + test_pairs = pairs[train_count : train_count + test_count] + val_pairs = pairs[train_count + test_count :] + + assert len(train_pairs) + len(test_pairs) + len(val_pairs) == len(pairs) + + for pair in train_pairs: + _move_pair(pair, train_out, args.dry_run) + for pair in test_pairs: + _move_pair(pair, test_out, args.dry_run) + for pair in val_pairs: + _move_pair(pair, val_out, args.dry_run) + + print( + "Done. " + f"Moved pairs -> train: {len(train_pairs)}, test: {len(test_pairs)}, val: {len(val_pairs)}" + ) + + if missing_png: + print(f"Skipped JSON without matching PNG: {len(missing_png)}") + if orphan_png: + print(f"Found PNG without matching JSON: {len(orphan_png)}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())