diff --git a/pyproject.toml b/pyproject.toml index e47e574..cd5f41b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "document-analysis-mcp" -version = "0.2.0" +version = "0.3.0" description = "General-purpose Document Analysis MCP server for PDF processing" readme = "README.md" requires-python = ">=3.10" diff --git a/src/document_analysis_mcp/cache/__init__.py b/src/document_analysis_mcp/cache/__init__.py index a786018..aa11907 100644 --- a/src/document_analysis_mcp/cache/__init__.py +++ b/src/document_analysis_mcp/cache/__init__.py @@ -1 +1,383 @@ -"""Hash-based document caching for deduplication.""" +"""Hash-based document caching for deduplication. + +This module provides a persistent cache for document extraction and analysis results. +It uses content hashing to detect duplicate documents and avoid re-processing. + +Key Features: +- SHA-256 content hashing for document deduplication +- TTL-based cache expiration (configurable via CACHE_TTL_DAYS) +- JSON storage for cache metadata and results +- Automatic cleanup of expired entries +""" + +import hashlib +import json +import logging +import time +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from document_analysis_mcp.config import get_settings + +logger = logging.getLogger(__name__) + +# Cache file names +METADATA_FILE = "cache_metadata.json" + + +@dataclass +class CacheEntry: + """Represents a cached document entry. + + Attributes: + content_hash: SHA-256 hash of the original PDF content. + created_at: UTC timestamp when the entry was created. + expires_at: UTC timestamp when the entry expires. + file_path: Path to the cached result file. + operation: Type of operation (e.g., 'extract', 'analyze', 'structure'). + metadata: Additional metadata about the cached content. + """ + + content_hash: str + created_at: datetime + expires_at: datetime + file_path: str + operation: str + metadata: dict[str, Any] = field(default_factory=dict) + + def is_expired(self) -> bool: + """Check if this cache entry has expired.""" + return datetime.now(timezone.utc) > self.expires_at + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "content_hash": self.content_hash, + "created_at": self.created_at.isoformat(), + "expires_at": self.expires_at.isoformat(), + "file_path": self.file_path, + "operation": self.operation, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "CacheEntry": + """Create a CacheEntry from a dictionary.""" + return cls( + content_hash=data["content_hash"], + created_at=datetime.fromisoformat(data["created_at"]), + expires_at=datetime.fromisoformat(data["expires_at"]), + file_path=data["file_path"], + operation=data["operation"], + metadata=data.get("metadata", {}), + ) + + +class DocumentCache: + """Hash-based cache for document processing results. + + Provides persistent caching of extraction and analysis results to avoid + re-processing identical documents. Uses SHA-256 content hashing for + deduplication. + + Usage: + cache = DocumentCache() + + # Check for cached result + cached = cache.get(pdf_content, "extract") + if cached: + return cached + + # Process document... + result = extract_pdf(pdf_content) + + # Cache the result + cache.put(pdf_content, "extract", result) + """ + + def __init__(self, cache_dir: Path | None = None, ttl_days: int | None = None) -> None: + """Initialize the document cache. + + Args: + cache_dir: Directory for cache storage. Defaults to settings.cache_dir. + ttl_days: Cache TTL in days. Defaults to settings.cache_ttl_days. + """ + settings = get_settings() + self._cache_dir = cache_dir or settings.cache_dir + self._ttl_days = ttl_days or settings.cache_ttl_days + self._metadata: dict[str, CacheEntry] = {} + self._initialized = False + + def _ensure_initialized(self) -> None: + """Ensure cache directory exists and metadata is loaded.""" + if self._initialized: + return + + # Create cache directory + self._cache_dir.mkdir(parents=True, exist_ok=True) + + # Load existing metadata + metadata_path = self._cache_dir / METADATA_FILE + if metadata_path.exists(): + try: + with open(metadata_path) as f: + data = json.load(f) + self._metadata = { + key: CacheEntry.from_dict(entry) for key, entry in data.items() + } + logger.debug("Loaded %d cache entries from %s", len(self._metadata), metadata_path) + except (json.JSONDecodeError, KeyError, ValueError) as e: + logger.warning("Failed to load cache metadata: %s", e) + self._metadata = {} + + self._initialized = True + + def _save_metadata(self) -> None: + """Save cache metadata to disk.""" + metadata_path = self._cache_dir / METADATA_FILE + try: + with open(metadata_path, "w") as f: + data = {key: entry.to_dict() for key, entry in self._metadata.items()} + json.dump(data, f, indent=2) + except OSError as e: + logger.error("Failed to save cache metadata: %s", e) + + @staticmethod + def compute_hash(content: str | bytes) -> str: + """Compute SHA-256 hash of content. + + Args: + content: String or bytes content to hash. + + Returns: + Hex-encoded SHA-256 hash. + """ + if isinstance(content, str): + content = content.encode("utf-8") + return hashlib.sha256(content).hexdigest() + + def _make_cache_key(self, content_hash: str, operation: str) -> str: + """Create a cache key from content hash and operation. + + Args: + content_hash: SHA-256 hash of the content. + operation: Operation type (e.g., 'extract', 'analyze'). + + Returns: + Cache key string. + """ + return f"{content_hash}:{operation}" + + def get( + self, + content: str | bytes, + operation: str, + params_hash: str | None = None, + ) -> dict[str, Any] | None: + """Retrieve a cached result for the given content and operation. + + Args: + content: PDF content (base64 string or bytes). + operation: Operation type to look up. + params_hash: Optional hash of operation parameters for cache keying. + + Returns: + Cached result dictionary, or None if not found or expired. + """ + self._ensure_initialized() + + content_hash = self.compute_hash(content) + cache_key = self._make_cache_key(content_hash, operation) + if params_hash: + cache_key = f"{cache_key}:{params_hash}" + + entry = self._metadata.get(cache_key) + if entry is None: + logger.debug("Cache miss: %s (not found)", cache_key[:16]) + return None + + if entry.is_expired(): + logger.debug("Cache miss: %s (expired)", cache_key[:16]) + self._remove_entry(cache_key) + return None + + # Load cached result + result_path = Path(entry.file_path) + if not result_path.exists(): + logger.warning("Cache file missing: %s", result_path) + self._remove_entry(cache_key) + return None + + try: + with open(result_path) as f: + result = json.load(f) + logger.info("Cache hit: %s", cache_key[:16]) + return result + except (json.JSONDecodeError, OSError) as e: + logger.warning("Failed to load cached result: %s", e) + self._remove_entry(cache_key) + return None + + def put( + self, + content: str | bytes, + operation: str, + result: dict[str, Any], + params_hash: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> str: + """Store a result in the cache. + + Args: + content: Original PDF content (base64 string or bytes). + operation: Operation type being cached. + result: Result dictionary to cache. + params_hash: Optional hash of operation parameters. + metadata: Optional metadata to store with the entry. + + Returns: + The content hash used for caching. + """ + self._ensure_initialized() + + content_hash = self.compute_hash(content) + cache_key = self._make_cache_key(content_hash, operation) + if params_hash: + cache_key = f"{cache_key}:{params_hash}" + + # Generate file path (include params_hash to avoid collisions) + timestamp = int(time.time()) + params_suffix = f"_{params_hash}" if params_hash else "" + result_filename = f"{content_hash[:16]}_{operation}{params_suffix}_{timestamp}.json" + result_path = self._cache_dir / result_filename + + # Calculate expiration + now = datetime.now(timezone.utc) + expires_at = datetime.fromtimestamp( + now.timestamp() + (self._ttl_days * 24 * 60 * 60), + tz=timezone.utc, + ) + + # Save result file + try: + with open(result_path, "w") as f: + json.dump(result, f, indent=2) + except OSError as e: + logger.error("Failed to write cache file: %s", e) + return content_hash + + # Create and save entry + entry = CacheEntry( + content_hash=content_hash, + created_at=now, + expires_at=expires_at, + file_path=str(result_path), + operation=operation, + metadata=metadata or {}, + ) + self._metadata[cache_key] = entry + self._save_metadata() + + logger.info( + "Cached result: %s (expires %s)", + cache_key[:16], + expires_at.isoformat(), + ) + return content_hash + + def _remove_entry(self, cache_key: str) -> None: + """Remove a cache entry and its associated file. + + Args: + cache_key: Cache key to remove. + """ + entry = self._metadata.pop(cache_key, None) + if entry: + try: + Path(entry.file_path).unlink(missing_ok=True) + except OSError as e: + logger.debug("Failed to remove cache file: %s", e) + self._save_metadata() + + def cleanup_expired(self) -> int: + """Remove all expired cache entries. + + Returns: + Number of entries removed. + """ + self._ensure_initialized() + + expired_keys = [key for key, entry in self._metadata.items() if entry.is_expired()] + + for key in expired_keys: + self._remove_entry(key) + + if expired_keys: + logger.info("Cleaned up %d expired cache entries", len(expired_keys)) + + return len(expired_keys) + + def clear(self) -> int: + """Clear all cache entries. + + Returns: + Number of entries removed. + """ + self._ensure_initialized() + + count = len(self._metadata) + for key in list(self._metadata.keys()): + self._remove_entry(key) + + logger.info("Cleared %d cache entries", count) + return count + + def get_stats(self) -> dict[str, Any]: + """Get cache statistics. + + Returns: + Dictionary containing cache statistics. + """ + self._ensure_initialized() + + expired_count = sum(1 for e in self._metadata.values() if e.is_expired()) + operations = {} + for entry in self._metadata.values(): + operations[entry.operation] = operations.get(entry.operation, 0) + 1 + + # Calculate total cache size + total_size = 0 + for entry in self._metadata.values(): + try: + total_size += Path(entry.file_path).stat().st_size + except OSError: + pass + + return { + "cache_dir": str(self._cache_dir), + "ttl_days": self._ttl_days, + "total_entries": len(self._metadata), + "expired_entries": expired_count, + "valid_entries": len(self._metadata) - expired_count, + "operations": operations, + "total_size_bytes": total_size, + "total_size_mb": round(total_size / (1024 * 1024), 2), + } + + +# Module-level cache instance (lazily initialized) +_cache: DocumentCache | None = None + + +def get_cache() -> DocumentCache: + """Get the global document cache instance. + + Returns: + Singleton DocumentCache instance. + """ + global _cache + if _cache is None: + _cache = DocumentCache() + return _cache diff --git a/src/document_analysis_mcp/server.py b/src/document_analysis_mcp/server.py index 9aa6acc..936bd78 100644 --- a/src/document_analysis_mcp/server.py +++ b/src/document_analysis_mcp/server.py @@ -13,12 +13,16 @@ from starlette.requests import Request from starlette.responses import JSONResponse +from document_analysis_mcp.cache import get_cache from document_analysis_mcp.config import get_settings from document_analysis_mcp.tools.classify import pdf_classify from document_analysis_mcp.tools.extract import pdf_extract_full +from document_analysis_mcp.tools.ocr import pdf_ocr +from document_analysis_mcp.tools.structure import pdf_extract_structure +from document_analysis_mcp.tracking import get_tracker # Server version - should match pyproject.toml -__version__ = "0.2.0" +__version__ = "0.3.0" # Track server startup time for uptime calculation _startup_time: datetime | None = None @@ -166,6 +170,136 @@ def pdf_classify_tool( ) +@mcp.tool() +def pdf_ocr_tool( + pdf_content: str, + max_pages: int = 50, + max_file_size_mb: float = 50.0, + language: str = "eng", + dpi: int = 300, + force_ocr: bool = False, + use_cache: bool = True, +) -> dict[str, Any]: + """Extract text from scanned PDFs using OCR. + + Uses Tesseract OCR to extract text from image-based PDFs where + standard text extraction fails or returns minimal text. + + Best for: + - Scanned documents + - PDFs created from images + - Documents where pdf_extract_full returns little text + + Args: + pdf_content: Base64-encoded PDF content. + max_pages: Maximum pages to OCR (OCR is slow, limit carefully). + max_file_size_mb: Maximum allowed file size in megabytes. + language: Tesseract language code (eng, fra, deu, etc.). + dpi: Resolution for page rendering (higher = better but slower). + force_ocr: Force OCR even if text extraction works. + use_cache: Use cache for previously processed documents. + + Returns: + Dictionary containing text, pages, metadata, processing_stats, + and ocr_stats with OCR-specific information. + """ + return pdf_ocr( + pdf_content=pdf_content, + max_pages=max_pages, + max_file_size_mb=max_file_size_mb, + language=language, + dpi=dpi, + force_ocr=force_ocr, + use_cache=use_cache, + ) + + +@mcp.tool() +def pdf_extract_structure_tool( + pdf_content: str, + max_pages: int = 100, + max_file_size_mb: float = 50.0, + extract_tables: bool = True, + extract_headings: bool = True, + extract_toc: bool = True, + use_cache: bool = True, +) -> dict[str, Any]: + """Extract document structure from a PDF. + + Analyzes document organization including: + - Table of Contents (if present) + - Tables with data and markdown formatting + - Section headings and hierarchy + + Best for: + - Understanding document organization + - Extracting tabular data + - Building document outlines + + Args: + pdf_content: Base64-encoded PDF content. + max_pages: Maximum pages to process. + max_file_size_mb: Maximum allowed file size in megabytes. + extract_tables: Whether to extract tables. + extract_headings: Whether to detect section headings. + extract_toc: Whether to detect Table of Contents. + use_cache: Use cache for previously processed documents. + + Returns: + Dictionary containing toc, tables, headings, structure_summary, + and processing_stats. + """ + return pdf_extract_structure( + pdf_content=pdf_content, + max_pages=max_pages, + max_file_size_mb=max_file_size_mb, + extract_tables=extract_tables, + extract_headings=extract_headings, + extract_toc=extract_toc, + use_cache=use_cache, + ) + + +@mcp.tool() +def cache_stats() -> dict[str, Any]: + """Get cache statistics and usage information. + + Returns statistics about the document cache including: + - Total entries and sizes + - Breakdown by operation type + - Expired entry counts + + Returns: + Dictionary containing cache statistics. + """ + cache = get_cache() + return cache.get_stats() + + +@mcp.tool() +def usage_summary(days: int = 7) -> dict[str, Any]: + """Get API usage summary and cost tracking. + + Returns usage statistics including: + - Total tokens used + - Estimated costs by model + - Daily usage breakdown + + Args: + days: Number of days to include in summary. + + Returns: + Dictionary containing usage summary and daily breakdown. + """ + tracker = get_tracker() + summary = tracker.get_summary() + daily = tracker.get_daily_summary(days=days) + return { + "summary": summary, + "daily": daily, + } + + @mcp.tool() def health_check() -> dict[str, Any]: """Check server health and configuration status. @@ -180,15 +314,19 @@ def health_check() -> dict[str, Any]: - classification_model: Model used for classification """ settings = get_settings() + cache = get_cache() + return { "status": "healthy", "version": __version__, "api_key_configured": settings.has_api_key, "cache_dir": str(settings.cache_dir), + "cache_ttl_days": settings.cache_ttl_days, "host": settings.doc_analysis_host, "port": settings.doc_analysis_port, "default_model": settings.default_model, "classification_model": settings.classification_model, + "cache_stats": cache.get_stats(), } diff --git a/src/document_analysis_mcp/tools/ocr.py b/src/document_analysis_mcp/tools/ocr.py new file mode 100644 index 0000000..7799e79 --- /dev/null +++ b/src/document_analysis_mcp/tools/ocr.py @@ -0,0 +1,431 @@ +"""PDF OCR tool for scanned documents. + +This module provides OCR capabilities for image-based PDFs using Tesseract. +It handles documents where standard text extraction fails or returns minimal text. + +Dependencies: + - pytesseract: Python bindings for Tesseract OCR + - Pillow: Image processing + - Tesseract OCR engine must be installed on the system +""" + +import base64 +import binascii +import io +import logging +import time +from typing import Any + +from pdfplumber.page import Page +from PIL import Image + +from document_analysis_mcp.cache import get_cache +from document_analysis_mcp.models.extraction import ( + DocumentMetadata, + PageContent, +) +from document_analysis_mcp.tracking import get_tracker + +logger = logging.getLogger(__name__) + +# OCR configuration +DEFAULT_DPI = 300 # Resolution for PDF to image conversion +DEFAULT_LANGUAGE = "eng" # Tesseract language code +MIN_TEXT_CHARS = 100 # Minimum chars before OCR is considered necessary + + +def _is_tesseract_available() -> bool: + """Check if Tesseract OCR is available. + + Returns: + True if Tesseract is installed and accessible. + """ + try: + import pytesseract + + pytesseract.get_tesseract_version() + return True + except Exception: + return False + + +def _page_to_image(page: Page, dpi: int = DEFAULT_DPI) -> Image.Image: + """Convert a pdfplumber page to a PIL Image. + + Args: + page: pdfplumber Page object. + dpi: Resolution for rendering. + + Returns: + PIL Image of the page. + """ + # Render page to image at specified DPI + image = page.to_image(resolution=dpi) + return image.original + + +def _ocr_image( + image: Image.Image, + language: str = DEFAULT_LANGUAGE, +) -> str: + """Perform OCR on an image. + + Args: + image: PIL Image to process. + language: Tesseract language code. + + Returns: + Extracted text from the image. + """ + import pytesseract + + # Convert to RGB if necessary (Tesseract works best with RGB) + if image.mode not in ("RGB", "L"): + image = image.convert("RGB") + + # Perform OCR + text = pytesseract.image_to_string(image, lang=language) + return text.strip() + + +def _validate_base64_size(pdf_content: str, max_file_size_mb: float) -> None: + """Validate base64-encoded content size before decoding. + + Args: + pdf_content: Base64-encoded PDF content. + max_file_size_mb: Maximum allowed file size in megabytes. + + Raises: + ValueError: If the encoded content exceeds the size limit. + """ + encoded_size_bytes = len(pdf_content) + estimated_decoded_size_mb = (encoded_size_bytes * 3 / 4) / (1024 * 1024) + + if estimated_decoded_size_mb > max_file_size_mb: + raise ValueError( + f"PDF content exceeds maximum size limit. " + f"Estimated size: {estimated_decoded_size_mb:.1f}MB, " + f"maximum allowed: {max_file_size_mb}MB" + ) + + +def pdf_ocr( + pdf_content: str, + max_pages: int = 50, + max_file_size_mb: float = 50.0, + language: str = DEFAULT_LANGUAGE, + dpi: int = DEFAULT_DPI, + force_ocr: bool = False, + use_cache: bool = True, +) -> dict[str, Any]: + """Extract text from a scanned PDF using OCR. + + This tool performs Optical Character Recognition on PDF pages that + contain images rather than selectable text. It's useful for: + - Scanned documents + - PDFs created from images + - Documents where standard extraction returns minimal text + + Args: + pdf_content: Base64-encoded PDF content. + max_pages: Maximum pages to process (OCR is slow, so limit carefully). + max_file_size_mb: Maximum allowed file size in megabytes. + language: Tesseract language code (default "eng" for English). + dpi: Resolution for rendering pages to images (higher = better quality but slower). + force_ocr: If True, perform OCR even if text extraction works. + use_cache: If True, check and use cache for previously processed documents. + + Returns: + Dictionary containing: + - success: Whether OCR extraction succeeded. + - text: OCR-extracted text content. + - pages: List of page extraction results. + - metadata: Document metadata. + - processing_stats: OCR processing statistics. + - ocr_stats: OCR-specific statistics (pages_ocr_processed, etc.). + - error: Error message if OCR failed. + + Raises: + ValueError: If pdf_content is empty or invalid. + """ + start_time = time.perf_counter() + + if not pdf_content: + raise ValueError("pdf_content cannot be empty") + + # Validate file size + _validate_base64_size(pdf_content, max_file_size_mb) + + # Check Tesseract availability + if not _is_tesseract_available(): + return { + "success": False, + "error": "Tesseract OCR is not installed or not accessible. " + "Please install Tesseract (apt-get install tesseract-ocr).", + "text": "", + "pages": [], + "metadata": {}, + "processing_stats": {}, + "ocr_stats": {"tesseract_available": False}, + } + + # Check cache + cache = get_cache() + if use_cache: + # Create params hash for cache key (includes OCR-specific params) + params_str = f"lang:{language}:dpi:{dpi}:force:{force_ocr}:max:{max_pages}" + params_hash = cache.compute_hash(params_str)[:16] + + cached_result = cache.get(pdf_content, "ocr", params_hash=params_hash) + if cached_result: + cached_result["processing_stats"]["cache_hit"] = True + return cached_result + + logger.info( + "Starting OCR extraction: max_pages=%d, language=%s, dpi=%d, force=%s", + max_pages, + language, + dpi, + force_ocr, + ) + + # Decode PDF + try: + pdf_bytes = base64.b64decode(pdf_content) + except (binascii.Error, ValueError) as e: + logger.error("Failed to decode base64 PDF: %s", e) + return { + "success": False, + "error": f"Invalid base64 encoding: {e}", + "text": "", + "pages": [], + "metadata": {}, + "processing_stats": {}, + "ocr_stats": {}, + } + + # Import pdfplumber for PDF processing + import pdfplumber + from pdfminer.pdfparser import PDFSyntaxError + from pdfminer.psparser import PSEOF + + pages: list[PageContent] = [] + ocr_page_count = 0 + text_page_count = 0 + total_chars = 0 + total_words = 0 + + try: + with pdfplumber.open(io.BytesIO(pdf_bytes)) as pdf: + total_pages = len(pdf.pages) + + # Extract metadata + meta = pdf.metadata or {} + metadata = DocumentMetadata( + title=meta.get("Title"), + author=meta.get("Author"), + subject=meta.get("Subject"), + page_count=total_pages, + file_size_bytes=len(pdf_bytes), + ) + + # Process each page + pages_to_process = min(max_pages, total_pages) + logger.info("Processing %d of %d pages", pages_to_process, total_pages) + + for page_num, page in enumerate(pdf.pages[:pages_to_process], start=1): + page_text = "" + + # First, try standard text extraction + try: + extracted_text = page.extract_text() or "" + except (PDFSyntaxError, PSEOF, ValueError) as e: + logger.debug("Text extraction failed on page %d: %s", page_num, e) + extracted_text = "" + + # Decide whether to use OCR + needs_ocr = force_ocr or len(extracted_text.strip()) < MIN_TEXT_CHARS + + if needs_ocr: + try: + # Convert page to image and perform OCR + image = _page_to_image(page, dpi=dpi) + page_text = _ocr_image(image, language=language) + ocr_page_count += 1 + logger.debug( + "OCR extracted %d chars from page %d", len(page_text), page_num + ) + except Exception as e: + logger.warning("OCR failed on page %d: %s", page_num, e) + # Fall back to whatever text we extracted + page_text = extracted_text + else: + page_text = extracted_text + text_page_count += 1 + + page_content = PageContent( + page_number=page_num, + text=page_text, + tables=[], # OCR doesn't extract tables + ) + pages.append(page_content) + total_chars += len(page_text) + total_words += len(page_text.split()) + + except (PDFSyntaxError, PSEOF) as e: + logger.error("PDF syntax error: %s", e) + return { + "success": False, + "error": f"Invalid PDF structure: {e}", + "text": "", + "pages": [], + "metadata": {}, + "processing_stats": {}, + "ocr_stats": {}, + } + except OSError as e: + logger.error("I/O error reading PDF: %s", e) + return { + "success": False, + "error": f"Failed to read PDF: {e}", + "text": "", + "pages": [], + "metadata": {}, + "processing_stats": {}, + "ocr_stats": {}, + } + + # Combine all page text + combined_text = "\n\n".join( + f"[Page {p.page_number}]\n{p.text}" for p in pages if p.text.strip() + ) + + processing_time_ms = (time.perf_counter() - start_time) * 1000 + + result = { + "success": True, + "text": combined_text, + "pages": [ + { + "page_number": p.page_number, + "text": p.text, + "char_count": len(p.text), + } + for p in pages + ], + "metadata": { + "title": metadata.title, + "author": metadata.author, + "subject": metadata.subject, + "page_count": metadata.page_count, + "file_size_bytes": metadata.file_size_bytes, + }, + "processing_stats": { + "pages_processed": len(pages), + "total_pages": total_pages, + "word_count": total_words, + "char_count": total_chars, + "processing_time_ms": round(processing_time_ms, 2), + "cache_hit": False, + }, + "ocr_stats": { + "tesseract_available": True, + "pages_ocr_processed": ocr_page_count, + "pages_text_extracted": text_page_count, + "language": language, + "dpi": dpi, + "force_ocr": force_ocr, + }, + } + + logger.info( + "OCR extraction complete: %d pages (%d OCR, %d text), %d words, %.0fms", + len(pages), + ocr_page_count, + text_page_count, + total_words, + processing_time_ms, + ) + + # Cache the result + if use_cache: + params_str = f"lang:{language}:dpi:{dpi}:force:{force_ocr}:max:{max_pages}" + params_hash = cache.compute_hash(params_str)[:16] + cache.put( + pdf_content, + "ocr", + result, + params_hash=params_hash, + metadata={"language": language, "dpi": dpi, "ocr_pages": ocr_page_count}, + ) + + # Track usage (no LLM tokens for OCR, but track processing time) + tracker = get_tracker() + tracker.record( + operation="ocr", + model="tesseract", + input_tokens=0, + output_tokens=0, + processing_time_ms=processing_time_ms, + document_hash=cache.compute_hash(pdf_content), + success=True, + metadata={ + "pages_processed": len(pages), + "ocr_pages": ocr_page_count, + "language": language, + }, + ) + + return result + + +# Tool metadata for MCP registration +PDF_OCR_METADATA = { + "name": "pdf_ocr", + "description": ( + "Extract text from scanned PDFs using OCR (Optical Character Recognition). " + "Use this tool for documents where standard text extraction fails or returns " + "minimal text, such as scanned documents or PDFs created from images. " + "Requires Tesseract OCR to be installed on the system." + ), + "parameters": { + "type": "object", + "properties": { + "pdf_content": { + "type": "string", + "description": "Base64-encoded PDF content", + }, + "max_pages": { + "type": "integer", + "default": 50, + "description": "Maximum pages to OCR (OCR is slow, limit carefully)", + }, + "max_file_size_mb": { + "type": "number", + "default": 50, + "description": "Maximum allowed file size in megabytes", + }, + "language": { + "type": "string", + "default": "eng", + "description": "Tesseract language code (eng, fra, deu, etc.)", + }, + "dpi": { + "type": "integer", + "default": 300, + "description": "Resolution for page rendering (higher = better but slower)", + }, + "force_ocr": { + "type": "boolean", + "default": False, + "description": "Force OCR even if text extraction works", + }, + "use_cache": { + "type": "boolean", + "default": True, + "description": "Use cache for previously processed documents", + }, + }, + "required": ["pdf_content"], + }, +} diff --git a/src/document_analysis_mcp/tools/structure.py b/src/document_analysis_mcp/tools/structure.py new file mode 100644 index 0000000..2f98ee3 --- /dev/null +++ b/src/document_analysis_mcp/tools/structure.py @@ -0,0 +1,661 @@ +"""PDF structure extraction tool. + +This module provides extraction of document structure including: +- Table of Contents (TOC) detection +- Table extraction and formatting +- Section/heading hierarchy detection + +Uses pdfplumber for table extraction and heuristic analysis +for TOC and heading detection. +""" + +import base64 +import binascii +import io +import logging +import re +import time +from dataclasses import dataclass, field +from typing import Any + +import pdfplumber +from pdfminer.pdfparser import PDFSyntaxError +from pdfminer.psparser import PSEOF + +from document_analysis_mcp.cache import get_cache +from document_analysis_mcp.tracking import get_tracker + +logger = logging.getLogger(__name__) + +# Default maximum file size +DEFAULT_MAX_FILE_SIZE_MB = 50.0 + +# Heading detection patterns +HEADING_PATTERNS = [ + # Numbered sections: 1., 1.1, 1.1.1, etc. + r"^(\d+(?:\.\d+)*)\s+(.+)$", + # Roman numerals: I., II., III., etc. + r"^([IVXLCDM]+)\.\s+(.+)$", + # Chapter/Section headers + r"^(Chapter|Section|Part)\s+(\d+|[IVXLCDM]+)[\.:]\s*(.+)$", + # Uppercase headers (likely headings) + r"^([A-Z][A-Z\s]{4,50})$", +] + +# TOC detection patterns +TOC_ENTRY_PATTERN = re.compile( + r"^(.+?)\s*[\.ยท\-_\s]{3,}\s*(\d+)\s*$", # "Topic ... 42" or "Topic --- 42" + re.MULTILINE, +) + + +@dataclass +class TableData: + """Extracted table data. + + Attributes: + page_number: Page where the table was found. + table_index: Index of the table on the page (0-indexed). + rows: Table data as a 2D list. + markdown: Table formatted as markdown. + row_count: Number of rows in the table. + col_count: Number of columns in the table. + has_header: Whether the table appears to have a header row. + """ + + page_number: int + table_index: int + rows: list[list[str]] + markdown: str + row_count: int + col_count: int + has_header: bool = True + + +@dataclass +class HeadingData: + """Detected heading/section data. + + Attributes: + page_number: Page where the heading was found. + text: Heading text. + level: Heading level (1 = top level, 2 = subsection, etc.). + number: Section number if detected (e.g., "1.2.3"). + position: Position on the page (line number or y-coordinate). + """ + + page_number: int + text: str + level: int + number: str | None = None + position: float = 0.0 + + +@dataclass +class TOCEntry: + """Table of Contents entry. + + Attributes: + title: Entry title/text. + page_number: Target page number. + level: Indentation/hierarchy level. + """ + + title: str + page_number: int + level: int = 0 + + +@dataclass +class DocumentStructure: + """Complete document structure analysis. + + Attributes: + toc: Detected Table of Contents entries. + tables: Extracted tables with data and formatting. + headings: Detected section headings. + page_count: Total pages in the document. + has_toc: Whether a TOC was detected. + """ + + toc: list[TOCEntry] = field(default_factory=list) + tables: list[TableData] = field(default_factory=list) + headings: list[HeadingData] = field(default_factory=list) + page_count: int = 0 + has_toc: bool = False + + +def _clean_cell(cell: Any) -> str: + """Clean a table cell value. + + Args: + cell: Raw cell value. + + Returns: + Cleaned string value. + """ + if cell is None: + return "" + return str(cell).strip().replace("|", "\\|").replace("\n", " ") + + +def _table_to_markdown(table: list[list[Any]]) -> str: + """Convert a table to markdown format. + + Args: + table: 2D list representing a table. + + Returns: + Markdown-formatted table string. + """ + if not table or not table[0]: + return "" + + # Clean cells + rows = [[_clean_cell(cell) for cell in row] for row in table] + + # Determine column count from first row + col_count = len(rows[0]) + + # Build markdown table + lines = [] + + # Header row + header = rows[0] + lines.append("| " + " | ".join(header) + " |") + + # Separator + lines.append("| " + " | ".join(["---"] * col_count) + " |") + + # Data rows + for row in rows[1:]: + # Pad row if it has fewer columns than header + padded_row = row + [""] * (col_count - len(row)) + lines.append("| " + " | ".join(padded_row[:col_count]) + " |") + + return "\n".join(lines) + + +def _detect_heading_level(text: str, position: float, page_height: float) -> int: + """Detect the heading level based on text patterns and position. + + Args: + text: Heading text. + position: Y-position on the page. + page_height: Total page height. + + Returns: + Heading level (1-4). + """ + text = text.strip() + + # Chapter/Part headers are level 1 + if re.match(r"^(Chapter|Part)\s+", text, re.IGNORECASE): + return 1 + + # Section headers are level 2 + if re.match(r"^Section\s+", text, re.IGNORECASE): + return 2 + + # Check numbered sections + num_match = re.match(r"^(\d+(?:\.\d+)*)\s+", text) + if num_match: + num = num_match.group(1) + dots = num.count(".") + return min(dots + 1, 4) # Cap at level 4 + + # All caps text near top of page is likely level 1-2 + if text.isupper() and len(text) > 5: + if position > page_height * 0.8: # Near top (PDF y increases downward) + return 1 + return 2 + + # Default to level 3 for other detected headings + return 3 + + +def _extract_headings(page: pdfplumber.page.Page, page_num: int) -> list[HeadingData]: + """Extract headings from a page using text analysis. + + Args: + page: pdfplumber Page object. + page_num: 1-indexed page number. + + Returns: + List of detected headings. + """ + headings: list[HeadingData] = [] + + try: + # Get text with positioning + chars = page.chars + if not chars: + return headings + + # Group characters by line (similar y-coordinate) + lines: dict[float, list[dict]] = {} + for char in chars: + y = round(char.get("top", 0), 0) + if y not in lines: + lines[y] = [] + lines[y].append(char) + + # Process each line + page_height = page.height + for y_pos, line_chars in sorted(lines.items()): + # Reconstruct line text + line_chars.sort(key=lambda c: c.get("x0", 0)) + line_text = "".join(c.get("text", "") for c in line_chars).strip() + + if not line_text or len(line_text) < 3: + continue + + # Check against heading patterns + for pattern in HEADING_PATTERNS: + match = re.match(pattern, line_text) + if match: + # Determine section number if present + groups = match.groups() + section_num = None + heading_text = line_text + + if len(groups) >= 2: + if re.match(r"^\d+(?:\.\d+)*$", groups[0]): + section_num = groups[0] + heading_text = groups[1] if len(groups) > 1 else line_text + elif groups[0].upper() in ("CHAPTER", "SECTION", "PART"): + section_num = groups[1] if len(groups) > 1 else None + heading_text = groups[2] if len(groups) > 2 else line_text + + level = _detect_heading_level(line_text, y_pos, page_height) + + headings.append( + HeadingData( + page_number=page_num, + text=heading_text.strip(), + level=level, + number=section_num, + position=y_pos, + ) + ) + break # Only match first pattern + + except (AttributeError, KeyError, ValueError) as e: + logger.debug("Error extracting headings from page %d: %s", page_num, e) + + return headings + + +def _detect_toc(pages_text: dict[int, str]) -> list[TOCEntry]: + """Detect Table of Contents from page text. + + Looks for characteristic TOC patterns in the first few pages. + + Args: + pages_text: Dictionary mapping page numbers to page text. + + Returns: + List of detected TOC entries. + """ + toc_entries: list[TOCEntry] = [] + + # Check first 5 pages for TOC + for page_num in range(1, min(6, max(pages_text.keys()) + 1)): + if page_num not in pages_text: + continue + + text = pages_text[page_num] + + # Look for "Table of Contents" or "Contents" header + if not re.search(r"(Table of\s+)?Contents", text, re.IGNORECASE): + continue + + # Extract TOC entries using pattern matching + for match in TOC_ENTRY_PATTERN.finditer(text): + title = match.group(1).strip() + target_page = int(match.group(2)) + + # Skip if title is too short or looks like a page number itself + if len(title) < 3 or title.isdigit(): + continue + + # Determine indentation level from leading whitespace + line_start = text.rfind("\n", 0, match.start()) + 1 + leading_ws = len(text[line_start : match.start()]) + level = min(leading_ws // 4, 3) # 4 spaces per level, max 3 + + toc_entries.append( + TOCEntry( + title=title, + page_number=target_page, + level=level, + ) + ) + + return toc_entries + + +def _validate_base64_size(pdf_content: str, max_file_size_mb: float) -> None: + """Validate base64-encoded content size before decoding. + + Args: + pdf_content: Base64-encoded PDF content. + max_file_size_mb: Maximum allowed file size in megabytes. + + Raises: + ValueError: If the encoded content exceeds the size limit. + """ + encoded_size_bytes = len(pdf_content) + estimated_decoded_size_mb = (encoded_size_bytes * 3 / 4) / (1024 * 1024) + + if estimated_decoded_size_mb > max_file_size_mb: + raise ValueError( + f"PDF content exceeds maximum size limit. " + f"Estimated size: {estimated_decoded_size_mb:.1f}MB, " + f"maximum allowed: {max_file_size_mb}MB" + ) + + +def pdf_extract_structure( + pdf_content: str, + max_pages: int = 100, + max_file_size_mb: float = DEFAULT_MAX_FILE_SIZE_MB, + extract_tables: bool = True, + extract_headings: bool = True, + extract_toc: bool = True, + use_cache: bool = True, +) -> dict[str, Any]: + """Extract document structure from a PDF. + + Analyzes the PDF to extract: + - Table of Contents (TOC) if present + - All tables with data and markdown formatting + - Section headings and their hierarchy + + This tool is useful for understanding document organization + and extracting structured data like tables. + + Args: + pdf_content: Base64-encoded PDF content. + max_pages: Maximum pages to process. + max_file_size_mb: Maximum allowed file size in megabytes. + extract_tables: Whether to extract tables. + extract_headings: Whether to detect section headings. + extract_toc: Whether to detect Table of Contents. + use_cache: Whether to use caching. + + Returns: + Dictionary containing: + - success: Whether extraction succeeded. + - toc: Table of Contents entries (if detected). + - tables: Extracted tables with data and markdown. + - headings: Detected section headings. + - structure_summary: Summary of document structure. + - processing_stats: Processing statistics. + - error: Error message if extraction failed. + + Raises: + ValueError: If pdf_content is empty or invalid. + """ + start_time = time.perf_counter() + + if not pdf_content: + raise ValueError("pdf_content cannot be empty") + + # Validate file size + _validate_base64_size(pdf_content, max_file_size_mb) + + # Check cache + cache = get_cache() + if use_cache: + params_str = ( + f"tables:{extract_tables}:headings:{extract_headings}:toc:{extract_toc}:max:{max_pages}" + ) + params_hash = cache.compute_hash(params_str)[:16] + + cached_result = cache.get(pdf_content, "structure", params_hash=params_hash) + if cached_result: + cached_result["processing_stats"]["cache_hit"] = True + return cached_result + + logger.info( + "Starting structure extraction: max_pages=%d, tables=%s, headings=%s, toc=%s", + max_pages, + extract_tables, + extract_headings, + extract_toc, + ) + + # Decode PDF + try: + pdf_bytes = base64.b64decode(pdf_content) + except (binascii.Error, ValueError) as e: + logger.error("Failed to decode base64 PDF: %s", e) + return { + "success": False, + "error": f"Invalid base64 encoding: {e}", + "toc": [], + "tables": [], + "headings": [], + "structure_summary": {}, + "processing_stats": {}, + } + + structure = DocumentStructure() + pages_text: dict[int, str] = {} + + try: + with pdfplumber.open(io.BytesIO(pdf_bytes)) as pdf: + structure.page_count = len(pdf.pages) + pages_to_process = min(max_pages, structure.page_count) + + for page_num, page in enumerate(pdf.pages[:pages_to_process], start=1): + # Extract text for TOC detection + try: + page_text = page.extract_text() or "" + pages_text[page_num] = page_text + except (PDFSyntaxError, PSEOF, ValueError) as e: + logger.debug("Text extraction failed on page %d: %s", page_num, e) + pages_text[page_num] = "" + + # Extract tables + if extract_tables: + try: + raw_tables = page.extract_tables() + for table_idx, table in enumerate(raw_tables): + if table and len(table) > 0: + markdown = _table_to_markdown(table) + if markdown: + # Clean and structure table data + clean_rows = [ + [_clean_cell(cell) for cell in row] for row in table + ] + + structure.tables.append( + TableData( + page_number=page_num, + table_index=table_idx, + rows=clean_rows, + markdown=markdown, + row_count=len(clean_rows), + col_count=len(clean_rows[0]) if clean_rows else 0, + has_header=True, # Assume first row is header + ) + ) + except (PDFSyntaxError, PSEOF, ValueError, TypeError) as e: + logger.debug("Table extraction failed on page %d: %s", page_num, e) + + # Extract headings + if extract_headings: + try: + page_headings = _extract_headings(page, page_num) + structure.headings.extend(page_headings) + except (PDFSyntaxError, PSEOF, ValueError) as e: + logger.debug("Heading extraction failed on page %d: %s", page_num, e) + + # Detect TOC from collected page text + if extract_toc and pages_text: + structure.toc = _detect_toc(pages_text) + structure.has_toc = len(structure.toc) > 0 + + except (PDFSyntaxError, PSEOF) as e: + logger.error("PDF syntax error: %s", e) + return { + "success": False, + "error": f"Invalid PDF structure: {e}", + "toc": [], + "tables": [], + "headings": [], + "structure_summary": {}, + "processing_stats": {}, + } + except OSError as e: + logger.error("I/O error reading PDF: %s", e) + return { + "success": False, + "error": f"Failed to read PDF: {e}", + "toc": [], + "tables": [], + "headings": [], + "structure_summary": {}, + "processing_stats": {}, + } + + processing_time_ms = (time.perf_counter() - start_time) * 1000 + + # Build result + result = { + "success": True, + "toc": [ + { + "title": entry.title, + "page_number": entry.page_number, + "level": entry.level, + } + for entry in structure.toc + ], + "tables": [ + { + "page_number": table.page_number, + "table_index": table.table_index, + "rows": table.rows, + "markdown": table.markdown, + "row_count": table.row_count, + "col_count": table.col_count, + "has_header": table.has_header, + } + for table in structure.tables + ], + "headings": [ + { + "page_number": heading.page_number, + "text": heading.text, + "level": heading.level, + "number": heading.number, + } + for heading in structure.headings + ], + "structure_summary": { + "page_count": structure.page_count, + "has_toc": structure.has_toc, + "toc_entry_count": len(structure.toc), + "table_count": len(structure.tables), + "heading_count": len(structure.headings), + "heading_levels": list({h.level for h in structure.headings}), + }, + "processing_stats": { + "pages_processed": min(max_pages, structure.page_count), + "total_pages": structure.page_count, + "processing_time_ms": round(processing_time_ms, 2), + "cache_hit": False, + }, + } + + logger.info( + "Structure extraction complete: %d pages, %d TOC entries, %d tables, %d headings, %.0fms", + result["structure_summary"]["page_count"], + result["structure_summary"]["toc_entry_count"], + result["structure_summary"]["table_count"], + result["structure_summary"]["heading_count"], + processing_time_ms, + ) + + # Cache the result + if use_cache: + params_str = ( + f"tables:{extract_tables}:headings:{extract_headings}:toc:{extract_toc}:max:{max_pages}" + ) + params_hash = cache.compute_hash(params_str)[:16] + cache.put( + pdf_content, + "structure", + result, + params_hash=params_hash, + metadata=result["structure_summary"], + ) + + # Track usage (no LLM tokens for structure extraction) + tracker = get_tracker() + tracker.record( + operation="structure", + model="pdfplumber", + input_tokens=0, + output_tokens=0, + processing_time_ms=processing_time_ms, + document_hash=cache.compute_hash(pdf_content), + success=True, + metadata=result["structure_summary"], + ) + + return result + + +# Tool metadata for MCP registration +PDF_EXTRACT_STRUCTURE_METADATA = { + "name": "pdf_extract_structure", + "description": ( + "Extract document structure from a PDF including Table of Contents, " + "tables, and section headings. Use this tool to understand document " + "organization, extract tabular data, or build a document outline." + ), + "parameters": { + "type": "object", + "properties": { + "pdf_content": { + "type": "string", + "description": "Base64-encoded PDF content", + }, + "max_pages": { + "type": "integer", + "default": 100, + "description": "Maximum pages to process", + }, + "max_file_size_mb": { + "type": "number", + "default": 50, + "description": "Maximum allowed file size in megabytes", + }, + "extract_tables": { + "type": "boolean", + "default": True, + "description": "Whether to extract tables", + }, + "extract_headings": { + "type": "boolean", + "default": True, + "description": "Whether to detect section headings", + }, + "extract_toc": { + "type": "boolean", + "default": True, + "description": "Whether to detect Table of Contents", + }, + "use_cache": { + "type": "boolean", + "default": True, + "description": "Use cache for previously processed documents", + }, + }, + "required": ["pdf_content"], + }, +} diff --git a/src/document_analysis_mcp/tracking/__init__.py b/src/document_analysis_mcp/tracking/__init__.py new file mode 100644 index 0000000..09200b7 --- /dev/null +++ b/src/document_analysis_mcp/tracking/__init__.py @@ -0,0 +1,419 @@ +"""Usage tracking for API costs and processing metrics. + +This module provides tracking and logging of: +- Token usage per document and operation +- Processing time metrics +- API cost estimation +- Historical usage data + +Data is stored in JSON Lines format for efficient append-only writes. +""" + +import json +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from document_analysis_mcp.config import get_settings + +logger = logging.getLogger(__name__) + +# Tracking file name +TRACKING_FILE = "usage_tracking.jsonl" + +# Token pricing (per million tokens) - Claude models as of Jan 2026 +# Source: https://www.anthropic.com/pricing +TOKEN_PRICING = { + "claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00}, + "claude-3-5-sonnet-20241022": {"input": 3.00, "output": 15.00}, + "claude-3-5-haiku-20241022": {"input": 0.80, "output": 4.00}, + "claude-3-opus-20240229": {"input": 15.00, "output": 75.00}, + "default": {"input": 3.00, "output": 15.00}, +} + + +@dataclass +class UsageRecord: + """Record of a single API usage event. + + Attributes: + timestamp: UTC timestamp of the usage event. + operation: Type of operation (e.g., 'extract', 'classify', 'ocr'). + model: Model used for the operation. + input_tokens: Number of input tokens used. + output_tokens: Number of output tokens generated. + processing_time_ms: Total processing time in milliseconds. + document_hash: Optional hash of the processed document. + success: Whether the operation succeeded. + error_message: Error message if the operation failed. + metadata: Additional metadata about the operation. + """ + + timestamp: datetime + operation: str + model: str + input_tokens: int = 0 + output_tokens: int = 0 + processing_time_ms: float = 0.0 + document_hash: str | None = None + success: bool = True + error_message: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def total_tokens(self) -> int: + """Total tokens used (input + output).""" + return self.input_tokens + self.output_tokens + + @property + def estimated_cost_usd(self) -> float: + """Estimated cost in USD based on token pricing. + + Returns: + Cost in USD, or 0.0 if model pricing is unknown. + """ + pricing = TOKEN_PRICING.get(self.model, TOKEN_PRICING["default"]) + input_cost = (self.input_tokens / 1_000_000) * pricing["input"] + output_cost = (self.output_tokens / 1_000_000) * pricing["output"] + return round(input_cost + output_cost, 6) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "timestamp": self.timestamp.isoformat(), + "operation": self.operation, + "model": self.model, + "input_tokens": self.input_tokens, + "output_tokens": self.output_tokens, + "total_tokens": self.total_tokens, + "processing_time_ms": self.processing_time_ms, + "document_hash": self.document_hash, + "success": self.success, + "error_message": self.error_message, + "estimated_cost_usd": self.estimated_cost_usd, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "UsageRecord": + """Create a UsageRecord from a dictionary.""" + return cls( + timestamp=datetime.fromisoformat(data["timestamp"]), + operation=data["operation"], + model=data["model"], + input_tokens=data.get("input_tokens", 0), + output_tokens=data.get("output_tokens", 0), + processing_time_ms=data.get("processing_time_ms", 0.0), + document_hash=data.get("document_hash"), + success=data.get("success", True), + error_message=data.get("error_message"), + metadata=data.get("metadata", {}), + ) + + +class UsageTracker: + """Tracks API usage and costs for document processing operations. + + Provides methods to record usage events and generate usage reports. + Data is stored in JSON Lines format for efficient append operations. + + Usage: + tracker = UsageTracker() + + # Record a usage event + tracker.record( + operation="extract", + model="claude-sonnet-4-20250514", + input_tokens=5000, + output_tokens=1000, + processing_time_ms=2500, + ) + + # Get usage summary + summary = tracker.get_summary() + """ + + def __init__(self, tracking_dir: Path | None = None) -> None: + """Initialize the usage tracker. + + Args: + tracking_dir: Directory for tracking data. Defaults to settings.cache_dir. + """ + settings = get_settings() + self._tracking_dir = tracking_dir or settings.cache_dir + self._tracking_file = self._tracking_dir / TRACKING_FILE + self._initialized = False + + def _ensure_initialized(self) -> None: + """Ensure tracking directory exists.""" + if self._initialized: + return + self._tracking_dir.mkdir(parents=True, exist_ok=True) + self._initialized = True + + def record( + self, + operation: str, + model: str, + input_tokens: int = 0, + output_tokens: int = 0, + processing_time_ms: float = 0.0, + document_hash: str | None = None, + success: bool = True, + error_message: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> UsageRecord: + """Record a usage event. + + Args: + operation: Type of operation performed. + model: Model used for the operation. + input_tokens: Number of input tokens used. + output_tokens: Number of output tokens generated. + processing_time_ms: Total processing time in milliseconds. + document_hash: Optional hash of the processed document. + success: Whether the operation succeeded. + error_message: Error message if the operation failed. + metadata: Additional metadata about the operation. + + Returns: + The created UsageRecord. + """ + self._ensure_initialized() + + record = UsageRecord( + timestamp=datetime.now(timezone.utc), + operation=operation, + model=model, + input_tokens=input_tokens, + output_tokens=output_tokens, + processing_time_ms=processing_time_ms, + document_hash=document_hash, + success=success, + error_message=error_message, + metadata=metadata or {}, + ) + + # Append to tracking file + try: + with open(self._tracking_file, "a") as f: + f.write(json.dumps(record.to_dict()) + "\n") + except OSError as e: + logger.error("Failed to write usage record: %s", e) + + logger.debug( + "Recorded usage: %s, %s, tokens=%d, cost=$%.6f", + operation, + model, + record.total_tokens, + record.estimated_cost_usd, + ) + + return record + + def get_records( + self, + operation: str | None = None, + model: str | None = None, + since: datetime | None = None, + until: datetime | None = None, + limit: int | None = None, + ) -> list[UsageRecord]: + """Retrieve usage records with optional filtering. + + Args: + operation: Filter by operation type. + model: Filter by model. + since: Filter to records after this timestamp. + until: Filter to records before this timestamp. + limit: Maximum number of records to return (most recent first). + + Returns: + List of matching UsageRecords. + """ + self._ensure_initialized() + + if not self._tracking_file.exists(): + return [] + + records: list[UsageRecord] = [] + + try: + with open(self._tracking_file) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + record = UsageRecord.from_dict(data) + + # Apply filters + if operation and record.operation != operation: + continue + if model and record.model != model: + continue + if since and record.timestamp < since: + continue + if until and record.timestamp > until: + continue + + records.append(record) + except (json.JSONDecodeError, KeyError, ValueError) as e: + logger.debug("Skipping malformed record: %s", e) + continue + except OSError as e: + logger.error("Failed to read usage records: %s", e) + return [] + + # Sort by timestamp descending and apply limit + records.sort(key=lambda r: r.timestamp, reverse=True) + if limit: + records = records[:limit] + + return records + + def get_summary( + self, + since: datetime | None = None, + until: datetime | None = None, + ) -> dict[str, Any]: + """Get a summary of usage statistics. + + Args: + since: Start of the summary period. + until: End of the summary period. + + Returns: + Dictionary containing usage summary statistics. + """ + records = self.get_records(since=since, until=until) + + if not records: + return { + "period_start": since.isoformat() if since else None, + "period_end": until.isoformat() if until else None, + "total_records": 0, + "successful_records": 0, + "failed_records": 0, + "total_input_tokens": 0, + "total_output_tokens": 0, + "total_tokens": 0, + "total_cost_usd": 0.0, + "total_processing_time_ms": 0.0, + "by_operation": {}, + "by_model": {}, + } + + # Aggregate statistics + successful = sum(1 for r in records if r.success) + total_input = sum(r.input_tokens for r in records) + total_output = sum(r.output_tokens for r in records) + total_cost = sum(r.estimated_cost_usd for r in records) + total_time = sum(r.processing_time_ms for r in records) + + # By operation + by_operation: dict[str, dict[str, Any]] = {} + for record in records: + if record.operation not in by_operation: + by_operation[record.operation] = { + "count": 0, + "input_tokens": 0, + "output_tokens": 0, + "cost_usd": 0.0, + } + stats = by_operation[record.operation] + stats["count"] += 1 + stats["input_tokens"] += record.input_tokens + stats["output_tokens"] += record.output_tokens + stats["cost_usd"] += record.estimated_cost_usd + + # By model + by_model: dict[str, dict[str, Any]] = {} + for record in records: + if record.model not in by_model: + by_model[record.model] = { + "count": 0, + "input_tokens": 0, + "output_tokens": 0, + "cost_usd": 0.0, + } + stats = by_model[record.model] + stats["count"] += 1 + stats["input_tokens"] += record.input_tokens + stats["output_tokens"] += record.output_tokens + stats["cost_usd"] += record.estimated_cost_usd + + # Round cost values + for stats in by_operation.values(): + stats["cost_usd"] = round(stats["cost_usd"], 4) + for stats in by_model.values(): + stats["cost_usd"] = round(stats["cost_usd"], 4) + + return { + "period_start": (min(r.timestamp for r in records)).isoformat(), + "period_end": (max(r.timestamp for r in records)).isoformat(), + "total_records": len(records), + "successful_records": successful, + "failed_records": len(records) - successful, + "total_input_tokens": total_input, + "total_output_tokens": total_output, + "total_tokens": total_input + total_output, + "total_cost_usd": round(total_cost, 4), + "total_processing_time_ms": round(total_time, 2), + "avg_processing_time_ms": round(total_time / len(records), 2), + "by_operation": by_operation, + "by_model": by_model, + } + + def get_daily_summary(self, days: int = 7) -> list[dict[str, Any]]: + """Get daily usage summaries for the past N days. + + Args: + days: Number of days to include. + + Returns: + List of daily summary dictionaries. + """ + from datetime import timedelta + + summaries = [] + now = datetime.now(timezone.utc) + + for i in range(days): + day_start = datetime(now.year, now.month, now.day, tzinfo=timezone.utc) - timedelta( + days=i + ) + day_end = day_start + timedelta(days=1) + + records = self.get_records(since=day_start, until=day_end) + + summaries.append( + { + "date": day_start.date().isoformat(), + "total_records": len(records), + "total_tokens": sum(r.total_tokens for r in records), + "total_cost_usd": round(sum(r.estimated_cost_usd for r in records), 4), + "operations": list({r.operation for r in records}), + } + ) + + return summaries + + +# Module-level tracker instance (lazily initialized) +_tracker: UsageTracker | None = None + + +def get_tracker() -> UsageTracker: + """Get the global usage tracker instance. + + Returns: + Singleton UsageTracker instance. + """ + global _tracker + if _tracker is None: + _tracker = UsageTracker() + return _tracker diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 0000000..c271098 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,319 @@ +"""Tests for the document cache module. + +This module tests hash-based caching for document processing results. +""" + +import tempfile +from datetime import datetime, timedelta, timezone +from pathlib import Path +from unittest.mock import patch + +from document_analysis_mcp.cache import ( + CacheEntry, + DocumentCache, + get_cache, +) + + +class TestCacheEntry: + """Tests for CacheEntry dataclass.""" + + def test_is_expired_false(self): + """Test that non-expired entries return False.""" + future = datetime.now(timezone.utc) + timedelta(days=30) + entry = CacheEntry( + content_hash="abc123", + created_at=datetime.now(timezone.utc), + expires_at=future, + file_path="/tmp/test.json", + operation="extract", + ) + assert not entry.is_expired() + + def test_is_expired_true(self): + """Test that expired entries return True.""" + past = datetime.now(timezone.utc) - timedelta(days=1) + entry = CacheEntry( + content_hash="abc123", + created_at=datetime.now(timezone.utc) - timedelta(days=31), + expires_at=past, + file_path="/tmp/test.json", + operation="extract", + ) + assert entry.is_expired() + + def test_to_dict(self): + """Test serialization to dictionary.""" + now = datetime.now(timezone.utc) + future = now + timedelta(days=30) + entry = CacheEntry( + content_hash="abc123", + created_at=now, + expires_at=future, + file_path="/tmp/test.json", + operation="extract", + metadata={"pages": 5}, + ) + data = entry.to_dict() + + assert data["content_hash"] == "abc123" + assert data["file_path"] == "/tmp/test.json" + assert data["operation"] == "extract" + assert data["metadata"] == {"pages": 5} + assert "created_at" in data + assert "expires_at" in data + + def test_from_dict(self): + """Test deserialization from dictionary.""" + now = datetime.now(timezone.utc) + future = now + timedelta(days=30) + data = { + "content_hash": "abc123", + "created_at": now.isoformat(), + "expires_at": future.isoformat(), + "file_path": "/tmp/test.json", + "operation": "extract", + "metadata": {"pages": 5}, + } + entry = CacheEntry.from_dict(data) + + assert entry.content_hash == "abc123" + assert entry.file_path == "/tmp/test.json" + assert entry.operation == "extract" + assert entry.metadata == {"pages": 5} + + def test_roundtrip(self): + """Test that to_dict/from_dict roundtrips correctly.""" + original = CacheEntry( + content_hash="test_hash", + created_at=datetime.now(timezone.utc), + expires_at=datetime.now(timezone.utc) + timedelta(days=30), + file_path="/path/to/file.json", + operation="analyze", + metadata={"key": "value"}, + ) + data = original.to_dict() + restored = CacheEntry.from_dict(data) + + assert restored.content_hash == original.content_hash + assert restored.file_path == original.file_path + assert restored.operation == original.operation + assert restored.metadata == original.metadata + + +class TestDocumentCache: + """Tests for DocumentCache class.""" + + def test_compute_hash_string(self): + """Test hash computation for strings.""" + hash1 = DocumentCache.compute_hash("test content") + hash2 = DocumentCache.compute_hash("test content") + hash3 = DocumentCache.compute_hash("different content") + + assert hash1 == hash2 + assert hash1 != hash3 + assert len(hash1) == 64 # SHA-256 hex length + + def test_compute_hash_bytes(self): + """Test hash computation for bytes.""" + hash1 = DocumentCache.compute_hash(b"test content") + hash2 = DocumentCache.compute_hash("test content") + + assert hash1 == hash2 + + def test_put_and_get(self): + """Test basic put and get operations.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = DocumentCache(cache_dir=Path(tmpdir), ttl_days=30) + + content = "test pdf content" + result = {"text": "extracted text", "pages": 5} + + # Put result in cache + cache.put(content, "extract", result) + + # Get result from cache + cached = cache.get(content, "extract") + + assert cached is not None + assert cached["text"] == "extracted text" + assert cached["pages"] == 5 + + def test_get_miss(self): + """Test cache miss returns None.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = DocumentCache(cache_dir=Path(tmpdir), ttl_days=30) + + cached = cache.get("nonexistent content", "extract") + assert cached is None + + def test_get_expired(self): + """Test that expired entries return None.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = DocumentCache(cache_dir=Path(tmpdir), ttl_days=30) + + content = "test content" + result = {"text": "test"} + + # Put result in cache + cache.put(content, "extract", result) + + # Manually expire the entry + content_hash = cache.compute_hash(content) + cache_key = f"{content_hash}:extract" + cache._metadata[cache_key].expires_at = datetime.now(timezone.utc) - timedelta(days=1) + + # Get should return None for expired entry + cached = cache.get(content, "extract") + assert cached is None + + def test_different_operations_cached_separately(self): + """Test that different operations are cached separately.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = DocumentCache(cache_dir=Path(tmpdir), ttl_days=30) + + content = "test content" + extract_result = {"type": "extract", "text": "extracted"} + classify_result = {"type": "classify", "class": "document"} + + cache.put(content, "extract", extract_result) + cache.put(content, "classify", classify_result) + + assert cache.get(content, "extract")["type"] == "extract" + assert cache.get(content, "classify")["type"] == "classify" + + def test_params_hash_cache_key(self): + """Test that params_hash creates separate cache entries.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = DocumentCache(cache_dir=Path(tmpdir), ttl_days=30) + + content = "test content" + result1 = {"params": "set1"} + result2 = {"params": "set2"} + + cache.put(content, "extract", result1, params_hash="hash1") + cache.put(content, "extract", result2, params_hash="hash2") + + assert cache.get(content, "extract", params_hash="hash1")["params"] == "set1" + assert cache.get(content, "extract", params_hash="hash2")["params"] == "set2" + + def test_cleanup_expired(self): + """Test cleanup of expired entries.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = DocumentCache(cache_dir=Path(tmpdir), ttl_days=30) + + # Add some entries + cache.put("content1", "extract", {"id": 1}) + cache.put("content2", "extract", {"id": 2}) + cache.put("content3", "extract", {"id": 3}) + + # Expire one entry + content_hash = cache.compute_hash("content2") + cache_key = f"{content_hash}:extract" + cache._metadata[cache_key].expires_at = datetime.now(timezone.utc) - timedelta(days=1) + + # Cleanup + removed = cache.cleanup_expired() + + assert removed == 1 + assert cache.get("content1", "extract") is not None + assert cache.get("content2", "extract") is None + assert cache.get("content3", "extract") is not None + + def test_clear(self): + """Test clearing all cache entries.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = DocumentCache(cache_dir=Path(tmpdir), ttl_days=30) + + # Add entries + cache.put("content1", "extract", {"id": 1}) + cache.put("content2", "extract", {"id": 2}) + + # Clear + removed = cache.clear() + + assert removed == 2 + assert cache.get("content1", "extract") is None + assert cache.get("content2", "extract") is None + + def test_get_stats(self): + """Test cache statistics.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = DocumentCache(cache_dir=Path(tmpdir), ttl_days=30) + + # Add entries + cache.put("content1", "extract", {"id": 1}) + cache.put("content2", "classify", {"id": 2}) + cache.put("content3", "extract", {"id": 3}) + + stats = cache.get_stats() + + assert stats["total_entries"] == 3 + assert stats["ttl_days"] == 30 + assert "operations" in stats + assert stats["operations"]["extract"] == 2 + assert stats["operations"]["classify"] == 1 + + def test_persistence(self): + """Test that cache persists across instances.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create cache and add entry + cache1 = DocumentCache(cache_dir=Path(tmpdir), ttl_days=30) + cache1.put("content", "extract", {"persisted": True}) + + # Create new cache instance + cache2 = DocumentCache(cache_dir=Path(tmpdir), ttl_days=30) + + # Should find the persisted entry + cached = cache2.get("content", "extract") + assert cached is not None + assert cached["persisted"] is True + + def test_missing_file_returns_none(self): + """Test that missing cache files return None.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = DocumentCache(cache_dir=Path(tmpdir), ttl_days=30) + + # Add entry + cache.put("content", "extract", {"id": 1}) + + # Delete the result file manually + content_hash = cache.compute_hash("content") + cache_key = f"{content_hash}:extract" + file_path = Path(cache._metadata[cache_key].file_path) + file_path.unlink() + + # Get should return None and clean up metadata + cached = cache.get("content", "extract") + assert cached is None + + +class TestGetCache: + """Tests for get_cache singleton function.""" + + def test_returns_same_instance(self): + """Test that get_cache returns the same instance.""" + # Reset the module-level cache + import document_analysis_mcp.cache as cache_module + + cache_module._cache = None + + cache1 = get_cache() + cache2 = get_cache() + + assert cache1 is cache2 + + def test_uses_settings(self): + """Test that get_cache uses settings for configuration.""" + import document_analysis_mcp.cache as cache_module + + cache_module._cache = None + + with patch("document_analysis_mcp.cache.get_settings") as mock_settings: + mock_settings.return_value.cache_dir = Path("/tmp/test-cache") + mock_settings.return_value.cache_ttl_days = 14 + + cache = get_cache() + + assert cache._cache_dir == Path("/tmp/test-cache") + assert cache._ttl_days == 14 diff --git a/tests/test_extract.py b/tests/test_extract.py index 61a966c..39dbca3 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -317,5 +317,5 @@ def test_health_check_tool(self): result = health_check.fn() assert result["status"] == "healthy" - assert result["version"] == "0.2.0" + assert result["version"] == "0.3.0" assert "api_key_configured" in result diff --git a/tests/test_ocr.py b/tests/test_ocr.py new file mode 100644 index 0000000..ff1a9ed --- /dev/null +++ b/tests/test_ocr.py @@ -0,0 +1,237 @@ +"""Tests for PDF OCR tool. + +This module tests the pdf_ocr tool which uses Tesseract for scanned documents. +""" + +import base64 +import io +from unittest.mock import MagicMock, patch + +import pytest +from PIL import Image +from reportlab.lib.pagesizes import letter +from reportlab.pdfgen import canvas + +from document_analysis_mcp.tools.ocr import ( + DEFAULT_DPI, + DEFAULT_LANGUAGE, + MIN_TEXT_CHARS, + _is_tesseract_available, + _validate_base64_size, + pdf_ocr, +) + + +def create_simple_pdf(text: str = "Hello World", num_pages: int = 1) -> bytes: + """Create a simple PDF with the given text. + + Args: + text: Text content for each page. + num_pages: Number of pages to create. + + Returns: + PDF file as bytes. + """ + buffer = io.BytesIO() + c = canvas.Canvas(buffer, pagesize=letter) + + for page_num in range(num_pages): + c.drawString(100, 700, f"{text} - Page {page_num + 1}") + if page_num < num_pages - 1: + c.showPage() + + c.save() + buffer.seek(0) + return buffer.read() + + +class TestValidateBase64Size: + """Tests for base64 size validation.""" + + def test_small_file_passes(self): + """Test that small files pass validation.""" + small_content = "A" * 1024 + _validate_base64_size(small_content, max_file_size_mb=1.0) + + def test_large_file_rejected(self): + """Test that files exceeding limit are rejected.""" + large_content = "A" * (2 * 1024 * 1024) + with pytest.raises(ValueError, match="exceeds maximum size limit"): + _validate_base64_size(large_content, max_file_size_mb=1.0) + + +class TestIsTesseractAvailable: + """Tests for Tesseract availability check.""" + + def test_with_tesseract_installed(self): + """Test when Tesseract is available.""" + with patch("pytesseract.get_tesseract_version") as mock_get_version: + mock_get_version.return_value = "5.0.0" + # Note: Function checks availability, so just verify it doesn't crash + result = _is_tesseract_available() + # Result depends on whether tesseract is actually installed + assert isinstance(result, bool) + + def test_without_tesseract(self): + """Test when Tesseract is not available.""" + with patch("pytesseract.get_tesseract_version") as mock_get_version: + mock_get_version.side_effect = Exception("Not found") + result = _is_tesseract_available() + assert result is False + + +class TestPdfOcr: + """Tests for pdf_ocr function.""" + + def test_empty_content_raises(self): + """Test that empty content raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + pdf_ocr("") + + def test_invalid_base64(self): + """Test handling of invalid base64 input.""" + with patch("document_analysis_mcp.tools.ocr._is_tesseract_available", return_value=True): + result = pdf_ocr("not-valid-base64!!!") + + assert not result["success"] + assert "Invalid base64 encoding" in result["error"] + + @patch("document_analysis_mcp.tools.ocr._is_tesseract_available") + def test_tesseract_not_available(self, mock_tesseract_check): + """Test handling when Tesseract is not installed.""" + mock_tesseract_check.return_value = False + + pdf_bytes = create_simple_pdf("OCR Test") + pdf_b64 = base64.b64encode(pdf_bytes).decode("utf-8") + + result = pdf_ocr(pdf_b64) + + assert not result["success"] + assert "Tesseract OCR is not installed" in result["error"] + assert result["ocr_stats"]["tesseract_available"] is False + + @patch("document_analysis_mcp.tools.ocr.get_tracker") + @patch("document_analysis_mcp.tools.ocr.get_cache") + @patch("document_analysis_mcp.tools.ocr._is_tesseract_available") + def test_successful_extraction_with_text(self, mock_tesseract, mock_cache, mock_tracker): + """Test OCR with a PDF that has extractable text (no OCR needed).""" + mock_tesseract.return_value = True + mock_cache.return_value.get.return_value = None + mock_tracker.return_value.record.return_value = MagicMock() + + # Create PDF with sufficient text + pdf_bytes = create_simple_pdf("A" * 200, num_pages=1) # More than MIN_TEXT_CHARS + pdf_b64 = base64.b64encode(pdf_bytes).decode("utf-8") + + result = pdf_ocr(pdf_b64, use_cache=False) + + assert result["success"] + assert "ocr_stats" in result + # Since text extraction worked, OCR shouldn't have been used + assert result["ocr_stats"]["pages_text_extracted"] >= 0 + + @patch("document_analysis_mcp.tools.ocr.get_tracker") + @patch("document_analysis_mcp.tools.ocr.get_cache") + @patch("document_analysis_mcp.tools.ocr._is_tesseract_available") + @patch("document_analysis_mcp.tools.ocr._ocr_image") + @patch("document_analysis_mcp.tools.ocr._page_to_image") + def test_force_ocr( + self, mock_page_to_image, mock_ocr_image, mock_tesseract, mock_cache, mock_tracker + ): + """Test force_ocr option.""" + mock_tesseract.return_value = True + mock_cache.return_value.get.return_value = None + mock_tracker.return_value.record.return_value = MagicMock() + mock_page_to_image.return_value = Image.new("RGB", (100, 100)) + mock_ocr_image.return_value = "OCR extracted text" + + pdf_bytes = create_simple_pdf("Some text") + pdf_b64 = base64.b64encode(pdf_bytes).decode("utf-8") + + result = pdf_ocr(pdf_b64, force_ocr=True, use_cache=False) + + assert result["success"] + assert result["ocr_stats"]["force_ocr"] is True + # With force_ocr, OCR should have been called + assert mock_ocr_image.called + + @patch("document_analysis_mcp.tools.ocr.get_tracker") + @patch("document_analysis_mcp.tools.ocr.get_cache") + @patch("document_analysis_mcp.tools.ocr._is_tesseract_available") + def test_max_pages_limit(self, mock_tesseract, mock_cache, mock_tracker): + """Test max_pages limits processing.""" + mock_tesseract.return_value = True + mock_cache.return_value.get.return_value = None + mock_tracker.return_value.record.return_value = MagicMock() + + pdf_bytes = create_simple_pdf("Page text", num_pages=10) + pdf_b64 = base64.b64encode(pdf_bytes).decode("utf-8") + + result = pdf_ocr(pdf_b64, max_pages=3, use_cache=False) + + assert result["success"] + assert result["processing_stats"]["pages_processed"] == 3 + assert result["processing_stats"]["total_pages"] == 10 + + @patch("document_analysis_mcp.tools.ocr.get_tracker") + @patch("document_analysis_mcp.tools.ocr.get_cache") + @patch("document_analysis_mcp.tools.ocr._is_tesseract_available") + def test_cache_hit(self, mock_tesseract, mock_cache, mock_tracker): + """Test cache hit returns cached result.""" + mock_tesseract.return_value = True + cached_result = { + "success": True, + "text": "Cached OCR text", + "processing_stats": {"cache_hit": False}, + } + mock_cache.return_value.get.return_value = cached_result + + pdf_bytes = create_simple_pdf("Test") + pdf_b64 = base64.b64encode(pdf_bytes).decode("utf-8") + + result = pdf_ocr(pdf_b64, use_cache=True) + + assert result["success"] + assert result["text"] == "Cached OCR text" + assert result["processing_stats"]["cache_hit"] is True + + @patch("document_analysis_mcp.tools.ocr.get_tracker") + @patch("document_analysis_mcp.tools.ocr.get_cache") + @patch("document_analysis_mcp.tools.ocr._is_tesseract_available") + def test_result_structure(self, mock_tesseract, mock_cache, mock_tracker): + """Test that result has expected structure.""" + mock_tesseract.return_value = True + mock_cache.return_value.get.return_value = None + mock_tracker.return_value.record.return_value = MagicMock() + + pdf_bytes = create_simple_pdf("Structure test") + pdf_b64 = base64.b64encode(pdf_bytes).decode("utf-8") + + result = pdf_ocr(pdf_b64, use_cache=False) + + assert result["success"] + assert "text" in result + assert "pages" in result + assert "metadata" in result + assert "processing_stats" in result + assert "ocr_stats" in result + assert "tesseract_available" in result["ocr_stats"] + assert "pages_ocr_processed" in result["ocr_stats"] + assert "language" in result["ocr_stats"] + assert "dpi" in result["ocr_stats"] + + +class TestDefaultValues: + """Tests for default configuration values.""" + + def test_default_dpi(self): + """Test default DPI value.""" + assert DEFAULT_DPI == 300 + + def test_default_language(self): + """Test default language value.""" + assert DEFAULT_LANGUAGE == "eng" + + def test_min_text_chars(self): + """Test minimum text chars threshold.""" + assert MIN_TEXT_CHARS == 100 diff --git a/tests/test_structure.py b/tests/test_structure.py new file mode 100644 index 0000000..1ebbca4 --- /dev/null +++ b/tests/test_structure.py @@ -0,0 +1,404 @@ +"""Tests for PDF structure extraction tool. + +This module tests the pdf_extract_structure tool for TOC, table, and heading detection. +""" + +import base64 +import io +from unittest.mock import MagicMock, patch + +import pytest +from reportlab.lib import colors +from reportlab.lib.pagesizes import letter +from reportlab.pdfgen import canvas +from reportlab.platypus import SimpleDocTemplate, Table, TableStyle + +from document_analysis_mcp.tools.structure import ( + DEFAULT_MAX_FILE_SIZE_MB, + HEADING_PATTERNS, + HeadingData, + TableData, + TOCEntry, + _clean_cell, + _detect_heading_level, + _table_to_markdown, + _validate_base64_size, + pdf_extract_structure, +) + + +def create_simple_pdf(text: str = "Hello World", num_pages: int = 1) -> bytes: + """Create a simple PDF with the given text. + + Args: + text: Text content for each page. + num_pages: Number of pages to create. + + Returns: + PDF file as bytes. + """ + buffer = io.BytesIO() + c = canvas.Canvas(buffer, pagesize=letter) + + for page_num in range(num_pages): + c.drawString(100, 700, f"{text} - Page {page_num + 1}") + if page_num < num_pages - 1: + c.showPage() + + c.save() + buffer.seek(0) + return buffer.read() + + +def create_pdf_with_table() -> bytes: + """Create a PDF with a simple table.""" + buffer = io.BytesIO() + doc = SimpleDocTemplate(buffer, pagesize=letter) + + # Create table data + data = [ + ["Header 1", "Header 2", "Header 3"], + ["Row 1 Col 1", "Row 1 Col 2", "Row 1 Col 3"], + ["Row 2 Col 1", "Row 2 Col 2", "Row 2 Col 3"], + ] + + table = Table(data) + table.setStyle( + TableStyle( + [ + ("BACKGROUND", (0, 0), (-1, 0), colors.grey), + ("GRID", (0, 0), (-1, -1), 1, colors.black), + ] + ) + ) + + doc.build([table]) + buffer.seek(0) + return buffer.read() + + +class TestCleanCell: + """Tests for _clean_cell function.""" + + def test_clean_none(self): + """Test cleaning None value.""" + assert _clean_cell(None) == "" + + def test_clean_string(self): + """Test cleaning string value.""" + assert _clean_cell(" test ") == "test" + + def test_escape_pipe(self): + """Test escaping pipe characters.""" + assert _clean_cell("a|b|c") == "a\\|b\\|c" + + def test_replace_newlines(self): + """Test replacing newlines.""" + assert _clean_cell("line1\nline2") == "line1 line2" + + def test_clean_number(self): + """Test cleaning number value.""" + assert _clean_cell(123) == "123" + + +class TestTableToMarkdown: + """Tests for _table_to_markdown function.""" + + def test_simple_table(self): + """Test converting a simple table.""" + table = [ + ["H1", "H2"], + ["V1", "V2"], + ] + result = _table_to_markdown(table) + + assert "| H1 | H2 |" in result + assert "| --- | --- |" in result + assert "| V1 | V2 |" in result + + def test_empty_table(self): + """Test empty table returns empty string.""" + assert _table_to_markdown([]) == "" + assert _table_to_markdown([[]]) == "" + + def test_table_with_none(self): + """Test table with None values.""" + table = [ + ["H1", "H2"], + [None, "V2"], + ] + result = _table_to_markdown(table) + assert "| | V2 |" in result + + def test_uneven_rows(self): + """Test table with uneven row lengths.""" + table = [ + ["H1", "H2", "H3"], + ["V1"], # Short row + ] + result = _table_to_markdown(table) + # Should pad with empty cells + assert "| V1 | | |" in result + + +class TestDetectHeadingLevel: + """Tests for _detect_heading_level function.""" + + def test_chapter_is_level_1(self): + """Test that Chapter headers are level 1.""" + level = _detect_heading_level("Chapter 1: Introduction", 500, 792) + assert level == 1 + + def test_section_is_level_2(self): + """Test that Section headers are level 2.""" + level = _detect_heading_level("Section 2.1 Overview", 400, 792) + assert level == 2 + + def test_numbered_heading_levels(self): + """Test numbered heading levels.""" + assert _detect_heading_level("1 First Level", 400, 792) == 1 + assert _detect_heading_level("1.2 Second Level", 400, 792) == 2 + assert _detect_heading_level("1.2.3 Third Level", 400, 792) == 3 + assert _detect_heading_level("1.2.3.4 Fourth Level", 400, 792) == 4 + # Level caps at 4 + assert _detect_heading_level("1.2.3.4.5 Fifth Level", 400, 792) == 4 + + def test_all_caps_near_top(self): + """Test all-caps headings near top of page.""" + # Near top (high y value in PDF coordinates) + level = _detect_heading_level("INTRODUCTION", 700, 792) + assert level in (1, 2) + + +class TestHeadingData: + """Tests for HeadingData dataclass.""" + + def test_create_heading(self): + """Test creating a heading data object.""" + heading = HeadingData( + page_number=1, + text="Introduction", + level=1, + number="1", + position=700.0, + ) + assert heading.page_number == 1 + assert heading.text == "Introduction" + assert heading.level == 1 + assert heading.number == "1" + + def test_heading_without_number(self): + """Test heading without section number.""" + heading = HeadingData( + page_number=2, + text="APPENDIX", + level=2, + ) + assert heading.number is None + + +class TestTableData: + """Tests for TableData dataclass.""" + + def test_create_table_data(self): + """Test creating table data object.""" + table = TableData( + page_number=1, + table_index=0, + rows=[["H1", "H2"], ["V1", "V2"]], + markdown="| H1 | H2 |\n| --- | --- |\n| V1 | V2 |", + row_count=2, + col_count=2, + has_header=True, + ) + assert table.page_number == 1 + assert table.row_count == 2 + assert table.col_count == 2 + + +class TestTOCEntry: + """Tests for TOCEntry dataclass.""" + + def test_create_toc_entry(self): + """Test creating TOC entry.""" + entry = TOCEntry( + title="Introduction", + page_number=1, + level=0, + ) + assert entry.title == "Introduction" + assert entry.page_number == 1 + assert entry.level == 0 + + +class TestValidateBase64Size: + """Tests for size validation.""" + + def test_small_file_passes(self): + """Test that small files pass.""" + _validate_base64_size("A" * 1000, 1.0) + + def test_large_file_fails(self): + """Test that large files fail.""" + with pytest.raises(ValueError, match="exceeds maximum size limit"): + _validate_base64_size("A" * (2 * 1024 * 1024), 1.0) + + +class TestPdfExtractStructure: + """Tests for pdf_extract_structure function.""" + + def test_empty_content_raises(self): + """Test that empty content raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + pdf_extract_structure("") + + def test_invalid_base64(self): + """Test handling of invalid base64.""" + result = pdf_extract_structure("not-valid-base64!!!") + assert not result["success"] + assert "Invalid base64 encoding" in result["error"] + + @patch("document_analysis_mcp.tools.structure.get_tracker") + @patch("document_analysis_mcp.tools.structure.get_cache") + def test_simple_pdf_extraction(self, mock_cache, mock_tracker): + """Test extraction from a simple PDF.""" + mock_cache.return_value.get.return_value = None + mock_tracker.return_value.record.return_value = MagicMock() + + pdf_bytes = create_simple_pdf("Test content") + pdf_b64 = base64.b64encode(pdf_bytes).decode("utf-8") + + result = pdf_extract_structure(pdf_b64, use_cache=False) + + assert result["success"] + assert "toc" in result + assert "tables" in result + assert "headings" in result + assert "structure_summary" in result + assert "processing_stats" in result + + @patch("document_analysis_mcp.tools.structure.get_tracker") + @patch("document_analysis_mcp.tools.structure.get_cache") + def test_max_pages_limit(self, mock_cache, mock_tracker): + """Test max_pages limits processing.""" + mock_cache.return_value.get.return_value = None + mock_tracker.return_value.record.return_value = MagicMock() + + pdf_bytes = create_simple_pdf("Multi-page test", num_pages=10) + pdf_b64 = base64.b64encode(pdf_bytes).decode("utf-8") + + result = pdf_extract_structure(pdf_b64, max_pages=3, use_cache=False) + + assert result["success"] + assert result["processing_stats"]["pages_processed"] == 3 + assert result["processing_stats"]["total_pages"] == 10 + + @patch("document_analysis_mcp.tools.structure.get_tracker") + @patch("document_analysis_mcp.tools.structure.get_cache") + def test_extraction_flags(self, mock_cache, mock_tracker): + """Test extraction can be disabled per feature.""" + mock_cache.return_value.get.return_value = None + mock_tracker.return_value.record.return_value = MagicMock() + + pdf_bytes = create_simple_pdf("Test") + pdf_b64 = base64.b64encode(pdf_bytes).decode("utf-8") + + result = pdf_extract_structure( + pdf_b64, + extract_tables=False, + extract_headings=False, + extract_toc=False, + use_cache=False, + ) + + assert result["success"] + # Features disabled should result in empty lists + # (Note: May still have entries if extraction is always run; + # the flags control whether to process, but empty results + # depend on actual PDF content) + + @patch("document_analysis_mcp.tools.structure.get_tracker") + @patch("document_analysis_mcp.tools.structure.get_cache") + def test_cache_hit(self, mock_cache, mock_tracker): + """Test cache hit returns cached result.""" + cached_result = { + "success": True, + "toc": [{"title": "Cached TOC", "page_number": 1, "level": 0}], + "tables": [], + "headings": [], + "structure_summary": {}, + "processing_stats": {"cache_hit": False}, + } + mock_cache.return_value.get.return_value = cached_result + + pdf_bytes = create_simple_pdf("Test") + pdf_b64 = base64.b64encode(pdf_bytes).decode("utf-8") + + result = pdf_extract_structure(pdf_b64, use_cache=True) + + assert result["success"] + assert result["toc"][0]["title"] == "Cached TOC" + assert result["processing_stats"]["cache_hit"] is True + + @patch("document_analysis_mcp.tools.structure.get_tracker") + @patch("document_analysis_mcp.tools.structure.get_cache") + def test_result_structure(self, mock_cache, mock_tracker): + """Test result has expected structure.""" + mock_cache.return_value.get.return_value = None + mock_tracker.return_value.record.return_value = MagicMock() + + pdf_bytes = create_simple_pdf("Structure test") + pdf_b64 = base64.b64encode(pdf_bytes).decode("utf-8") + + result = pdf_extract_structure(pdf_b64, use_cache=False) + + assert result["success"] + + # Check structure_summary + summary = result["structure_summary"] + assert "page_count" in summary + assert "has_toc" in summary + assert "toc_entry_count" in summary + assert "table_count" in summary + assert "heading_count" in summary + + # Check processing_stats + stats = result["processing_stats"] + assert "pages_processed" in stats + assert "total_pages" in stats + assert "processing_time_ms" in stats + assert "cache_hit" in stats + + +class TestHeadingPatterns: + """Tests for heading pattern matching.""" + + def test_patterns_exist(self): + """Test that heading patterns are defined.""" + assert len(HEADING_PATTERNS) > 0 + + def test_numbered_pattern(self): + """Test numbered section pattern.""" + import re + + numbered = HEADING_PATTERNS[0] + match = re.match(numbered, "1.2.3 Section Title") + assert match is not None + assert match.group(1) == "1.2.3" + assert match.group(2) == "Section Title" + + def test_roman_numeral_pattern(self): + """Test roman numeral pattern.""" + import re + + roman = HEADING_PATTERNS[1] + match = re.match(roman, "III. Third Section") + assert match is not None + + +class TestDefaultValues: + """Tests for default values.""" + + def test_default_max_file_size(self): + """Test default file size limit.""" + assert DEFAULT_MAX_FILE_SIZE_MB == 50.0 diff --git a/tests/test_tracking.py b/tests/test_tracking.py new file mode 100644 index 0000000..51f08ca --- /dev/null +++ b/tests/test_tracking.py @@ -0,0 +1,344 @@ +"""Tests for the usage tracking module. + +This module tests API usage tracking and cost estimation. +""" + +import tempfile +from datetime import datetime, timedelta, timezone +from pathlib import Path + +from document_analysis_mcp.tracking import ( + TOKEN_PRICING, + TRACKING_FILE, + UsageRecord, + UsageTracker, + get_tracker, +) + + +class TestUsageRecord: + """Tests for UsageRecord dataclass.""" + + def test_total_tokens(self): + """Test total tokens calculation.""" + record = UsageRecord( + timestamp=datetime.now(timezone.utc), + operation="extract", + model="claude-sonnet-4-20250514", + input_tokens=1000, + output_tokens=500, + ) + assert record.total_tokens == 1500 + + def test_estimated_cost_sonnet(self): + """Test cost estimation for Sonnet model.""" + record = UsageRecord( + timestamp=datetime.now(timezone.utc), + operation="extract", + model="claude-sonnet-4-20250514", + input_tokens=1_000_000, # 1M tokens + output_tokens=1_000_000, # 1M tokens + ) + # Sonnet: $3/M input + $15/M output = $18 + assert record.estimated_cost_usd == 18.0 + + def test_estimated_cost_haiku(self): + """Test cost estimation for Haiku model.""" + record = UsageRecord( + timestamp=datetime.now(timezone.utc), + operation="classify", + model="claude-3-5-haiku-20241022", + input_tokens=1_000_000, + output_tokens=1_000_000, + ) + # Haiku: $0.80/M input + $4.00/M output = $4.80 + assert record.estimated_cost_usd == 4.8 + + def test_estimated_cost_unknown_model(self): + """Test cost estimation for unknown model uses default.""" + record = UsageRecord( + timestamp=datetime.now(timezone.utc), + operation="test", + model="unknown-model", + input_tokens=1_000_000, + output_tokens=1_000_000, + ) + # Default uses same as Sonnet: $3/M + $15/M = $18 + assert record.estimated_cost_usd == 18.0 + + def test_to_dict(self): + """Test serialization to dictionary.""" + now = datetime.now(timezone.utc) + record = UsageRecord( + timestamp=now, + operation="extract", + model="claude-sonnet-4-20250514", + input_tokens=1000, + output_tokens=500, + processing_time_ms=2500.0, + document_hash="abc123", + success=True, + metadata={"pages": 5}, + ) + data = record.to_dict() + + assert data["operation"] == "extract" + assert data["model"] == "claude-sonnet-4-20250514" + assert data["input_tokens"] == 1000 + assert data["output_tokens"] == 500 + assert data["total_tokens"] == 1500 + assert data["processing_time_ms"] == 2500.0 + assert data["document_hash"] == "abc123" + assert data["success"] is True + assert data["metadata"] == {"pages": 5} + assert "estimated_cost_usd" in data + + def test_from_dict(self): + """Test deserialization from dictionary.""" + now = datetime.now(timezone.utc) + data = { + "timestamp": now.isoformat(), + "operation": "classify", + "model": "claude-3-5-haiku-20241022", + "input_tokens": 500, + "output_tokens": 100, + "processing_time_ms": 1000.0, + "document_hash": "xyz789", + "success": True, + "metadata": {"type": "legal"}, + } + record = UsageRecord.from_dict(data) + + assert record.operation == "classify" + assert record.model == "claude-3-5-haiku-20241022" + assert record.input_tokens == 500 + assert record.output_tokens == 100 + assert record.document_hash == "xyz789" + + +class TestUsageTracker: + """Tests for UsageTracker class.""" + + def test_record(self): + """Test recording a usage event.""" + with tempfile.TemporaryDirectory() as tmpdir: + tracker = UsageTracker(tracking_dir=Path(tmpdir)) + + record = tracker.record( + operation="extract", + model="claude-sonnet-4-20250514", + input_tokens=1000, + output_tokens=500, + processing_time_ms=2500.0, + ) + + assert record.operation == "extract" + assert record.model == "claude-sonnet-4-20250514" + assert record.total_tokens == 1500 + + # Verify file was written + tracking_file = Path(tmpdir) / TRACKING_FILE + assert tracking_file.exists() + + def test_get_records(self): + """Test retrieving recorded events.""" + with tempfile.TemporaryDirectory() as tmpdir: + tracker = UsageTracker(tracking_dir=Path(tmpdir)) + + # Record some events + tracker.record(operation="extract", model="sonnet", input_tokens=100, output_tokens=50) + tracker.record(operation="classify", model="haiku", input_tokens=200, output_tokens=100) + tracker.record(operation="extract", model="sonnet", input_tokens=300, output_tokens=150) + + # Get all records + records = tracker.get_records() + assert len(records) == 3 + + def test_get_records_filter_operation(self): + """Test filtering records by operation.""" + with tempfile.TemporaryDirectory() as tmpdir: + tracker = UsageTracker(tracking_dir=Path(tmpdir)) + + tracker.record(operation="extract", model="sonnet", input_tokens=100, output_tokens=50) + tracker.record(operation="classify", model="haiku", input_tokens=200, output_tokens=100) + tracker.record(operation="extract", model="sonnet", input_tokens=300, output_tokens=150) + + records = tracker.get_records(operation="extract") + assert len(records) == 2 + assert all(r.operation == "extract" for r in records) + + def test_get_records_filter_model(self): + """Test filtering records by model.""" + with tempfile.TemporaryDirectory() as tmpdir: + tracker = UsageTracker(tracking_dir=Path(tmpdir)) + + tracker.record(operation="extract", model="sonnet", input_tokens=100, output_tokens=50) + tracker.record(operation="classify", model="haiku", input_tokens=200, output_tokens=100) + tracker.record(operation="ocr", model="tesseract", input_tokens=0, output_tokens=0) + + records = tracker.get_records(model="sonnet") + assert len(records) == 1 + assert records[0].model == "sonnet" + + def test_get_records_filter_time_range(self): + """Test filtering records by time range.""" + with tempfile.TemporaryDirectory() as tmpdir: + tracker = UsageTracker(tracking_dir=Path(tmpdir)) + + now = datetime.now(timezone.utc) + + # Record events + tracker.record(operation="op1", model="model", input_tokens=100, output_tokens=50) + + # Filter to only recent records + records = tracker.get_records(since=now - timedelta(hours=1)) + assert len(records) == 1 + + def test_get_records_limit(self): + """Test limiting number of returned records.""" + with tempfile.TemporaryDirectory() as tmpdir: + tracker = UsageTracker(tracking_dir=Path(tmpdir)) + + # Record many events + for i in range(10): + tracker.record(operation=f"op{i}", model="model", input_tokens=i, output_tokens=i) + + records = tracker.get_records(limit=5) + assert len(records) == 5 + + def test_get_summary_empty(self): + """Test summary with no records.""" + with tempfile.TemporaryDirectory() as tmpdir: + tracker = UsageTracker(tracking_dir=Path(tmpdir)) + + summary = tracker.get_summary() + + assert summary["total_records"] == 0 + assert summary["total_tokens"] == 0 + assert summary["total_cost_usd"] == 0.0 + + def test_get_summary(self): + """Test usage summary calculation.""" + with tempfile.TemporaryDirectory() as tmpdir: + tracker = UsageTracker(tracking_dir=Path(tmpdir)) + + tracker.record( + operation="extract", + model="claude-sonnet-4-20250514", + input_tokens=1000, + output_tokens=500, + ) + tracker.record( + operation="classify", + model="claude-3-5-haiku-20241022", + input_tokens=500, + output_tokens=100, + ) + tracker.record( + operation="extract", + model="claude-sonnet-4-20250514", + input_tokens=2000, + output_tokens=1000, + success=False, + ) + + summary = tracker.get_summary() + + assert summary["total_records"] == 3 + assert summary["successful_records"] == 2 + assert summary["failed_records"] == 1 + assert summary["total_input_tokens"] == 3500 + assert summary["total_output_tokens"] == 1600 + assert summary["total_tokens"] == 5100 + assert "by_operation" in summary + assert "by_model" in summary + assert summary["by_operation"]["extract"]["count"] == 2 + assert summary["by_operation"]["classify"]["count"] == 1 + + def test_get_daily_summary(self): + """Test daily summary breakdown.""" + with tempfile.TemporaryDirectory() as tmpdir: + tracker = UsageTracker(tracking_dir=Path(tmpdir)) + + # Record some events today + tracker.record( + operation="extract", model="sonnet", input_tokens=1000, output_tokens=500 + ) + tracker.record(operation="classify", model="haiku", input_tokens=500, output_tokens=100) + + daily = tracker.get_daily_summary(days=3) + + assert len(daily) == 3 + # Today's summary should have 2 records + assert daily[0]["total_records"] == 2 + assert daily[0]["total_tokens"] == 2100 + + def test_record_with_metadata(self): + """Test recording event with metadata.""" + with tempfile.TemporaryDirectory() as tmpdir: + tracker = UsageTracker(tracking_dir=Path(tmpdir)) + + record = tracker.record( + operation="extract", + model="sonnet", + input_tokens=100, + output_tokens=50, + metadata={"document_type": "research_paper", "pages": 42}, + ) + + assert record.metadata["document_type"] == "research_paper" + assert record.metadata["pages"] == 42 + + def test_record_failure(self): + """Test recording a failed operation.""" + with tempfile.TemporaryDirectory() as tmpdir: + tracker = UsageTracker(tracking_dir=Path(tmpdir)) + + record = tracker.record( + operation="extract", + model="sonnet", + input_tokens=0, + output_tokens=0, + success=False, + error_message="API rate limit exceeded", + ) + + assert record.success is False + assert record.error_message == "API rate limit exceeded" + + records = tracker.get_records() + assert records[0].success is False + + +class TestTokenPricing: + """Tests for token pricing constants.""" + + def test_pricing_has_common_models(self): + """Test that pricing includes common models.""" + assert "claude-sonnet-4-20250514" in TOKEN_PRICING + assert "claude-3-5-haiku-20241022" in TOKEN_PRICING + assert "claude-3-opus-20240229" in TOKEN_PRICING + assert "default" in TOKEN_PRICING + + def test_pricing_structure(self): + """Test that pricing has expected structure.""" + for _model, pricing in TOKEN_PRICING.items(): + assert "input" in pricing + assert "output" in pricing + assert pricing["input"] > 0 + assert pricing["output"] > 0 + + +class TestGetTracker: + """Tests for get_tracker singleton function.""" + + def test_returns_same_instance(self): + """Test that get_tracker returns the same instance.""" + import document_analysis_mcp.tracking as tracking_module + + tracking_module._tracker = None + + tracker1 = get_tracker() + tracker2 = get_tracker() + + assert tracker1 is tracker2