diff --git a/README.md b/README.md
index aab401f..aefa5c3 100644
--- a/README.md
+++ b/README.md
@@ -13,16 +13,16 @@
## 📖 Description
-With `extrai`, you can extract data from text documents with LLMs, which will be formatted into a given `SQLModel` and registered in your database.
+`extrai` extracts data from text documents using LLMs, formatting the output into a given `SQLModel` and registering it in a database.
-The core of the library is its [Consensus Mechanism](https://docs.extrai.xyz/concepts/consensus_mechanism.html). We make the same request multiple times, using the same or different providers, and then select the values that meet a certain threshold.
+The library uses a [Consensus Mechanism](https://docs.extrai.xyz/concepts/consensus_mechanism.html) to improve accuracy: it makes the same request multiple times, using the same or different providers, then selects the values that meet a configured threshold.
`extrai` also has other features, like [generating `SQLModel`s](https://docs.extrai.xyz/how_to/generate_sql_model.html) from a prompt and documents, and [generating few-shot examples](https://docs.extrai.xyz/how_to/generate_example_json.html). For complex, nested data, the library offers [Hierarchical Extraction](https://docs.extrai.xyz/how_to/handle_complex_data_with_hierarchical_extraction.html), breaking down the extraction into manageable, hierarchical steps. It also includes [built-in analytics](https://docs.extrai.xyz/analytics_collector.html) to monitor performance and output quality.
## ✨ Key Features
-- **[Consensus Mechanism](https://docs.extrai.xyz/concepts/consensus_mechanism.html)**: Improves extraction accuracy by consolidating multiple LLM outputs.
-- **[Dynamic SQLModel Generation](https://docs.extrai.xyz/sqlmodel_generator.html)**: Generate `SQLModel` schemas from natural language descriptions.
+- **[Consensus Mechanism](https://docs.extrai.xyz/concepts/consensus_mechanism.html)**: Consolidates multiple LLM outputs to improve extraction accuracy.
+- **[Dynamic SQLModel Generation](https://docs.extrai.xyz/sqlmodel_generator.html)**: Generates `SQLModel` schemas from natural language descriptions.
- **[Hierarchical Extraction](https://docs.extrai.xyz/how_to/handle_complex_data_with_hierarchical_extraction.html)**: Handles complex, nested data by breaking down the extraction into manageable, hierarchical steps.
- **[Extensible LLM Support](https://docs.extrai.xyz/llm_providers.html)**: Integrates with various LLM providers through a client interface.
- **[Built-in Analytics](https://docs.extrai.xyz/analytics_collector.html)**: Collects metrics on LLM performance and output quality to refine prompts and monitor errors.
@@ -59,7 +59,7 @@ For a complete guide, please see the full documentation. Here are the key sectio
- **Community**
- [Contributing Guide](https://docs.extrai.xyz/contributing.html)
-## ⚙️ Worflow Overview
+## ⚙️ Workflow Overview
The library is built around a few key components that work together to manage the extraction workflow. The following diagram illustrates the high-level workflow (see [Architecture Overview](https://docs.extrai.xyz/concepts/architecture_overview.html)):
diff --git a/docs/advanced/batch_deep_dive.md b/docs/advanced/batch_deep_dive.md
new file mode 100644
index 0000000..9760adb
--- /dev/null
+++ b/docs/advanced/batch_deep_dive.md
@@ -0,0 +1,120 @@
+# Batch Processing Deep Dive
+
+Batch processing in `extrai` is designed for high-volume extraction tasks where immediate results are not required and cost optimization is a priority. It leverages the batch APIs of LLM providers (such as OpenAI's Batch API) to process requests asynchronously at a lower cost, often 50% cheaper.
+
+## The Batch State Machine
+
+The batch pipeline is managed by a robust state machine that transitions a job through various phases. This ensures that even long-running jobs can be tracked, resumed, and recovered in case of failures.
+
+### States
+
+The `BatchJobStatus` enum defines the possible states:
+
+* **SUBMITTED**: The initial extraction request has been sent to the LLM provider.
+* **PROCESSING**: The provider is currently running the batch.
+* **READY_TO_PROCESS**: The provider has finished, and the results are downloaded but not yet processed by `extrai`.
+* **COMPLETED**: All results have been processed, consensus run, and objects hydrated.
+* **FAILED**: The batch job failed at the provider or during local processing.
+* **CANCELLED**: The job was manually cancelled.
+
+#### Counting Phase States
+
+If entity counting is enabled, the job goes through a "pre-flight" counting phase:
+
+* **COUNTING_SUBMITTED**: The counting request is with the provider.
+* **COUNTING_PROCESSING**: Counting is running.
+* **COUNTING_READY_TO_PROCESS**: Counts are ready to be used for generating the main extraction prompts.
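
Taken together, these states can be sketched as a single enum. This is an illustrative sketch: the member names mirror the lists above, but the string values and the `is_terminal` helper are assumptions, not the library's actual definitions.

```python
from enum import Enum

class BatchJobStatus(str, Enum):
    # Optional pre-flight counting phase
    COUNTING_SUBMITTED = "counting_submitted"
    COUNTING_PROCESSING = "counting_processing"
    COUNTING_READY_TO_PROCESS = "counting_ready_to_process"
    # Main extraction phase
    SUBMITTED = "submitted"
    PROCESSING = "processing"
    READY_TO_PROCESS = "ready_to_process"
    # Terminal states
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

TERMINAL = {BatchJobStatus.COMPLETED, BatchJobStatus.FAILED, BatchJobStatus.CANCELLED}

def is_terminal(status: BatchJobStatus) -> bool:
    """A job in a terminal state needs no further polling."""
    return status in TERMINAL
```

A polling loop would simply keep calling the provider until `is_terminal(status)` is true.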
+
+### State Transition Diagram
+
+```mermaid
+stateDiagram-v2
+ [*] --> COUNTING_SUBMITTED: count_entities=True
+ [*] --> SUBMITTED: count_entities=False
+
+ COUNTING_SUBMITTED --> COUNTING_PROCESSING
+ COUNTING_PROCESSING --> COUNTING_READY_TO_PROCESS
+ COUNTING_READY_TO_PROCESS --> SUBMITTED: Generate Extraction Prompts
+
+ SUBMITTED --> PROCESSING
+ PROCESSING --> READY_TO_PROCESS
+ READY_TO_PROCESS --> COMPLETED: Consensus & Hydration
+
+ PROCESSING --> FAILED
+ SUBMITTED --> FAILED
+ COUNTING_PROCESSING --> FAILED
+
+ FAILED --> [*]
+ COMPLETED --> [*]
+```
+
+## Production Workflows
+
+### Submission & Polling
+
+The `WorkflowOrchestrator.synthesize_batch` method handles the lifecycle. You can use it in a blocking or non-blocking way.
+
+**Blocking (Simplest):**
+```python
+results = await orchestrator.synthesize_batch(
+ input_strings=docs,
+ wait_for_completion=True
+)
+```
+This will poll the provider internally, handle transitions from counting to extraction, and return the final objects.
+
+**Non-Blocking (Async):**
+```python
+# Submit
+batch_id = await orchestrator.synthesize_batch(
+ input_strings=docs,
+ wait_for_completion=False
+)
+
+# Later... check status
+current_status = await orchestrator.get_batch_status(batch_id, db_session)
+if current_status.status == BatchJobStatus.COMPLETED:
+ results = await orchestrator.get_batch_results(batch_id)
+```
+
+### Error Recovery & Resuming
+
+If your application crashes while a batch is running, you don't lose the job. The `root_batch_id` is the key to recovery.
+
+**Resuming Monitoring (e.g. after script restart)**
+
+If the batch job is still active or completed at the provider, but your script stopped monitoring it, simply call `monitor_batch_job` again:
+
+```python
+# Resume monitoring
+results = await orchestrator.monitor_batch_job(
+ root_batch_id="batch_123_abc",
+ db_session=session,
+ poll_interval=60
+)
+```
+This method inspects the current state (e.g., `COUNTING_SUBMITTED`, `READY_TO_PROCESS`) and automatically picks up where it left off, handling transitions between phases.
+
+**Retrying or Extending a Batch**
+
+If a batch failed or if you want to extend a completed workflow (e.g., adding more hierarchical steps), use `create_continuation_batch`. This creates a *new* batch job that copies the successful steps from the old one, saving time and money.
+
+```python
+# Continue from step 2
+new_batch_id = await orchestrator.create_continuation_batch(
+ original_batch_id="failed_batch_id",
+ db_session=session,
+ start_from_step_index=2, # Copy steps 0 and 1, restart from 2
+ wait_for_completion=True
+)
+```
+
+## Hierarchical Batches & Shallow Schemas
+
+When using `use_hierarchical_extraction=True` with batches, the process involves multiple "hops".
+
+1. **Level 1 Extraction**: The root object is extracted.
+2. **Shallow Schema Generation**: For lists of children, `extrai` generates "Shallow Schemas" (just IDs and essential fields) to keep the context window small.
+3. **Child Batches**: New batch jobs are spawned for each child entity to extract full details.
+
+This complex coordination is handled automatically by the `BatchPipeline`.
diff --git a/docs/analytics_collector.rst b/docs/analytics_collector.rst
index 94d5aa2..ef76491 100644
--- a/docs/analytics_collector.rst
+++ b/docs/analytics_collector.rst
@@ -57,6 +57,19 @@ These methods track the health of LLM API calls and output processing.
# To record an error when the LLM's JSON output fails schema validation
analytics_collector.record_llm_output_validation_error()
+**Token Usage Tracking**
+
+The collector automatically aggregates token usage from supported LLM providers (e.g., OpenAI, Gemini).
+
+.. code-block:: python
+
+ # This is typically called internally by the LLM client
+ analytics_collector.record_llm_usage(input_tokens=150, output_tokens=50)
+
+ # Access totals
+ print(f"Total Input: {analytics_collector.total_input_tokens}")
+ print(f"Total Output: {analytics_collector.total_output_tokens}")
+
**Consensus Run Details**
This method is used to log the outcome of a consensus process. It takes a dictionary of metrics, usually generated by the ``JSONConsensus`` component.
@@ -137,6 +150,8 @@ A report provides a detailed summary of the entire workflow.
"llm_output_parse_errors": 2,
"llm_output_validation_errors": 1,
"total_invalid_parsing_errors": 3,
+ "total_input_tokens": 1500,
+ "total_output_tokens": 450,
"number_of_consensus_runs": 1,
"hydrated_objects_successes": 10,
"hydration_failures": 0,
diff --git a/docs/api/extrai.core.rst b/docs/api/extrai.core.rst
index 381e15f..283ad0d 100644
--- a/docs/api/extrai.core.rst
+++ b/docs/api/extrai.core.rst
@@ -1,77 +1,197 @@
-llm\_consensus\_extraction.core package
-=======================================
+extrai.core package
+===================
Submodules
----------
-llm\_consensus\_extraction.core.analytics\_collector module
------------------------------------------------------------
+extrai.core.analytics_collector module
+--------------------------------------
.. automodule:: extrai.core.analytics_collector
:members:
:show-inheritance:
:undoc-members:
-llm\_consensus\_extraction.core.base\_llm\_client module
---------------------------------------------------------
+extrai.core.base_llm_client module
+----------------------------------
.. automodule:: extrai.core.base_llm_client
:members:
:show-inheritance:
:undoc-members:
-llm\_consensus\_extraction.core.db\_writer module
--------------------------------------------------
+extrai.core.batch_models module
+-------------------------------
-.. automodule:: extrai.core.db_writer
+.. automodule:: extrai.core.batch_models
:members:
:show-inheritance:
:undoc-members:
-llm\_consensus\_extraction.core.example\_json\_generator module
----------------------------------------------------------------
+extrai.core.batch_pipeline module
+---------------------------------
+
+.. automodule:: extrai.core.batch_pipeline
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+extrai.core.client_rotator module
+---------------------------------
+
+.. automodule:: extrai.core.client_rotator
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+extrai.core.conflict_resolvers module
+-------------------------------------
+
+.. automodule:: extrai.core.conflict_resolvers
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+extrai.core.entity_counter module
+---------------------------------
+
+.. automodule:: extrai.core.entity_counter
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+extrai.core.errors module
+-------------------------
+
+.. automodule:: extrai.core.errors
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+extrai.core.example_json_generator module
+-----------------------------------------
.. automodule:: extrai.core.example_json_generator
:members:
:show-inheritance:
:undoc-members:
-llm\_consensus\_extraction.core.json\_consensus module
-------------------------------------------------------
+extrai.core.extraction_config module
+------------------------------------
+
+.. automodule:: extrai.core.extraction_config
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+extrai.core.extraction_context_preparer module
+----------------------------------------------
+
+.. automodule:: extrai.core.extraction_context_preparer
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+extrai.core.extraction_pipeline module
+--------------------------------------
+
+.. automodule:: extrai.core.extraction_pipeline
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+extrai.core.extraction_request_factory module
+---------------------------------------------
+
+.. automodule:: extrai.core.extraction_request_factory
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+extrai.core.hierarchical_extractor module
+-----------------------------------------
+
+.. automodule:: extrai.core.hierarchical_extractor
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+extrai.core.json_consensus module
+---------------------------------
.. automodule:: extrai.core.json_consensus
:members:
:show-inheritance:
:undoc-members:
-llm\_consensus\_extraction.core.prompt\_builder module
-------------------------------------------------------
+extrai.core.llm_runner module
+-----------------------------
+
+.. automodule:: extrai.core.llm_runner
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+extrai.core.model_registry module
+---------------------------------
+
+.. automodule:: extrai.core.model_registry
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+extrai.core.model_wrapper_builder module
+----------------------------------------
+
+.. automodule:: extrai.core.model_wrapper_builder
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+extrai.core.prompt_builder module
+---------------------------------
.. automodule:: extrai.core.prompt_builder
:members:
:show-inheritance:
:undoc-members:
-llm\_consensus\_extraction.core.schema\_inspector module
---------------------------------------------------------
+extrai.core.result_processor module
+-----------------------------------
-.. automodule:: extrai.core.schema_inspector
+.. automodule:: extrai.core.result_processor
:members:
:show-inheritance:
:undoc-members:
-llm\_consensus\_extraction.core.sqlalchemy\_hydrator module
------------------------------------------------------------
+extrai.core.schema_inspector module
+-----------------------------------
-.. automodule:: extrai.core.sqlalchemy_hydrator
+.. automodule:: extrai.core.schema_inspector
:members:
:show-inheritance:
:undoc-members:
-llm\_consensus\_extraction.core.sqlmodel\_generator module
-----------------------------------------------------------
+extrai.core.sqlmodel_generator module
+-------------------------------------
.. automodule:: extrai.core.sqlmodel_generator
:members:
:show-inheritance:
:undoc-members:
+
+extrai.core.workflow_orchestrator module
+----------------------------------------
+
+.. automodule:: extrai.core.workflow_orchestrator
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+Module contents
+---------------
+
+.. automodule:: extrai.core
+ :members:
+ :show-inheritance:
+ :undoc-members:
diff --git a/docs/batch_processing_integration.md b/docs/batch_processing_integration.md
new file mode 100644
index 0000000..5c3369c
--- /dev/null
+++ b/docs/batch_processing_integration.md
@@ -0,0 +1,100 @@
+# Batch Processing Integration Guide
+
+This document outlines the architecture and usage of the Batch API integration in `WorkflowOrchestrator`. This feature allows for asynchronous, cost-effective extraction workflows using providers like Google Gemini or OpenAI.
+
+## Overview
+
+The Batch Processing integration allows you to offload the heavy lifting of LLM extraction to a background batch process. This is particularly useful for:
+
+* **Cost Reduction**: Batch APIs often offer significant discounts (e.g., 50% off).
+* **Scalability**: Decoupling submission from processing allows for massive throughput.
+* **Resilience**: The integration includes an automatic state machine that handles retries, hierarchical dependencies, and complex workflows.
+
+## Features
+
+### 1. Structured Output Support
+The Batch Pipeline fully supports **Structured Output** (Pydantic models). The system automatically converts your SQLModel classes into the appropriate JSON Schema/Response Format expected by the provider's Batch API.
+
+### 2. Automated Hierarchical Workflows
+For complex nested data, the pipeline implements a robust **State Machine** that manages dependencies automatically:
+
+1. **Counting Phase (`COUNTING_SUBMITTED`)**: (Optional) Submits a batch job to count entities.
+2. **Root Extraction (`SUBMITTED`)**: Extracts the top-level objects using the counts as constraints.
+3. **Child Extraction (`HIERARCHICAL_STEP_SUBMITTED`)**: Recursively submits batch jobs for child entities, linking them to their parents via Foreign Key Recovery.
+4. **Completion (`COMPLETED`)**: Aggregates all results into a final hydrated object graph.
+
+### 3. Smart Resumption and Continuation
+If a batch job fails or is interrupted, you don't need to start over.
+* **Monitoring**: Use `monitor_batch_job` to resume tracking an active batch.
+* **Continuation**: Use `create_continuation_batch` to create a new batch that continues from a specific step of a previous one, preserving completed work.
+
+## Usage Guide
+
+### 1. Simple Execution (Managed Loop)
+The easiest way to use the batch API is with the `wait_for_completion=True` flag. The orchestrator handles the polling loop for you.
+
+```python
+results = await orchestrator.synthesize_batch(
+ input_strings=["..."],
+ db_session=db_session,
+ wait_for_completion=True,
+ count_entities=True
+)
+```
+
+### 2. Manual Integration (e.g., Celery)
+For production environments, you may want to manage the polling yourself.
+
+**Phase A: Submission**
+```python
+job_id = await orchestrator.synthesize_batch(
+ input_strings=["..."],
+ db_session=db_session,
+ count_entities=True,
+ wait_for_completion=False
+)
+# Store job_id in your database
+```
+
+**Phase B: Processing Loop**
+Periodically check the status and process results. The `process_batch` method is the engine that drives the state machine.
+
+```python
+# In your background worker
+status = await orchestrator.get_batch_status(job_id, db_session)
+
+if status == BatchJobStatus.READY_TO_PROCESS:
+ # Download and process results
+ # If this was a hierarchical step, this method automatically submits the NEXT batch!
+ result = await orchestrator.process_batch(job_id, db_session)
+
+ if result.status == BatchJobStatus.SUBMITTED:
+ # A new batch (e.g., for child entities) was started.
+ new_provider_id = result.retry_batch_id
+ # Update your DB tracking
+
+ elif result.status == BatchJobStatus.COMPLETED:
+ # The entire workflow is done.
+ final_objects = result.hydrated_objects
+```
+
+### 3. Resuming a Job
+If a job gets stuck or you need to re-run a specific phase (e.g., re-do extraction but keep the counts), use `create_continuation_batch`.
+
+```python
+# Restart from hierarchical step 2 (e.g., "Flights" extraction)
+new_batch_id = await orchestrator.create_continuation_batch(
+ original_batch_id="failed_batch_id",
+ db_session=db_session,
+ start_from_step_index=2,
+ wait_for_completion=True
+)
+```
+
+## Technical Details
+
+### Context Storage
+Batch jobs can involve large contexts (documents + history). The system stores this context in the database using optimized `JSON` column types (mapped to `LONGTEXT` or native `JSON` depending on the DB) to prevent size limit errors.
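
The size constraint can be illustrated with the standard library alone; the figures below are illustrative, but MySQL's `TEXT` limit is the real motivation for mapping the column to `LONGTEXT`:

```python
import json

# Hypothetical batch context: source documents plus conversation history.
context = {
    "documents": ["lorem ipsum " * 10_000],          # one large source text
    "history": [{"role": "user", "content": "extract entities"}],
}
payload = json.dumps(context)

# MySQL's TEXT type caps out at 2**16 - 1 bytes; batch contexts easily
# exceed that, hence the LONGTEXT (2**32 - 1 bytes) variant on MySQL.
TEXT_LIMIT = 2**16 - 1
needs_longtext = len(payload.encode("utf-8")) > TEXT_LIMIT
```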
+
+### Shallow Schema Enforcement
+During hierarchical steps, the system uses "Shallow Schemas" (Pydantic models without nested relationships) to prevent the LLM from hallucinating deep structures that should be extracted in subsequent steps.
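
The idea can be illustrated with plain dataclasses (the real library derives its shallow views from your SQLModels; `Flight`, `Passenger`, and `FlightShallow` here are invented examples):

```python
from dataclasses import dataclass, field, fields
from typing import List

@dataclass
class Passenger:
    name: str

@dataclass
class Flight:  # full model: carries a nested relationship
    flight_number: str
    passengers: List[Passenger] = field(default_factory=list)

@dataclass
class FlightShallow:
    # Shallow view for a hierarchical step: the `passengers` relationship
    # is stripped, so the LLM cannot hallucinate nested passenger data here.
    flight_number: str

shallow_names = [f.name for f in fields(FlightShallow)]  # ["flight_number"]
```

The child batch then extracts full `Passenger` details in a later step, keyed back to each flight.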
diff --git a/docs/concepts/architecture_overview.rst b/docs/concepts/architecture_overview.rst
index 3a65f27..a59820b 100644
--- a/docs/concepts/architecture_overview.rst
+++ b/docs/concepts/architecture_overview.rst
@@ -5,6 +5,19 @@ Architecture Overview
The `extrai` library follows a modular, multi-stage pipeline to transform unstructured text into structured, database-ready objects. This document provides an overview of this architecture, covering both the standard and optional dynamic model generation workflows.
+Modular Design (Facade Pattern)
+-------------------------------
+
+The core `WorkflowOrchestrator` acts as a **Facade**, delegating complex logic to specialized components. This separation of concerns ensures maintainability and extensibility.
+
+* **ModelRegistry**: Handles schema discovery and model lookup.
+* **ExtractionPipeline**: Orchestrates the core extraction workflow (Standard or Hierarchical).
+* **EntityCounter**: Performs the pre-extraction counting phase to constrain LLM output.
+* **LLMRunner**: Manages LLM client rotation, parallel revision generation, and consensus execution.
+* **JSONConsensus**: The engine for resolving conflicts between LLM revisions.
+* **BatchPipeline**: Manages asynchronous batch job submission, tracking, and processing.
+* **ResultProcessor**: Handles object hydration (converting JSON to SQLModel) and database persistence.
+
Core Workflow Diagram
---------------------
@@ -38,20 +51,18 @@ The following diagram illustrates the complete workflow, including the optional
subgraph "Data Extraction"
EG("📝
Example Generation
(Optional)")
- P("✍️
Prompt Generation")
+ CNT("1️⃣
Entity Counting")
+ P("✍️
Prompt/Schema Prep")
subgraph "LLM Extraction Revisions"
direction LR
E1("🤖
Revision 1")
- H1("💧
SQLAlchemy Hydration 1")
E2("🤖
Revision 2")
- H2("💧
SQLAlchemy Hydration 2")
- E3("🤖
...")
- H3("💧
...")
+ E3("🤖
Revision 3")
end
F("🤝
JSON Consensus")
- H("💧
SQLAlchemy Hydration")
+ H("💧
Hydration (ResultProcessor)")
end
subgraph Outputs
@@ -65,15 +76,13 @@ The following diagram illustrates the complete workflow, including the optional
A --> P
B --> EG
EG --> P
- P --> E1
- P --> E2
- P --> E3
- E1 --> H1
- E2 --> H2
- E3 --> H3
- H1 --> F
- H2 --> F
- H3 --> F
+ P --> CNT
+ CNT --> E1
+ CNT --> E2
+ CNT --> E3
+ E1 --> F
+ E2 --> F
+ E3 --> F
F --> H
H --> O
H --> DB
@@ -89,11 +98,49 @@ The following diagram illustrates the complete workflow, including the optional
%% Apply styles
class A,B,C,D,L1,L2 inputStyle;
- class P,E1,E2,E3,H,EG processStyle;
+ class P,CNT,E1,E2,E3,H,EG processStyle;
class F consensusStyle;
class O,DB,SM outputStyle;
class MG modelGenStyle;
+Component Interaction (Sequence Diagram)
+----------------------------------------
+
+This sequence diagram details the interaction between the Facade (`WorkflowOrchestrator`) and its internal subsystems during a standard extraction run.
+
+.. mermaid::
+
+ sequenceDiagram
+ participant User
+ participant WO as WorkflowOrchestrator
+ participant EC as EntityCounter
+ participant LLM as LLMRunner
+ participant CON as JSONConsensus
+ participant RP as ResultProcessor
+ participant DB as Database
+
+ User->>WO: synthesize_and_save(text)
+ WO->>EC: count_entities(text)
+ EC-->>WO: Entity Counts
+
+ WO->>LLM: run(text, schema, counts)
+ loop Revisions
+ LLM->>LLM: Generate Revision 1..N
+ end
+ LLM-->>WO: List[JSON Revisions]
+
+ WO->>CON: consensus(revisions)
+ CON-->>WO: Final JSON
+
+ WO->>RP: hydrate(json)
+ RP-->>WO: SQLModel Objects
+
+ WO->>RP: save(objects, session)
+ RP->>DB: commit()
+ RP-->>WO: Saved Objects
+
+ WO-->>User: Saved Objects
+
Workflow Stages
---------------
@@ -103,16 +150,20 @@ The library processes data through the following stages:
1. **Documents Ingestion**: The `WorkflowOrchestrator` accepts one or more text documents as the primary input for extraction.
-2. **Schema Introspection**: The library inspects the provided `SQLModel` classes (either predefined or dynamically generated) to create a detailed JSON schema. This schema is crucial for instructing the LLM on the desired output format.
+2. **Schema Introspection**: The `ModelRegistry` inspects the provided `SQLModel` classes (either predefined or dynamically generated) to create a detailed JSON schema or Pydantic model wrapper. This schema is crucial for instructing the LLM on the desired output format (Standard JSON or Structured Output).
+
+3. **Example Generation (Optional)**: To improve the accuracy of the LLM, the `ExtractionPipeline` can auto-generate few-shot examples from the schema. These examples are included in the prompt to give the LLM a clear template to follow.
-3. **Example Generation (Optional)**: To improve the accuracy of the LLM, the `ExampleJSONGenerator` can create few-shot examples from the schema. These examples are included in the prompt to give the LLM a clear template to follow.
+4. **Entity Counting**: Before full extraction, the `EntityCounter` performs a high-level pass to count the expected number of entities. This count is injected into the extraction prompt as a "Critical Quantity Constraint" to improve recall.
-4. **Prompt Generation**: The `PromptBuilder` combines the JSON schema, the input documents, and any few-shot examples into a comprehensive system prompt and a user prompt.
+5. **Prompt & Request Preparation**: The `ExtractionRequestFactory` and `PromptBuilder` combine the schema, input documents, entity counts, and examples into a comprehensive prompt. If using "Structured Extraction", this step also prepares the Pydantic models for the LLM's `response_format`.
-5. **LLM Interaction & Revisioning**: The configured `LLMClient` sends the prompts to the LLM to produce multiple, independent JSON structures (revisions). This step is fundamental to the consensus mechanism.
+6. **LLM Interaction & Revisioning**: The `LLMRunner` rotates through configured clients and sends the prompts to the LLM to produce multiple, independent revisions. This step is fundamental to the consensus mechanism.
-6. **JSON Validation & Consensus**: Each JSON revision from the LLM is validated against the schema. The `JSONConsensus` class then takes all valid revisions and applies a consensus algorithm to resolve discrepancies, producing a single, unified JSON object.
+7. **Consensus**: The `JSONConsensus` engine takes all valid revisions and applies a **Weighted Consensus** algorithm. Revisions are weighted by their global similarity to others, and conflicts are resolved field-by-field. Strategies like `SimilarityClusterResolver` ensure that semantic equivalents (e.g., "US" vs "U.S.A.") are correctly unified.
-7. **SQLAlchemy Object Hydration**: The `SQLAlchemyHydrator` transforms the final consensus JSON into a graph of `SQLModel` instances, correctly linking related objects.
+8. **Object Hydration**: The `ResultProcessor` transforms the final consensus JSON into a graph of `SQLModel` instances. It supports two strategies:
+ * **Direct Hydration**: Recursive instantiation for "Structured Output" results.
+ * **SQLAlchemy Hydration**: Graph reconstruction for flat JSON results (legacy).
-8. **Database Persistence (Optional)**: The hydrated `SQLModel` objects can be saved to a relational database via a standard SQLAlchemy session.
+9. **Database Persistence (Optional)**: The hydrated objects are persisted to the database. The `ResultProcessor` performs **Foreign Key Recovery** to ensure relationships are preserved even if Primary Keys were stripped during processing to prevent collisions.
diff --git a/docs/concepts/consensus_mechanism.rst b/docs/concepts/consensus_mechanism.rst
index ca08481..c213920 100644
--- a/docs/concepts/consensus_mechanism.rst
+++ b/docs/concepts/consensus_mechanism.rst
@@ -13,7 +13,7 @@ The Core Idea: Field-Level Agreement
Instead of comparing entire JSON objects, which can be brittle, the consensus mechanism works on a field-by-field basis. It achieves this through a three-step process:
1. **Flattening**: Each JSON revision is "flattened" into a simple key-value dictionary. Nested structures and list elements are represented using a dot-notation path.
-2. **Aggregation & Voting**: The algorithm aggregates all the values for each unique path across all revisions and determines if any value meets a predefined agreement threshold.
+2. **Weighted Aggregation**: The algorithm calculates a "Trust Score" for each revision based on its similarity to others, then aggregates values for each path.
3. **Un-flattening**: The paths that reached a consensus are used to reconstruct the final, nested JSON object.
Step 1: Flattening
@@ -46,48 +46,59 @@ These are flattened into:
- **Revision 1:** ``{"name": "SuperWidget", "specs.ram_gb": 16, "tags.0": "A", "tags.1": "B"}``
- **Revision 2:** ``{"name": "SuperWidget", "specs.ram_gb": 32, "tags.0": "A", "tags.1": "C"}``
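
The flattening step can be sketched in a few lines of Python (an illustrative reimplementation, not the library's internal code):

```python
def flatten(obj, prefix=""):
    """Flatten nested dicts/lists into dot-notation path -> value pairs."""
    flat = {}
    if isinstance(obj, dict):
        items = obj.items()
    elif isinstance(obj, list):
        items = ((str(i), v) for i, v in enumerate(obj))
    else:
        return {prefix: obj}  # leaf value: keep it under its full path
    for key, value in items:
        path = f"{prefix}.{key}" if prefix else key
        flat.update(flatten(value, path))
    return flat
```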
-Step 2: Aggregation and Voting
-------------------------------
+Step 2: Weighted Aggregation and Voting
+---------------------------------------
-The algorithm then groups the values for each path:
+Unlike simple majority voting, `extrai` uses a **Weighted Consensus** algorithm.
-- ``name``: ``["SuperWidget", "SuperWidget"]``
-- ``specs.ram_gb``: ``[16, 32]``
-- ``tags.0``: ``["A", "A"]``
-- ``tags.1``: ``["B", "C"]``
+1. **Trust Score Calculation**: Each revision is compared against all others (using Levenshtein similarity). Revisions that are more similar to the group average get a higher weight. This helps filter out "hallucinations" or "lazy" responses that diverge significantly from the consensus.
+2. **Vote Aggregation**: The algorithm groups values for each path and sums their weights.
-Next, it checks each path against the ``consensus_threshold``. This threshold (a float between 0.0 and 1.0) defines the minimum proportion of revisions that must agree. Let's assume a ``consensus_threshold`` of ``0.5``, meaning more than 50% of revisions must agree.
+For example, if Revision 1 is deemed "more trustworthy" (weight 1.2) and Revision 2 is "less trustworthy" (weight 0.8):
-- ``name``: "SuperWidget" appears in 2/2 revisions (100%). **Consensus reached.**
-- ``specs.ram_gb``: 16 appears in 1/2 (50%), 32 appears in 1/2 (50%). Neither meets the "> 50%" threshold. **No consensus.**
-- ``tags.0``: "A" appears in 2/2 revisions (100%). **Consensus reached.**
-- ``tags.1``: "B" appears in 1/2 (50%), "C" appears in 1/2 (50%). **No consensus.**
+- ``name`` ("SuperWidget"): 1.2 + 0.8 = 2.0 (Total Agreement)
+- ``tags.1``:
+ - "B": 1.2
+ - "C": 0.8
-Step 3: Un-flattening and Conflict Resolution
----------------------------------------------
+The system then checks if the **Weighted Agreement Ratio** (value weight / total weight) meets the ``consensus_threshold``.
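
A minimal sketch of this weighted vote for a single flattened path (the real ``JSONConsensus`` engine computes the revision weights itself; the function below just assumes they are given):

```python
def weighted_agreement(votes, threshold):
    """votes: list of (value, revision_weight) pairs for one flattened path.
    Returns the winning value if its weighted ratio meets the threshold."""
    totals = {}
    for value, weight in votes:
        totals[value] = totals.get(value, 0.0) + weight
    total_weight = sum(totals.values())
    best_value = max(totals, key=totals.get)
    ratio = totals[best_value] / total_weight  # weighted agreement ratio
    return best_value if ratio >= threshold else None

# "B" from the trusted revision (1.2) vs "C" from the weaker one (0.8):
weighted_agreement([("B", 1.2), ("C", 0.8)], threshold=0.5)  # -> "B"
```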
-Only the paths that reached consensus are kept:
+Step 3: Conflict Resolution & Clustering
+----------------------------------------
-- ``name``: "SuperWidget"
-- ``tags.0``: "A"
+What happens when no value meets the threshold? The system employs a `conflict_resolver`.
-These are then un-flattened to produce the final JSON object:
+**Standard Resolvers**:
-.. code-block:: json
+- ``default_conflict_resolver``: Omits the field if no consensus is reached.
+- ``prefer_most_common_resolver``: Picks the value with the highest weight, even if it's below the threshold.
- {
- "name": "SuperWidget",
- "tags": ["A"]
- }
+**Advanced: Similarity Cluster Resolver**
-Notice that ``specs.ram_gb`` and the second tag are missing. This is the default behavior when no consensus is reached for a path.
+String fields often suffer from minor formatting differences that cause false conflicts. The ``SimilarityClusterResolver`` handles this by clustering similar values.
-Conflict Resolution
--------------------
+**Example Scenario**:
+Four LLM revisions extract a "Country" field:
-What happens when no value meets the threshold is determined by a ``conflict_resolver`` function that can be passed to the ``JSONConsensus`` initializer. The library provides two main strategies:
+1. "USA"
+2. "U.S.A."
+3. "United States"
+4. "France"
+
+Without clustering, each value might have a low agreement score (e.g., 25% each), failing the threshold.
+
+With **Similarity Clustering**:
+
+1. The resolver detects that "USA" and "U.S.A." are highly similar (Levenshtein similarity).
+2. "United States" might also be linked via semantic matching (if enabled).
+3. "France" is distinct.
+4. The system treats {"USA", "U.S.A."} as a single consensus group with 50% agreement (assuming equal weights).
+5. If this meets the threshold, the most standard format (e.g., "USA") is selected as the final value.
+
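A minimal sketch of the clustering step, using the standard library's ``difflib`` as a stand-in for the real similarity metric. Note that pure string similarity alone would not merge "United States" into the group; that requires the semantic matching mentioned above. This is not the ``SimilarityClusterResolver`` API, just an illustration of the idea.

```python
# Minimal clustering sketch using stdlib difflib as a stand-in for the
# library's real similarity metric (not the SimilarityClusterResolver API).
from difflib import SequenceMatcher

def cluster_values(values: list[str], cutoff: float = 0.6) -> list[list[str]]:
    clusters: list[list[str]] = []
    for value in values:
        for cluster in clusters:
            # Compare against the cluster's first (representative) member.
            ratio = SequenceMatcher(None, value.lower(), cluster[0].lower()).ratio()
            if ratio >= cutoff:
                cluster.append(value)
                break
        else:
            clusters.append([value])  # no similar cluster found: start a new one
    return clusters

print(cluster_values(["USA", "U.S.A.", "United States", "France"]))
```

"USA" and "U.S.A." land in the same cluster, while "France" stays distinct.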
+Analytics & Metrics
+-------------------
-- ``default_conflict_resolver``: If no consensus is found for a path, the field is simply omitted from the final output.
-- ``prefer_most_common_resolver``: If no consensus is found, this resolver will pick the most frequent value, even if it doesn't meet the threshold. This is useful if you always want a value for a field, even if the LLM was inconsistent.
+The process produces detailed metrics available in the `WorkflowAnalyticsCollector`:
-You can also implement your own custom conflict resolver function for more advanced logic.
+- **consensus_confidence_score**: An aggregate score (0.0 - 1.0) indicating how "confident" the system is in the final result, based on the average agreement ratio across all fields.
+- **average_string_similarity**: Measures how textually similar the revisions were to each other.
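As a rough sketch, a confidence score of this kind can be computed as the mean per-field agreement ratio. This formula is an assumption for illustration; the library's exact computation may differ.

```python
# Sketch of an aggregate confidence score as the mean per-field agreement
# ratio. This formula is an assumption, not the library's exact computation.
def confidence_score(field_ratios: list[float]) -> float:
    """field_ratios: the winning value's weighted agreement ratio per field."""
    return sum(field_ratios) / len(field_ratios) if field_ratios else 0.0

# name agreed 2.0/2.0, tags.0 agreed 2.0/2.0, tags.1's best candidate had 1.2/2.0
print(confidence_score([1.0, 1.0, 0.6]))  # ~0.867
```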
diff --git a/docs/concepts/result_processing.rst b/docs/concepts/result_processing.rst
new file mode 100644
index 0000000..4bf2282
--- /dev/null
+++ b/docs/concepts/result_processing.rst
@@ -0,0 +1,32 @@
+.. _result_processing:
+
+Result Processing & Hydration
+=============================
+
+After the consensus phase produces a clean JSON object, the `ResultProcessor` is responsible for converting it back into Python objects and persisting them.
+
+Hydration Strategies
+--------------------
+
+The processor employs two strategies depending on the extraction mode:
+
+1. **Direct Hydration (Structured Output)**
+   When using `use_structured_output=True` (default), the JSON structure exactly mirrors the Pydantic models derived from your SQLModels.
+
+ * **Mechanism**: Recursive Pydantic instantiation (`Model(**data)`).
+ * **Pros**: Fast, type-safe, handles nested objects automatically.
+ * **Cons**: Requires the LLM to strictly follow the schema.
+
+2. **Graph Reconstruction (Flat/Legacy)**
+   When working with flat JSON output (used in some legacy modes or specific prompt configurations), the processor must reconstruct the graph.
+
+ * **Mechanism**: It identifies objects by their keys and reconstructs relationships manually.
+
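Direct hydration can be sketched with plain dataclasses standing in for the generated Pydantic models. The model names here are illustrative assumptions.

```python
# Sketch of direct hydration: recursive instantiation of nested models.
# Plain dataclasses stand in for the Pydantic/SQLModel classes here.
from dataclasses import dataclass, field

@dataclass
class Employee:
    name: str

@dataclass
class Department:
    name: str
    employees: list[Employee] = field(default_factory=list)

def hydrate(data: dict) -> Department:
    # Mirrors Model(**data): build the children first, then the parent.
    employees = [Employee(**e) for e in data.get("employees", [])]
    return Department(name=data["name"], employees=employees)

dept = hydrate({"name": "Research", "employees": [{"name": "Ada"}]})
print(dept.employees[0].name)  # Ada
```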
+Foreign Key Recovery
+--------------------
+
+In **Hierarchical Extraction**, child objects (like `Employees`) are extracted in separate steps from their parents (`Departments`).
+
+* **The Problem**: The child objects generated by the LLM don't know the real database IDs of their parents (since the parents might have just been inserted).
+* **The Solution**:
+
+ 1. The `BatchPipeline` tracks the "Parent Context" for every child batch.
+ 2. When hydrating the child, the `ResultProcessor` injects the correct `parent_id` based on this context.
+ 3. This ensures that even though they were extracted separately, the database relationships are correctly preserved.
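The injection step can be sketched as follows. The record shapes and key names (`department_ref`, `department_id`) are hypothetical, shown only to illustrate the idea.

```python
# Sketch of foreign-key recovery: inject the real parent id into each child
# record based on the tracked parent context. Record shapes and key names
# are hypothetical, for illustration only.
def inject_parent_ids(children: list[dict], parent_context: dict[str, int]) -> list[dict]:
    """parent_context maps a temporary parent label to its real database id."""
    for child in children:
        temp_label = child.pop("department_ref")
        child["department_id"] = parent_context[temp_label]
    return children

children = [{"name": "Ada", "department_ref": "Research"}]
print(inject_parent_ids(children, {"Research": 7}))
# [{'name': 'Ada', 'department_id': 7}]
```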
diff --git a/docs/features/entity_counting.md b/docs/features/entity_counting.md
new file mode 100644
index 0000000..0de1fa3
--- /dev/null
+++ b/docs/features/entity_counting.md
@@ -0,0 +1,61 @@
+# Entity Counting
+
+## Overview
+
+The **Entity Counting** feature performs a high-level pass with the LLM to estimate the number of entities to be extracted *before* the main extraction phase. This count is then injected into the extraction prompt as a "Critical Quantity Constraint", significantly improving recall and reducing hallucinations.
+
+This logic is encapsulated in the `EntityCounter` class, which is automatically managed by the `WorkflowOrchestrator`.
+
+## Key Capabilities
+
+1. **Recall Improvement**: By explicitly asking "how many items are there?", the LLM is forced to scan the text more thoroughly.
+2. **Hallucination Reduction**: The constraint prevents the LLM from inventing extra entities to fill a quota.
+3. **Context Awareness**: In hierarchical extraction workflows, the counter is aware of previously extracted parent entities, allowing it to provide more accurate counts for child entities (e.g., "3 flights for Traveler A, 2 flights for Traveler B").
+4. **Batch Integration**: Counting is fully integrated into the asynchronous batch pipeline (`COUNTING_SUBMITTED` phase), ensuring scalability.
+
+## Configuration
+
+To enable entity counting, simply pass `count_entities=True` to the synthesis method.
+
+```python
+results = await orchestrator.synthesize(
+ input_strings=[document_text],
+ count_entities=True
+)
+```
+
+### Customizing the Counting Phase
+
+You can fine-tune the counting process using the following parameters:
+
+**1. Custom Context (`custom_counting_context`)**
+
+Provide specific hints or context for the counting phase. This is useful if the entities are ambiguous.
+
+```python
+results = await orchestrator.synthesize(
+ ...,
+ count_entities=True,
+ custom_counting_context="Count 'Invoices' only if they have a non-zero total."
+)
+```
+
+**2. Dedicated LLM Client (`counting_llm_client`)**
+
+You can use a different model for counting (e.g., a faster/cheaper one) than for the main extraction.
+
+```python
+# Initialize orchestrator with a specialized counting client
+orchestrator = WorkflowOrchestrator(
+ root_sqlmodel_class=MyModel,
+ llm_client=gpt4_client, # Main extraction (high precision)
+ counting_llm_client=gpt35_client # Counting (high speed/low cost)
+)
+```
+
+## How it Works
+
+1. **Prompting**: The `EntityCounter` generates a specialized prompt asking the LLM to return a JSON object mapping model names to counts (e.g., `{"Invoice": 5, "LineItem": 20}`).
+2. **Execution**: The prompt is sent to the `counting_llm_client` (or the default client).
+3. **Validation**: The output is validated to ensure it matches the expected structure.
+4. **Injection**: The counts are formatted into a constraint string (e.g., *"CRITICAL: You must extract exactly 5 Invoice items..."*) and added to the system prompt of the subsequent extraction phase.
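The injection step (4) can be sketched as follows. The exact constraint wording is an assumption for illustration, not the library's actual prompt text.

```python
# Sketch of step 4 (injection): turning a counts mapping into a constraint
# string. The exact wording is illustrative, not the library's actual prompt.
def build_count_constraint(counts: dict[str, int]) -> str:
    parts = [f"exactly {n} {model} items" for model, n in counts.items()]
    return "CRITICAL: You must extract " + ", ".join(parts) + "."

print(build_count_constraint({"Invoice": 5, "LineItem": 20}))
# CRITICAL: You must extract exactly 5 Invoice items, exactly 20 LineItem items.
```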
diff --git a/docs/getting_started.rst b/docs/getting_started.rst
index 92dddae..985846d 100644
--- a/docs/getting_started.rst
+++ b/docs/getting_started.rst
@@ -19,9 +19,9 @@ Prerequisites
Before you begin, make sure you have:
-1. **Python 3.8+** installed.
+1. **Python 3.10+** installed.
2. The `extrai` library installed (see :ref:`installation`).
-3. An **API key** for an LLM provider (e.g., Gemini, OpenAI).
+3. An **API key** for a supported LLM provider (Google Gemini, OpenAI, DeepSeek, etc.).
Step 1: Define Your Data Models
-------------------------------
diff --git a/docs/how_to/handle_complex_data_with_hierarchical_extraction.rst b/docs/how_to/handle_complex_data_with_hierarchical_extraction.rst
index 36803a1..3b65ca5 100644
--- a/docs/how_to/handle_complex_data_with_hierarchical_extraction.rst
+++ b/docs/how_to/handle_complex_data_with_hierarchical_extraction.rst
@@ -5,14 +5,48 @@ How to Handle Complex Data with Hierarchical Extraction
When dealing with complex, nested data structures, a single LLM call can struggle to produce a valid, complete JSON object. The hierarchical extraction feature is designed to solve this by breaking the problem down into smaller, more manageable pieces.
-When to Use This Feature
-------------------------
+When to Use This Feature (Decision Tree)
+----------------------------------------
-Use hierarchical extraction when your data has clear parent-child relationships (e.g., companies, departments, and employees) and standard extraction methods fail to capture the full structure.
+Use hierarchical extraction when your data has clear parent-child relationships and standard extraction fails. Use the following decision tree to decide:
+
+1. **Is your Data Model Deeply Nested?** (e.g., Company -> Department -> Team -> Employee)
+ * **Yes**: Proceed to 2.
+ * **No**: Use **Standard Extraction**.
+
+2. **Is the text document very long or dense?** (e.g., > 10 pages)
+ * **Yes**: Use **Hierarchical Extraction**.
+ * **No**: Proceed to 3.
+
+3. **Does Standard Extraction fail to link children correctly?** (e.g., Employees assigned to wrong Department)
+ * **Yes**: Use **Hierarchical Extraction**.
+ * **No**: Use **Standard Extraction** (it's faster and cheaper).
.. warning::
Enabling hierarchical extraction can significantly increase the number of LLM API calls and the total processing time. Expect the number of calls to be roughly `(number of revisions) * (depth of the model)`. For a model with a depth of 3 and 3 revisions, this means about 9 calls, plus one for example generation. It should be used judiciously when the standard extraction method proves insufficient.
+Context Injection Mechanism
+---------------------------
+
+One of the challenges in hierarchical extraction is maintaining context. How does the LLM know which "Department" an "Employee" belongs to if we extract them separately?
+
+`extrai` uses **Context Injection** to solve this:
+
+1. **Top-Level Extraction**: First, we extract the parent objects (e.g., `Department` A and `Department` B).
+2. **ID Assignment**: We assign temporary IDs to these parents.
+3. **Child Extraction Prompt**: When asking the LLM to extract `Employees`, we inject the parent context into the prompt:
+
+ .. code-block:: text
+
+ We have already found the following Departments:
+ - ID 1: Research (focus on AI)
+ - ID 2: Sales (focus on North America)
+
+ Now, extract all Employees.
+ IMPORTANT: Assign each Employee to the correct 'department_id' from the list above.
+
+This ensures that the LLM has the necessary "Foreign Keys" available to link objects correctly during the generation phase.
+
Step 1: Define Your Nested Data Models
---------------------------------------
@@ -103,6 +137,26 @@ The output shows that the orchestrator successfully identified and linked the co
This demonstrates how hierarchical extraction can robustly handle nested data by processing it one level at a time.
+Hierarchical Extraction in Batch Mode
+-------------------------------------
+
+Hierarchical extraction also works seamlessly with the **Batch API**, allowing you to process massive datasets cost-effectively.
+
+Key Features:
+
+* **Automatic State Machine**: The pipeline automatically transitions from Parent -> Child steps.
+* **Shallow Schemas**: To prevent hallucinations, the LLM is only given the schema for the *current* level (e.g., just `Company`, then just `Department`), preventing it from trying to extract the whole tree at once.
+* **Contextual Linking**: When extracting children (e.g., `Departments`), the prompt includes the previously extracted parents (`Company` list) so the LLM can correctly assign Foreign Keys.
+
+Usage:
+
+.. code-block:: python
+
+ results = await orchestrator.synthesize_batch(
+ input_strings=[large_text_blob],
+ wait_for_completion=True, # The orchestrator handles the multi-step polling loop
+ count_entities=True # Recommended for complex hierarchies
+ )
+
.. seealso::
For a complete, runnable script, see the example file: `examples/hierarchical_extraction.py`.
diff --git a/docs/how_to/implement_custom_conflict_resolvers.rst b/docs/how_to/implement_custom_conflict_resolvers.rst
new file mode 100644
index 0000000..f1f4344
--- /dev/null
+++ b/docs/how_to/implement_custom_conflict_resolvers.rst
@@ -0,0 +1,62 @@
+.. _custom_conflict_resolution:
+
+How to Implement Custom Conflict Resolvers
+==========================================
+
+When the consensus algorithm encounters conflicting values for a field (e.g., one revision says "Active" and another "Inactive"), and the weighted agreement is below the threshold, it delegates the decision to a **Conflict Resolver**.
+
+By default, `extrai` uses resolvers that either drop the field or pick the most common value. However, you can inject your own logic.
+
+The Resolver Interface
+----------------------
+
+A conflict resolver is simply a function with the following signature:
+
+.. code-block:: python
+
+ from typing import Any, List, Optional
+
+ def my_resolver(
+ path: str,
+ values: List[Any],
+ weights: List[float]
+ ) -> Optional[Any]:
+ ...
+
+* **path**: The dot-notation path to the field (e.g., `"users.0.status"`).
+* **values**: A list of all candidate values from the different revisions.
+* **weights**: The corresponding trust weights for those values.
+* **Return**: The resolved value, or `None` to omit the field.
+
+Example: Strict Numeric Resolver
+--------------------------------
+
+Let's say you are extracting financial data, and if there is *any* disagreement on a price, you want to be conservative and pick the **highest** price found (e.g., for cost estimation), rather than the most common one.
+
+.. code-block:: python
+
+ def conservative_max_price_resolver(path: str, values: List[Any], weights: List[float]) -> Optional[Any]:
+ # Only apply custom logic to 'price' fields
+ if "price" in path:
+ # Filter for numeric values only
+ numeric_values = [v for v in values if isinstance(v, (int, float))]
+ if numeric_values:
+ return max(numeric_values)
+
+ # Fallback to default behavior for other fields
+ # You can call the default resolver here or return None
+ return None
+
+Using Your Resolver
+-------------------
+
+Pass your function to the `WorkflowOrchestrator` during initialization.
+
+.. code-block:: python
+
+ from extrai.core import WorkflowOrchestrator
+
+ orchestrator = WorkflowOrchestrator(
+ ...,
+ conflict_resolver=conservative_max_price_resolver
+ )
diff --git a/docs/index.rst b/docs/index.rst
index b02065e..59f801f 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -16,20 +16,37 @@ Welcome to Extrai's documentation!
.. toctree::
:maxdepth: 2
- :caption: How-to Guides
+ :caption: Features
- how_to/generate_sql_model
- how_to/generate_example_json
- how_to/customize_extraction_prompts
- how_to/handle_complex_data_with_hierarchical_extraction
- how_to/using_multiple_llm_providers
+ features/entity_counting
+ features/structured_output
+ batch_processing_integration
.. toctree::
:maxdepth: 2
:caption: Core Concepts
concepts/architecture_overview
- concepts/consensus_mechanism
+ concepts/consensus_mechanism
+ concepts/result_processing
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Advanced Topics
+
+ advanced/batch_deep_dive
+ advanced/error_handling
+
+.. toctree::
+ :maxdepth: 2
+ :caption: How-to Guides
+
+ how_to/generate_sql_model
+ how_to/generate_example_json
+ how_to/customize_extraction_prompts
+ how_to/handle_complex_data_with_hierarchical_extraction
+ how_to/implement_custom_conflict_resolvers
+ how_to/using_multiple_llm_providers
.. toctree::
:maxdepth: 2
diff --git a/docs/introduction.rst b/docs/introduction.rst
index b46a6cd..ef3ca13 100644
--- a/docs/introduction.rst
+++ b/docs/introduction.rst
@@ -1,13 +1,13 @@
Intro
--------
-With `extrai`, you can extract data from text documents with LLMs, which will be formatted into a given `SQLModel` and registered in your database.
+`extrai` extracts data from text documents using LLMs, formatting the output into a given `SQLModel` and registering it in a database.
-The core of the library is its :ref:`Consensus Mechanism `. We make the same request multiple times, using the same or different providers, and then select the values that meet a certain threshold.
+The library utilizes a :ref:`Consensus Mechanism ` to ensure accuracy. It makes the same request multiple times, using the same or different providers, and then selects the values that meet a configured threshold.
`extrai` also has other features, like :ref:`generating SQLModels ` from a prompt and documents, and :ref:`generating few-shot examples `. For complex, nested data, the library offers :ref:`Hierarchical Extraction `, breaking down the extraction into manageable, hierarchical steps. It also includes :ref:`built-in analytics ` to monitor performance and output quality.
-Worflow Overview
+Workflow Overview
----------------------
The library is built around a few key components that work together to manage the extraction workflow. The following diagram illustrates the high-level workflow (see :ref:`Architecture Overview ` for more details):
diff --git a/docs/sqlmodel_generator.rst b/docs/sqlmodel_generator.rst
index 28a8e4f..46003d0 100644
--- a/docs/sqlmodel_generator.rst
+++ b/docs/sqlmodel_generator.rst
@@ -13,7 +13,7 @@ This component is responsible for:
- Generating a prompt for an LLM based on your documents and task description.
- Interacting with the LLM to get a structured JSON description of the desired data models.
-- Generating Python code from this description.
+- Generating Python code from this description (via the ``PythonModelBuilder`` component).
- Dynamically compiling and loading the new ``SQLModel`` classes into your application at runtime.
Core Workflow
diff --git a/docs/workflow_orchestrator.rst b/docs/workflow_orchestrator.rst
index 5e76e5f..c39b027 100644
--- a/docs/workflow_orchestrator.rst
+++ b/docs/workflow_orchestrator.rst
@@ -26,7 +26,7 @@ The typical workflow involves these steps:
Initialization and Configuration
--------------------------------
-The constructor of the ``WorkflowOrchestrator`` is key to configuring its behavior.
+The constructor of the ``WorkflowOrchestrator`` is key to configuring its behavior. The orchestrator acts as a facade, coordinating internal components (like ``ExtractionPipeline`` and ``BatchPipeline``) to simplify the API.
.. code-block:: python
@@ -87,6 +87,12 @@ Here are the parameters you can use:
# Multiple clients for resilience
orchestrator = WorkflowOrchestrator(..., llm_client=[client1, client2])
+``counting_llm_client``
+ An optional LLM client instance specifically for the entity counting phase. This allows you to use a cheaper or faster model for the initial count, while reserving the main ``llm_client`` for the detailed extraction.
+
+ * **Type**: ``Optional[BaseLLMClient]``
+ * **Default**: ``None`` (uses ``llm_client``)
+
``num_llm_revisions``
The total number of times the LLM will be asked to generate a JSON output for the given input. A higher number increases the chances of a reliable consensus but also increases costs and latency.
@@ -155,13 +161,17 @@ Once the orchestrator is configured, you can start processing documents using on
hydrated_objects = await orchestrator.synthesize(
input_strings=["Text document 1...", "Text document 2..."],
- db_session_for_hydration=db_session # Optional: for relationship resolution
+ db_session_for_hydration=db_session, # Optional: for relationship resolution
+ count_entities=True, # Optional: enable entity counting
+ custom_counting_context="..." # Optional: context for counting
)
**Parameters:**
* ``input_strings`` (``List[str]``): A list of strings, where each string is a document to be processed.
* ``db_session_for_hydration`` (``Optional[Session]``): An optional SQLAlchemy session. If provided, the hydrator will use it to resolve relationships. If not, a temporary in-memory session is created.
+ * ``count_entities`` (``bool``, default ``False``): If True, performs an initial pass to count entities before extraction.
+ * ``custom_counting_context`` (``str``, optional): Custom instructions or context specifically for the entity counting phase.
* ``extraction_example_json`` (``str``, optional): A JSON string that provides a few-shot example to the LLM, guiding it to produce a better-structured output. If not provided, the orchestrator will attempt to auto-generate one.
* ``extraction_example_object`` (``Optional[Union[SQLModel, List[SQLModel]]]``, optional): An existing SQLModel object or a list of them to be used as the few-shot example. This is an alternative to providing the example as a raw JSON string.
* ``custom_extraction_process`` (``str``, optional): Custom, step-by-step instructions for the LLM on how to perform the extraction.
@@ -181,6 +191,88 @@ Once the orchestrator is configured, you can start processing documents using on
The parameters are the same as for ``synthesize()``, except it requires a ``db_session`` to commit the transaction.
+``synthesize_batch()``
+ Submits an asynchronous batch job for extraction. This is ideal for large-scale processing or when using cheaper batch APIs.
+
+ .. code-block:: python
+
+ # Non-blocking submission (returns str)
+ batch_id = await orchestrator.synthesize_batch(
+ input_strings=["..."],
+ db_session=db_session,
+ wait_for_completion=False
+ )
+ print(f"Batch submitted: {batch_id}")
+
+ # Blocking until complete (returns BatchProcessResult)
+ result = await orchestrator.synthesize_batch(
+ input_strings=["..."],
+ db_session=db_session,
+ wait_for_completion=True
+ )
+ print(f"Extraction complete: {len(result.hydrated_objects)} objects found.")
+
+ **Parameters:**
+
+ * ``input_strings`` (``List[str]``): A list of strings to be processed.
+ * ``db_session`` (``Optional[Session]``): A database session (SQLModel/SQLAlchemy) used for initial counting and potentially for result hydration.
+ * ``wait_for_completion`` (``bool``, default ``False``):
+ * If ``True``, the method polls the batch status until completion (or error). It handles multi-stage hierarchical batches automatically. **Returns**: ``BatchProcessResult`` (containing hydrated objects).
+ * If ``False``, it submits the batch to the LLM provider and returns immediately. **Returns**: ``str`` (the batch job ID).
+ * ``count_entities`` (``bool``, default ``False``): If True, performs an initial pass to count entities before extraction.
+ * ``custom_counting_context`` (``str``, optional): Custom instructions or context specifically for the entity counting phase.
+ * ... (other parameters similar to ``synthesize``)
+
+``create_continuation_batch()``
+ Creates a new batch job that continues from a specific step of a previous batch. This is useful for retrying failed steps or extending a workflow without re-running earlier successful steps.
+
+ .. code-block:: python
+
+ # Continue from step 2 (0-indexed) of a previous batch
+ new_batch_id = await orchestrator.create_continuation_batch(
+ original_batch_id="prev_batch_id",
+ db_session=db_session,
+ start_from_step_index=2,
+ wait_for_completion=True
+ )
+
+ **Parameters:**
+
+ * ``original_batch_id`` (``str``): The ID of the batch to continue from.
+ * ``start_from_step_index`` (``int``): The hierarchical step index to start the new batch from. Steps before this index will be copied from the original batch.
+ * ``wait_for_completion`` (``bool``, default ``False``): Same behavior as in ``synthesize_batch``.
+ * (Other parameters allow overriding configuration for the new batch)
+
+``monitor_batch_job()``
+ Polls the status of an existing batch job until it completes. This method handles multi-stage workflows (like counting -> extraction, or hierarchical steps) by automatically detecting phase transitions and submitting subsequent jobs.
+
+ .. code-block:: python
+
+ # Resume monitoring a batch (e.g., after script restart)
+ result = await orchestrator.monitor_batch_job(
+ root_batch_id="existing_batch_id",
+ db_session=db_session
+ )
+
+ **Parameters:**
+
+ * ``root_batch_id`` (``str``): The ID of the batch to monitor.
+ * ``poll_interval`` (``int``, default ``60``): Seconds to wait between status checks.
+
+``get_batch_status()``
+ Retrieves the current status of a batch job, checking with the LLM provider if necessary.
+
+ .. code-block:: python
+
+ status = await orchestrator.get_batch_status("batch_id", db_session)
+
+``process_batch()``
+ Processes a completed batch job. This downloads results, runs consensus, hydrates objects, and persists them to the database. It is typically called automatically by ``monitor_batch_job``, but can be used manually for non-blocking workflows.
+
+ .. code-block:: python
+
+ result = await orchestrator.process_batch("batch_id", db_session)
+
Concise Usage Example
---------------------
diff --git a/src/extrai/core/__init__.py b/src/extrai/core/__init__.py
index 1a4a763..31a0286 100644
--- a/src/extrai/core/__init__.py
+++ b/src/extrai/core/__init__.py
@@ -8,30 +8,29 @@
from .errors import (
WorkflowError,
LLMInteractionError,
- ConfigurationError,
ConsensusProcessError,
HydrationError,
LLMConfigurationError,
LLMOutputParseError,
LLMOutputValidationError,
LLMAPICallError,
- ExampleGenerationError,
+ ConfigurationError,
)
-from .base_llm_client import BaseLLMClient
from .analytics_collector import WorkflowAnalyticsCollector
-from .json_consensus import JSONConsensus, default_conflict_resolver
+from .json_consensus import JSONConsensus
from .prompt_builder import generate_system_prompt, generate_user_prompt_for_docs
-from .schema_inspector import (
- generate_llm_schema_from_models,
- discover_sqlmodels_from_root,
- inspect_sqlalchemy_model,
-)
-from .sqlalchemy_hydrator import SQLAlchemyHydrator
-from .db_writer import persist_objects
+from .model_registry import ModelRegistry
+from .schema_inspector import SchemaInspector
+from .result_processor import ResultProcessor, SQLAlchemyHydrator, persist_objects
from .workflow_orchestrator import WorkflowOrchestrator
from .sqlmodel_generator import SQLModelCodeGenerator
from .example_json_generator import ExampleJSONGenerator
+from .conflict_resolvers import (
+ SimilarityClusterResolver,
+ default_conflict_resolver,
+ prefer_most_common_resolver,
+)
__all__ = [
# Errors
@@ -44,20 +43,22 @@
"LLMOutputParseError",
"LLMOutputValidationError",
"LLMAPICallError",
-    "ExampleGenerationError",
# Classes & Functions
"BaseLLMClient",
"WorkflowAnalyticsCollector",
"JSONConsensus",
"default_conflict_resolver",
+ "prefer_most_common_resolver",
"generate_system_prompt",
"generate_user_prompt_for_docs",
- "generate_llm_schema_from_models",
- "discover_sqlmodels_from_root",
- "inspect_sqlalchemy_model",
+ "ModelRegistry",
+ "SchemaInspector",
+ "ResultProcessor",
"SQLAlchemyHydrator",
"persist_objects",
"WorkflowOrchestrator",
"SQLModelCodeGenerator",
"ExampleJSONGenerator",
+ "SimilarityClusterResolver",
]
diff --git a/src/extrai/core/analytics_collector.py b/src/extrai/core/analytics_collector.py
index 41ad6a4..751d9f7 100644
--- a/src/extrai/core/analytics_collector.py
+++ b/src/extrai/core/analytics_collector.py
@@ -22,6 +22,34 @@ def __init__(self, logger: Optional[logging.Logger] = None):
self._consensus_run_details_list: List[Dict[str, Any]] = []
self._custom_events: List[Dict[str, Any]] = []
self._workflow_errors: List[Dict[str, Any]] = []
+ self._llm_output_validations_errors_details: List[Dict[str, Any]] = []
+ self._total_llm_cost: float = 0.0
+ self._total_input_tokens: int = 0
+ self._total_output_tokens: int = 0
+ self._llm_cost_details: List[Dict[str, Any]] = []
+
+ def record_llm_usage(
+ self,
+ input_tokens: int,
+ output_tokens: int,
+ model: str,
+ cost: float = 0.0,
+ details: Optional[Dict[str, Any]] = None,
+ ):
+ """Records the token usage and optional cost of an LLM call."""
+ self._total_input_tokens += input_tokens
+ self._total_output_tokens += output_tokens
+ self._total_llm_cost += cost
+
+ usage_details = {
+ "model": model,
+ "input_tokens": input_tokens,
+ "output_tokens": output_tokens,
+ "cost": cost,
+ }
+ if details:
+ usage_details.update(details)
+ self._llm_cost_details.append(usage_details)
def record_llm_api_call_success(self):
"""Increments the count of successful LLM API calls."""
@@ -35,9 +63,10 @@ def record_llm_output_parse_error(self):
"""Increments the count of LLM output parsing errors."""
self._llm_output_parse_errors += 1
- def record_llm_output_validation_error(self):
+    def record_llm_output_validation_error(self, details: Optional[Dict[str, Any]] = None):
"""Increments the count of LLM output validation errors."""
self._llm_output_validation_errors += 1
+        if details is not None:
+            self._llm_output_validations_errors_details.append(details)
def record_hydration_success(self, count: int):
"""Records the number of successfully hydrated objects."""
@@ -100,6 +129,21 @@ def llm_output_validation_errors(self) -> int:
"""Returns the total count of LLM output validation errors."""
return self._llm_output_validation_errors
+ @property
+ def total_llm_cost(self) -> float:
+ """Returns the total cost of LLM calls."""
+ return self._total_llm_cost
+
+ @property
+ def total_input_tokens(self) -> int:
+ """Returns the total input tokens used."""
+ return self._total_input_tokens
+
+ @property
+ def total_output_tokens(self) -> int:
+ """Returns the total output tokens used."""
+ return self._total_output_tokens
+
@property
def number_of_consensus_runs(self) -> int:
"""Returns the total number of consensus runs recorded."""
@@ -201,6 +245,9 @@ def get_report(self) -> Dict[str, Any]:
"number_of_consensus_runs": self.number_of_consensus_runs,
"hydrated_objects_successes": self._hydrated_objects_successes,
"hydration_failures": self._hydration_failures,
+ "total_llm_cost": self._total_llm_cost,
+ "total_input_tokens": self._total_input_tokens,
+ "total_output_tokens": self._total_output_tokens,
}
if self._consensus_run_details_list:
report.update(
@@ -225,6 +272,12 @@ def get_report(self) -> Dict[str, Any]:
report["custom_events"] = self._custom_events
if self._workflow_errors:
report["workflow_errors"] = self._workflow_errors
+ if self._llm_output_validations_errors_details:
+ report["llm_output_validations_errors_details"] = (
+ self._llm_output_validations_errors_details
+ )
+ if self._llm_cost_details:
+ report["llm_cost_details"] = self._llm_cost_details
return report
@@ -265,3 +318,7 @@ def reset(self):
self._consensus_run_details_list = []
self._custom_events = []
self._workflow_errors = []
+ self._total_llm_cost = 0.0
+ self._total_input_tokens = 0
+ self._total_output_tokens = 0
+    self._llm_cost_details = []
+    self._llm_output_validations_errors_details = []
diff --git a/src/extrai/core/base_llm_client.py b/src/extrai/core/base_llm_client.py
index 22cbc7f..3de5e78 100644
--- a/src/extrai/core/base_llm_client.py
+++ b/src/extrai/core/base_llm_client.py
@@ -61,7 +61,12 @@ def __init__(
self.logger.setLevel(logging.WARNING)
@abstractmethod
- async def _execute_llm_call(self, system_prompt: str, user_prompt: str) -> str:
+ async def _execute_llm_call(
+ self,
+ system_prompt: str,
+ user_prompt: str,
+ analytics_collector: Optional[WorkflowAnalyticsCollector] = None,
+ ) -> str:
"""
Makes the actual API call to the LLM and returns the raw string content.
@@ -71,6 +76,7 @@ async def _execute_llm_call(self, system_prompt: str, user_prompt: str) -> str:
Args:
system_prompt: The system prompt for the LLM.
user_prompt: The user prompt for the LLM.
+ analytics_collector: Optional analytics collector for tracking costs.
Returns:
The raw string content from the LLM response. Should return an empty
@@ -82,6 +88,32 @@ async def _execute_llm_call(self, system_prompt: str, user_prompt: str) -> str:
"""
...
+ async def generate_structured(
+ self,
+ system_prompt: str,
+ user_prompt: str,
+ response_model: Type[Any],
+ analytics_collector: Optional[WorkflowAnalyticsCollector] = None,
+ **kwargs: Any,
+ ) -> Any:
+ """
+ Generates a structured output directly from the LLM.
+ This defaults to raising NotImplementedError for providers that don't support it.
+
+ Args:
+ system_prompt: The system prompt.
+ user_prompt: The user prompt.
+ response_model: The Pydantic model class to parse the response into.
+ analytics_collector: Optional analytics collector.
+ **kwargs: Additional arguments.
+
+ Returns:
+ An instance of response_model.
+ """
+ raise NotImplementedError(
+ "Structured generation is not supported by this provider."
+ )
+
async def _attempt_single_generation_and_validation(
self,
*,
@@ -89,17 +121,22 @@ async def _attempt_single_generation_and_validation(
user_prompt: str,
validation_callable: Callable[[str, str], Dict[str, Any]],
revision_info_for_error: str,
+ analytics_collector: Optional[WorkflowAnalyticsCollector] = None,
) -> Dict[str, Any]:
"""
Performs one LLM call and one validation attempt.
"""
raw_response_content = await self._execute_llm_call(
- system_prompt=system_prompt, user_prompt=user_prompt
+ system_prompt=system_prompt,
+ user_prompt=user_prompt,
+ analytics_collector=analytics_collector,
)
if not raw_response_content:
raise ValueError(f"{revision_info_for_error}: LLM returned empty content.")
+ self.logger.debug(f"Received LLM response: {raw_response_content}")
+
validated_data = validation_callable(
raw_response_content, revision_info_for_error
)
@@ -129,6 +166,7 @@ async def _generate_one_revision_with_retries(
user_prompt=user_prompt,
validation_callable=validation_callable,
revision_info_for_error=revision_info_for_error,
+ analytics_collector=analytics_collector,
)
if analytics_collector:
analytics_collector.record_llm_api_call_success()
@@ -221,6 +259,7 @@ async def _generate_all_revisions(
self.logger.info(
f"Revision generation summary: {num_successful} successful, {num_failures} failed."
)
+ self.logger.debug(f"Generated objects: {successful_revisions}")
if failures:
# If all revisions failed, raise an aggregate error.
@@ -236,6 +275,72 @@ async def _generate_all_revisions(
return successful_revisions
+ async def create_batch_job(
+ self,
+ requests: List[Dict[str, Any]],
+ endpoint: str = "/v1/chat/completions",
+ completion_window: str = "24h",
+ metadata: Optional[Dict[str, str]] = None,
+ ) -> Any:
+ """
+ Creates a batch job for processing multiple requests.
+
+ Args:
+ requests: List of request bodies. Each request should be a dictionary
+ representing the body of a single API call (e.g. chat completion).
+ Each request MUST have a 'custom_id' field for identification.
+ endpoint: The API endpoint to target (default: /v1/chat/completions).
+ completion_window: The time window for completion (default: 24h).
+ metadata: Optional metadata to attach to the batch.
+
+ Returns:
+ The created batch job object.
+
+ Raises:
+ NotImplementedError: If the provider does not support batch processing.
+ """
+ raise NotImplementedError("Batch processing is not supported by this provider.")
+
+ async def retrieve_batch_job(self, batch_id: str) -> Any:
+ """
+ Retrieves the status and details of a batch job.
+ """
+ raise NotImplementedError("Batch processing is not supported by this provider.")
+
+ async def list_batch_jobs(
+ self, limit: int = 20, after: Optional[str] = None
+ ) -> Any:
+ """
+ Lists batch jobs.
+ """
+ raise NotImplementedError("Batch processing is not supported by this provider.")
+
+ async def cancel_batch_job(self, batch_id: str) -> Any:
+ """
+ Cancels a batch job.
+ """
+ raise NotImplementedError("Batch processing is not supported by this provider.")
+
+ async def retrieve_batch_results(self, file_id: str) -> str:
+ """
+ Retrieves the content of a batch output file.
+ """
+ raise NotImplementedError("Batch processing is not supported by this provider.")
+
+ def extract_content_from_batch_response(
+ self, response: Dict[str, Any]
+ ) -> Optional[str]:
+ """
+ Extracts the text content from a single item in a batch response file.
+
+ Args:
+ response: A dictionary representing a single line/item from the batch output.
+
+ Returns:
+ The extracted content string, or None if extraction failed.
+ """
+ raise NotImplementedError("Batch processing is not supported by this provider.")
+
async def generate_json_revisions(
self,
system_prompt: str,
@@ -276,6 +381,7 @@ async def generate_and_validate_raw_json_output(
max_validation_retries_per_revision: int,
target_json_schema: Optional[Dict[str, Any]] = None,
analytics_collector: Optional[WorkflowAnalyticsCollector] = None,
+ attempt_unwrap: bool = True,
) -> List[Dict[str, Any]]:
"""
Generates multiple JSON output revisions, validating against a raw JSON schema.
@@ -286,6 +392,7 @@ def validation_callable(content: str, revision_info: str) -> Dict[str, Any]:
raw_llm_content=content,
revision_info_for_error=revision_info,
target_json_schema=target_json_schema,
+ attempt_unwrap=attempt_unwrap,
)
return await self._generate_all_revisions(
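The optional-capability pattern added to the base client above (methods like `generate_structured` defaulting to `NotImplementedError`) lets callers probe support with try/except instead of isinstance checks. A minimal standalone sketch, where `BaseClient`, `EchoClient`, and `Reply` are hypothetical stand-ins rather than extrai classes:

```python
import asyncio
from typing import Any, Type


class BaseClient:
    # Optional capability: the default raises, so callers can probe
    # support with try/except rather than isinstance checks.
    async def generate_structured(
        self, system_prompt: str, user_prompt: str, response_model: Type[Any]
    ) -> Any:
        raise NotImplementedError("Structured generation is not supported by this provider.")


class Reply:
    def __init__(self, text: str):
        self.text = text


class EchoClient(BaseClient):
    async def generate_structured(self, system_prompt, user_prompt, response_model):
        # A real provider would call its SDK here; this stub just builds the model.
        return response_model(text=user_prompt)


async def main() -> None:
    try:
        await BaseClient().generate_structured("sys", "user", Reply)
    except NotImplementedError:
        print("base: unsupported")
    reply = await EchoClient().generate_structured("sys", "user", Reply)
    print("echo:", reply.text)


asyncio.run(main())
```

The same probing style works for the batch methods: a caller can attempt `create_batch_job` and fall back to per-request calls when the provider raises.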
diff --git a/src/extrai/core/batch_models.py b/src/extrai/core/batch_models.py
new file mode 100644
index 0000000..93e6525
--- /dev/null
+++ b/src/extrai/core/batch_models.py
@@ -0,0 +1,70 @@
+from enum import Enum
+from typing import List, Optional, Any, Dict
+from datetime import datetime, timezone
+from sqlmodel import SQLModel, Field, Relationship
+from sqlalchemy import JSON
+
+
+class BatchJobStatus(str, Enum):
+ SUBMITTED = "submitted"
+ PROCESSING = "processing"
+ READY_TO_PROCESS = "ready_to_process"
+ COMPLETED = "completed"
+ FAILED = "failed"
+ CANCELLED = "cancelled"
+
+ # Counting phase statuses
+ COUNTING_SUBMITTED = "counting_submitted"
+ COUNTING_PROCESSING = "counting_processing"
+ COUNTING_READY_TO_PROCESS = "counting_ready_to_process"
+
+
+class BatchJobContext(SQLModel, table=True):
+ """
+ Stores the state of a batch job managed by the WorkflowOrchestrator.
+ """
+
+ root_batch_id: str = Field(primary_key=True)
+ current_batch_id: str = Field(index=True) # Provider's batch ID
+ status: BatchJobStatus = Field(default=BatchJobStatus.SUBMITTED)
+
+ input_strings: List[str] = Field(default_factory=list, sa_type=JSON)
+ config: Dict[str, Any] = Field(default_factory=dict, sa_type=JSON)
+
+ # Store results when completed
+ results: Optional[List[Any]] = Field(default=None, sa_type=JSON)
+
+ # Tracking retries
+ retry_count: int = Field(default=0)
+
+ created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
+ updated_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
+
+ # Error tracking
+ last_error: Optional[str] = None
+
+ steps: List["BatchJobStep"] = Relationship(back_populates="batch")
+
+
+class BatchJobStep(SQLModel, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+ batch_id: str = Field(foreign_key="batchjobcontext.root_batch_id")
+ step_index: int
+ status: BatchJobStatus = Field(default=BatchJobStatus.COMPLETED)
+ result: List[Any] = Field(default_factory=list, sa_type=JSON)
+ metadata_json: Dict[str, Any] = Field(default_factory=dict, sa_type=JSON)
+ created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
+
+ batch: BatchJobContext = Relationship(back_populates="steps")
+
+
+class BatchProcessResult(SQLModel):
+ """
+ Result returned by process_batch.
+ """
+
+ status: BatchJobStatus
+ hydrated_objects: Optional[List[Any]] = None
+ original_pk_map: Optional[Dict[Any, Any]] = Field(default=None, exclude=True)
+ retry_batch_id: Optional[str] = None
+ message: Optional[str] = None
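Because `BatchJobStatus` subclasses both `str` and `Enum`, members compare equal to their raw values and serialize to plain JSON strings, which is what allows the pipeline to persist them in JSON columns and do substring checks like `"counting" in context.status.value`. A standalone sketch with a trimmed member list:

```python
import json
from enum import Enum


class BatchJobStatus(str, Enum):
    SUBMITTED = "submitted"
    COUNTING_SUBMITTED = "counting_submitted"
    COMPLETED = "completed"


# str mixin: members == their raw value, and json.dumps needs no custom encoder.
assert BatchJobStatus.SUBMITTED == "submitted"
print(json.dumps({"status": BatchJobStatus.COMPLETED}))
print("counting" in BatchJobStatus.COUNTING_SUBMITTED.value)
```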
diff --git a/src/extrai/core/batch_pipeline.py b/src/extrai/core/batch_pipeline.py
new file mode 100644
index 0000000..1e9c1bd
--- /dev/null
+++ b/src/extrai/core/batch_pipeline.py
@@ -0,0 +1,754 @@
+import json
+import uuid
+import logging
+from datetime import datetime, timezone
+from typing import List, Dict, Any, Optional, Union
+from sqlmodel import SQLModel, Session, select
+
+from extrai.core.base_llm_client import BaseLLMClient
+from .client_rotator import ClientRotator
+from .extraction_context_preparer import ExtractionContextPreparer
+from .model_registry import ModelRegistry
+from .extraction_config import ExtractionConfig
+from .prompt_builder import PromptBuilder
+from .entity_counter import EntityCounter
+from .analytics_collector import WorkflowAnalyticsCollector
+from .batch_models import (
+ BatchJobContext,
+ BatchJobStatus,
+ BatchProcessResult,
+ BatchJobStep,
+)
+from .model_wrapper_builder import ModelWrapperBuilder
+from extrai.utils.llm_output_processing import process_and_validate_llm_output
+from extrai.utils.alignment_utils import normalize_json_revisions
+from .json_consensus import JSONConsensus
+from .extraction_request_factory import ExtractionRequestFactory
+
+
+class BatchPipeline:
+ """Manages batch extraction workflows."""
+
+ def __init__(
+ self,
+ model_registry: ModelRegistry,
+ llm_client: Union["BaseLLMClient", List["BaseLLMClient"]],
+ config: ExtractionConfig,
+ analytics_collector: WorkflowAnalyticsCollector,
+ logger: logging.Logger,
+ counting_llm_client: Optional[BaseLLMClient] = None,
+ ):
+ self.model_registry = model_registry
+ self.config = config
+ self.analytics_collector = analytics_collector
+ self.logger = logger
+
+ self.client_rotator = ClientRotator(llm_client)
+ self.prompt_builder = PromptBuilder(model_registry, logger=logger)
+ c_client = counting_llm_client or llm_client
+ if isinstance(c_client, list):
+ c_client = c_client[0]
+
+ self.entity_counter = EntityCounter(
+ model_registry, c_client, config, analytics_collector, logger=logger
+ )
+ self.context_preparer = ExtractionContextPreparer(
+ model_registry,
+ analytics_collector,
+ config.max_validation_retries_per_revision,
+ logger=logger,
+ )
+ self.model_wrapper_builder = ModelWrapperBuilder()
+ self.consensus = JSONConsensus(
+ consensus_threshold=config.consensus_threshold,
+ conflict_resolver=config.conflict_resolver,
+ logger=logger,
+ )
+ self.request_factory = ExtractionRequestFactory(
+ model_registry,
+ self.prompt_builder,
+ self.model_wrapper_builder,
+ logger=logger,
+ )
+
+ async def submit_batch(
+ self,
+ db_session: Session,
+ input_strings: List[str],
+ extraction_example_json: str = "",
+ extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]] = None,
+ custom_extraction_process: str = "",
+ custom_extraction_guidelines: str = "",
+ custom_final_checklist: str = "",
+ custom_context: str = "",
+ count_entities: bool = False,
+ custom_counting_context: str = "",
+ ) -> str:
+ """Submits a batch job and returns root_batch_id."""
+ if not input_strings:
+ raise ValueError("input_strings cannot be empty")
+
+ # Prepare example
+ example_json = await self.context_preparer.prepare_example(
+ extraction_example_json,
+ extraction_example_object,
+ self.client_rotator.get_next_client,
+ )
+
+ root_batch_id = str(uuid.uuid4())
+
+ # Initialize configuration
+ config_data = {
+ "extraction_example_json": example_json,
+ "custom_extraction_process": custom_extraction_process,
+ "custom_extraction_guidelines": custom_extraction_guidelines,
+ "custom_final_checklist": custom_final_checklist,
+ "custom_context": custom_context,
+ "count_entities": count_entities,
+ "custom_counting_context": custom_counting_context,
+ "schema_json": self.model_registry.llm_schema_json,
+ }
+
+ if self.config.use_hierarchical_extraction:
+ config_data.update({"hierarchical": True, "current_model_index": 0})
+
+ context = BatchJobContext(
+ root_batch_id=root_batch_id,
+ current_batch_id="pending",
+ status=BatchJobStatus.SUBMITTED,
+ input_strings=input_strings,
+ config=config_data,
+ )
+ db_session.add(context)
+ db_session.commit()
+
+ if count_entities:
+ await self._submit_counting_phase(context, db_session)
+ else:
+ await self._submit_extraction_phase(context, db_session, step_index=0)
+
+ self.logger.info(f"Batch workflow initiated: {root_batch_id}")
+ return root_batch_id
+
+ async def create_continuation_batch(
+ self,
+ db_session: Session,
+ original_batch_id: str,
+ new_config: Dict[str, Any],
+ start_from_step_index: int,
+ ) -> str:
+ """
+ Creates a new batch cycle continuing from a previous batch's state.
+ Copies completed steps up to start_from_step_index into the new batch.
+ """
+ old_context = db_session.get(BatchJobContext, original_batch_id)
+ if not old_context:
+ raise ValueError(f"Original batch {original_batch_id} not found")
+
+ new_batch_id = str(uuid.uuid4())
+
+ # Ensure new config has required fields
+ if self.config.use_hierarchical_extraction and "hierarchical" not in new_config:
+ new_config["hierarchical"] = True
+ new_config["current_model_index"] = start_from_step_index
+
+ if "expected_entity_descriptions" in old_context.config:
+ new_config["expected_entity_descriptions"] = old_context.config[
+ "expected_entity_descriptions"
+ ]
+
+ new_context = BatchJobContext(
+ root_batch_id=new_batch_id,
+ current_batch_id="pending",
+ status=BatchJobStatus.SUBMITTED,
+ input_strings=old_context.input_strings,
+ config=new_config,
+ )
+ db_session.add(new_context)
+ db_session.commit()
+
+ # Copy valid steps from old batch
+ if start_from_step_index > 0:
+ old_steps = db_session.exec(
+ select(BatchJobStep)
+ .where(BatchJobStep.batch_id == original_batch_id)
+ .where(BatchJobStep.step_index < start_from_step_index)
+ .where(BatchJobStep.status == BatchJobStatus.COMPLETED)
+ ).all()
+
+ for step in old_steps:
+ new_step = BatchJobStep(
+ batch_id=new_batch_id,
+ step_index=step.step_index,
+ status=step.status,
+ result=step.result,
+ metadata_json=step.metadata_json,
+ )
+ db_session.add(new_step)
+
+ db_session.commit()
+
+ self.logger.info(
+ f"Created continuation batch {new_batch_id} from {original_batch_id}, starting at step {start_from_step_index}"
+ )
+
+ # Determine starting phase
+ # If counting is enabled, we start with counting phase for the starting step
+ if new_config.get("count_entities"):
+ step_idx = start_from_step_index if new_config.get("hierarchical") else 0
+ await self._submit_counting_phase(
+ new_context, db_session, step_index=step_idx
+ )
+ elif new_config.get("hierarchical"):
+ await self._submit_extraction_phase(
+ new_context, db_session, step_index=start_from_step_index
+ )
+ else:
+ await self._submit_extraction_phase(new_context, db_session, step_index=0)
+
+ return new_batch_id
+
+ async def _submit_counting_phase(
+ self,
+ context: BatchJobContext,
+ db_session: Session,
+ step_index: Optional[int] = None,
+ ):
+ input_strings = context.input_strings
+
+ # Determine which models to count
+ if context.config.get("hierarchical"):
+ idx = (
+ step_index
+ if step_index is not None
+ else context.config.get("current_model_index", 0)
+ )
+ if 0 <= idx < len(self.model_registry.models):
+ model_names = [self.model_registry.models[idx].__name__]
+ else:
+ model_names = self.model_registry.get_all_model_names()
+ else:
+ model_names = self.model_registry.get_all_model_names()
+
+ custom_counting_context = context.config.get("custom_counting_context", "")
+
+ system_prompt, user_prompt = self.entity_counter.prepare_counting_prompts(
+ input_strings, model_names, custom_counting_context
+ )
+
+ client = self.entity_counter.llm_client
+ requests = self._create_batch_requests(
+ system_prompt, user_prompt, num_revisions=1, override_client=client
+ )
+
+ batch_job = await client.create_batch_job(requests)
+ provider_batch_id = batch_job.id if hasattr(batch_job, "id") else str(batch_job)
+
+ context.current_batch_id = provider_batch_id
+ context.status = BatchJobStatus.COUNTING_SUBMITTED
+ context.updated_at = datetime.now(timezone.utc)
+ db_session.add(context)
+ db_session.commit()
+
+ self.logger.info(
+ f"Submitted counting batch for {context.root_batch_id}: {provider_batch_id}"
+ )
+
+ async def _submit_extraction_phase(
+ self, context: BatchJobContext, db_session: Session, step_index: int = 0
+ ):
+ # Update current index in config if hierarchical
+ if context.config.get("hierarchical"):
+ # Update the config dictionary properly
+ new_config = context.config.copy()
+ new_config["current_model_index"] = step_index
+ context.config = new_config
+ db_session.add(context)
+ db_session.commit()
+
+ # Retrieve previous entities from completed steps
+ previous_entities = []
+ if context.config.get("hierarchical") and step_index > 0:
+ steps = db_session.exec(
+ select(BatchJobStep)
+ .where(BatchJobStep.batch_id == context.root_batch_id)
+ .where(BatchJobStep.step_index < step_index)
+ .where(BatchJobStep.status == BatchJobStatus.COMPLETED)
+ .order_by(BatchJobStep.step_index)
+ ).all()
+ for s in steps:
+ previous_entities.extend(s.result)
+
+ # Prepare request
+ request = self.request_factory.prepare_request(
+ input_strings=context.input_strings,
+ config=self.config,
+ extraction_example_json=context.config.get("extraction_example_json", ""),
+ custom_extraction_process=context.config.get(
+ "custom_extraction_process", ""
+ ),
+ custom_extraction_guidelines=context.config.get(
+ "custom_extraction_guidelines", ""
+ ),
+ custom_final_checklist=context.config.get("custom_final_checklist", ""),
+ custom_context=context.config.get("custom_context", ""),
+ expected_entity_descriptions=context.config.get(
+ "expected_entity_descriptions"
+ ),
+ previous_entities=previous_entities if previous_entities else None,
+ hierarchical_model_index=step_index
+ if context.config.get("hierarchical")
+ else None,
+ )
+
+ requests = self._create_batch_requests(
+ request.system_prompt, request.user_prompt, request.json_schema
+ )
+
+ client = self.client_rotator.get_next_client()
+ batch_job = await client.create_batch_job(requests)
+ provider_batch_id = batch_job.id if hasattr(batch_job, "id") else str(batch_job)
+
+ context.current_batch_id = provider_batch_id
+ context.status = BatchJobStatus.SUBMITTED
+ context.updated_at = datetime.now(timezone.utc)
+ db_session.add(context)
+ db_session.commit()
+
+ phase_name = (
+ f"step {step_index}" if context.config.get("hierarchical") else "extraction"
+ )
+ self.logger.info(
+ f"Submitted extraction batch ({phase_name}) for {context.root_batch_id}: {provider_batch_id}"
+ )
+
+ async def get_status(
+ self, root_batch_id: str, db_session: Session
+ ) -> BatchJobStatus:
+ context = db_session.get(BatchJobContext, root_batch_id)
+ if not context:
+ raise ValueError(f"Batch job {root_batch_id} not found")
+
+ terminal_states = [
+ BatchJobStatus.COMPLETED,
+ BatchJobStatus.FAILED,
+ BatchJobStatus.CANCELLED,
+ BatchJobStatus.READY_TO_PROCESS,
+ BatchJobStatus.COUNTING_READY_TO_PROCESS,
+ ]
+ if context.status in terminal_states:
+ return context.status
+
+ try:
+ # Determine client based on phase
+ if context.status in [
+ BatchJobStatus.COUNTING_SUBMITTED,
+ BatchJobStatus.COUNTING_PROCESSING,
+ ]:
+ client = self.entity_counter.llm_client
+ else:
+ client = self.client_rotator.get_next_client()
+
+ batch_job = await client.retrieve_batch_job(context.current_batch_id)
+ new_provider_status = self._map_provider_status(batch_job.status)
+
+ new_status = context.status
+ # Map provider status based on current internal phase
+ if context.status in [
+ BatchJobStatus.COUNTING_SUBMITTED,
+ BatchJobStatus.COUNTING_PROCESSING,
+ ]:
+ if new_provider_status == BatchJobStatus.READY_TO_PROCESS:
+ new_status = BatchJobStatus.COUNTING_READY_TO_PROCESS
+ elif new_provider_status == BatchJobStatus.FAILED:
+ new_status = BatchJobStatus.FAILED
+ elif new_provider_status == BatchJobStatus.CANCELLED:
+ new_status = BatchJobStatus.CANCELLED
+ elif new_provider_status == BatchJobStatus.PROCESSING:
+ new_status = BatchJobStatus.COUNTING_PROCESSING
+ else:
+ new_status = new_provider_status
+
+ if new_status != context.status:
+ context.status = new_status
+ context.updated_at = datetime.now(timezone.utc)
+ db_session.add(context)
+ db_session.commit()
+
+ except Exception as e:
+ self.logger.error(f"Failed to check batch status: {e}", exc_info=True)
+
+ return context.status
+
+ async def process_batch(
+ self, root_batch_id: str, db_session: Session
+ ) -> BatchProcessResult:
+ status = await self.get_status(root_batch_id, db_session)
+ context = db_session.get(BatchJobContext, root_batch_id)
+
+ # 1. Already Completed
+ if status == BatchJobStatus.COMPLETED and context.results:
+ return await self._finalize_completion(context, db_session)
+
+ # 2. Counting Phase Completed -> Submit Extraction
+ if status == BatchJobStatus.COUNTING_READY_TO_PROCESS:
+ return await self._process_counting_completion(context, db_session)
+
+ # 3. Extraction Phase Ready -> Process Results
+ if status == BatchJobStatus.READY_TO_PROCESS:
+ return await self._process_extraction_completion(context, db_session)
+
+ return BatchProcessResult(
+ status=status, message="Batch not ready for processing"
+ )
+
+ async def _process_counting_completion(
+ self, context: BatchJobContext, db_session: Session
+ ) -> BatchProcessResult:
+ try:
+ # Use counting client
+ client = self.entity_counter.llm_client
+ results_content = await client.retrieve_batch_results(
+ context.current_batch_id
+ )
+
+ # Determine expected models for validation
+ if context.config.get("hierarchical"):
+ current_idx = context.config.get("current_model_index", 0)
+ if 0 <= current_idx < len(self.model_registry.models):
+ target_model_names = [
+ self.model_registry.models[current_idx].__name__
+ ]
+ else:
+ target_model_names = self.model_registry.get_all_model_names()
+ else:
+ target_model_names = self.model_registry.get_all_model_names()
+
+ # Parse descriptions
+ descriptions = []
+
+ for line in results_content.strip().split("\n"):
+ if not line.strip():
+ continue
+ try:
+ item = json.loads(line)
+ content = client.extract_content_from_batch_response(item)
+ if content:
+ raw_json = json.loads(content)
+ if isinstance(raw_json, list) and raw_json:
+ raw_json = raw_json[0]
+ if isinstance(raw_json, dict):
+ validated_counts = self.entity_counter.validate_counts(
+ raw_json, target_model_names
+ )
+ for model_name, descs in validated_counts.items():
+ for desc in descs:
+ descriptions.append(f"[{model_name}] {desc}")
+ except Exception as e:
+ self.logger.warning(f"Failed to parse counting result: {e}")
+
+ # Update config with descriptions
+ new_config = context.config.copy()
+ new_config["expected_entity_descriptions"] = descriptions
+ context.config = new_config
+
+ # Proceed to Extraction
+ next_step = (
+ context.config.get("current_model_index", 0)
+ if context.config.get("hierarchical")
+ else 0
+ )
+ await self._submit_extraction_phase(
+ context, db_session, step_index=next_step
+ )
+
+ return BatchProcessResult(
+ status=BatchJobStatus.PROCESSING,
+ message="Transitioned from counting to extraction",
+ retry_batch_id=context.root_batch_id,
+ )
+
+ except Exception as e:
+ self.logger.error(f"Failed to process counting results: {e}", exc_info=True)
+ return BatchProcessResult(
+ status=BatchJobStatus.FAILED, message=f"Counting failed: {e}"
+ )
+
+ async def _process_extraction_completion(
+ self, context: BatchJobContext, db_session: Session
+ ) -> BatchProcessResult:
+ try:
+ results = await self._retrieve_and_validate_results(context)
+
+ if results:
+ consensus_output, details = self.consensus.get_consensus(results)
+ if details:
+ self.analytics_collector.record_consensus_run_details(details)
+
+ processed = self._process_consensus_output(consensus_output)
+
+ if context.config.get("hierarchical"):
+ return await self._process_hierarchical_step(
+ context, processed, db_session
+ )
+
+ # Finalize non-hierarchical
+ context.results = processed
+ context.status = BatchJobStatus.COMPLETED
+ context.updated_at = datetime.now(timezone.utc)
+ db_session.add(context)
+ db_session.commit()
+
+ return await self._finalize_completion(context, db_session)
+
+ # No valid revisions survived validation: retry this step if attempts remain.
+ return await self._handle_batch_retry(
+ context, context.root_batch_id, db_session
+ )
+
+ except Exception as e:
+ self.logger.error(f"Batch processing failed: {e}", exc_info=True)
+ return BatchProcessResult(status=BatchJobStatus.FAILED, message=str(e))
+
+ async def _process_hierarchical_step(
+ self,
+ context: BatchJobContext,
+ processed_results: List[Dict],
+ db_session: Session,
+ ) -> BatchProcessResult:
+ current_index = context.config.get("current_model_index", 0)
+
+ # Save step result to DB
+ step = BatchJobStep(
+ batch_id=context.root_batch_id,
+ step_index=current_index,
+ status=BatchJobStatus.COMPLETED,
+ result=processed_results,
+ metadata_json={"timestamp": datetime.now(timezone.utc).isoformat()},
+ )
+ db_session.add(step)
+
+ # Advance index
+ next_index = current_index + 1
+
+ # Update config
+ new_config = context.config.copy()
+ new_config["current_model_index"] = next_index
+ context.config = new_config
+ db_session.add(context)
+ db_session.commit()
+
+ if next_index >= len(self.model_registry.models):
+ # All steps done - aggregate results for final hydration
+ all_steps = db_session.exec(
+ select(BatchJobStep)
+ .where(BatchJobStep.batch_id == context.root_batch_id)
+ .order_by(BatchJobStep.step_index)
+ ).all()
+
+ final_results = []
+ for s in all_steps:
+ final_results.extend(s.result)
+
+ context.results = final_results
+ context.status = BatchJobStatus.COMPLETED
+ context.updated_at = datetime.now(timezone.utc)
+ db_session.add(context)
+ db_session.commit()
+ return await self._finalize_completion(context, db_session)
+
+ # Submit next step (counting or extraction)
+ model_name = self.model_registry.models[next_index].__name__
+
+ if context.config.get("count_entities"):
+ await self._submit_counting_phase(
+ context, db_session, step_index=next_index
+ )
+ return BatchProcessResult(
+ status=BatchJobStatus.COUNTING_PROCESSING,
+ message=f"Submitted counting step for {model_name}",
+ retry_batch_id=context.root_batch_id,
+ )
+ else:
+ await self._submit_extraction_phase(
+ context, db_session, step_index=next_index
+ )
+ return BatchProcessResult(
+ status=BatchJobStatus.PROCESSING,
+ message=f"Submitted hierarchical step for {model_name}",
+ retry_batch_id=context.root_batch_id,
+ )
+
+ async def _finalize_completion(
+ self, context: BatchJobContext, db_session: Session
+ ) -> BatchProcessResult:
+ # Deferred import to avoid a potential import cycle with result_processor.
+ from .result_processor import ResultProcessor
+
+ processor = ResultProcessor(
+ self.model_registry, self.analytics_collector, self.logger
+ )
+
+ # Determine default model type for hydration
+ default_model_type = None
+ if self.config.use_structured_output:
+ default_model_type = self.model_registry.root_model.__name__
+
+ hydrated = processor.hydrate(
+ context.results, db_session, default_model_type=default_model_type
+ )
+ return BatchProcessResult(
+ status=BatchJobStatus.COMPLETED,
+ hydrated_objects=hydrated,
+ original_pk_map=processor.original_pk_map,
+ )
+
+ async def _retrieve_and_validate_results(
+ self, context: BatchJobContext
+ ) -> List[List[Dict]]:
+ client = self.client_rotator.get_next_client()
+ results_content = await client.retrieve_batch_results(context.current_batch_id)
+
+ # Determine default model type clearly
+ default_model_type = None
+ if self.config.use_structured_output:
+ if context.config.get("hierarchical"):
+ current_idx = context.config.get("current_model_index", 0)
+ if 0 <= current_idx < len(self.model_registry.models):
+ default_model_type = self.model_registry.models[
+ current_idx
+ ].__name__
+ else:
+ default_model_type = self.model_registry.root_model.__name__
+
+ valid_revisions = []
+ for line in results_content.strip().split("\n"):
+ if not line.strip():
+ continue
+
+ try:
+ item = json.loads(line)
+ content = client.extract_content_from_batch_response(item)
+
+ if content:
+ validated = process_and_validate_llm_output(
+ raw_llm_content=content,
+ model_schema_map=self.model_registry.model_map,
+ revision_info_for_error="batch_revision",
+ analytics_collector=self.analytics_collector,
+ default_model_type=default_model_type,
+ )
+ if validated:
+ valid_revisions.append(validated)
+ except Exception as e:
+ self.logger.warning(f"Failed to validate batch revision: {e}")
+
+ return normalize_json_revisions(valid_revisions) if valid_revisions else []
+
+ async def _handle_batch_retry(
+ self, context: BatchJobContext, root_batch_id: str, db_session: Session
+ ):
+ max_retries = self.config.max_validation_retries_per_revision
+
+ if context.retry_count < max_retries:
+ context.retry_count += 1
+ self.logger.info(
+ f"Retrying batch {root_batch_id} ({context.retry_count}/{max_retries})"
+ )
+
+ # Resubmit current step
+ if "counting" in context.status.value:
+ await self._submit_counting_phase(context, db_session)
+ else:
+ current_idx = context.config.get("current_model_index", 0)
+ await self._submit_extraction_phase(
+ context, db_session, step_index=current_idx
+ )
+
+ return BatchProcessResult(
+ status=BatchJobStatus.PROCESSING,
+ message="Retry submitted",
+ retry_batch_id=root_batch_id,
+ )
+
+ context.status = BatchJobStatus.FAILED
+ context.last_error = "Max retries exceeded"
+ context.updated_at = datetime.now(timezone.utc)
+ db_session.add(context)
+ db_session.commit()
+
+ return BatchProcessResult(
+ status=BatchJobStatus.FAILED, message="Max retries exceeded"
+ )
+
+ def _create_batch_requests(
+ self,
+ system_prompt: str,
+ user_prompt: str,
+ json_schema: Optional[Dict] = None,
+ num_revisions: Optional[int] = None,
+ override_client: Optional[BaseLLMClient] = None,
+ ) -> List[Dict]:
+ requests = []
+ client = override_client or self.client_rotator.current_client
+ revisions = (
+ num_revisions
+ if num_revisions is not None
+ else self.config.num_llm_revisions
+ )
+
+ if self.config.use_structured_output and json_schema:
+ self.logger.info("Using structured output for batch requests")
+
+ for i in range(revisions):
+ body = {
+ "messages": [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ],
+ "temperature": client.temperature,
+ }
+ if hasattr(client, "model_name"):
+ body["model"] = client.model_name
+
+ if self.config.use_structured_output and json_schema:
+ body["response_format"] = {
+ "type": "json_schema",
+ "json_schema": {
+ "name": "extraction_response",
+ "schema": json_schema,
+ "strict": True,
+ },
+ }
+ elif self.model_registry.llm_schema_json:
+ body["response_format"] = {"type": "json_object"}
+
+ requests.append({"custom_id": f"rev-{i}", "body": body})
+ return requests
+
+ def _map_provider_status(self, provider_status: Any) -> BatchJobStatus:
+ status_str = str(provider_status).lower()
+
+ if "complete" in status_str or "succeeded" in status_str:
+ return BatchJobStatus.READY_TO_PROCESS
+ elif "fail" in status_str:
+ return BatchJobStatus.FAILED
+ elif "cancel" in status_str:
+ return BatchJobStatus.CANCELLED
+ elif (
+ "process" in status_str
+ or "progress" in status_str
+ or "active" in status_str
+ or "running" in status_str
+ ):
+ return BatchJobStatus.PROCESSING
+
+ return BatchJobStatus.SUBMITTED
+
+ def _process_consensus_output(self, output: Any) -> List[Dict[str, Any]]:
+ if output is None:
+ return []
+ if isinstance(output, list):
+ return output
+ if isinstance(output, dict):
+ if "results" in output and isinstance(output["results"], list):
+ return output["results"]
+ return [output]
+ return []
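The `_map_provider_status` heuristic normalizes heterogeneous provider vocabularies by substring matching. A standalone replica of the same logic (with a `progress` substring included here so OpenAI-style `in_progress` maps to processing rather than falling through to the queued default — an assumption about provider vocabulary, not part of the original):

```python
def map_provider_status(provider_status) -> str:
    """Collapse provider-specific batch states onto the pipeline's internal ones."""
    status_str = str(provider_status).lower()
    if "complete" in status_str or "succeeded" in status_str:
        return "ready_to_process"
    if "fail" in status_str:
        return "failed"
    if "cancel" in status_str:
        return "cancelled"
    if (
        "process" in status_str
        or "progress" in status_str
        or "active" in status_str
        or "running" in status_str
    ):
        return "processing"
    # Anything unrecognized (e.g. "validating", "expired") is treated as still queued.
    return "submitted"


for s in ("completed", "in_progress", "cancelling", "validating"):
    print(s, "->", map_provider_status(s))
```

Unmatched states stay non-terminal, so a later `get_status` poll can still move the job forward once the provider reports a recognized state.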
diff --git a/src/extrai/core/client_rotator.py b/src/extrai/core/client_rotator.py
new file mode 100644
index 0000000..48298e6
--- /dev/null
+++ b/src/extrai/core/client_rotator.py
@@ -0,0 +1,29 @@
+from typing import List, Union
+from .base_llm_client import BaseLLMClient
+
+
+class ClientRotator:
+ """
+ Manages rotation through a list of LLM clients.
+ """
+
+ def __init__(self, clients: Union[BaseLLMClient, List[BaseLLMClient]]):
+ self.clients = clients if isinstance(clients, list) else [clients]
+ if not self.clients:
+ raise ValueError("At least one client must be provided")
+ self._current_index = 0
+
+ def get_next_client(self) -> BaseLLMClient:
+ """Returns the next client in the rotation."""
+ client = self.clients[self._current_index]
+ self._current_index = (self._current_index + 1) % len(self.clients)
+ return client
+
+ @property
+ def current_client(self) -> BaseLLMClient:
+ """Returns the current client without advancing rotation."""
+ return self.clients[self._current_index]
+
+ def reset(self):
+ """Resets the rotation to the start."""
+ self._current_index = 0
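`ClientRotator`'s round-robin behavior is easiest to see with plain strings standing in for client objects. A stripped-down copy of the class above (the real one takes `BaseLLMClient` instances):

```python
class ClientRotator:
    """Round-robin over a list of clients (stripped-down sketch)."""

    def __init__(self, clients):
        self.clients = clients if isinstance(clients, list) else [clients]
        if not self.clients:
            raise ValueError("At least one client must be provided")
        self._current_index = 0

    def get_next_client(self):
        # Return the current client, then advance the cursor.
        client = self.clients[self._current_index]
        self._current_index = (self._current_index + 1) % len(self.clients)
        return client

    @property
    def current_client(self):
        return self.clients[self._current_index]


rotator = ClientRotator(["provider_a", "provider_b", "provider_c"])
print([rotator.get_next_client() for _ in range(5)])
```

One subtlety: `get_next_client` returns the client *before* advancing, so immediately afterwards `current_client` points at the next client in line, not the one just returned.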
diff --git a/src/extrai/core/code_generation/__init__.py b/src/extrai/core/code_generation/__init__.py
new file mode 100644
index 0000000..0b4a187
--- /dev/null
+++ b/src/extrai/core/code_generation/__init__.py
@@ -0,0 +1,3 @@
+from .python_builder import PythonModelBuilder
+
+__all__ = ["PythonModelBuilder"]
diff --git a/src/extrai/core/code_generation/python_builder.py b/src/extrai/core/code_generation/python_builder.py
new file mode 100644
index 0000000..51eb89f
--- /dev/null
+++ b/src/extrai/core/code_generation/python_builder.py
@@ -0,0 +1,324 @@
+import keyword
+from typing import Any, Dict, Set, List
+
+
+class ImportManager:
+ """Manages imports for the generated code, handling consolidation."""
+
+ def __init__(self):
+ self.typing_imports: Set[str] = set()
+ self.sqlmodel_imports: Set[str] = {"SQLModel"}
+ self.module_imports: Set[str] = set()
+ self.custom_imports: Set[str] = set()
+
+ def add_import_for_type(self, type_str: str):
+ if "datetime." in type_str:
+ self.module_imports.add("datetime")
+ if "uuid." in type_str:
+ self.module_imports.add("uuid")
+ if "Optional[" in type_str:
+ self.typing_imports.add("Optional")
+ if "List[" in type_str:
+ self.typing_imports.add("List")
+ if "Dict[" in type_str:
+ self.typing_imports.add("Dict")
+ if "Union[" in type_str:
+ self.typing_imports.add("Union")
+ if "Any" in type_str:
+ self.typing_imports.add("Any")
+
+ def add_custom_imports(self, imports: List[str]):
+ for imp in imports:
+ self.custom_imports.add(imp.strip())
+
+ def render(self) -> str:
+ import_lines = []
+
+ # Consolidate custom imports with auto-detected ones
+ for custom_imp in self.custom_imports:
+ if custom_imp.startswith("from sqlmodel"):
+ try:
+ items = {
+ item.strip()
+ for item in custom_imp.split(" import ")[1].split(",")
+ }
+ self.sqlmodel_imports.update(items)
+ except IndexError:
+ import_lines.append(custom_imp)
+ elif custom_imp.startswith("from typing"):
+ try:
+ items = {
+ item.strip()
+ for item in custom_imp.split(" import ")[1].split(",")
+ }
+ self.typing_imports.update(items)
+ except IndexError:
+ import_lines.append(custom_imp)
+ elif custom_imp.startswith("import "):
+ modules = {
+ mod.strip()
+ for mod in custom_imp.replace("import ", "").split(",")
+ if mod.strip()
+ }
+ if modules:
+ self.module_imports.update(modules)
+ else:
+ import_lines.append(custom_imp)
+ else:
+ import_lines.append(custom_imp) # Add other complex imports as is
+
+        if self.sqlmodel_imports:
+            import_lines.append(
+                f"from sqlmodel import {', '.join(sorted(self.sqlmodel_imports))}"
+            )
+        if self.typing_imports:
+            import_lines.append(
+                f"from typing import {', '.join(sorted(self.typing_imports))}"
+            )
+        for mod in sorted(self.module_imports):
+            import_lines.append(f"import {mod}")
+
+ return "\n".join(sorted(set(import_lines)))
+
+
+class FieldGenerator:
+ """Generates the code for a single field in a SQLModel."""
+
+ def __init__(self, field_info: Dict[str, Any], import_manager: ImportManager):
+ self.field_info = field_info
+ self.imports = import_manager
+ self.field_name_original = self.field_info["name"]
+ self.field_name_python = self.field_name_original
+ self.args_map: Dict[str, str] = {}
+
+ def _handle_keyword_name(self):
+ if keyword.iskeyword(self.field_name_original):
+ self.field_name_python = self.field_name_original + "_"
+ self.args_map["alias"] = f'"{self.field_name_original}"'
+
+ def _get_default_value_arg(self):
+ if "default_factory" in self.field_info:
+ factory_str = self.field_info["default_factory"]
+ self.args_map["default_factory"] = factory_str
+ if "." in factory_str:
+ potential_module = factory_str.split(".")[0]
+ if potential_module.isidentifier() and potential_module not in [
+ "list",
+ "dict",
+ "set",
+ "tuple",
+ ]:
+ self.imports.module_imports.add(potential_module)
+ elif "default" in self.field_info:
+ default_val = self.field_info["default"]
+ if isinstance(default_val, str):
+ self.args_map["default"] = f'"{default_val}"'
+ elif isinstance(default_val, bool):
+ self.args_map["default"] = str(default_val)
+ elif default_val is None:
+ self.args_map["default"] = "None"
+ else:
+ self.args_map["default"] = str(default_val)
+
+ def _get_nullable_arg(self):
+ field_type_str = self.field_info["type"]
+ is_optional_type = field_type_str.startswith("Optional[")
+ is_pk = self.field_info.get("primary_key", False)
+ explicit_nullable = self.field_info.get("nullable")
+
+ if explicit_nullable is True:
+ self.args_map["nullable"] = "True"
+ elif explicit_nullable is False:
+            # nullable=False applies unless this is a non-optional primary key
+            if not is_pk or is_optional_type:
+                self.args_map["nullable"] = "False"
+ elif is_optional_type:
+ self.args_map["nullable"] = "True"
+
+ if is_pk and is_optional_type and self.args_map.get("nullable") != "False":
+ self.args_map["nullable"] = "True"
+
+ def _get_sa_column_args(self):
+ sa_column_kwargs = self.field_info.get("sa_column_kwargs", {})
+ for k, v in sa_column_kwargs.items():
+ if k in ["server_default", "onupdate"]:
+ self.args_map[k] = f'"{v}"' if isinstance(v, str) else str(v)
+ elif k == "sa_type":
+ self.args_map["sa_type"] = str(v)
+ if str(v) == "JSON":
+ self.imports.sqlmodel_imports.add("JSON")
+ if "sqlalchemy." in str(v):
+ self.imports.module_imports.add("sqlalchemy")
+
+ def _get_common_args(self):
+ if "description" in self.field_info:
+ desc = self.field_info["description"]
+ escaped_desc = (
+ desc.replace("\\", "\\\\")
+ .replace('"', '\\"')
+ .replace("\n", "\\n")
+ .replace("\r", "\\r")
+ )
+ self.args_map["description"] = f'"{escaped_desc}"'
+ if "foreign_key" in self.field_info:
+ self.args_map["foreign_key"] = f'"{self.field_info["foreign_key"]}"'
+ if self.field_info.get("index"):
+ self.args_map["index"] = "True"
+ if self.field_info.get("primary_key"):
+ self.args_map["primary_key"] = "True"
+ if self.field_info.get("unique"):
+ self.args_map["unique"] = "True"
+
+ def _determine_field_arguments(self):
+ self._handle_keyword_name()
+ self._get_default_value_arg()
+ self._get_nullable_arg()
+ self._get_sa_column_args()
+ self._get_common_args()
+
+ def generate_code(self) -> str:
+ field_type_str = self.field_info["type"]
+ self.imports.add_import_for_type(field_type_str)
+
+ if "field_options_str" in self.field_info:
+ field_options = self.field_info["field_options_str"]
+ if "JSON" in field_options:
+ self.imports.sqlmodel_imports.add("JSON")
+ if "Relationship" in field_options:
+ self.imports.sqlmodel_imports.add("Relationship")
+
+            if keyword.iskeyword(self.field_name_original):
+                self.field_name_python = self.field_name_original + "_"
+            return f"    {self.field_name_python}: {field_type_str} = {field_options}"
+
+ self._determine_field_arguments()
+
+ if not self.args_map:
+ return f" {self.field_name_python}: {field_type_str}"
+
+ self.imports.sqlmodel_imports.add("Field")
+ ordered_keys = [
+ "primary_key",
+ "alias",
+ "default",
+ "default_factory",
+ "unique",
+ "index",
+ "foreign_key",
+ "nullable",
+ "sa_type",
+ "description",
+ "server_default",
+ "onupdate",
+ ]
+ final_args_list = [
+ f"{key}={self.args_map[key]}"
+ for key in ordered_keys
+ if key in self.args_map
+ ]
+ field_args_str = ", ".join(final_args_list)
+ return (
+ f" {self.field_name_python}: {field_type_str} = Field({field_args_str})"
+ )
+
+
+class ClassCodeBuilder:
+ """Builds the final Python code string from its components."""
+
+ def __init__(
+ self,
+ model_name: str,
+ import_manager: ImportManager,
+ description: str,
+ table_name: str,
+ base_classes: List[str],
+ is_table_model: bool,
+ ):
+ self.model_name = model_name
+ self.import_manager = import_manager
+ self.description = description
+ self.table_name = table_name
+ self.base_classes = base_classes
+ self.is_table_model = is_table_model
+ self.fields_code: List[str] = []
+
+ def add_field(self, field_code: str):
+ self.fields_code.append(field_code)
+
+ def render_class_definition(self) -> str:
+ fields_str = "\n".join(self.fields_code) if self.fields_code else " pass"
+
+ base_classes_str = ", ".join(self.base_classes)
+ class_decorator_args = []
+ if self.is_table_model:
+ class_decorator_args.append("table=True")
+
+ class_header = f"class {self.model_name}({base_classes_str}"
+ if "SQLModel" in base_classes_str and class_decorator_args:
+ class_header += f", {', '.join(class_decorator_args)}"
+ class_header += "):"
+
+ docstring_section = ""
+ if self.description:
+ docstring_section = f' """{self.description}"""\n'
+
+ table_name_section = ""
+ if self.is_table_model:
+ table_name_section = f' __tablename__ = "{self.table_name}"\n\n'
+ elif self.description:
+ table_name_section = "\n"
+
+ return f"{class_header}\n{docstring_section}{table_name_section}{fields_str}\n"
+
+
+class PythonModelBuilder:
+ """Facade for generating Python code for SQLModels from description dictionaries."""
+
+ def generate_model_code(self, model_descriptions: List[Dict[str, Any]]) -> str:
+ """
+ Generates Python code for the provided model descriptions.
+
+ Args:
+ model_descriptions: A list of dictionaries, where each dictionary describes a SQLModel.
+
+ Returns:
+ A string containing the complete Python code with imports and class definitions.
+ """
+ import_manager = ImportManager()
+ class_definitions = []
+
+ # First pass to gather all imports from all model definitions
+ for model_desc in model_descriptions:
+ import_manager.add_custom_imports(model_desc.get("imports", []))
+ base_classes = model_desc.get("base_classes_str", ["SQLModel"])
+ if "SQLModel" in base_classes:
+ import_manager.sqlmodel_imports.add("SQLModel")
+ if "fields" in model_desc and model_desc["fields"]:
+ for f_info in model_desc["fields"]:
+ # This populates the import manager with types from fields
+ _ = FieldGenerator(f_info, import_manager).generate_code()
+
+ # Now, build each class definition
+ for model_desc in model_descriptions:
+ model_name = model_desc["model_name"]
+ base_classes = model_desc.get("base_classes_str", ["SQLModel"])
+
+ builder = ClassCodeBuilder(
+ model_name=model_name,
+ import_manager=import_manager,
+ description=model_desc.get("description", ""),
+ table_name=model_desc.get("table_name", f"{model_name.lower()}s"),
+ base_classes=base_classes,
+ is_table_model=model_desc.get("is_table_model", True),
+ )
+
+ if "fields" in model_desc and model_desc["fields"]:
+ for f_info in model_desc["fields"]:
+ field_generator = FieldGenerator(f_info, import_manager)
+ builder.add_field(field_generator.generate_code())
+
+ class_definitions.append(builder.render_class_definition())
+
+ imports_str = import_manager.render()
+ full_code = f"{imports_str}\n\n\n" + "\n\n".join(class_definitions)
+ return full_code
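The shape of the input `PythonModelBuilder.generate_model_code` expects can be illustrated with a hypothetical model description; the model and field names below are examples, not part of the library:

```python
# Hypothetical description of one table model, using the keys read by
# PythonModelBuilder/FieldGenerator above.
model_descriptions = [
    {
        "model_name": "Book",
        "table_name": "books",  # optional; defaults to model_name lowercased + "s"
        "is_table_model": True,
        "description": "A book extracted from a document.",
        "imports": ["from typing import Optional"],
        "fields": [
            {"name": "id", "type": "Optional[int]", "primary_key": True, "default": None},
            {"name": "title", "type": "str", "description": "Book title"},
            {"name": "author_id", "type": "Optional[int]", "foreign_key": "authors.id"},
        ],
    }
]
# PythonModelBuilder().generate_model_code(model_descriptions) renders a
# `class Book(SQLModel, table=True):` definition plus consolidated imports.
print(model_descriptions[0]["model_name"])
```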
diff --git a/src/extrai/core/conflict_resolvers.py b/src/extrai/core/conflict_resolvers.py
new file mode 100644
index 0000000..3bdf486
--- /dev/null
+++ b/src/extrai/core/conflict_resolvers.py
@@ -0,0 +1,149 @@
+# extrai/core/conflict_resolvers.py
+from collections import Counter
+from typing import List, Optional, Callable, Dict, Any
+from extrai.utils.flattening_utils import Path, JSONValue
+from difflib import SequenceMatcher
+
+# Define conflict resolution strategies
+ConflictResolutionStrategy = Callable[
+ [Path, List[JSONValue], Optional[List[float]]], Optional[JSONValue]
+]
+
+
+def default_conflict_resolver(
+ path: Path, values: List[JSONValue], weights: Optional[List[float]] = None
+) -> Optional[JSONValue]:
+ """
+ Default conflict resolution: if no consensus, omit the field.
+ """
+ return None
+
+
+def prefer_most_common_resolver(
+ _path: Path, values: List[JSONValue], weights: Optional[List[float]] = None
+) -> Optional[JSONValue]:
+ """
+ Conflict resolution: prefer the most common value.
+ If weights are provided, prefers the value with the highest total weight.
+ """
+ if not values:
+ return None
+
+ if weights and len(weights) == len(values):
+ # Weighted voting
+ weighted_counts: Dict[Any, float] = {}
+        # Conflict resolution typically runs on leaves, so values are expected
+        # to be primitives (str, int, float, bool, None). Unhashable values
+        # (lists, dicts) cannot be dict keys; they are skipped below.
+
+ for val, w in zip(values, weights):
+            try:
+                weighted_counts[val] = weighted_counts.get(val, 0.0) + w
+            except TypeError:
+                # Unhashable value (e.g. a list): skip weighted counting for it
+                # and let the unweighted fallback below handle the field.
+                pass
+
+ if weighted_counts:
+ # Pick value with max weight
+ # Break ties by first occurrence (insertion order in weighted_counts)
+ most_common_value = max(weighted_counts, key=weighted_counts.get)
+ return most_common_value
+
+    # Fall back to an unweighted count. Counter raises TypeError on
+    # unhashable values, in which case we return the first value.
+ try:
+ count = Counter(values)
+ most_common_value, _ = count.most_common(1)[0]
+ return most_common_value
+ except TypeError:
+ # Fallback for unhashable
+ return values[0]
+
+
+def levenshtein_similarity(a: str, b: str) -> float:
+    """Similarity ratio in [0, 1]. Note: despite the name, this uses
+    difflib.SequenceMatcher (Ratcliff/Obershelp), not true Levenshtein distance."""
+    return SequenceMatcher(None, a, b).ratio()
+
+
+class SimilarityClusterResolver:
+ """
+ Resolves conflicts by clustering values based on string similarity.
+ Useful for filtering out outliers (e.g. "War" vs "Christmas", "Gifts").
+ """
+
+ def __init__(
+ self,
+ similarity_threshold: float = 0.6,
+ scorer: Callable[[str, str], float] = levenshtein_similarity,
+ ):
+ self.similarity_threshold = similarity_threshold
+ self.scorer = scorer
+
+ def __call__(
+ self, path: Path, values: List[JSONValue], weights: Optional[List[float]] = None
+ ) -> Optional[JSONValue]:
+ if not values:
+ return None
+
+ # Only applicable if values are strings
+ if not all(isinstance(v, str) for v in values):
+ return prefer_most_common_resolver(path, values, weights)
+
+ # 1. Compute pairwise similarities and build adjacency list
+ n = len(values)
+ adj = {i: [] for i in range(n)}
+ for i in range(n):
+ for j in range(i + 1, n):
+ score = self.scorer(values[i], values[j])
+ if score >= self.similarity_threshold:
+ adj[i].append(j)
+ adj[j].append(i)
+
+ # 2. Find connected components (clusters)
+ visited = set()
+ clusters = []
+ for i in range(n):
+ if i not in visited:
+ component = []
+ stack = [i]
+ visited.add(i)
+ while stack:
+ node = stack.pop()
+ component.append(node)
+ for neighbor in adj[node]:
+ if neighbor not in visited:
+ visited.add(neighbor)
+ stack.append(neighbor)
+ clusters.append(component)
+
+ if not clusters:
+ return prefer_most_common_resolver(path, values, weights)
+
+ # 3. Find the best cluster
+ # If weights are provided, pick the cluster with the highest total weight.
+ # Otherwise, pick the largest cluster.
+
+ if weights and len(weights) == n:
+
+ def cluster_weight(indices):
+ return sum(weights[i] for i in indices)
+
+ best_cluster_indices = max(clusters, key=cluster_weight)
+ else:
+ best_cluster_indices = max(clusters, key=len)
+
+ # 4. Pick the representative from the best cluster
+ cluster_values = [values[i] for i in best_cluster_indices]
+ cluster_weights = (
+ [weights[i] for i in best_cluster_indices] if weights else None
+ )
+
+ return prefer_most_common_resolver(path, cluster_values, cluster_weights)
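The scorer backing `SimilarityClusterResolver` is `difflib.SequenceMatcher`; a standalone sketch of how near-duplicates clear the default 0.6 threshold while an outlier does not (sample strings are illustrative):

```python
from difflib import SequenceMatcher


def similarity(a: str, b: str) -> float:
    # Same scorer the resolver uses by default
    return SequenceMatcher(None, a, b).ratio()


candidate = "Christmas"
near_duplicate = "Christmas "  # trailing space from a noisy extraction
outlier = "War"

print(round(similarity(candidate, near_duplicate), 2))  # well above 0.6
print(round(similarity(candidate, outlier), 2))         # well below 0.6
```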
diff --git a/src/extrai/core/db_writer.py b/src/extrai/core/db_writer.py
deleted file mode 100644
index 323ff70..0000000
--- a/src/extrai/core/db_writer.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# random_docs_to_sql/core/db_writer.py
-import logging
-from sqlalchemy.orm import Session
-from sqlalchemy.exc import SQLAlchemyError
-from typing import List, Any
-
-
-class DatabaseWriterError(Exception):
- """Custom exception for database writer errors."""
-
- pass
-
-
-def persist_objects(
- db_session: Session, objects_to_persist: List[Any], logger: logging.Logger
-) -> None:
- """
- Persists a list of SQLAlchemy objects to the database using the provided session.
-
- Args:
- db_session: The SQLAlchemy session to use for database operations.
- objects_to_persist: A list of SQLAlchemy model instances to be saved.
-
- Raises:
- DatabaseWriterError: If an error occurs during the database commit.
- """
- if not objects_to_persist:
- logger.info("No objects provided to persist.")
- return
-
- try:
- # All objects should already be associated with the session
- # from the hydration phase
- db_session.add_all(objects_to_persist)
- db_session.commit()
- logger.info(
- f"Successfully persisted {len(objects_to_persist)} objects to the database."
- )
- except SQLAlchemyError as e:
- logger.error(f"Database commit failed: {e}", exc_info=True)
- try:
- db_session.rollback()
- logger.info("Database session rolled back successfully.")
- except SQLAlchemyError as rollback_e:
- logger.error(
- f"Failed to rollback database session: {rollback_e}", exc_info=True
- )
- # Potentially raise a more critical error or handle nested failure
- raise DatabaseWriterError(f"Failed to persist objects due to: {e}")
- except Exception as e:
- logger.error(
- f"An unexpected error occurred during object persistence: {e}",
- exc_info=True,
- )
-
- if db_session.is_active:
- db_session.rollback()
- logger.info("Database session rolled back due to unexpected error.")
-
- raise DatabaseWriterError(f"An unexpected error occurred: {e}")
diff --git a/src/extrai/core/entity_counter.py b/src/extrai/core/entity_counter.py
new file mode 100644
index 0000000..e69f9a4
--- /dev/null
+++ b/src/extrai/core/entity_counter.py
@@ -0,0 +1,112 @@
+import logging
+from typing import List, Dict, Any, Optional
+from pydantic import create_model
+
+from .model_registry import ModelRegistry
+from .extraction_config import ExtractionConfig
+from .prompt_builder import (
+ generate_entity_counting_system_prompt,
+ generate_entity_counting_user_prompt,
+)
+
+
+class EntityCounter:
+ """Counts entities in input documents using LLM."""
+
+ def __init__(
+ self,
+ model_registry: ModelRegistry,
+ llm_client,
+ config: ExtractionConfig,
+ analytics_collector,
+ logger: logging.Logger,
+ ):
+ self.model_registry = model_registry
+ self.llm_client = llm_client
+ self.config = config
+ self.analytics_collector = analytics_collector
+ self.logger = logger
+
+ def prepare_counting_prompts(
+ self,
+ input_strings: List[str],
+ model_names: List[str],
+ custom_counting_context: str = "",
+ previous_entities: Optional[List[Dict[str, Any]]] = None,
+ ):
+ """Prepares prompts for batch counting."""
+ # Generate schema for models
+ schema_json = self.model_registry.get_schema_for_models(model_names)
+
+ # Build prompts
+ system_prompt = generate_entity_counting_system_prompt(
+ model_names, schema_json, custom_counting_context, previous_entities
+ )
+ user_prompt = generate_entity_counting_user_prompt(input_strings)
+
+ return system_prompt, user_prompt
+
+ def validate_counts(
+ self, raw_counts: Dict[str, Any], model_names: List[str]
+ ) -> Dict[str, List[str]]:
+ """Validates raw counting results against dynamic model."""
+ fields = {name: (List[str], ...) for name in model_names}
+ EntityCountModel = create_model("EntityCountModel", **fields)
+ try:
+ validated = EntityCountModel(**raw_counts)
+ return validated.model_dump()
+ except Exception as e:
+ self.logger.warning(f"Count validation failed: {e}")
+ return {}
+
+ async def count_entities(
+ self,
+ input_strings: List[str],
+ model_names: List[str],
+ custom_counting_context: str = "",
+ previous_entities: Optional[List[Dict[str, Any]]] = None,
+ ) -> Dict[str, List[str]]:
+ """Performs entity counting for specified models."""
+ self.logger.info(f"Counting entities for: {model_names}")
+
+ system_prompt, user_prompt = self.prepare_counting_prompts(
+ input_strings, model_names, custom_counting_context, previous_entities
+ )
+
+ # Create validation model
+ fields = {name: (List[str], ...) for name in model_names}
+ EntityCountModel = create_model("EntityCountModel", **fields)
+
+ # Call LLM
+ try:
+            # Accept either a single client or a list; use the first for counting
+ if isinstance(self.llm_client, list):
+ client = self.llm_client[0]
+ else:
+ client = self.llm_client
+
+ result = await client.generate_and_validate_raw_json_output(
+ system_prompt=system_prompt,
+ user_prompt=user_prompt,
+ target_json_schema=None,
+ num_revisions=1,
+ max_validation_retries_per_revision=self.config.max_validation_retries_per_revision,
+ attempt_unwrap=False,
+ )
+
+ # Process result
+ if isinstance(result, list) and result:
+ result = result[0]
+
+ if isinstance(result, dict):
+ validated = EntityCountModel(**result)
+ counts = validated.model_dump()
+ self.logger.info(f"Entity counts: {counts}")
+ return counts
+
+ self.logger.warning("Entity counting returned invalid result")
+ return {}
+
+ except Exception as e:
+ self.logger.error(f"Entity counting failed: {e}")
+ return {}
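The dynamic validation model built in `validate_counts` relies on pydantic's `create_model`; a minimal sketch with illustrative model names:

```python
from typing import List

from pydantic import create_model

model_names = ["Author", "Book"]
# One required List[str] field per model name, mirroring validate_counts
fields = {name: (List[str], ...) for name in model_names}
EntityCountModel = create_model("EntityCountModel", **fields)

raw_counts = {"Author": ["Jules Verne"], "Book": ["Around the World in Eighty Days"]}
validated = EntityCountModel(**raw_counts)
print(validated.model_dump())
```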
diff --git a/src/extrai/core/example_json_generator.py b/src/extrai/core/example_json_generator.py
index 5353df3..1ef0de6 100644
--- a/src/extrai/core/example_json_generator.py
+++ b/src/extrai/core/example_json_generator.py
@@ -17,10 +17,7 @@
LLMOutputValidationError,
ConfigurationError,
)
-from .schema_inspector import (
- generate_llm_schema_from_models,
- discover_sqlmodels_from_root,
-)
+from .schema_inspector import SchemaInspector
class ExampleJSONGenerator:
@@ -59,16 +56,21 @@ def __init__(
self.output_model = output_model
self.analytics_collector = analytics_collector
self.max_validation_retries_per_revision = max_validation_retries_per_revision
+ self.schema_inspector = SchemaInspector(self.logger)
# Derive schema and root model name from the SQLModel
try:
# Discover all related models starting from the root model
- all_models = discover_sqlmodels_from_root(output_model)
+ all_models = self.schema_inspector.discover_sqlmodels_from_root(
+ output_model
+ )
# Generate the comprehensive schema for the LLM, which includes all related models
# to guide the LLM in creating a nested example.
- self.target_json_schema_for_llm_str = generate_llm_schema_from_models(
- initial_model_classes=all_models
+ self.target_json_schema_for_llm_str = (
+ self.schema_inspector.generate_llm_schema_from_models(
+ initial_model_classes=all_models
+ )
)
# The schema for basic validation by the LLM client needs to match the new
@@ -106,7 +108,9 @@ async def generate_example(self) -> str:
try:
# Discover all related models to build the schema map for validation.
- all_models = discover_sqlmodels_from_root(self.output_model)
+ all_models = self.schema_inspector.discover_sqlmodels_from_root(
+ self.output_model
+ )
model_schema_map = {model.__name__: model for model in all_models}
validated_revisions = await self.llm_client.generate_json_revisions(
diff --git a/src/extrai/core/extraction_config.py b/src/extrai/core/extraction_config.py
new file mode 100644
index 0000000..eb52c43
--- /dev/null
+++ b/src/extrai/core/extraction_config.py
@@ -0,0 +1,27 @@
+# extrai/core/extraction_config.py
+
+from dataclasses import dataclass
+from typing import Callable, Optional
+
+
+@dataclass
+class ExtractionConfig:
+ """Configuration for extraction workflows."""
+
+ num_llm_revisions: int = 3
+ max_validation_retries_per_revision: int = 2
+ consensus_threshold: float = 0.51
+ conflict_resolver: Optional[Callable] = None
+ use_hierarchical_extraction: bool = False
+ use_structured_output: bool = False
+
+ def __post_init__(self):
+ """Validates configuration parameters."""
+ if self.num_llm_revisions < 1:
+ raise ValueError("num_llm_revisions must be at least 1")
+
+ if self.max_validation_retries_per_revision < 1:
+ raise ValueError("max_validation_retries_per_revision must be at least 1")
+
+ if not (0.0 <= self.consensus_threshold <= 1.0):
+ raise ValueError("consensus_threshold must be between 0.0 and 1.0")
diff --git a/src/extrai/core/extraction_context_preparer.py b/src/extrai/core/extraction_context_preparer.py
new file mode 100644
index 0000000..1f49274
--- /dev/null
+++ b/src/extrai/core/extraction_context_preparer.py
@@ -0,0 +1,118 @@
+import json
+import logging
+from typing import List, Optional, Union, Callable
+from sqlmodel import SQLModel
+
+from .model_registry import ModelRegistry
+from .example_json_generator import ExampleJSONGenerator, ExampleGenerationError
+from .analytics_collector import WorkflowAnalyticsCollector
+from .errors import WorkflowError
+from .base_llm_client import BaseLLMClient
+from extrai.utils.serialization_utils import serialize_sqlmodel_with_relationships
+
+
+class ExtractionContextPreparer:
+ """
+ Helper class to prepare context for extraction, including example generation.
+ """
+
+ def __init__(
+ self,
+ model_registry: ModelRegistry,
+ analytics_collector: WorkflowAnalyticsCollector,
+ max_retries: int,
+ logger: logging.Logger,
+ ):
+ self.model_registry = model_registry
+ self.analytics_collector = analytics_collector
+ self.max_retries = max_retries
+ self.logger = logger
+
+ async def prepare_example(
+ self,
+ extraction_example_json: str,
+ extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]],
+ client_provider: Callable[[], BaseLLMClient],
+ ) -> str:
+ """
+ Prepares or auto-generates extraction example.
+
+ Priority:
+ 1. Use provided extraction_example_json if available
+ 2. Serialize extraction_example_object if provided
+ 3. Auto-generate example using LLM
+ """
+ # If JSON provided, use it directly
+ if extraction_example_json:
+ self.logger.info("Using provided extraction example JSON")
+ return extraction_example_json
+
+ # If object provided, serialize it
+ if extraction_example_object:
+ self.logger.info("Serializing extraction example object")
+ return self._serialize_example_object(extraction_example_object)
+
+ # Auto-generate
+ self.logger.info("No example provided, auto-generating...")
+ return await self._auto_generate_example(client_provider)
+
+ def _serialize_example_object(self, obj: Union[SQLModel, List[SQLModel]]) -> str:
+ """Serializes SQLModel object(s) to JSON."""
+ objects = obj if isinstance(obj, list) else [obj]
+ serialized = []
+
+ for o in objects:
+ if isinstance(o, SQLModel):
+ serialized.append(serialize_sqlmodel_with_relationships(o))
+ else:
+ self.logger.warning(
+ f"Skipping non-SQLModel object in example: {type(o)}"
+ )
+
+ if not serialized:
+ self.logger.warning("No valid SQLModel objects to serialize")
+ return ""
+
+ return json.dumps(serialized, default=str, indent=2)
+
+ async def _auto_generate_example(
+ self, client_provider: Callable[[], BaseLLMClient]
+ ) -> str:
+ """Auto-generates an extraction example using LLM."""
+ try:
+ generator = ExampleJSONGenerator(
+ llm_client=client_provider(),
+ output_model=self.model_registry.root_model,
+ analytics_collector=self.analytics_collector,
+ max_validation_retries_per_revision=self.max_retries,
+ logger=self.logger,
+ )
+
+ self.logger.info(
+ f"Auto-generating extraction example for "
+ f"{self.model_registry.root_model.__name__}..."
+ )
+
+ example = await generator.generate_example()
+
+ self.analytics_collector.record_custom_event(
+ "example_json_auto_generation_success"
+ )
+ self.logger.info("Successfully auto-generated extraction example")
+
+ return example
+
+ except ExampleGenerationError as e:
+ self.analytics_collector.record_custom_event(
+ "example_json_auto_generation_failure"
+ )
+ raise WorkflowError(
+ f"Failed to auto-generate extraction example: {e}"
+ ) from e
+ except Exception as e:
+ self.analytics_collector.record_custom_event(
+ "example_json_auto_generation_unexpected_failure"
+ )
+ raise WorkflowError(
+ f"Unexpected error during extraction example auto-generation: {e}"
+ ) from e
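The priority chain in `prepare_example` (provided JSON string, then serialized object, then LLM auto-generation) can be sketched with plain callables; the names below are illustrative, not the library API:

```python
import json


def prepare_example(example_json, example_obj, auto_generate):
    # 1. A ready-made JSON string wins
    if example_json:
        return example_json
    # 2. Otherwise serialize the provided object
    if example_obj is not None:
        return json.dumps(example_obj, default=str, indent=2)
    # 3. Fall back to generating one (stands in for the LLM call)
    return auto_generate()


print(prepare_example('{"title": "A"}', None, lambda: "{}"))  # the JSON wins
print(prepare_example("", {"title": "B"}, lambda: "{}"))      # object serialized
print(prepare_example("", None, lambda: '{"title": "C"}'))    # auto-generated
```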
diff --git a/src/extrai/core/extraction_pipeline.py b/src/extrai/core/extraction_pipeline.py
new file mode 100644
index 0000000..0e92139
--- /dev/null
+++ b/src/extrai/core/extraction_pipeline.py
@@ -0,0 +1,253 @@
+import logging
+from typing import List, Dict, Any, Optional, Union
+from sqlmodel import SQLModel
+
+from extrai.core.base_llm_client import BaseLLMClient
+from .client_rotator import ClientRotator
+from .extraction_context_preparer import ExtractionContextPreparer
+from .model_registry import ModelRegistry
+from .extraction_config import ExtractionConfig
+from .prompt_builder import PromptBuilder
+from .entity_counter import EntityCounter
+from .llm_runner import LLMRunner
+from .hierarchical_extractor import HierarchicalExtractor
+from .analytics_collector import WorkflowAnalyticsCollector
+from .model_wrapper_builder import ModelWrapperBuilder
+from .extraction_request_factory import ExtractionRequestFactory
+
+
+class ExtractionPipeline:
+ """
+ Manages the standard extraction pipeline.
+
+ Flow:
+ 1. Prepare extraction example (auto-generate if needed)
+ 2. Count entities (optional)
+ 3. Run extraction (standard or hierarchical)
+ 4. Return consensus results
+
+ This class coordinates between multiple components to execute
+ the complete extraction workflow from input strings to consensus JSON.
+ """
+
+ def __init__(
+ self,
+ model_registry: ModelRegistry,
+ llm_client: Union["BaseLLMClient", List["BaseLLMClient"]],
+ config: ExtractionConfig,
+ analytics_collector: WorkflowAnalyticsCollector,
+ logger: logging.Logger,
+ counting_llm_client: Optional[BaseLLMClient] = None,
+ ):
+ """
+ Initialize the extraction pipeline.
+
+ Args:
+ model_registry: Registry of SQLModel schemas
+ llm_client: Single client or list of LLM clients for rotation
+ config: Extraction configuration
+ analytics_collector: Collector for tracking metrics
+ logger: Logger instance
+ counting_llm_client: Optional specific client for counting tasks
+ """
+ self.model_registry = model_registry
+ self.config = config
+ self.analytics_collector = analytics_collector
+ self.logger = logger
+
+ # Initialize sub-components
+ self.client_rotator = ClientRotator(llm_client)
+ self.prompt_builder = PromptBuilder(model_registry, logger)
+ self.entity_counter = EntityCounter(
+ model_registry,
+ counting_llm_client or llm_client,
+ config,
+ analytics_collector,
+ logger,
+ )
+ self.context_preparer = ExtractionContextPreparer(
+ model_registry,
+ analytics_collector,
+ config.max_validation_retries_per_revision,
+ logger,
+ )
+ self.llm_runner = LLMRunner(
+ model_registry, llm_client, config, analytics_collector, logger
+ )
+ self.model_wrapper_builder = ModelWrapperBuilder()
+
+ self.request_factory = ExtractionRequestFactory(
+ model_registry, self.prompt_builder, self.model_wrapper_builder, logger
+ )
+
+ # Initialize hierarchical extractor if needed
+ self.hierarchical_extractor = None
+ if config.use_hierarchical_extraction:
+ self.hierarchical_extractor = HierarchicalExtractor(
+ model_registry=model_registry,
+ prompt_builder=self.prompt_builder,
+ entity_counter=self.entity_counter,
+ llm_runner=self.llm_runner,
+ logger=logger,
+ request_factory=self.request_factory,
+ model_wrapper_builder=self.model_wrapper_builder,
+ use_structured_output=config.use_structured_output,
+ config=config,
+ )
+ logger.warning(
+ "Hierarchical extraction enabled. "
+ "This may significantly increase LLM API calls and processing time "
+ "based on model complexity and the number of entities."
+ )
+
+ async def extract(
+ self,
+ input_strings: List[str],
+ extraction_example_json: str = "",
+ extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]] = None,
+ custom_extraction_process: str = "",
+ custom_extraction_guidelines: str = "",
+ custom_final_checklist: str = "",
+ custom_context: str = "",
+ count_entities: bool = False,
+ custom_counting_context: str = "",
+ ) -> List[Dict[str, Any]]:
+ """
+ Executes extraction and returns consensus JSON.
+
+ Args:
+ input_strings: List of document strings to extract from
+ extraction_example_json: Optional JSON string for few-shot prompting
+ extraction_example_object: Optional SQLModel object(s) to use as example
+ custom_extraction_process: Optional custom extraction process instructions
+ custom_extraction_guidelines: Optional custom extraction guidelines
+ custom_final_checklist: Optional custom final checklist
+ custom_context: Optional custom contextual information
+            count_entities: If True, performs entity counting before extraction
+            custom_counting_context: Optional extra context for the counting prompt
+
+ Returns:
+ List of dictionaries representing extracted entities (consensus output)
+
+ Raises:
+ WorkflowError: If extraction fails
+ """
+ self.logger.info(
+ f"Starting extraction for {self.model_registry.root_model.__name__}..."
+ )
+
+ # Step 1: Prepare example
+ example_json = await self.context_preparer.prepare_example(
+ extraction_example_json,
+ extraction_example_object,
+ self.client_rotator.get_next_client,
+ )
+
+ # Step 2: Count entities if requested
+ # Note: For hierarchical extraction, counting is handled per-model within the extractor
+ expected_entity_descriptions = None
+ if count_entities and not self.config.use_hierarchical_extraction:
+ expected_entity_descriptions = await self._count_entities(
+ input_strings, custom_counting_context
+ )
+ if expected_entity_descriptions is not None:
+ self.logger.info(f"Entity count: {len(expected_entity_descriptions)}")
+
+ # Step 3: Run extraction (hierarchical or standard)
+ if self.config.use_hierarchical_extraction:
+ self.logger.info("Using hierarchical extraction mode")
+            # The extractor should have been initialized in __init__ when
+            # hierarchical extraction is enabled; recreate it here as a safeguard.
+            if not self.hierarchical_extractor:
+ self.hierarchical_extractor = HierarchicalExtractor(
+ self.model_registry,
+ self.prompt_builder,
+ self.entity_counter,
+ self.llm_runner,
+ self.logger,
+ self.request_factory,
+ self.model_wrapper_builder,
+ self.config.use_structured_output,
+ self.config,
+ )
+
+ results = await self.hierarchical_extractor.extract(
+ input_strings=input_strings,
+ extraction_example_json=example_json,
+ custom_extraction_process=custom_extraction_process,
+ custom_extraction_guidelines=custom_extraction_guidelines,
+ custom_final_checklist=custom_final_checklist,
+ custom_context=custom_context,
+ count_entities=count_entities,
+ custom_counting_context=custom_counting_context,
+ )
+ else:
+ # Unified Non-Hierarchical Flow
+ self.logger.info(
+ f"Using {'structured' if self.config.use_structured_output else 'standard'} extraction mode"
+ )
+
+ request = self.request_factory.prepare_request(
+ input_strings=input_strings,
+ config=self.config,
+ extraction_example_json=example_json,
+ custom_extraction_process=custom_extraction_process,
+ custom_extraction_guidelines=custom_extraction_guidelines,
+ custom_final_checklist=custom_final_checklist,
+ custom_context=custom_context,
+ expected_entity_descriptions=expected_entity_descriptions,
+ )
+
+ self.logger.debug(
+ f"System prompt length: {len(request.system_prompt)} chars"
+ )
+ self.logger.debug(f"User prompt length: {len(request.user_prompt)} chars")
+
+ if request.response_model:
+ results = await self.llm_runner.run_structured_extraction_cycle(
+ system_prompt=request.system_prompt,
+ user_prompt=request.user_prompt,
+ response_model=request.response_model,
+ )
+ else:
+ results = await self.llm_runner.run_extraction_cycle(
+ system_prompt=request.system_prompt, user_prompt=request.user_prompt
+ )
+
+ self.logger.info(f"Extraction completed. Found {len(results)} entities.")
+ return results
+
+ async def _count_entities(
+ self, input_strings: List[str], custom_counting_context: str = ""
+ ) -> Optional[List[str]]:
+ """
+ Counts entities in the input documents.
+
+ Args:
+ input_strings: Documents to analyze
+ custom_counting_context: Custom context for counting phase
+
+ Returns:
+ List of descriptions of all model entities, or None if counting fails
+ """
+ all_model_names = self.model_registry.get_all_model_names()
+
+ try:
+ counts = await self.entity_counter.count_entities(
+ input_strings, all_model_names, custom_counting_context
+ )
+ flat_descriptions = []
+ for model_name, descriptions in counts.items():
+ for desc in descriptions:
+ flat_descriptions.append(f"[{model_name}] {desc}")
+ return flat_descriptions
+ except Exception as e:
+ self.logger.warning(f"Entity counting failed: {e}")
+ return None
+
+ def __repr__(self) -> str:
+ """String representation of the pipeline."""
+ mode = "hierarchical" if self.config.use_hierarchical_extraction else "standard"
+ return (
+ f"ExtractionPipeline(mode={mode}, "
+ f"root={self.model_registry.root_model.__name__})"
+ )
diff --git a/src/extrai/core/extraction_request_factory.py b/src/extrai/core/extraction_request_factory.py
new file mode 100644
index 0000000..1433411
--- /dev/null
+++ b/src/extrai/core/extraction_request_factory.py
@@ -0,0 +1,145 @@
+import logging
+from typing import List, Dict, Any, Optional, NamedTuple
+
+from extrai.core.model_registry import ModelRegistry
+from extrai.core.prompt_builder import PromptBuilder
+from extrai.core.model_wrapper_builder import ModelWrapperBuilder
+from extrai.core.extraction_config import ExtractionConfig
+from extrai.utils.serialization_utils import make_json_serializable
+
+
+class ExtractionRequest(NamedTuple):
+ system_prompt: str
+ user_prompt: str
+ json_schema: Optional[Dict[str, Any]]
+ model_name: Optional[str]
+ response_model: Optional[Any] = None
+
+
+class ExtractionRequestFactory:
+ """
+ Factory to prepare extraction requests (prompts and schemas).
+    Centralizes preparation logic for standard, structured, and hierarchical extraction.
+ """
+
+ def __init__(
+ self,
+ model_registry: ModelRegistry,
+ prompt_builder: PromptBuilder,
+ model_wrapper_builder: ModelWrapperBuilder,
+ logger: Optional[logging.Logger] = None,
+ ):
+ self.model_registry = model_registry
+ self.prompt_builder = prompt_builder
+ self.model_wrapper_builder = model_wrapper_builder
+ self.logger = logger or logging.getLogger(__name__)
+
+ def prepare_request(
+ self,
+ input_strings: List[str],
+ config: ExtractionConfig,
+ extraction_example_json: str = "",
+ custom_extraction_process: str = "",
+ custom_extraction_guidelines: str = "",
+ custom_final_checklist: str = "",
+ custom_context: str = "",
+ expected_entity_descriptions: Optional[List[str]] = None,
+ previous_entities: Optional[List[Dict[str, Any]]] = None,
+ hierarchical_model_index: Optional[int] = None,
+ ) -> ExtractionRequest:
+ """
+ Prepares the extraction request based on the configuration and current state.
+
+ Args:
+ input_strings: List of document strings.
+ config: Extraction configuration.
+ extraction_example_json: Example JSON for few-shot.
+ custom_extraction_process: Custom instructions.
+ custom_extraction_guidelines: Custom guidelines.
+ custom_final_checklist: Custom checklist (standard mode only).
+ custom_context: Additional context string.
+ expected_entity_descriptions: List of descriptions (from counting).
+ previous_entities: List of previously extracted entities (for hierarchical).
+ hierarchical_model_index: Index of the model to extract (hierarchical only).
+
+ Returns:
+ ExtractionRequest containing prompts, schema, and target model name.
+ """
+ # 1. Determine Target Model
+ if config.use_hierarchical_extraction:
+ if hierarchical_model_index is None:
+ hierarchical_model_index = 0
+
+ if not (0 <= hierarchical_model_index < len(self.model_registry.models)):
+ raise ValueError(
+ f"Invalid hierarchical_model_index: {hierarchical_model_index}"
+ )
+
+ target_model = self.model_registry.models[hierarchical_model_index]
+ target_model_name = target_model.__name__
+ else:
+ target_model = self.model_registry.root_model
+ target_model_name = None
+
+ # 2. Serialize Previous Entities (Context)
+ serializable_previous_entities = None
+ if previous_entities:
+ serializable_previous_entities = make_json_serializable(previous_entities)
+
+ # 3. Generate Request
+ json_schema = None
+ wrapper_model = None
+
+ if config.use_structured_output:
+ # If hierarchical, we only want the shallow model for this step (include_relationships=False)
+ # If standard (not hierarchical), we want deep extraction (include_relationships=True)
+ include_relationships = not config.use_hierarchical_extraction
+
+ wrapper_model = self.model_wrapper_builder.generate_wrapper_model(
+ target_model, include_relationships=include_relationships
+ )
+
+ system_prompt, user_prompt = self.prompt_builder.build_structured_prompts(
+ input_strings=input_strings,
+ custom_extraction_process=custom_extraction_process,
+ custom_extraction_guidelines=custom_extraction_guidelines,
+ custom_context=custom_context,
+ extraction_example_json=extraction_example_json,
+ expected_entity_descriptions=expected_entity_descriptions,
+ previous_entities=serializable_previous_entities,
+ target_model_name=target_model_name
+ if config.use_hierarchical_extraction
+ else None,
+ )
+ json_schema = wrapper_model.model_json_schema()
+
+ else:
+ if config.use_hierarchical_extraction:
+ schema_json = self.model_registry.get_schema_for_models(
+ [target_model_name]
+ )
+ else:
+ schema_json = self.model_registry.llm_schema_json
+
+ system_prompt, user_prompt = self.prompt_builder.build_prompts(
+ input_strings=input_strings,
+ schema_json=schema_json,
+ extraction_example_json=extraction_example_json,
+ custom_extraction_process=custom_extraction_process,
+ custom_extraction_guidelines=custom_extraction_guidelines,
+ custom_final_checklist=custom_final_checklist,
+ custom_context=custom_context,
+ expected_entity_descriptions=expected_entity_descriptions,
+ previous_entities=serializable_previous_entities,
+ target_model_name=target_model_name
+ if config.use_hierarchical_extraction
+ else None,
+ )
+
+ return ExtractionRequest(
+ system_prompt=system_prompt,
+ user_prompt=user_prompt,
+ json_schema=json_schema,
+ model_name=target_model_name,
+ response_model=wrapper_model,
+ )
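The factory's target-model selection can be condensed into a pure function. This is a simplified sketch under stated assumptions (the `Request` tuple and `select_target` name are illustrative, not the library's API), showing the rule the code above implements: hierarchical mode indexes into the model list and validates the index, while standard mode targets the root model and leaves `model_name` unset:

```python
from typing import List, NamedTuple, Optional


class Request(NamedTuple):
    model_name: Optional[str]
    structured: bool


def select_target(
    models: List[str],
    hierarchical: bool,
    structured: bool,
    index: Optional[int] = None,
) -> Request:
    if hierarchical:
        # Default to the first model, then bounds-check the index.
        index = 0 if index is None else index
        if not (0 <= index < len(models)):
            raise ValueError(f"Invalid hierarchical_model_index: {index}")
        return Request(model_name=models[index], structured=structured)
    # Standard mode extracts against the root model; model_name stays None.
    return Request(model_name=None, structured=structured)


req = select_target(["Invoice", "LineItem"], hierarchical=True, structured=True, index=1)
# → Request(model_name='LineItem', structured=True)
```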
diff --git a/src/extrai/core/hierarchical_extractor.py b/src/extrai/core/hierarchical_extractor.py
new file mode 100644
index 0000000..62f0234
--- /dev/null
+++ b/src/extrai/core/hierarchical_extractor.py
@@ -0,0 +1,129 @@
+import logging
+from typing import List, Dict, Any, Tuple, Optional
+
+from .model_registry import ModelRegistry
+from .prompt_builder import PromptBuilder
+from .entity_counter import EntityCounter
+from .llm_runner import LLMRunner
+from .model_wrapper_builder import ModelWrapperBuilder
+from .extraction_request_factory import ExtractionRequestFactory
+from extrai.core.extraction_config import ExtractionConfig
+from extrai.utils.serialization_utils import make_json_serializable
+
+
+class HierarchicalExtractor:
+ """
+ Performs hierarchical extraction by processing models level-by-level.
+
+ Uses breadth-first traversal to extract parent entities first,
+ then uses them as context for extracting children.
+ """
+
+ def __init__(
+ self,
+ model_registry: ModelRegistry,
+ prompt_builder: PromptBuilder,
+ entity_counter: EntityCounter,
+ llm_runner: LLMRunner,
+ logger: logging.Logger,
+ request_factory: ExtractionRequestFactory,
+        model_wrapper_builder: Optional[ModelWrapperBuilder] = None,
+ use_structured_output: bool = False,
+ config: Optional[ExtractionConfig] = None,
+ ):
+ self.model_registry = model_registry
+ self.prompt_builder = prompt_builder
+ self.entity_counter = entity_counter
+ self.llm_runner = llm_runner
+ self.logger = logger
+ self.request_factory = request_factory
+ self.model_wrapper_builder = model_wrapper_builder
+ self.use_structured_output = use_structured_output
+ self.config = config
+
+ async def extract(
+ self,
+ input_strings: List[str],
+ extraction_example_json: str,
+ custom_extraction_process: str,
+ custom_extraction_guidelines: str,
+ custom_final_checklist: str,
+ custom_context: str,
+ count_entities: bool,
+ custom_counting_context: str = "",
+ ) -> List[Dict[str, Any]]:
+ """Executes hierarchical extraction."""
+ self.logger.info("Starting hierarchical extraction...")
+
+ models = self.model_registry.models
+ results_store: Dict[Tuple[str, str], Dict[str, Any]] = {}
+
+ for i, model_class in enumerate(models):
+ model_name = model_class.__name__
+ self.logger.info(f"Processing model: {model_name}")
+
+ # Count entities if needed
+ expected_entity_descriptions = None
+ if count_entities:
+ # Prepare previous entities for context
+ previous_entities = None
+ if results_store:
+ previous_entities = make_json_serializable(
+ list(results_store.values())
+ )
+
+ counts = await self.entity_counter.count_entities(
+ input_strings,
+ [model_name],
+ custom_counting_context,
+ previous_entities=previous_entities,
+ )
+ expected_entity_descriptions = counts.get(model_name)
+
+ if not self.config:
+ raise ValueError(
+ "ExtractionConfig is required for HierarchicalExtractor"
+ )
+
+ request = self.request_factory.prepare_request(
+ input_strings=input_strings,
+ config=self.config,
+ extraction_example_json=extraction_example_json,
+ custom_extraction_process=custom_extraction_process,
+ custom_extraction_guidelines=custom_extraction_guidelines,
+ custom_final_checklist=custom_final_checklist,
+ custom_context=custom_context,
+ expected_entity_descriptions=expected_entity_descriptions,
+ previous_entities=list(results_store.values())
+ if results_store
+ else None,
+ hierarchical_model_index=i,
+ )
+
+ if self.use_structured_output:
+ entities = await self.llm_runner.run_structured_extraction_cycle(
+ system_prompt=request.system_prompt,
+ user_prompt=request.user_prompt,
+ response_model=request.response_model,
+ )
+ else:
+ entities = await self.llm_runner.run_extraction_cycle(
+ system_prompt=request.system_prompt, user_prompt=request.user_prompt
+ )
+
+ # Store results
+ for idx, entity in enumerate(entities):
+ if "_type" not in entity:
+ entity["_type"] = model_name
+
+ temp_id = entity.get("_temp_id")
+ storage_id = temp_id if temp_id else f"__synthetic_{idx}__"
+
+ if (model_name, storage_id) not in results_store:
+ results_store[(model_name, storage_id)] = entity
+
+ self.logger.info(
+ f"Completed {model_name}. Total entities: {len(results_store)}"
+ )
+
+ return list(results_store.values())
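The result-storage step above deduplicates entities by keying them on `(model_name, _temp_id)`, falling back to a synthetic id when the LLM omitted one, and keeping the first occurrence. A self-contained sketch of that bookkeeping (the `store_entities` helper is illustrative; the field names mirror the code above):

```python
def store_entities(results_store, model_name, entities):
    """Store entities keyed by (model_name, id), first occurrence wins."""
    for idx, entity in enumerate(entities):
        # Tag untyped entities with the model they were extracted for.
        entity.setdefault("_type", model_name)
        # Prefer the LLM-provided _temp_id; otherwise synthesize one.
        storage_id = entity.get("_temp_id") or f"__synthetic_{idx}__"
        results_store.setdefault((model_name, storage_id), entity)
    return results_store


store = {}
store_entities(store, "Invoice", [{"_temp_id": "inv1", "total": 10}])
# A second extraction of the same _temp_id is ignored: first occurrence wins.
store_entities(store, "Invoice", [{"_temp_id": "inv1", "total": 99}])
```

Here `store` ends up with a single `("Invoice", "inv1")` entry whose `total` is 10.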
diff --git a/src/extrai/core/json_consensus.py b/src/extrai/core/json_consensus.py
index 8d884cd..294a861 100644
--- a/src/extrai/core/json_consensus.py
+++ b/src/extrai/core/json_consensus.py
@@ -2,7 +2,7 @@
import logging
import math
from collections import Counter
-from typing import List, Dict, Any, Callable, Optional, Union, Tuple
+from typing import List, Dict, Any, Optional, Union, Tuple
from extrai.utils.flattening_utils import (
flatten_json,
unflatten_json,
@@ -12,6 +12,12 @@
JSONArray,
FlattenedJSON,
)
+from extrai.core.conflict_resolvers import (
+ ConflictResolutionStrategy,
+ default_conflict_resolver,
+ prefer_most_common_resolver,
+ levenshtein_similarity,
+)
# Sentinel value to indicate that no consensus was reached for a path.
_NO_CONSENSUS = object()
@@ -19,39 +25,11 @@
# Define a type for a list of JSON revisions
JSONRevisions = List[Union[JSONObject, JSONArray]]
-# Define conflict resolution strategies
-ConflictResolutionStrategy = Callable[[Path, List[JSONValue]], Optional[JSONValue]]
-
-
-def default_conflict_resolver(
- path: Path, values: List[JSONValue]
-) -> Optional[JSONValue]:
- """
- Default conflict resolution: if no consensus, omit the field.
- (Or choose the most common, even if below threshold, or raise error - configurable)
- For this default, we'll omit.
- """
- # print(f"Conflict at path {path}: Values {values}. No consensus achieved. Omitting.")
- return None
-
-
-def prefer_most_common_resolver(
- _path: Path, values: List[JSONValue]
-) -> Optional[JSONValue]:
- """
- Conflict resolution: prefer the most common value even if it doesn't meet the threshold.
- If multiple values have the same highest frequency, it picks one arbitrarily (based on Counter behavior).
- """
- if not values:
- return None
- count = Counter(values)
- most_common_value, _ = count.most_common(1)[0]
- return most_common_value
-
class JSONConsensus:
"""
Calculates a consensus JSON object from multiple JSON revisions.
+ Supports weighted consensus based on global revision similarity.
"""
def __init__(
@@ -68,14 +46,9 @@ def __init__(
Args:
consensus_threshold: The minimum proportion of revisions that must agree on a
value for it to be included in the consensus.
- E.g., 0.5 means more than 50% agreement needed.
- A value of 0.0 would mean any single occurrence is enough (less useful).
- A value of 1.0 would mean unanimous agreement is required.
conflict_resolver: A function to call when no value for a path meets the
- consensus threshold. It takes the path and list of
- conflicting values and returns a single value or None.
- If None, a default resolver that omits the field is used.
- logger: An optional logger instance. If not provided, a default logger is created.
+ consensus threshold.
+ logger: An optional logger instance.
"""
self.logger = logger or logging.getLogger(self.__class__.__name__)
if not logger:
@@ -86,7 +59,7 @@ def __init__(
"Extrai threshold must be between 0.0 (exclusive) and 1.0 (inclusive)."
)
self.consensus_threshold = consensus_threshold
- self.conflict_resolver = conflict_resolver
+ self.conflict_resolver = conflict_resolver or default_conflict_resolver
def get_consensus(
self, revisions: JSONRevisions
@@ -97,11 +70,17 @@ def get_consensus(
if not revisions:
return {}, analytics
- path_to_values = self._aggregate_paths(revisions)
- analytics["unique_paths_considered"] = len(path_to_values)
+ # Calculate revision weights based on global similarity (Jaccard-like on flattened fields)
+ revision_weights = self._calculate_revision_weights(revisions)
+ analytics["revision_weights"] = revision_weights
+
+ # Aggregate paths, keeping track of which revision provided which value
+ path_to_values_with_indices = self._aggregate_paths(revisions)
+ analytics["unique_paths_considered"] = len(path_to_values_with_indices)
+ # Build consensus
consensus_flat_json = self._build_consensus_json(
- path_to_values, num_revisions, analytics
+ path_to_values_with_indices, num_revisions, analytics, revision_weights
)
analytics["paths_in_consensus_output"] = len(consensus_flat_json)
@@ -109,6 +88,12 @@ def get_consensus(
consensus_flat_json, revisions
)
+ if analytics["unique_paths_considered"] > 0:
+ analytics["consensus_confidence_score"] = (
+ analytics["paths_agreed_by_threshold"]
+ / analytics["unique_paths_considered"]
+ )
+
return final_consensus_object, analytics
def _initialize_analytics(self, num_revisions: int) -> Dict[str, Any]:
@@ -119,31 +104,114 @@ def _initialize_analytics(self, num_revisions: int) -> Dict[str, Any]:
"paths_agreed_by_threshold": 0,
"paths_resolved_by_conflict_resolver": 0,
"paths_omitted_due_to_no_consensus_or_resolver_omission": 0,
+ "consensus_confidence_score": 0.0,
+ "average_string_similarity": 0.0, # Average Levenshtein ratio (1.0 = identical)
}
- def _aggregate_paths(self, revisions: JSONRevisions) -> Dict[Path, List[JSONValue]]:
- path_to_values: Dict[Path, List[JSONValue]] = {}
+ def _calculate_revision_weights(self, revisions: JSONRevisions) -> List[float]:
+ """
+ Calculates weights for each revision based on its similarity to other revisions.
+ Revisions that are similar to others get higher weights (centrality).
+ """
+ n = len(revisions)
+ if n <= 1:
+ return [1.0] * n
+
+ flat_revs = [flatten_json(r) for r in revisions]
+ scores = [[0.0] * n for _ in range(n)]
+
+ for i in range(n):
+ for j in range(i, n):
+ if i == j:
+ scores[i][j] = 1.0
+ continue
+
+ # Compare flat_revs[i] and flat_revs[j]
+ keys_i = set(flat_revs[i].keys())
+ keys_j = set(flat_revs[j].keys())
+ common_keys = keys_i.intersection(keys_j)
+ union_keys = keys_i.union(keys_j)
+
+ if not union_keys:
+ similarity = 0.0
+ else:
+ score_sum = 0.0
+ for k in common_keys:
+ val_i = flat_revs[i][k]
+ val_j = flat_revs[j][k]
+ if isinstance(val_i, str) and isinstance(val_j, str):
+ score_sum += levenshtein_similarity(val_i, val_j)
+ else:
+ score_sum += 1.0 if val_i == val_j else 0.0
+
+ similarity = score_sum / len(union_keys)
+
+ scores[i][j] = similarity
+ scores[j][i] = similarity
+
+ # Weight for rev i = sum of similarities to others
+ weights = []
+ for i in range(n):
+ w = sum(scores[i][j] for j in range(n) if i != j)
+ weights.append(w)
+
+ total = sum(weights)
+ if total == 0:
+ return [1.0 / n] * n
+
+ return [w / total for w in weights]
+
+ def _aggregate_paths(
+ self, revisions: JSONRevisions
+ ) -> Dict[Path, List[Tuple[JSONValue, int]]]:
+ """
+ Aggregates values for each path, preserving the source revision index.
+ """
+ path_to_values: Dict[Path, List[Tuple[JSONValue, int]]] = {}
flattened_revisions = [flatten_json(rev) for rev in revisions]
- for flat_rev in flattened_revisions:
+ for idx, flat_rev in enumerate(flattened_revisions):
for path, value in flat_rev.items():
- path_to_values.setdefault(path, []).append(value)
+ path_to_values.setdefault(path, []).append((value, idx))
return path_to_values
def _build_consensus_json(
self,
- path_to_values: Dict[Path, List[JSONValue]],
+ path_to_values: Dict[Path, List[Tuple[JSONValue, int]]],
num_revisions: int,
analytics: Dict[str, Any],
+ revision_weights: List[float],
) -> FlattenedJSON:
consensus_flat_json: FlattenedJSON = {}
- for path, values in path_to_values.items():
- agreed_value = self._get_consensus_for_path(path, values, num_revisions)
+
+ total_string_sim = 0.0
+ string_path_count = 0
+
+ for path, values_with_indices in path_to_values.items():
+ values = [v for v, i in values_with_indices]
+ indices = [i for v, i in values_with_indices]
+ current_weights = [revision_weights[i] for i in indices]
+
+ # Analytics: Calculate average string similarity
+ if len(values) > 1 and all(isinstance(v, str) for v in values):
+ path_sim_sum = 0.0
+ pair_count = 0
+ for i in range(len(values)):
+ for j in range(i + 1, len(values)):
+ path_sim_sum += levenshtein_similarity(values[i], values[j])
+ pair_count += 1
+ if pair_count > 0:
+ total_string_sim += path_sim_sum / pair_count
+ string_path_count += 1
+
+ agreed_value = self._get_consensus_for_path(
+ path, values, current_weights, num_revisions
+ )
if agreed_value is not _NO_CONSENSUS:
consensus_flat_json[path] = agreed_value
analytics["paths_agreed_by_threshold"] += 1
else:
- # This is a conflict. Record the disagreement details.
+ # Conflict
value_counts = Counter(values)
disagreement_details = {
"path": ".".join(map(str, path)),
@@ -155,19 +223,23 @@ def _build_consensus_json(
disagreement_details
)
- # For '_temp_id' and '_type', always prefer the most common value.
+ # Special handling for _temp_id and _type
if path[-1] in ["_temp_id", "_type"]:
self.logger.debug(
f"Conflict at path '{'.'.join(map(str, path))}': "
f"Using most common value resolver for special attribute."
)
- resolved_value = prefer_most_common_resolver(path, values)
+ resolved_value = prefer_most_common_resolver(
+ path, values, current_weights
+ )
else:
self.logger.debug(
f"Conflict at path '{'.'.join(map(str, path))}': "
f"Invoking custom conflict resolver. Values: {values}"
)
- resolved_value = self.conflict_resolver(path, values)
+ resolved_value = self.conflict_resolver(
+ path, values, current_weights
+ )
if resolved_value is not None:
self.logger.debug(
@@ -183,23 +255,50 @@ def _build_consensus_json(
analytics[
"paths_omitted_due_to_no_consensus_or_resolver_omission"
] += 1
+
+ if string_path_count > 0:
+ analytics["average_string_similarity"] = (
+ total_string_sim / string_path_count
+ )
+
return consensus_flat_json
def _get_consensus_for_path(
- self, path: Path, values: List[JSONValue], num_revisions: int
+ self,
+ path: Path,
+ values: List[JSONValue],
+ weights: List[float],
+ num_revisions: int,
) -> Union[JSONValue, object]:
- most_common_candidate = prefer_most_common_resolver(path, values)
- max_count = values.count(most_common_candidate)
-
- # Unanimity check
- is_unanimous = max_count == num_revisions
- if math.isclose(self.consensus_threshold, 1.0):
- return most_common_candidate if is_unanimous else _NO_CONSENSUS
-
- # Threshold check
- agreement_ratio = max_count / num_revisions
- if agreement_ratio > self.consensus_threshold:
- return most_common_candidate
+ # Use weighted voting if weights provided
+ most_common_candidate = prefer_most_common_resolver(path, values, weights)
+
+ # Calculate agreement ratio based on weights if available, else count
+ if weights and len(weights) == len(values):
+            agreement_ratio = sum(
+                w for v, w in zip(values, weights) if v == most_common_candidate
+            )
+            max_count = values.count(most_common_candidate)
+            is_unanimous = max_count == num_revisions
+            if math.isclose(self.consensus_threshold, 1.0):
+                return most_common_candidate if is_unanimous else _NO_CONSENSUS
+
+            # Threshold check
+            if agreement_ratio > self.consensus_threshold:
+ return most_common_candidate
+
+ return _NO_CONSENSUS
+
+ else:
+ # Fallback to unweighted
+ max_count = values.count(most_common_candidate)
+ is_unanimous = max_count == num_revisions
+ if math.isclose(self.consensus_threshold, 1.0):
+ return most_common_candidate if is_unanimous else _NO_CONSENSUS
+
+ agreement_ratio = max_count / num_revisions
+ if agreement_ratio > self.consensus_threshold:
+ return most_common_candidate
return _NO_CONSENSUS
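The revision-weighting scheme in `_calculate_revision_weights` can be sketched in isolation. This is a self-contained approximation: `difflib.SequenceMatcher.ratio` stands in for the library's `levenshtein_similarity` helper (an assumption, not the actual implementation), and the pairwise score is Jaccard-like over flattened keys, as the code above describes:

```python
from difflib import SequenceMatcher


def similarity(a, b):
    """Jaccard-like score over flattened dicts, fuzzy on string values."""
    common = set(a) & set(b)
    union = set(a) | set(b)
    if not union:
        return 0.0
    score = 0.0
    for k in common:
        if isinstance(a[k], str) and isinstance(b[k], str):
            # Stand-in for levenshtein_similarity: 1.0 means identical strings.
            score += SequenceMatcher(None, a[k], b[k]).ratio()
        else:
            score += 1.0 if a[k] == b[k] else 0.0
    return score / len(union)


def revision_weights(flat_revs):
    """Weight each revision by its total similarity to the others (centrality)."""
    n = len(flat_revs)
    if n <= 1:
        return [1.0] * n
    raw = [
        sum(similarity(flat_revs[i], flat_revs[j]) for j in range(n) if j != i)
        for i in range(n)
    ]
    total = sum(raw)
    return [1.0 / n] * n if total == 0 else [w / total for w in raw]


revs = [{"name": "Alice"}, {"name": "Alice"}, {"name": "Bob"}]
weights = revision_weights(revs)
```

The two agreeing revisions share the weight equally, while the outlier gets the least, so its votes count less in the threshold check.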
diff --git a/src/extrai/core/llm_runner.py b/src/extrai/core/llm_runner.py
new file mode 100644
index 0000000..d26510e
--- /dev/null
+++ b/src/extrai/core/llm_runner.py
@@ -0,0 +1,415 @@
+# extrai/core/llm_runner.py
+
+import logging
+import asyncio
+from typing import List, Dict, Any, Union
+
+from .model_registry import ModelRegistry
+from .extraction_config import ExtractionConfig
+from .json_consensus import JSONConsensus, default_conflict_resolver
+from .analytics_collector import WorkflowAnalyticsCollector
+from .base_llm_client import BaseLLMClient
+from .errors import (
+ LLMInteractionError,
+ ConsensusProcessError,
+ LLMConfigurationError,
+ LLMOutputParseError,
+ LLMOutputValidationError,
+ LLMAPICallError,
+)
+from extrai.utils.alignment_utils import normalize_json_revisions
+
+
+class LLMRunner:
+ """
+ Manages LLM client rotation and extraction cycles.
+
+ Responsibilities:
+ - Rotate through multiple LLM clients for load balancing
+ - Execute parallel LLM calls for multiple revisions
+ - Run consensus mechanism on results
+ - Handle LLM-related errors gracefully
+
+ This class abstracts away the complexity of managing multiple LLM
+ clients and coordinating their outputs through consensus.
+ """
+
+ def __init__(
+ self,
+ model_registry: ModelRegistry,
+ llm_client: Union[BaseLLMClient, List[BaseLLMClient]],
+ config: ExtractionConfig,
+ analytics_collector: WorkflowAnalyticsCollector,
+ logger: logging.Logger,
+ ):
+ """
+ Initialize the LLM runner.
+
+ Args:
+ model_registry: Registry of SQLModel schemas
+ llm_client: Single client or list of LLM clients
+ config: Extraction configuration
+ analytics_collector: Collector for tracking metrics
+ logger: Logger instance
+
+ Raises:
+ ValueError: If llm_client list is empty or contains invalid clients
+ """
+ self.model_registry = model_registry
+ self.config = config
+ self.analytics_collector = analytics_collector
+ self.logger = logger
+
+ # Setup clients with validation
+ self.clients = self._setup_clients(llm_client)
+ self.client_index = 0
+
+ # Setup consensus mechanism
+ self.consensus = JSONConsensus(
+ consensus_threshold=config.consensus_threshold,
+ conflict_resolver=config.conflict_resolver or default_conflict_resolver,
+ logger=logger,
+ )
+
+ self.logger.info(
+ f"LLMRunner initialized with {len(self.clients)} client(s), "
+ f"{config.num_llm_revisions} revisions per cycle"
+ )
+
+ def _setup_clients(
+ self, llm_client: Union[BaseLLMClient, List[BaseLLMClient]]
+ ) -> List[BaseLLMClient]:
+ """
+ Validates and normalizes LLM client input.
+
+ Args:
+ llm_client: Single client or list of clients
+
+ Returns:
+ List of validated LLM clients
+
+ Raises:
+ ValueError: If input is invalid
+ """
+ if isinstance(llm_client, list):
+ if not llm_client:
+ raise ValueError("llm_client list cannot be empty")
+
+ if not all(isinstance(c, BaseLLMClient) for c in llm_client):
+ raise ValueError(
+ "All items in llm_client list must be instances of BaseLLMClient"
+ )
+ clients = llm_client
+ elif isinstance(llm_client, BaseLLMClient):
+ clients = [llm_client]
+ else:
+ raise ValueError(
+ "llm_client must be an instance of BaseLLMClient or a list of them"
+ )
+
+ # Set logger on all clients
+ for client in clients:
+ client.logger = self.logger
+
+ return clients
+
+ def get_next_client(self) -> BaseLLMClient:
+ """
+ Returns next client in rotation (round-robin).
+
+ This enables load balancing across multiple LLM providers
+ or API keys.
+
+ Returns:
+ Next BaseLLMClient in the rotation
+ """
+ client = self.clients[self.client_index]
+ self.client_index = (self.client_index + 1) % len(self.clients)
+ return client
+
+ async def run_extraction_cycle(
+ self, system_prompt: str, user_prompt: str
+ ) -> List[Dict[str, Any]]:
+ """
+ Runs a complete extraction cycle.
+
+ Steps:
+ 1. Generate multiple revisions in parallel using different clients
+ 2. Normalize results to handle array ordering issues
+ 3. Run consensus mechanism to reconcile differences
+ 4. Return processed output
+
+ Args:
+ system_prompt: System prompt for LLM
+ user_prompt: User prompt containing documents
+
+ Returns:
+ List of consensus entity dictionaries
+
+ Raises:
+ LLMInteractionError: If LLM calls fail
+ ConsensusProcessError: If consensus fails
+ """
+ self.logger.info(
+ f"Starting extraction cycle with {self.config.num_llm_revisions} revisions"
+ )
+
+ # Step 1: Generate revisions in parallel
+ revisions = await self._generate_revisions(system_prompt, user_prompt)
+
+ self.logger.debug(f"Generated {len(revisions)} revisions before normalization")
+
+ # Step 2: Normalize for consensus (handles array ordering)
+ revisions = normalize_json_revisions(revisions)
+
+ self.logger.debug(f"Normalized to {len(revisions)} revisions for consensus")
+
+ # Step 3: Run consensus
+ results = self._run_consensus(revisions)
+
+ self.logger.info(f"Extraction cycle completed with {len(results)} entities")
+
+ return results
+
+ async def run_structured_extraction_cycle(
+ self,
+ system_prompt: str,
+ user_prompt: str,
+ response_model: Any,
+ ) -> List[Dict[str, Any]]:
+ """
+ Runs a structured extraction cycle using response_model directly.
+ """
+ self.logger.info(
+ f"Starting structured extraction cycle with {self.config.num_llm_revisions} revisions"
+ )
+
+ tasks = []
+ for i in range(self.config.num_llm_revisions):
+ client = self.get_next_client()
+ tasks.append(
+ asyncio.create_task(
+ client.generate_structured(
+ system_prompt=system_prompt,
+ user_prompt=user_prompt,
+ response_model=response_model,
+ )
+ )
+ )
+
+ try:
+ results = await asyncio.gather(*tasks)
+ except Exception as e:
+ self.logger.error(f"Structured extraction failed: {e}")
+ raise LLMInteractionError(f"Structured extraction failed: {e}") from e
+
+ # Convert Pydantic models to dicts
+ revisions = []
+ for result in results:
+ if hasattr(result, "model_dump"):
+ revisions.append(result.model_dump(mode="json"))
+ elif hasattr(result, "dict"):
+ revisions.append(result.dict())
+ else:
+ self.logger.warning(f"Result {type(result)} is not a Pydantic model.")
+
+ # Extract the list of entities if present
+ normalized_revisions = []
+ for rev in revisions:
+ if "entities" in rev and isinstance(rev["entities"], list):
+ normalized_revisions.append(rev["entities"])
+ else:
+ normalized_revisions.append(rev)
+
+ # Step 2: Normalize
+ normalized_revisions = normalize_json_revisions(normalized_revisions)
+
+ # Step 3: Consensus
+ final_results = self._run_consensus(normalized_revisions)
+
+ return final_results
+
+ async def _generate_revisions(
+ self, system_prompt: str, user_prompt: str
+ ) -> List[Any]:
+ """
+ Generates multiple LLM revisions in parallel.
+
+ Each revision is generated by a different client (round-robin)
+ to distribute load and increase diversity of outputs.
+
+ Args:
+ system_prompt: System prompt for LLM
+ user_prompt: User prompt containing documents
+
+ Returns:
+ List of revision outputs from LLM clients
+
+ Raises:
+ LLMInteractionError: If LLM interaction fails
+ """
+ tasks = []
+
+ # Create parallel tasks for each revision
+ for i in range(self.config.num_llm_revisions):
+ client = self.get_next_client()
+
+ self.logger.debug(
+ f"Creating revision task {i + 1}/{self.config.num_llm_revisions} "
+ f"with client {type(client).__name__}"
+ )
+
+ task = asyncio.create_task(
+ client.generate_json_revisions(
+ system_prompt=system_prompt,
+ user_prompt=user_prompt,
+ num_revisions=1, # Each task generates 1 revision
+ model_schema_map=self.model_registry.model_map,
+ max_validation_retries_per_revision=self.config.max_validation_retries_per_revision,
+ analytics_collector=self.analytics_collector,
+ )
+ )
+ tasks.append(task)
+
+ # Execute all tasks in parallel
+ try:
+ revisions = await asyncio.gather(*tasks)
+
+ # Validate we got results
+ if not revisions and self.config.num_llm_revisions > 0:
+ raise LLMInteractionError(
+ "LLM client returned no revisions despite being requested."
+ )
+
+ return revisions
+ except (
+ LLMConfigurationError,
+ LLMOutputParseError,
+ LLMOutputValidationError,
+ LLMAPICallError,
+ LLMInteractionError,
+ ) as client_err:
+ # Known LLM client errors
+ self.logger.error(f"LLM client operation failed: {client_err}")
+ raise LLMInteractionError(
+ f"LLM client operation failed: {client_err}"
+ ) from client_err
+
+ except Exception as e:
+ # Unexpected errors
+ self.logger.error(f"Unexpected error during LLM interaction: {e}")
+ raise LLMInteractionError(
+ f"An unexpected error occurred during LLM interaction: {e}"
+ ) from e
+
+ def _run_consensus(self, revisions: List[Any]) -> List[Dict[str, Any]]:
+ """
+ Runs consensus mechanism on revisions.
+
+ Args:
+ revisions: List of normalized revision outputs
+
+ Returns:
+ List of consensus entity dictionaries
+
+ Raises:
+ ConsensusProcessError: If consensus fails
+ """
+ try:
+ self.logger.debug(f"Running consensus on {len(revisions)} revisions")
+
+ # Run consensus
+ consensus_output, details = self.consensus.get_consensus(revisions)
+
+ # Record analytics if available
+ if details:
+ self.analytics_collector.record_consensus_run_details(details)
+ self.logger.debug(f"Consensus details: {details}")
+
+ # Process and normalize output
+ processed = self._process_consensus_output(consensus_output)
+
+ self.logger.debug(f"Consensus produced {len(processed)} entities")
+
+ return processed
+
+ except ConsensusProcessError:
+ # Re-raise consensus errors as-is
+ raise
+
+ except Exception as e:
+ # Wrap unexpected errors
+ self.logger.error(f"Consensus processing failed: {e}")
+ raise ConsensusProcessError(
+ f"Failed during JSON consensus processing: {e}"
+ ) from e
+
+ def _process_consensus_output(self, consensus_output: Any) -> List[Dict[str, Any]]:
+ """
+ Normalizes consensus output to list format.
+
+ The consensus mechanism can return several shapes:
+ - None (no consensus was reached)
+ - A list of dicts (the standard format)
+ - A dict with a 'results' key (the inner list is unwrapped)
+ - A single dict (it is wrapped in a list)
+
+ Args:
+ consensus_output: Raw output from consensus mechanism
+
+ Returns:
+ Normalized list of entity dictionaries
+
+ Raises:
+ ConsensusProcessError: If output format is unexpected
+ """
+ # Handle None
+ if consensus_output is None:
+ self.logger.warning("Consensus returned None, returning empty list")
+ return []
+
+ # Handle list (standard format)
+ if isinstance(consensus_output, list):
+ return consensus_output
+
+ # Handle dict
+ if isinstance(consensus_output, dict):
+ # Check for 'results' key (wrapped format)
+ if "results" in consensus_output and isinstance(
+ consensus_output["results"], list
+ ):
+ return consensus_output["results"]
+
+ # Single entity dict, wrap in list
+ return [consensus_output]
+
+ # Unexpected type
+ raise ConsensusProcessError(
+ f"Unexpected consensus output type: {type(consensus_output)}. "
+ f"Expected None, list, or dict."
+ )
+
+ def get_client_count(self) -> int:
+ """
+ Returns the number of LLM clients in rotation.
+
+ Returns:
+ Number of clients
+ """
+ return len(self.clients)
+
+ def reset_client_rotation(self):
+ """
+ Resets client rotation to start from the first client.
+
+ Useful for testing or ensuring consistent behavior.
+ """
+ self.client_index = 0
+ self.logger.debug("Client rotation reset to index 0")
+
+ def __repr__(self) -> str:
+ """String representation of the runner."""
+ return (
+ f"LLMRunner(clients={len(self.clients)}, "
+ f"revisions={self.config.num_llm_revisions})"
+ )
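The normalization rules described in `_process_consensus_output` can be sketched as a standalone function (a simplified illustration, not the library's API; `TypeError` stands in for `ConsensusProcessError`):

```python
from typing import Any, Dict, List


def normalize_consensus_output(output: Any) -> List[Dict[str, Any]]:
    """Normalize consensus output to a list of entity dicts."""
    if output is None:
        return []  # no consensus reached
    if isinstance(output, list):
        return output  # standard format
    if isinstance(output, dict):
        results = output.get("results")
        if isinstance(results, list):
            return results  # wrapped format: unwrap the inner list
        return [output]  # single entity dict: wrap in a list
    raise TypeError(f"Unexpected consensus output type: {type(output)}")


print(normalize_consensus_output(None))                     # → []
print(normalize_consensus_output({"results": [{"a": 1}]}))  # → [{'a': 1}]
print(normalize_consensus_output({"a": 1}))                 # → [{'a': 1}]
```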
diff --git a/src/extrai/core/model_registry.py b/src/extrai/core/model_registry.py
new file mode 100644
index 0000000..06eca47
--- /dev/null
+++ b/src/extrai/core/model_registry.py
@@ -0,0 +1,175 @@
+# extrai/core/model_registry.py
+
+import json
+import logging
+from typing import Type, List, Optional
+from sqlmodel import SQLModel
+
+from .errors import ConfigurationError
+from .schema_inspector import SchemaInspector
+
+
+class ModelRegistry:
+ """
+ Manages SQLModel schemas and their JSON representations.
+
+ Responsibilities:
+ - Discover all models from root
+ - Generate JSON schemas for LLM
+ - Provide model lookup by name
+ - Cache schemas for performance
+ """
+
+ def __init__(
+ self, root_model: Type[SQLModel], logger: Optional[logging.Logger] = None
+ ):
+ """
+ Initialize the model registry.
+
+ Args:
+ root_model: The root SQLModel class to discover from
+ logger: Optional logger instance
+
+ Raises:
+ ConfigurationError: If model discovery or schema generation fails
+ """
+ self.logger = logger or logging.getLogger(__name__)
+ self.root_model = root_model
+ self.inspector = SchemaInspector(self.logger)
+
+ # Validate root model
+ try:
+ if not root_model or not issubclass(root_model, SQLModel):
+ raise ConfigurationError("root_model must be a valid SQLModel class")
+ except TypeError as e:
+ # issubclass raises TypeError when root_model is not a class
+ raise ConfigurationError("root_model must be a valid SQLModel class") from e
+
+ # Discover and validate models
+ self.models = self._discover_models(root_model)
+ self.model_map = {m.__name__: m for m in self.models}
+
+ # Generate schemas
+ self.llm_schema_json = self._generate_llm_schema()
+
+ self.logger.info(
+ f"ModelRegistry initialized with {len(self.models)} models: "
+ f"{', '.join(self.model_map.keys())}"
+ )
+
+ def _discover_models(self, root_model: Type[SQLModel]) -> List[Type[SQLModel]]:
+ """
+ Discovers all SQLModel classes from root.
+
+ Args:
+ root_model: The root model to start discovery from
+
+ Returns:
+ List of discovered SQLModel classes
+
+ Raises:
+ ConfigurationError: If discovery fails or no models found
+ """
+ try:
+ models = self.inspector.discover_sqlmodels_from_root(root_model)
+ if not models:
+ raise ConfigurationError(
+ f"No SQLModel classes discovered from root model {root_model.__name__}"
+ )
+ return models
+ except Exception as e:
+ raise ConfigurationError(
+ f"Failed to discover SQLModel classes from {root_model.__name__}: {e}"
+ ) from e
+
+ def _generate_llm_schema(self) -> str:
+ """
+ Generates JSON schema for LLM prompts.
+
+ Returns:
+ JSON string representation of the schema
+
+ Raises:
+ ConfigurationError: If schema generation fails or produces invalid JSON
+ """
+ try:
+ schema = self.inspector.generate_llm_schema_from_models(self.models)
+ if not schema:
+ raise ConfigurationError("Generated LLM schema is empty")
+
+ # Validate it's valid JSON
+ json.loads(schema)
+ return schema
+ except json.JSONDecodeError as e:
+ raise ConfigurationError(
+ f"Generated LLM schema is invalid JSON: {e}"
+ ) from e
+ except Exception as e:
+ raise ConfigurationError(f"Failed to generate LLM schema: {e}") from e
+
+ def get_schema_for_models(self, model_names: List[str]) -> str:
+ """
+ Generates schema JSON for specific models.
+
+ Args:
+ model_names: List of model class names to include in schema
+
+ Returns:
+ JSON string representation of the schema for specified models
+
+ Note:
+ If no valid models are found in model_names, returns the full schema
+ """
+ models = [
+ self.model_map[name] for name in model_names if name in self.model_map
+ ]
+
+ if not models:
+ self.logger.warning(
+ f"No valid models found in {model_names}, using full schema"
+ )
+ return self.llm_schema_json
+
+ try:
+ return self.inspector.generate_llm_schema_from_models(models)
+ except Exception as e:
+ self.logger.error(f"Failed to generate schema for {model_names}: {e}")
+ return self.llm_schema_json
+
+ def get_model_by_name(self, name: str) -> Optional[Type[SQLModel]]:
+ """
+ Retrieves a model class by name.
+
+ Args:
+ name: The name of the model class
+
+ Returns:
+ The SQLModel class if found, None otherwise
+ """
+ return self.model_map.get(name)
+
+ def get_all_model_names(self) -> List[str]:
+ """
+ Returns list of all discovered model names.
+
+ Returns:
+ List of model class names
+ """
+ return list(self.model_map.keys())
+
+ def has_model(self, name: str) -> bool:
+ """
+ Checks if a model with the given name exists.
+
+ Args:
+ name: The name of the model class
+
+ Returns:
+ True if model exists, False otherwise
+ """
+ return name in self.model_map
+
+ def __repr__(self) -> str:
+ """String representation of the registry."""
+ return (
+ f"ModelRegistry(root={self.root_model.__name__}, models={len(self.models)})"
+ )
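The fail-fast JSON check in `_generate_llm_schema` can be illustrated in isolation (a hedged sketch; `ValueError` stands in for the library's `ConfigurationError`):

```python
import json


def validate_schema_json(schema: str) -> str:
    """Fail fast if the generated schema string is empty or not valid JSON."""
    if not schema:
        raise ValueError("Generated LLM schema is empty")
    try:
        json.loads(schema)  # parse only for validation; keep the string form
    except json.JSONDecodeError as e:
        raise ValueError(f"Generated LLM schema is invalid JSON: {e}") from e
    return schema


print(validate_schema_json('{"Book": {"title": "str"}}'))
```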
diff --git a/src/extrai/core/model_wrapper_builder.py b/src/extrai/core/model_wrapper_builder.py
new file mode 100644
index 0000000..e7b1393
--- /dev/null
+++ b/src/extrai/core/model_wrapper_builder.py
@@ -0,0 +1,157 @@
+from typing import Type, List, Optional, Any, Dict
+from pydantic import BaseModel, create_model, Field
+from sqlmodel import SQLModel
+from sqlalchemy import inspect
+from sqlalchemy.orm import RelationshipProperty
+
+
+class ModelWrapperBuilder:
+ """
+ Utility class to convert SQLModel classes (with Relationship fields)
+ into pure Pydantic models suitable for structured LLM output (e.g. OpenAI).
+ Replaces Relationships with nested models.
+ """
+
+ def __init__(self):
+ self._generated_models: Dict[Type[SQLModel], Type[BaseModel]] = {}
+
+ def generate_wrapper_model(
+ self, root_sqlmodel: Type[SQLModel], include_relationships: bool = True
+ ) -> Type[BaseModel]:
+ """
+ Generates a Pydantic wrapper model for the given root SQLModel.
+ This wrapper creates a hierarchy of Pydantic models where relationships
+ are replaced by nested lists or single instances of the related Pydantic model.
+
+ The result is additionally wrapped in a container model so that a list of root entities is captured.
+
+ Args:
+ root_sqlmodel: The root SQLModel class.
+ include_relationships: If False, relationships will be excluded from the schema.
+ Useful for hierarchical extraction steps.
+ """
+ self._generated_models = {}
+
+ pydantic_model = self._create_pydantic_model_recursive(
+ root_sqlmodel, include_relationships
+ )
+
+ # Define a wrapper model with an `entities` field holding a list of the root model.
+ wrapper_name = f"{root_sqlmodel.__name__}ExtractionResult"
+
+ wrapper_model = create_model(
+ wrapper_name,
+ entities=(
+ List[pydantic_model],
+ Field(
+ description=f"List of extracted {root_sqlmodel.__name__} entities."
+ ),
+ ),
+ )
+
+ return wrapper_model
+
+ def _enrich_field_description(self, field_info: Any) -> Any:
+ """
+ Appends validation constraints to the field description to help the LLM.
+ """
+ import copy
+
+ new_field_info = copy.copy(field_info)
+
+ # Using .metadata to access constraints from Pydantic v2
+ constraints = []
+ if hasattr(new_field_info, "metadata"):
+ for item in new_field_info.metadata:
+ if hasattr(item, "max_length"):
+ constraints.append(f"max_length={item.max_length}")
+ if hasattr(item, "min_length"):
+ constraints.append(f"min_length={item.min_length}")
+ # Looking for numeric constraints
+ if hasattr(item, "ge"):
+ constraints.append(f"min_value={item.ge}")
+ if hasattr(item, "le"):
+ constraints.append(f"max_value={item.le}")
+ if hasattr(item, "gt"):
+ constraints.append(f"greater_than={item.gt}")
+ if hasattr(item, "lt"):
+ constraints.append(f"less_than={item.lt}")
+
+ # Fallback for older Pydantic or direct attributes if metadata is not used
+ if (
+ hasattr(new_field_info, "max_length")
+ and new_field_info.max_length is not None
+ ):
+ if f"max_length={new_field_info.max_length}" not in constraints:
+ constraints.append(f"max_length={new_field_info.max_length}")
+
+ if constraints:
+ constraint_str = "Constraints: " + ", ".join(constraints)
+ if new_field_info.description:
+ new_field_info.description = (
+ f"{new_field_info.description} ({constraint_str})"
+ )
+ else:
+ new_field_info.description = constraint_str
+
+ return new_field_info
+
+ def _create_pydantic_model_recursive(
+ self, sql_model: Type[SQLModel], include_relationships: bool = True
+ ) -> Type[BaseModel]:
+ if sql_model in self._generated_models:
+ return self._generated_models[sql_model]
+
+ model_name = f"{sql_model.__name__}Structure"
+
+ fields = {}
+ inspector = inspect(sql_model)
+
+ for name, field_info in sql_model.model_fields.items():
+ # Check if it is a relationship field, if so, we will handle it later
+ if name in inspector.relationships:
+ continue
+
+ # Enrich description with constraints
+ enriched_field_info = self._enrich_field_description(field_info)
+ fields[name] = (field_info.annotation, enriched_field_info)
+
+ relationships = {}
+ if include_relationships:
+ # Reuse the inspector created above during field discovery
+ for rel in inspector.relationships:
+ if isinstance(rel, RelationshipProperty):
+ target_model = rel.mapper.class_
+
+ if rel.direction.name == "MANYTOONE":
+ # Skip child->parent links to enforce hierarchy
+ continue
+
+ # Recurse
+ nested_model = self._create_pydantic_model_recursive(
+ target_model, include_relationships
+ )
+
+ if rel.uselist:
+ # List[NestedModel]
+ field_type = List[nested_model]
+ field_desc = f"List of {target_model.__name__} items."
+ else:
+ # NestedModel (Optional?)
+ field_type = Optional[nested_model]
+ field_desc = f"Related {target_model.__name__} item."
+
+ relationships[rel.key] = (
+ field_type,
+ Field(default=None, description=field_desc),
+ )
+
+ # Merge fields
+ all_fields = {**fields, **relationships}
+
+ # Create the model
+ model = create_model(model_name, **all_fields)
+
+ self._generated_models[sql_model] = model
+ return model
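The `create_model`-based technique used by `ModelWrapperBuilder` can be sketched with hand-built models (a sketch assuming Pydantic v2; the `Author`/`Book` names are illustrative, not part of the library):

```python
from typing import List, Optional
from pydantic import Field, create_model

# Nested model for the related entity (stands in for a SQLModel relationship target).
BookStructure = create_model(
    "BookStructure",
    title=(str, Field(description="Book title.")),
)

# Parent model: the relationship becomes a plain nested list field.
AuthorStructure = create_model(
    "AuthorStructure",
    name=(str, Field(description="Author name.")),
    books=(Optional[List[BookStructure]], Field(default=None, description="List of Book items.")),
)

# Container enforcing that the output is a list of root entities.
AuthorExtractionResult = create_model(
    "AuthorExtractionResult",
    entities=(List[AuthorStructure], Field(description="List of extracted Author entities.")),
)

result = AuthorExtractionResult.model_validate(
    {"entities": [{"name": "Jane", "books": [{"title": "A"}]}]}
)
print(result.entities[0].books[0].title)  # → A
```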
diff --git a/src/extrai/core/prompt_builder.py b/src/extrai/core/prompt_builder.py
index e9b4922..1f4e226 100644
--- a/src/extrai/core/prompt_builder.py
+++ b/src/extrai/core/prompt_builder.py
@@ -1,318 +1,110 @@
-def generate_system_prompt(
- schema_json: str,
- extraction_example_json: str = "",
- custom_extraction_process: str = "",
- custom_extraction_guidelines: str = "",
- custom_final_checklist: str = "",
- custom_context: str = "",
-) -> str:
- """
- Generates a generic system prompt for guiding an LLM to extract information
- from text and structure it according to a provided JSON schema.
-
- Args:
- schema_json: A string containing the JSON schema for the target data structure.
- extraction_example_json: An optional string containing an example of a JSON
- object that conforms to the schema.
- custom_extraction_process: Optional custom instructions for the extraction process.
- custom_extraction_guidelines: Optional custom guidelines for extraction.
- custom_final_checklist: Optional custom final checklist for the LLM.
- custom_context: Optional custom contextual information to be included in the prompt.
-
- Returns:
- A string representing the system prompt.
- """
-
- default_extraction_process = """\
-# EXTRACTION PROCESS
-Follow this step-by-step process meticulously:
-1. **Understand the Goal:** Your primary objective is to extract information from the provided text and structure it precisely according to the JSON schema.
-2. **Full Text Analysis:** Read and comprehend the entirety of the provided document(s) before initiating extraction. This helps in understanding context and relationships.
-3. **Schema Adherence:** The provided JSON schema is your definitive guide. All extracted data must conform to this schema in terms of structure, field names, and data types.
-4. **Identify Relevant Data:** Locate all data points within the text that correspond to the fields defined in the JSON schema.
-5. **Map Data to Schema:** Carefully assign the identified data to the correct fields in the schema.
-6. **Handle Ambiguity and Missing Information:**
- * If information for a field is ambiguous, use your reasoning capabilities to determine the most plausible interpretation based on the context.
- * If information for an optional field is not present, omit the field or use `null` if the schema allows.
- * For required fields, if information is genuinely missing and cannot be inferred, this is a critical issue. However, strive to find or infer it. If the schema defines a default, consider that.
-7. **Prioritize Explicit Information:** Base your extraction on information explicitly stated in the text. Avoid making assumptions unless absolutely necessary and clearly justifiable by the context.
-8. **Synthesize from Multiple Documents:** If multiple documents are provided, synthesize the information comprehensively. If conflicting information arises, prioritize what appears to be the most current, official, or reliable source. Note any significant discrepancies if the output format allows, but the primary goal is a single coherent JSON.
-9. **Data Type Conformance:** Strictly adhere to the data types specified in the JSON schema (e.g., string, number, boolean, array, object). Numbers should be formatted as numbers (e.g., `123`, `12.34`), not strings containing numbers (e.g., `"123"`). Booleans should be `true` or `false`.
-10. **Nested Structures and Relationships:**
- * For nested objects or arrays, ensure your JSON output accurately reflects the hierarchical structure defined in the schema.
- * If the schema implies relationships between different entities (e.g., using foreign keys or requiring linking), ensure these are correctly represented.
- * If temporary identifiers are needed to link entities within the JSON output (e.g., for items that will later become related records in a database), generate unique and descriptive temporary IDs (e.g., "_temp_id_entityName_XYZ123"). Use these temporary IDs consistently for all references within the current JSON output.
"""
-
- default_extraction_guidelines = """\
-# IMPORTANT EXTRACTION GUIDELINES
-- **Output Format:** Your entire output must be a single, valid JSON object. Do not include any explanatory text, comments, apologies, or any other content before or after the JSON object.
-- **Output Structure Mandate:** Your response MUST be a single JSON object. This object MUST have a single top-level key named "result". The value of this "result" key MUST be the JSON object that conforms to the provided JSON schema. Example: `{"result": {your_schema_compliant_object_here}}`. Do NOT use any other top-level keys. Do NOT return the schema-compliant object directly as the root.
-- **Field Names:** Use the exact field names (case-sensitive) as specified in the JSON schema for the object under the "result" key.
-- **Structured Elements:** Pay close attention to structured elements within the text, such as tables, lists, headings, and emphasized text, as they often contain key information.
-- **Dates and Times:** Unless the schema specifies a different format, use ISO 8601 format for dates (YYYY-MM-DD) and date-times (YYYY-MM-DDTHH:MM:SSZ).
-- **Enumerations (Enums):** If a field in the schema is an enumeration with a predefined set of allowed values, ensure that the extracted value is one of those permitted values.
-- **Null Values:** Use `null` for optional fields where data is not available or not applicable, provided the schema allows for null values for that field. Do not use strings like "N/A", "Not available", or empty strings "" unless the schema explicitly defines such string literals as valid values.
-- **String Values:** Ensure all string values in the JSON are correctly escaped (e.g., quotes within strings).
-- **Meticulousness:** Accuracy is paramount. Double-check your extracted data against the source text and the schema before finalizing your output.
-"""
-
- default_final_checklist = """\
-# FINAL CHECK BEFORE SUBMISSION
-1. **Valid JSON?** Is the entire output a single, syntactically correct JSON object?
-2. **Output Structure Correct?** Does the output JSON object have a single top-level key named "result"?
-3. **Schema Conformity?** Does the JSON object under the "result" key strictly adhere to all aspects of the provided JSON schema (all required fields present, correct data types for all values, correct structure for nested objects and arrays)?
-4. **Field Name Accuracy?** Are all field names within the object under the "result" key exactly as specified in the schema (case-sensitive)?
-5. **Relationship Integrity?** If temporary IDs or other linking mechanisms were required within the object under the "result" key, are they used correctly and consistently?
-6. **Null Handling?** Are `null` values used appropriately for missing optional data, according to schema constraints?
-7. **No Extraneous Text?** Is there absolutely no text or characters outside of the main JSON object itself?
-"""
-
- # Use custom instructions if provided, otherwise use defaults
- extraction_process = custom_extraction_process or default_extraction_process
- extraction_guidelines = (
- custom_extraction_guidelines or default_extraction_guidelines
- )
- final_checklist = custom_final_checklist or default_final_checklist
-
- prompt_parts = [
- "You are an advanced AI specializing in data extraction and structuring. Your task is to analyze user-provided text and transform the relevant information into a structured JSON object, strictly adhering to the provided JSON schema.",
- "You must focus on precision, accuracy, and complete adherence to the schema.",
- "\n# JSON SCHEMA TO ADHERE TO:",
- "```json",
- schema_json,
- "```",
- ]
-
- if custom_context: # New block
- prompt_parts.append("\n# ADDITIONAL CONTEXT:")
- prompt_parts.append(custom_context)
-
- prompt_parts.extend([f"\n{extraction_process}", f"\n{extraction_guidelines}"])
-
- if extraction_example_json:
- prompt_parts.append("\n# EXAMPLE OF EXTRACTION:")
- prompt_parts.append(
- "## CONCEPTUAL INPUT TEXT (This is illustrative; your actual input text will be different):"
- )
- prompt_parts.append(
- "\"Imagine a piece of text that contains details about an entity or event. For instance, if the schema is about a 'Book', the text might say: 'The Great Novel, written by Jane Author in 2023, has 300 pages and is published by World Publishers. ISBN: 978-0123456789.'\""
- )
- prompt_parts.append(
- "## EXAMPLE EXTRACTED JSON (This JSON conforms to the schema based on the conceptual text above):"
- )
- prompt_parts.append("```json")
-
- if extraction_example_json.strip().startswith(
- "{"
- ) and extraction_example_json.strip().endswith("}"):
- prompt_parts.append(f'{{\n "result": {extraction_example_json}\n}}')
- else:
- prompt_parts.append(extraction_example_json)
- prompt_parts.append("```")
-
- prompt_parts.append(f"\n{final_checklist}")
- prompt_parts.append(
- "\nProceed with the extraction based on the user's documents. Your response MUST be only the single, valid JSON object. Do not include any other narrative, explanations, or conversational elements in your output."
- )
-
- return "\n\n".join(prompt_parts).strip()
-
-
-def generate_user_prompt_for_docs(
- documents: list[str], custom_context: str = ""
-) -> str:
- """
- Generates a simple user prompt containing the documents for extraction.
-
- Args:
- documents: A list of strings, where each string is a document
- or a piece of text for extraction.
- custom_context: Optional custom contextual information to be included in the prompt.
-
- Returns:
- A string representing the user prompt with the documents.
- """
- separator = "\n\n---END OF DOCUMENT---\n\n---START OF NEW DOCUMENT---\n\n"
- combined_documents = separator.join(documents)
-
- prompt = """
-Please extract information from the following document(s) strictly according to the schema and instructions previously provided (in the system prompt).
+This module serves as a facade for the prompt generation logic, which has been
+modularized into the `extrai.core.prompts` package.
"""
- if custom_context:
- prompt += f"\n{custom_context}\n"
-
- prompt += f"""
-# DOCUMENT(S) FOR EXTRACTION:
-
-{combined_documents}
-
----
-Remember: Your output must be only a single, valid JSON object.
-""".strip()
- return prompt
-
-def generate_sqlmodel_creation_system_prompt(
- schema_json: str, user_task_description: str
-) -> str:
+import logging
+from typing import List, Optional, Tuple
+from extrai.core.model_registry import ModelRegistry
+
+from extrai.core.prompts.common import generate_user_prompt_for_docs
+from extrai.core.prompts.extraction import (
+ generate_system_prompt,
+)
+from extrai.core.prompts.structured_extraction import (
+ generate_structured_system_prompt,
+)
+from extrai.core.prompts.sqlmodel import (
+ generate_sqlmodel_creation_system_prompt,
+)
+from extrai.core.prompts.counting import (
+ generate_entity_counting_system_prompt,
+ generate_entity_counting_user_prompt,
+)
+from extrai.core.prompts.examples import (
+ generate_prompt_for_example_json_generation,
+)
+
+
+class PromptBuilder:
"""
- Generates a specialized system prompt for guiding an LLM to create a
- SQLModel class description (as a JSON object).
-
- The LLM will be given input documents (via the user prompt) and this system
- prompt. Its goal is to produce a JSON object that describes a new SQLModel,
- and this JSON object must conform to the `schema_json` provided here.
-
- Args:
- schema_json: A string containing the JSON schema that the LLM's output
- (the SQLModel description JSON) must conform to. This typically
- comes from "sqlmodel_description_schema.json".
- user_task_description: A natural language description from the user about
- what entities or data structure they want to model.
-
- Returns:
- A string representing the system prompt for SQLModel description generation.
+ Facade class for generating prompts, maintaining compatibility with
+ pipeline components that expect an object instance.
"""
- prompt_parts = [
- "You are an AI assistant tasked with designing one or more SQLModel class definitions.",
- "Your goal is to generate a JSON object that contains a list of SQLModel class descriptions. This description will then be used to generate Python code.",
- "You will be provided with a user's task description and relevant documents (in the user prompt) to inform your design.",
- "\n# REQUIREMENTS FOR YOUR OUTPUT:",
- "1. Your entire output MUST be a single, valid JSON object.",
- "2. This JSON object MUST contain a single top-level key: `sql_models`. The value of this key MUST be a list of JSON objects, where each object in the list describes a single SQLModel.",
- "3. Each object in the `sql_models` list MUST strictly adhere to the following JSON schema for a SQLModel description:",
- "```json",
- schema_json,
- "```",
- "\n# IMPORTANT CONSIDERATIONS FOR DATABASE TABLE MODELS:",
- "The SQLModel you are describing will typically be a database table (this is the default if `is_table_model` is not specified or is `true` in your output JSON).",
- "When defining fields for such table models:",
- "- **Scalar Types:** Standard types like `str`, `int`, `float`, `bool`, `datetime.datetime`, `uuid.UUID` are generally fine.",
- "- **List and Dict Types:** If a field needs to store a list (e.g., `List[str]`) or a dictionary (e.g., `Dict[str, Any]`), these cannot be directly mapped to standard SQL column types. You MUST specify how they should be stored using the `field_options_str` property for that field. The recommended way is to store them as JSON.",
- ' - **Example for `List[str]`:** For a field `tags: List[str]`, you should include this in its description object: `"field_options_str": "Field(default_factory=list, sa_type=JSON)"`',
- ' - **Example for `Dict[str, Any]`:** For a field `metadata: Dict[str, Any]`, include: `"field_options_str": "Field(default_factory=dict, sa_type=JSON)"`',
- '- **Import JSON:** If you use `sa_type=JSON` in any `field_options_str`, you MUST also add `"from sqlmodel import JSON"` to the main `imports` array in your generated JSON description.',
- "Failure to correctly define `List` or `Dict` fields for table models (by not using `field_options_str` with `sa_type=JSON` or a similar valid SQLAlchemy type) will lead to errors.",
- '- **Required Fields and Defaults:** Any field that is NOT `Optional` (e.g., `type: "str"`, `type: "int"`) is a REQUIRED field. For all required fields, you MUST provide a sensible `default` value in its description object to ensure the model can be instantiated for validation. For strings, use `""` as the default. For numbers, use `0` or `0.0`. For booleans, use `false`. Failure to provide a default for a required field will cause the system to crash.',
- "- **Relationships and Foreign Keys:** When modeling relationships (e.g., one-to-many), you must define fields for both the foreign key and the relationship itself.",
- ' - **Foreign Key Field:** The model on the "many" side of a relationship (e.g., `LineItem`) needs a foreign key field. This field MUST be defined as `Optional` with a `default` of `None` to pass validation.',
- ' - **Foreign Key Naming Consistency:** The `foreign_key` value is critical. It MUST be a string in the format `"table_name.column_name"`. The `table_name` part MUST exactly match the `table_name` defined in the parent model. For example, if the `Invoice` model has `"table_name": "invoices"`, then the foreign key in `LineItem` MUST be `"invoices.id"`. A mismatch like `"invoice.id"` will cause a crash.',
- ' - **Relationship Fields:** Both models should have a `Relationship` attribute. The "one" side gets a `List` of the "many" side, and the "many" side gets an `Optional` of the "one" side. Use `field_options_str` to define them. Example for `Invoice`: `{"name": "line_items", "type": "List[\\"LineItem\\"]", "field_options_str": "Relationship(back_populates=\\"invoice\\")"}`. Example for `LineItem`: `{"name": "invoice", "type": "Optional[\\"Invoice\\"]", "field_options_str": "Relationship(back_populates=\\"line_items\\")"}`.',
- ' - **Imports for Relationships:** If you use `Relationship`, you MUST add `"from sqlmodel import Relationship"` to the `imports` array. If you use `List`, you must import it from `typing`.',
- "\n# USER'S TASK:",
- f'The user wants to define a SQLModel based on the following objective: "{user_task_description}"',
- "Consider the documents provided by the user to understand the entities, fields, types, and relationships needed for this model. Pay close attention to the requirements for List/Dict types if the model is a table, and try to provide default values for required fields.",
- "Focus on creating a comprehensive and accurate model description in the JSON format specified by the schema.",
- ]
-
- # Hardcoded example of a SQLModel description JSON
- example_json = """
-{
- "sql_models": [
- {
- "model_name": "ExampleItem",
- "table_name": "example_items",
- "description": "An example item model for demonstration.",
- "fields": [
- {
- "name": "id",
- "type": "Optional[int]",
- "primary_key": true,
- "default": null,
- "nullable": true,
- "description": "The unique identifier for the item."
- },
- {
- "name": "name",
- "type": "str",
- "description": "The name of the item.",
- "max_length": 100,
- "nullable": false
- },
- {
- "name": "quantity",
- "type": "int",
- "description": "The number of items in stock.",
- "default": 0,
- "ge": 0
- },
- {
- "name": "created_at",
- "type": "datetime.datetime",
- "default_factory": "datetime.datetime.utcnow",
- "description": "Timestamp of when the item was created."
- },
- {
- "name": "categories",
- "type": "List[str]",
- "description": "Categories for the item, stored as JSON.",
- "field_options_str": "Field(default_factory=list, sa_type=JSON)"
- }
- ],
- "imports": [
- "from typing import Optional, List",
- "import datetime",
- "from sqlmodel import SQLModel, Field, JSON"
- ]
- }
- ]
-}
-"""
- prompt_parts.extend(
- [
- "\n# EXAMPLE OF A VALID SQLMODEL DESCRIPTION JSON (Illustrating a list of models):",
- "This is an example of the kind of JSON object you should produce (it conforms to the schema above):",
- "```json",
- example_json.strip(),
- "```",
- ]
- )
- prompt_parts.append(
- "\nCarefully analyze the user's task and the provided documents. "
- "Generate only the single JSON object that describes the SQLModels, wrapped in the `sql_models` key. "
- "Do not include any other narrative, explanations, or conversational elements in your output."
- )
+ def __init__(
+ self, model_registry: ModelRegistry, logger: Optional[logging.Logger] = None
+ ):
+ self.model_registry = model_registry
+ self.logger = logger or logging.getLogger(__name__)
+
+ def build_prompts(
+ self,
+ input_strings: List[str],
+ schema_json: str,
+ extraction_example_json: str = "",
+ custom_extraction_process: str = "",
+ custom_extraction_guidelines: str = "",
+ custom_final_checklist: str = "",
+ custom_context: str = "",
+ expected_entity_descriptions: Optional[List[str]] = None,
+ previous_entities: Optional[List[dict]] = None,
+ target_model_name: Optional[str] = None,
+ ) -> Tuple[str, str]:
+ """
+ Builds system and user prompts for extraction.
+ """
+ system_prompt = generate_system_prompt(
+ schema_json=schema_json,
+ extraction_example_json=extraction_example_json,
+ custom_extraction_process=custom_extraction_process,
+ custom_extraction_guidelines=custom_extraction_guidelines,
+ custom_final_checklist=custom_final_checklist,
+ custom_context=custom_context,
+ expected_entity_descriptions=expected_entity_descriptions,
+ previous_entities=previous_entities,
+ target_model_name=target_model_name,
+ )
- return "\n\n".join(prompt_parts).strip()
+ user_prompt = generate_user_prompt_for_docs(input_strings)
+
+ return system_prompt, user_prompt
+
+ def build_structured_prompts(
+ self,
+ input_strings: List[str],
+ custom_extraction_process: str = "",
+ custom_extraction_guidelines: str = "",
+ custom_context: str = "",
+ extraction_example_json: str = "",
+ expected_entity_descriptions: Optional[List[str]] = None,
+ previous_entities: Optional[List[dict]] = None,
+ target_model_name: Optional[str] = None,
+ ) -> Tuple[str, str]:
+ """
+ Builds prompts for structured extraction.
+ """
+ system_prompt = generate_structured_system_prompt(
+ custom_extraction_process=custom_extraction_process,
+ custom_extraction_guidelines=custom_extraction_guidelines,
+ custom_context=custom_context,
+ extraction_example_json=extraction_example_json,
+ expected_entity_descriptions=expected_entity_descriptions,
+ previous_entities=previous_entities,
+ target_model_name=target_model_name,
+ )
+ user_prompt = generate_user_prompt_for_docs(input_strings)
-def generate_prompt_for_example_json_generation(
- target_model_schema_str: str, root_model_name: str
-) -> str:
- """
- Generates a system prompt for guiding an LLM to create a single, valid
- example JSON object based on a provided schema.
+ return system_prompt, user_prompt
- Args:
- target_model_schema_str: A string containing the JSON schema for which
- an example is to be generated.
- root_model_name: The name of the root model/entity this schema represents
- (e.g., "Product", "User"). Used for context in the prompt.
- Returns:
- A string representing the system prompt for example JSON generation.
- """
- prompt_parts = [
- "You are an AI assistant tasked with generating a sample JSON object.",
- f"The goal is to create a single, valid JSON object that conforms to the provided schema for a model named '{root_model_name}' and its related models.",
- "This sample will be used as a few-shot example for another LLM task, so it needs to be accurate and representative.",
- "\n# JSON SCHEMA TO ADHERE TO:",
- "```json",
- target_model_schema_str,
- "```",
- "\n# INSTRUCTIONS FOR YOUR OUTPUT:",
- "1. **Output Content:** Your entire output MUST be a single, valid JSON object.",
- "2. **Output Structure:** Your output MUST be a single JSON object with a top-level key named 'entities'. The value of 'entities' MUST be a list of JSON objects, where each object represents a single data entity.",
- "3. **No Extra Text:** Do NOT include any explanatory text, comments, apologies, markdown formatting (like ```json), or any other content before or after the JSON object.",
- "4. **Schema Compliance:** Strictly adhere to all field names (case-sensitive), data types (string, number, boolean, array, object), and structural requirements defined in the schema for each entity in the 'entities' list.",
- "5. **Entity Metadata:** Each object inside the 'entities' list MUST include two metadata fields:",
- ' * `_type`: This field\'s value MUST be a string matching the name of the model it represents (e.g., "Product", "ProductSpecs").',
- ' * `_temp_id`: This field\'s value MUST be a unique temporary string identifier for that specific entity instance (e.g., "product_example_001", "spec_example_001"). Use these IDs in the `_ref_id` or `_ref_ids` fields to link entities.',
- "6. **Simplicity and Clarity:** The generated example should be simple and illustrative. Populate all other fields (defined in the schema) with plausible, concise, and representative data. Avoid overly complex or lengthy values unless the schema demands it.",
- f"7. **Completeness and Relationships:** Your 'entities' list should contain an instance of the root model (`{root_model_name}`) and at least one instance of each of its related models as described in the schema. For example, if generating an example for a 'Product' that has 'ProductSpecs', the 'entities' list should contain at least one 'Product' object and one 'ProductSpecs' object, linked together using their `_temp_id`s in the appropriate `_ref_id` or `_ref_ids` field.",
- f"\nConsider the schema for '{root_model_name}' and its related models. Generate a representative set of linked entities in the format `{{\"entities\": [...]}}`.",
- "Proceed with generating the JSON object.",
- ]
- return "\n\n".join(prompt_parts).strip()
+__all__ = [
+ "PromptBuilder",
+ "generate_system_prompt",
+ "generate_user_prompt_for_docs",
+ "generate_sqlmodel_creation_system_prompt",
+ "generate_entity_counting_system_prompt",
+ "generate_entity_counting_user_prompt",
+ "generate_prompt_for_example_json_generation",
+]
diff --git a/src/extrai/core/prompts/__init__.py b/src/extrai/core/prompts/__init__.py
new file mode 100644
index 0000000..5898c7b
--- /dev/null
+++ b/src/extrai/core/prompts/__init__.py
@@ -0,0 +1,17 @@
+from .extraction import generate_system_prompt
+from .common import generate_user_prompt_for_docs
+from .sqlmodel import generate_sqlmodel_creation_system_prompt
+from .counting import (
+ generate_entity_counting_system_prompt,
+ generate_entity_counting_user_prompt,
+)
+from .examples import generate_prompt_for_example_json_generation
+
+__all__ = [
+ "generate_system_prompt",
+ "generate_user_prompt_for_docs",
+ "generate_sqlmodel_creation_system_prompt",
+ "generate_entity_counting_system_prompt",
+ "generate_entity_counting_user_prompt",
+ "generate_prompt_for_example_json_generation",
+]
diff --git a/src/extrai/core/prompts/common.py b/src/extrai/core/prompts/common.py
new file mode 100644
index 0000000..d94553b
--- /dev/null
+++ b/src/extrai/core/prompts/common.py
@@ -0,0 +1,36 @@
+from typing import List
+
+
+def generate_user_prompt_for_docs(
+ documents: List[str], custom_context: str = ""
+) -> str:
+ """
+ Generates a generic user prompt containing the documents for extraction.
+ Used by both standard and structured extraction flows.
+
+ Args:
+ documents: A list of strings, where each string is a document
+ or a piece of text for extraction.
+ custom_context: Optional custom contextual information to be included in the prompt.
+
+ Returns:
+ A string representing the user prompt with the documents.
+ """
+ separator = "\n\n---END OF DOCUMENT---\n\n---START OF NEW DOCUMENT---\n\n"
+ combined_documents = separator.join(documents)
+
+ prompt = """
+Please extract information from the following document(s).
+"""
+ if custom_context:
+ prompt += f"\n{custom_context}\n"
+
+ prompt += f"""
+# DOCUMENT(S) FOR EXTRACTION:
+
+{combined_documents}
+
+---
+Remember: Your output must be only a single, valid JSON object.
+""".strip()
+ return prompt
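The document-joining behavior of `generate_user_prompt_for_docs` can be sketched standalone. This is a minimal illustration that mirrors the separator logic from the diff above (the document strings are invented for the example, and the sketch does not import the library):

```python
# Mirror of the separator-join logic in generate_user_prompt_for_docs.
separator = "\n\n---END OF DOCUMENT---\n\n---START OF NEW DOCUMENT---\n\n"
documents = ["Invoice #123 from ABC Corp.", "Invoice #456 from XYZ Inc."]
combined = separator.join(documents)

# Assemble a prompt in the same shape as the function above.
prompt = (
    "Please extract information from the following document(s).\n\n"
    "# DOCUMENT(S) FOR EXTRACTION:\n\n"
    f"{combined}\n\n---\n"
    "Remember: Your output must be only a single, valid JSON object."
)

# N documents are joined by N-1 separators.
assert combined.count("---END OF DOCUMENT---") == len(documents) - 1
```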
diff --git a/src/extrai/core/prompts/counting.py b/src/extrai/core/prompts/counting.py
new file mode 100644
index 0000000..6dafaf9
--- /dev/null
+++ b/src/extrai/core/prompts/counting.py
@@ -0,0 +1,105 @@
+import json
+from typing import List, Dict, Any, Optional
+
+
+def generate_entity_counting_system_prompt(
+    model_names: List[str],
+    schema_json: Optional[str] = None,
+ custom_counting_context: str = "",
+ previous_entities: Optional[List[Dict[str, Any]]] = None,
+) -> str:
+ """
+ Generates a system prompt for counting entities in the provided documents.
+
+ Args:
+ model_names: A list of names of the models/entities to count.
+        schema_json: An optional string containing the JSON schema for the models.
+            This helps the LLM understand the structure of the entities to count.
+ custom_counting_context: Optional custom context to guide the counting phase.
+ previous_entities: Optional list of previously extracted entities for context.
+
+ Returns:
+ A string representing the system prompt for entity counting.
+ """
+ model_list_str = ", ".join(model_names)
+ prompt = f"""
+You are an expert data analyst. Your task is to analyze the provided documents and count the occurrences of specific entities.
+
+You need to count the following entities: {model_list_str}.
+"""
+
+ if custom_counting_context:
+ prompt += f"""
+# CUSTOM CONTEXT:
+{custom_counting_context}
+"""
+
+ if previous_entities:
+ entities_json = json.dumps(previous_entities, indent=2)
+ prompt += f"""
+# PREVIOUSLY EXTRACTED ENTITIES:
+{entities_json}
+
+IMPORTANT: If the entities you are counting are related to any of the previously extracted entities above, you MUST specify the unique ID (or temp_id) of that related entity in your description string. This ensures correct linking in subsequent steps.
+"""
+
+ prompt += f"""
+# ENTITY DEFINITIONS:
+To help you identify these entities correctly, here are their schema definitions:
+```json
+{schema_json}
+```
+"""
+
+ prompt += """
+# OUTPUT INSTRUCTIONS:
+1. **Output Format:** Your output must be a single, valid JSON object.
+2. **Keys:** The JSON object keys must be the exact names of the entities provided above.
+3. **Values:** The values must be a list of strings, where each string is a description of the entity found.
+4. **Order:** The order of the descriptions in the list must match the order of appearance in the document.
+5. **Relational Detail:** If an entity relates to a previously extracted entity (e.g., a child entity belonging to a parent), your description MUST include the ID of that parent entity from the provided context.
+6. **No Extra Text:** Do NOT include any explanations, markdown formatting, or text outside the JSON object.
+
+Example Output:
+{
+  "Invoice": [
+    "Invoice #123 from ABC Corp with a value of 50 euros",
+    "Invoice #456 from XYZ Inc with a value of 506 euros",
+    "Invoice #789 from Foo Bar with a value of 30 euros"
+  ],
+  "LineItem": [
+    "Item A - Widget linked to Invoice ID: invoice_123",
+    "Item B - Gadget linked to Invoice ID: invoice_123",
+    "Item C - Doohickey linked to Invoice ID: invoice_456"
+  ]
+}
+
+Proceed with identifying and describing the entities in the user's documents.
+""".strip()
+ return prompt
+
+
+def generate_entity_counting_user_prompt(documents: List[str]) -> str:
+ """
+ Generates a user prompt containing the documents for entity counting.
+
+ Args:
+ documents: A list of strings, where each string is a document.
+
+ Returns:
+ A string representing the user prompt.
+ """
+ separator = "\n\n---END OF DOCUMENT---\n\n---START OF NEW DOCUMENT---\n\n"
+ combined_documents = separator.join(documents)
+
+ prompt = f"""
+Please identify and describe the entities in the following document(s) according to the instructions in the system prompt.
+
+# DOCUMENT(S) TO ANALYZE:
+
+{combined_documents}
+
+---
+Remember: Your output must be only a single, valid JSON object mapping entity names to lists of entity descriptions.
+""".strip()
+ return prompt
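The JSON shape that the counting prompt asks the model to return can be checked with a few lines of plain Python. This sketch uses a hypothetical LLM response in the format shown in the system prompt above (entity names mapped to ordered lists of description strings); the entity data is invented:

```python
import json

# Hypothetical LLM response in the shape the counting prompt requests.
response = json.dumps({
    "Invoice": ["Invoice #123 from ABC Corp with a value of 50 euros"],
    "LineItem": ["Item A - Widget linked to Invoice ID: invoice_123"],
})

# Entity counts fall out as the lengths of the description lists.
counts = {name: len(descs) for name, descs in json.loads(response).items()}
assert counts == {"Invoice": 1, "LineItem": 1}
```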
diff --git a/src/extrai/core/prompts/examples.py b/src/extrai/core/prompts/examples.py
new file mode 100644
index 0000000..14d7a86
--- /dev/null
+++ b/src/extrai/core/prompts/examples.py
@@ -0,0 +1,38 @@
+def generate_prompt_for_example_json_generation(
+ target_model_schema_str: str, root_model_name: str
+) -> str:
+ """
+ Generates a system prompt for guiding an LLM to create a single, valid
+ example JSON object based on a provided schema.
+
+ Args:
+ target_model_schema_str: A string containing the JSON schema for which
+ an example is to be generated.
+ root_model_name: The name of the root model/entity this schema represents
+ (e.g., "Product", "User"). Used for context in the prompt.
+
+ Returns:
+ A string representing the system prompt for example JSON generation.
+ """
+ prompt_parts = [
+ "You are an AI assistant tasked with generating a sample JSON object.",
+ f"The goal is to create a single, valid JSON object that conforms to the provided schema for a model named '{root_model_name}' and its related models.",
+ "This sample will be used as a few-shot example for another LLM task, so it needs to be accurate and representative.",
+ "\n# JSON SCHEMA TO ADHERE TO:",
+ "```json",
+ target_model_schema_str,
+ "```",
+ "\n# INSTRUCTIONS FOR YOUR OUTPUT:",
+ "1. **Output Content:** Your entire output MUST be a single, valid JSON object.",
+ "2. **Output Structure:** Your output MUST be a single JSON object with a top-level key named 'entities'. The value of 'entities' MUST be a list of JSON objects, where each object represents a single data entity.",
+ "3. **No Extra Text:** Do NOT include any explanatory text, comments, apologies, markdown formatting (like ```json), or any other content before or after the JSON object.",
+ "4. **Schema Compliance:** Strictly adhere to all field names (case-sensitive), data types (string, number, boolean, array, object), and structural requirements defined in the schema for each entity in the 'entities' list.",
+ "5. **Entity Metadata:** Each object inside the 'entities' list MUST include two metadata fields:",
+ ' * `_type`: This field\'s value MUST be a string matching the name of the model it represents (e.g., "Product", "ProductSpecs").',
+ ' * `_temp_id`: This field\'s value MUST be a unique temporary string identifier for that specific entity instance (e.g., "product_example_001", "spec_example_001"). Use these IDs in the `_ref_id` or `_ref_ids` fields to link entities.',
+ "6. **Simplicity and Clarity:** The generated example should be simple and illustrative. Populate all other fields (defined in the schema) with plausible, concise, and representative data. Avoid overly complex or lengthy values unless the schema demands it.",
+ f"7. **Completeness and Relationships:** Your 'entities' list should contain an instance of the root model (`{root_model_name}`) and at least one instance of each of its related models as described in the schema. For example, if generating an example for a 'Product' that has 'ProductSpecs', the 'entities' list should contain at least one 'Product' object and one 'ProductSpecs' object, linked together using their `_temp_id`s in the appropriate `_ref_id` or `_ref_ids` field.",
+ f"\nConsider the schema for '{root_model_name}' and its related models. Generate a representative set of linked entities in the format `{{\"entities\": [...]}}`.",
+ "Proceed with generating the JSON object.",
+ ]
+ return "\n\n".join(prompt_parts).strip()
diff --git a/src/extrai/core/prompts/extraction.py b/src/extrai/core/prompts/extraction.py
new file mode 100644
index 0000000..8bca2a2
--- /dev/null
+++ b/src/extrai/core/prompts/extraction.py
@@ -0,0 +1,166 @@
+import json
+from typing import Optional, List, Dict, Any
+
+
+def generate_system_prompt(
+ schema_json: str,
+ extraction_example_json: str = "",
+ custom_extraction_process: str = "",
+ custom_extraction_guidelines: str = "",
+ custom_final_checklist: str = "",
+ custom_context: str = "",
+ expected_entity_descriptions: Optional[List[str]] = None,
+ previous_entities: Optional[List[Dict[str, Any]]] = None,
+ target_model_name: Optional[str] = None,
+) -> str:
+ """
+ Generates a generic system prompt for guiding an LLM to extract information
+ from text and structure it according to a provided JSON schema.
+
+ Args:
+ schema_json: A string containing the JSON schema for the target data structure.
+ extraction_example_json: An optional string containing an example of a JSON
+ object that conforms to the schema.
+ custom_extraction_process: Optional custom instructions for the extraction process.
+ custom_extraction_guidelines: Optional custom guidelines for extraction.
+ custom_final_checklist: Optional custom final checklist for the LLM.
+ custom_context: Optional custom contextual information to be included in the prompt.
+ expected_entity_descriptions: Optional list of descriptions for the entities to be extracted.
+ previous_entities: List of previously extracted entities for hierarchical linking.
+ target_model_name: Name of the specific model to extract (for hierarchical steps).
+
+ Returns:
+ A string representing the system prompt.
+ """
+
+ default_extraction_process = """\
+# EXTRACTION PROCESS
+Follow this step-by-step process meticulously:
+1. **Understand the Goal:** Your primary objective is to extract information from the provided text and structure it precisely according to the JSON schema.
+2. **Full Text Analysis:** Read and comprehend the entirety of the provided document(s) before initiating extraction. This helps in understanding context and relationships.
+3. **Schema Adherence:** The provided JSON schema is your definitive guide. All extracted data must conform to this schema in terms of structure, field names, and data types.
+4. **Identify Relevant Data:** Locate all data points within the text that correspond to the fields defined in the JSON schema.
+5. **Map Data to Schema:** Carefully assign the identified data to the correct fields in the schema.
+6. **Handle Ambiguity and Missing Information:**
+ * If information for a field is ambiguous, use your reasoning capabilities to determine the most plausible interpretation based on the context.
+ * If information for an optional field is not present, omit the field or use `null` if the schema allows.
+ * For required fields, if information is genuinely missing and cannot be inferred, this is a critical issue. However, strive to find or infer it. If the schema defines a default, consider that.
+7. **Prioritize Explicit Information:** Base your extraction on information explicitly stated in the text. Avoid making assumptions unless absolutely necessary and clearly justifiable by the context.
+8. **Synthesize from Multiple Documents:** If multiple documents are provided, synthesize the information comprehensively. If conflicting information arises, prioritize what appears to be the most current, official, or reliable source. Note any significant discrepancies if the output format allows, but the primary goal is a single coherent JSON.
+9. **Data Type Conformance:** Strictly adhere to the data types specified in the JSON schema (e.g., string, number, boolean, array, object). Numbers should be formatted as numbers (e.g., `123`, `12.34`), not strings containing numbers (e.g., `"123"`). Booleans should be `true` or `false`.
+10. **Nested Structures and Relationships:**
+ * For nested objects or arrays, ensure your JSON output accurately reflects the hierarchical structure defined in the schema.
+ * If the schema implies relationships between different entities (e.g., using foreign keys or requiring linking), ensure these are correctly represented.
+ * If temporary identifiers are needed to link entities within the JSON output, generate unique and descriptive temporary IDs based on the entity's key attributes.
+11. **ID and Temporary ID Generation Directives:**
+ * **Explicit IDs:** If the text contains an explicit identifier for an entity (e.g., "ID: 123", "Code: A-55"), use it for the `id` field if the schema has one.
+ * **Temporary IDs:** When generating temporary IDs for linking entities (e.g., for `temp_id`, `_id`, or foreign keys):
+ * **Format:** You MUST use the format `[entity_type]_[key_attribute]` in `snake_case`. E.g., `user_john_doe`, `order_12345`.
+ * **Determinism:** Do NOT use random strings (like UUIDs) or simple counters (like `item_1`) unless there is absolutely no distinguishing attribute. Random values make consistency checking impossible.
+ * **Sanitization:** Convert to lowercase and replace spaces/special characters with underscores.
+ * **Consistency:** If the same entity appears multiple times, it MUST have the identical temporary ID every time.
+"""
+
+ default_extraction_guidelines = """\
+# IMPORTANT EXTRACTION GUIDELINES
+- **Ordering:** Maintain the order of items as they appear in the source text when populating arrays.
+- **Output Format:** Your entire output must be a single, valid JSON object. Do not include any other explanatory text, comments, apologies, or any other content before or after the JSON object.
+- **Output Structure Mandate:** Your response MUST be a single JSON object. This object MUST have a single top-level key named "result". The value of this "result" key MUST be the JSON object that conforms to the provided JSON schema. Example: `{"result": {your_schema_compliant_object_here}}`. Do NOT use any other top-level keys. Do NOT return the schema-compliant object directly as the root.
+- **Field Names:** Use the exact field names (case-sensitive) as specified in the JSON schema for the object under the "result" key.
+- **Structured Elements:** Pay close attention to structured elements within the text, such as tables, lists, headings, and emphasized text, as they often contain key information.
+- **Dates and Times:** Unless the schema specifies a different format, use ISO 8601 format for dates (YYYY-MM-DD) and date-times (YYYY-MM-DDTHH:MM:SSZ).
+- **Enumerations (Enums):** If a field in the schema is an enumeration with a predefined set of allowed values, ensure that the extracted value is one of those permitted values.
+- **Null Values:** Use `null` for optional fields where data is not available or not applicable, provided the schema allows for null values for that field. Do not use strings like "N/A", "Not available", or empty strings "" unless the schema explicitly defines such string literals as valid values.
+- **String Values:** Ensure all string values in the JSON are correctly escaped (e.g., quotes within strings).
+- **Foreign Key Fields:** If a model has a required foreign key field (e.g., `object_id`) and you are establishing the relationship using a temporary ID field (e.g., `airline_ref_id`), you MUST provide a placeholder value (e.g., `0`) for the foreign key field if the schema requires it. This ensures the JSON remains valid against the schema constraints.
+- **ID Consistency:** Ensure that `id` and `temp_id` values are consistent throughout the JSON. If you refer to `user_john_doe` in one place, do not refer to them as `user_john` elsewhere. Avoid generating random UUIDs or hashes for IDs unless explicitly instructed. Prefer human-readable, content-derived IDs for temporary linking.
+- **Meticulousness:** Accuracy is paramount. Double-check your extracted data against the source text and the schema before finalizing your output.
+"""
+
+ default_final_checklist = """\
+# FINAL CHECK BEFORE SUBMISSION
+1. **Valid JSON?** Is the entire output a single, syntactically correct JSON object?
+2. **Output Structure Correct?** Does the output JSON object have a single top-level key named "result"?
+3. **Schema Conformity?** Does the JSON object under the "result" key strictly adhere to all aspects of the provided JSON schema (all required fields present, correct data types for all values, correct structure for nested objects and arrays)?
+4. **Field Name Accuracy?** Are all field names within the object under the "result" key exactly as specified in the schema (case-sensitive)?
+5. **Relationship Integrity?** If temporary IDs or other linking mechanisms were required within the object under the "result" key, are they used correctly and consistently?
+6. **Null Handling?** Are `null` values used appropriately for missing optional data, according to schema constraints?
+7. **No Extraneous Text?** Is there absolutely no text or characters outside of the main JSON object itself?
+"""
+
+ # Use custom instructions if provided, otherwise use defaults
+ extraction_process = custom_extraction_process or default_extraction_process
+ extraction_guidelines = (
+ custom_extraction_guidelines or default_extraction_guidelines
+ )
+ final_checklist = custom_final_checklist or default_final_checklist
+
+ prompt_parts = [
+ "You are an advanced AI specializing in data extraction and structuring. Your task is to analyze user-provided text and transform the relevant information into a structured JSON object, strictly adhering to the provided JSON schema.",
+ "You must focus on precision, accuracy, and complete adherence to the schema.",
+ "\n# JSON SCHEMA TO ADHERE TO:",
+ "```json",
+ schema_json,
+ "```",
+ ]
+
+ if target_model_name:
+ prompt_parts.append("\n# TARGET ENTITY")
+ prompt_parts.append(
+ f"Your task is to extract **only** entities of type '{target_model_name}'. "
+ "Do not extract other entity types in this step."
+ )
+
+ if expected_entity_descriptions:
+ prompt_parts.append("\n# EXPECTED ENTITIES & ORDER:")
+ prompt_parts.append(
+ "You MUST extract entities matching the following descriptions, in this exact order:"
+ )
+ for i, desc in enumerate(expected_entity_descriptions, 1):
+ prompt_parts.append(f"{i}. {desc}")
+ prompt_parts.append(
+ f"\nYou must extract EXACTLY {len(expected_entity_descriptions)} items/entities corresponding to these descriptions."
+ )
+
+ if custom_context:
+ prompt_parts.append("\n# ADDITIONAL CONTEXT:")
+ prompt_parts.append(custom_context)
+
+ if previous_entities:
+ entities_json = json.dumps(previous_entities, indent=2)
+ prompt_parts.append("\n# PREVIOUSLY EXTRACTED ENTITIES:")
+ prompt_parts.append(entities_json)
+ prompt_parts.append(
+ "\nIMPORTANT: Use the 'id' values from the entities above to populate foreign key fields "
+ "(e.g. 'recipe_id') in the new entities you extract. Ensure correct linking."
+ )
+
+ prompt_parts.extend([f"\n{extraction_process}", f"\n{extraction_guidelines}"])
+
+ if extraction_example_json:
+ prompt_parts.append("\n# EXAMPLE OF EXTRACTION:")
+ prompt_parts.append(
+ "## CONCEPTUAL INPUT TEXT (This is illustrative; your actual input text will be different):"
+ )
+ prompt_parts.append(
+ "\"Imagine a piece of text that contains details about an entity or event. For instance, if the schema is about a 'Book', the text might say: 'The Great Novel, written by Jane Author in 2023, has 300 pages and is published by World Publishers. ISBN: 978-0123456789.'\""
+ )
+ prompt_parts.append(
+ "## EXAMPLE EXTRACTED JSON (This JSON conforms to the schema based on the conceptual text above):"
+ )
+ prompt_parts.append("```json")
+
+    stripped_example = extraction_example_json.strip()
+    if stripped_example.startswith("{") and stripped_example.endswith("}"):
+        prompt_parts.append(f'{{\n  "result": {stripped_example}\n}}')
+    else:
+        prompt_parts.append(stripped_example)
+ prompt_parts.append("```")
+
+ prompt_parts.append(f"\n{final_checklist}")
+ prompt_parts.append(
+ "\nProceed with the extraction based on the user's documents. Your response MUST be only the single, valid JSON object. Do not include any other narrative, explanations, or conversational elements in your output."
+ )
+
+ return "\n\n".join(prompt_parts).strip()
diff --git a/src/extrai/core/prompts/sqlmodel.py b/src/extrai/core/prompts/sqlmodel.py
new file mode 100644
index 0000000..bb2322a
--- /dev/null
+++ b/src/extrai/core/prompts/sqlmodel.py
@@ -0,0 +1,123 @@
+def generate_sqlmodel_creation_system_prompt(
+ schema_json: str, user_task_description: str
+) -> str:
+ """
+ Generates a specialized system prompt for guiding an LLM to create a
+ SQLModel class description (as a JSON object).
+
+ The LLM will be given input documents (via the user prompt) and this system
+ prompt. Its goal is to produce a JSON object that describes a new SQLModel,
+ and this JSON object must conform to the `schema_json` provided here.
+
+ Args:
+ schema_json: A string containing the JSON schema that the LLM's output
+ (the SQLModel description JSON) must conform to. This typically
+ comes from "sqlmodel_description_schema.json".
+ user_task_description: A natural language description from the user about
+ what entities or data structure they want to model.
+
+ Returns:
+ A string representing the system prompt for SQLModel description generation.
+ """
+ prompt_parts = [
+ "You are an AI assistant tasked with designing one or more SQLModel class definitions.",
+ "Your goal is to generate a JSON object that contains a list of SQLModel class descriptions. This description will then be used to generate Python code.",
+ "You will be provided with a user's task description and relevant documents (in the user prompt) to inform your design.",
+ "\n# REQUIREMENTS FOR YOUR OUTPUT:",
+ "1. Your entire output MUST be a single, valid JSON object.",
+ "2. This JSON object MUST contain a single top-level key: `sql_models`. The value of this key MUST be a list of JSON objects, where each object in the list describes a single SQLModel.",
+ "3. Each object in the `sql_models` list MUST strictly adhere to the following JSON schema for a SQLModel description:",
+ "```json",
+ schema_json,
+ "```",
+ "\n# IMPORTANT CONSIDERATIONS FOR DATABASE TABLE MODELS:",
+ "The SQLModel you are describing will typically be a database table (this is the default if `is_table_model` is not specified or is `true` in your output JSON).",
+ "When defining fields for such table models:",
+ "- **Scalar Types:** Standard types like `str`, `int`, `float`, `bool`, `datetime.datetime`, `uuid.UUID` are generally fine.",
+ "- **List and Dict Types:** If a field needs to store a list (e.g., `List[str]`) or a dictionary (e.g., `Dict[str, Any]`), these cannot be directly mapped to standard SQL column types. You MUST specify how they should be stored using the `field_options_str` property for that field. The recommended way is to store them as JSON.",
+ ' - **Example for `List[str]`:** For a field `tags: List[str]`, you should include this in its description object: `"field_options_str": "Field(default_factory=list, sa_type=JSON)"`',
+ ' - **Example for `Dict[str, Any]`:** For a field `metadata: Dict[str, Any]`, include: `"field_options_str": "Field(default_factory=dict, sa_type=JSON)"`',
+ '- **Import JSON:** If you use `sa_type=JSON` in any `field_options_str`, you MUST also add `"from sqlmodel import JSON"` to the main `imports` array in your generated JSON description.',
+ "Failure to correctly define `List` or `Dict` fields for table models (by not using `field_options_str` with `sa_type=JSON` or a similar valid SQLAlchemy type) will lead to errors.",
+ '- **Required Fields and Defaults:** Any field that is NOT `Optional` (e.g., `type: "str"`, `type: "int"`) is a REQUIRED field. For all required fields, you MUST provide a sensible `default` value in its description object to ensure the model can be instantiated for validation. For strings, use `""` as the default. For numbers, use `0` or `0.0`. For booleans, use `false`. Failure to provide a default for a required field will cause the system to crash.',
+ "- **Relationships and Foreign Keys:** When modeling relationships (e.g., one-to-many), you must define fields for both the foreign key and the relationship itself.",
+ ' - **Foreign Key Field:** The model on the "many" side of a relationship (e.g., `LineItem`) needs a foreign key field. This field MUST be defined as `Optional` with a `default` of `None` to pass validation.',
+ ' - **Foreign Key Naming Consistency:** The `foreign_key` value is critical. It MUST be a string in the format `"table_name.column_name"`. The `table_name` part MUST exactly match the `table_name` defined in the parent model. For example, if the `Invoice` model has `"table_name": "invoices"`, then the foreign key in `LineItem` MUST be `"invoices.id"`. A mismatch like `"invoice.id"` will cause a crash.',
+ ' - **Relationship Fields:** Both models should have a `Relationship` attribute. The "one" side gets a `List` of the "many" side, and the "many" side gets an `Optional` of the "one" side. Use `field_options_str` to define them. Example for `Invoice`: `{"name": "line_items", "type": "List[\\"LineItem\\"]", "field_options_str": "Relationship(back_populates=\\"invoice\\")"}`. Example for `LineItem`: `{"name": "invoice", "type": "Optional[\\"Invoice\\"]", "field_options_str": "Relationship(back_populates=\\"line_items\\")"}`.',
+ ' - **Imports for Relationships:** If you use `Relationship`, you MUST add `"from sqlmodel import Relationship"` to the `imports` array. If you use `List`, you must import it from `typing`.',
+ "\n# USER'S TASK:",
+ f'The user wants to define a SQLModel based on the following objective: "{user_task_description}"',
+ "Consider the documents provided by the user to understand the entities, fields, types, and relationships needed for this model. Pay close attention to the requirements for List/Dict types if the model is a table, and try to provide default values for required fields.",
+ "Focus on creating a comprehensive and accurate model description in the JSON format specified by the schema.",
+ ]
+
+ # Hardcoded example of a SQLModel description JSON
+ example_json = """
+{
+ "sql_models": [
+ {
+ "model_name": "ExampleItem",
+ "table_name": "example_items",
+ "description": "An example item model for demonstration.",
+ "fields": [
+ {
+ "name": "id",
+ "type": "Optional[int]",
+ "primary_key": true,
+ "default": null,
+ "nullable": true,
+ "description": "The unique identifier for the item."
+ },
+ {
+ "name": "name",
+ "type": "str",
+ "description": "The name of the item.",
+ "max_length": 100,
+ "nullable": false
+ },
+ {
+ "name": "quantity",
+ "type": "int",
+ "description": "The number of items in stock.",
+ "default": 0,
+ "ge": 0
+ },
+ {
+ "name": "created_at",
+ "type": "datetime.datetime",
+ "default_factory": "datetime.datetime.utcnow",
+ "description": "Timestamp of when the item was created."
+ },
+ {
+ "name": "categories",
+ "type": "List[str]",
+ "description": "Categories for the item, stored as JSON.",
+ "field_options_str": "Field(default_factory=list, sa_type=JSON)"
+ }
+ ],
+ "imports": [
+ "from typing import Optional, List",
+ "import datetime",
+ "from sqlmodel import SQLModel, Field, JSON"
+ ]
+ }
+ ]
+}
+"""
+ prompt_parts.extend(
+ [
+ "\n# EXAMPLE OF A VALID SQLMODEL DESCRIPTION JSON (Illustrating a list of models):",
+ "This is an example of the kind of JSON object you should produce (it conforms to the schema above):",
+ "```json",
+ example_json.strip(),
+ "```",
+ ]
+ )
+
+ prompt_parts.append(
+ "\nCarefully analyze the user's task and the provided documents. "
+ "Generate only the single JSON object that describes the SQLModels, wrapped in the `sql_models` key. "
+ "Do not include any other narrative, explanations, or conversational elements in your output."
+ )
+
+ return "\n\n".join(prompt_parts).strip()
diff --git a/src/extrai/core/prompts/structured_extraction.py b/src/extrai/core/prompts/structured_extraction.py
new file mode 100644
index 0000000..8b8b430
--- /dev/null
+++ b/src/extrai/core/prompts/structured_extraction.py
@@ -0,0 +1,88 @@
+import json
+from typing import Optional, List, Dict, Any
+
+
+def generate_structured_system_prompt(
+ custom_extraction_process: str = "",
+ custom_extraction_guidelines: str = "",
+ custom_context: str = "",
+ extraction_example_json: str = "",
+ expected_entity_descriptions: Optional[List[str]] = None,
+ previous_entities: Optional[List[Dict[str, Any]]] = None,
+ target_model_name: Optional[str] = None,
+) -> str:
+ """
+ Generates a system prompt tailored for structured output extraction.
+ Simplified instructions as the structure is enforced by the API.
+
+ Args:
+ custom_extraction_process: Optional custom instructions for the extraction process.
+ custom_extraction_guidelines: Optional custom guidelines for extraction.
+ custom_context: Optional custom contextual information.
+ extraction_example_json: Optional example JSON string.
+ expected_entity_descriptions: Optional list of descriptions for the entities to be extracted.
+ previous_entities: List of previously extracted entities for hierarchical linking.
+ target_model_name: Name of the specific model to extract (for hierarchical steps).
+
+ Returns:
+ A string representing the system prompt.
+ """
+
+ default_instructions = """\
+# EXTRACTION INSTRUCTIONS
+You are an expert data extraction AI. Your goal is to extract structured data from the provided text.
+
+1. **Analyze the Text:** Read the provided documents carefully.
+2. **Extract Entities:** Identify all entities that match the requested structure.
+3. **Accuracy:** Ensure all extracted data is accurate and supported by the text.
+4. **Inference:** If a field is missing but can be reasonably inferred from context, you may do so. Otherwise, leave it as null/None.
+5. **Relationships:** Capture relationships by nesting entities as defined in the structure.
+"""
+
+ parts = [default_instructions]
+
+ if target_model_name:
+ parts.append("# TARGET ENTITY")
+ parts.append(
+ f"Your task is to extract **only** entities of type '{target_model_name}'. "
+ "Do not extract other entity types in this step."
+ )
+
+ if expected_entity_descriptions:
+ parts.append("# EXPECTED ENTITIES & ORDER")
+ parts.append(
+ "You MUST extract entities matching the following descriptions, in this exact order:"
+ )
+ for i, desc in enumerate(expected_entity_descriptions, 1):
+ parts.append(f"{i}. {desc}")
+ parts.append(
+ f"\nYou must extract EXACTLY {len(expected_entity_descriptions)} items/entities corresponding to these descriptions."
+ )
+
+ # Assemble comprehensive custom instructions
+ instructions_parts = []
+ if custom_extraction_process:
+ instructions_parts.append(custom_extraction_process)
+
+ if custom_context:
+ instructions_parts.append(f"CONTEXT:\n{custom_context}")
+
+ if previous_entities:
+ entities_json = json.dumps(previous_entities, indent=2)
+ instructions_parts.append(
+ f"PREVIOUSLY EXTRACTED ENTITIES:\n{entities_json}\n\n"
+ "IMPORTANT: Use the 'id' values from the entities above to populate foreign key fields "
+ "(e.g. 'recipe_id') in the new entities you extract. Ensure correct linking."
+ )
+
+ if custom_extraction_guidelines:
+ instructions_parts.append(f"GUIDELINES:\n{custom_extraction_guidelines}")
+
+ if extraction_example_json:
+ instructions_parts.append(f"EXAMPLE REFERENCE:\n{extraction_example_json}")
+
+ if instructions_parts:
+ parts.append("# CUSTOM INSTRUCTIONS")
+ parts.append("\n\n".join(instructions_parts))
+
+ return "\n\n".join(parts)
diff --git a/src/extrai/core/result_processor.py b/src/extrai/core/result_processor.py
new file mode 100644
index 0000000..d106e91
--- /dev/null
+++ b/src/extrai/core/result_processor.py
@@ -0,0 +1,701 @@
+import logging
+import uuid
+from typing import (
+ List,
+ Dict,
+ Any,
+ Optional,
+ Type,
+ get_origin,
+ get_args,
+ Union,
+ NamedTuple,
+)
+from sqlalchemy.orm import Session
+from sqlalchemy import create_engine, inspect
+from sqlalchemy.exc import SQLAlchemyError
+from sqlmodel import SQLModel
+
+from .model_registry import ModelRegistry
+from .errors import HydrationError, WorkflowError
+
+SQLModelInstance = SQLModel
+
+
+class DatabaseWriterError(Exception):
+ """Custom exception for database writer errors."""
+
+ pass
+
+
+class PrimaryKeyInfo(NamedTuple):
+ name: Optional[str]
+ type: Optional[Type[Any]]
+ has_uuid_factory: bool
+
+
+class DirectHydrator:
+ """
+ Hydrates SQLModel objects directly from structured nested dictionaries.
+ Used when the LLM output is guaranteed to match the model structure (e.g. Structured Output).
+ Does not require _temp_id or _type fields.
+ Supports recursive hydration of nested relationships.
+ """
+
+ def __init__(
+ self,
+ session: Session,
+ logger: Optional[logging.Logger] = None,
+ original_pk_map: Optional[Dict[tuple[str, Any], SQLModelInstance]] = None,
+ all_instances: Optional[List[SQLModelInstance]] = None,
+ ):
+ self.session = session
+ self.logger = logger or logging.getLogger(__name__)
+ self.original_pk_map = original_pk_map if original_pk_map is not None else {}
+ self.all_instances = all_instances if all_instances is not None else []
+
+ def hydrate(
+ self,
+ data: List[Dict[str, Any]],
+ model_map: Dict[str, Type[SQLModel]],
+ default_model_class: Optional[Type[SQLModel]] = None,
+ ) -> List[SQLModelInstance]:
+ instances = []
+ for item in data:
+ try:
+ # Determine model class
+ _type = item.get("_type")
+ model_class = None
+
+ if _type and _type in model_map:
+ model_class = model_map[_type]
+ elif default_model_class:
+ model_class = default_model_class
+
+ if not model_class:
+ raise ValueError(
+ f"Could not determine model class for item (missing _type and no default): {item}"
+ )
+
+ instance = self._hydrate_recursive(item, model_class, model_map)
+ self.session.add(instance)
+ instances.append(instance)
+ except Exception as e:
+ self.logger.error(
+ f"Failed to hydrate item directly: {e}", exc_info=True
+ )
+ raise ValueError(f"Direct hydration failed: {e}") from e
+ return instances
+
+ def _hydrate_recursive(
+ self,
+ data: Dict[str, Any],
+ model_class: Type[SQLModel],
+ model_map: Dict[str, Type[SQLModel]],
+ ) -> SQLModelInstance:
+ """
+ Recursively hydrates an instance and its relationships.
+ """
+ # 1. Identify relationship fields
+ mapper = inspect(model_class)
+ relationships = {r.key: r for r in mapper.relationships}
+
+ # 2. Separate scalar data from relationship data
+ scalar_data = {}
+ relation_data = {}
+
+ for k, v in data.items():
+ if k in relationships:
+ relation_data[k] = v
+ else:
+ scalar_data[k] = v
+
+ pk_field_name = None
+ for field_name, model_field in model_class.model_fields.items():
+ if getattr(model_field, "primary_key", False):
+ pk_field_name = field_name
+ break
+
+ # Capture Original PK
+ if pk_field_name and pk_field_name in scalar_data:
+ original_pk = scalar_data[pk_field_name]
+ if original_pk is not None:
+ # Register the key now; it is mapped to the instance after creation.
+ # The type name comes from '_type' if present, else the class name.
+ type_name = data.get("_type", model_class.__name__)
+ self.original_pk_map[(type_name, original_pk)] = None
+
+ del scalar_data[pk_field_name]
+
+ if "_type" in scalar_data:
+ del scalar_data["_type"]
+
+ # 3. Instantiate the model from scalar data
+ instance = model_class.model_validate(scalar_data)
+
+ # Map original PK to instance
+ if pk_field_name and pk_field_name in data: # check original data
+ original_pk = data[pk_field_name]
+ if original_pk is not None:
+ type_name = data.get("_type", model_class.__name__)
+ self.original_pk_map[(type_name, original_pk)] = instance
+
+ self.all_instances.append(instance)
+
+ # 4. Populate relationships
+ for rel_key, rel_value in relation_data.items():
+ if rel_value is None:
+ setattr(instance, rel_key, None)
+ continue
+
+ rel_prop = relationships[rel_key]
+ target_class = rel_prop.mapper.class_
+
+ if isinstance(rel_value, list):
+ # One-to-Many / Many-to-Many
+ related_instances = []
+ for child_data in rel_value:
+ if isinstance(child_data, dict):
+ # Handle polymorphism in child
+ child_class = target_class
+ if "_type" in child_data and child_data["_type"] in model_map:
+ child_class = model_map[child_data["_type"]]
+
+ child_instance = self._hydrate_recursive(
+ child_data, child_class, model_map
+ )
+ related_instances.append(child_instance)
+ setattr(instance, rel_key, related_instances)
+
+ elif isinstance(rel_value, dict):
+ # Many-to-One / One-to-One
+ child_class = target_class
+ if "_type" in rel_value and rel_value["_type"] in model_map:
+ child_class = model_map[rel_value["_type"]]
+
+ child_instance = self._hydrate_recursive(
+ rel_value, child_class, model_map
+ )
+ setattr(instance, rel_key, child_instance)
+
+ return instance
+
+
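The recursive strategy above can be sketched independently of SQLModel: relationship keys (known per model) are hydrated recursively, and everything else is treated as scalar data. This is a simplified illustration using plain dicts; the `Invoice`/`LineItem` names and the static `RELATIONSHIPS` table are hypothetical stand-ins for the SQLAlchemy mapper introspection used in `DirectHydrator`.

```python
# Simplified sketch of DirectHydrator's recursive hydration, using plain
# dicts in place of SQLModel classes. The model names and the static
# RELATIONSHIPS table are hypothetical.
RELATIONSHIPS = {"Invoice": {"line_items": "LineItem"}, "LineItem": {}}

def hydrate(data, model):
    rels = RELATIONSHIPS[model]
    # Scalar fields: everything that is neither a relationship nor '_type'.
    instance = {k: v for k, v in data.items() if k not in rels and k != "_type"}
    # Relationship fields: recurse into lists (to-many) or dicts (to-one).
    for key, child_model in rels.items():
        value = data.get(key)
        if isinstance(value, list):
            instance[key] = [hydrate(c, child_model) for c in value]
        elif isinstance(value, dict):
            instance[key] = hydrate(value, child_model)
    return instance

doc = {"_type": "Invoice", "number": "A-1",
       "line_items": [{"sku": "X", "qty": 2}]}
result = hydrate(doc, "Invoice")
```

In the real implementation, the relationship table is derived per class via `sqlalchemy.inspect(model_class).relationships` rather than hard-coded.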
+class SQLAlchemyHydrator:
+ """
+ Hydrates SQLModel objects from consensus JSON data.
+ It uses a two-pass strategy: first, create all object instances,
+ then link their relationships using temporary IDs.
+ """
+
+ def __init__(
+ self,
+ session: Session,
+ logger: Optional[logging.Logger] = None,
+ original_pk_map: Optional[Dict[tuple[str, Any], SQLModelInstance]] = None,
+ all_instances: Optional[List[SQLModelInstance]] = None,
+ ):
+ """
+ Initializes the Hydrator.
+
+ Args:
+ session: The SQLAlchemy session to use for database operations
+ and instance management (e.g., adding instances).
+ logger: Optional logger instance.
+ original_pk_map: Optional shared map of (type name, original PK)
+ to hydrated instances, used for foreign-key recovery.
+ all_instances: Optional shared list that collects every hydrated
+ instance across hydrator invocations.
+ """
+ self.session: Session = session
+ self.temp_id_to_instance_map: Dict[
+ str, SQLModelInstance
+ ] = {} # Stores _temp_id -> SQLModel instance
+ self.original_pk_map = original_pk_map if original_pk_map is not None else {}
+ self.all_instances = all_instances if all_instances is not None else []
+ self.logger = logger or logging.getLogger(__name__)
+
+ def _filter_special_fields(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """Removes _temp_id, _type, and relationship reference fields before Pydantic validation."""
+ return {
+ k: v
+ for k, v in data.items()
+ if k not in ["_temp_id", "_type"]
+ and not k.endswith("_ref_id")
+ and not k.endswith("_ref_ids")
+ }
+
+ def _validate_entities_list(self, entities_list: List[Dict[str, Any]]) -> None:
+ """Performs initial validation on the input entities list."""
+ if not isinstance(entities_list, list):
+ raise TypeError(
+ f"Input 'entities_list' must be a list. Got: {type(entities_list)}"
+ )
+ if not all(isinstance(item, dict) for item in entities_list):
+ first_non_dict = next(
+ (item for item in entities_list if not isinstance(item, dict)), None
+ )
+ raise ValueError(
+ "All items in 'entities_list' must be dictionaries. "
+ f"Found an item of type: {type(first_non_dict)}."
+ )
+
+ def _get_primary_key_info(self, model_class: Type[SQLModel]) -> PrimaryKeyInfo:
+ """Introspects the model to find primary key details."""
+ for field_name, model_field in model_class.model_fields.items():
+ if getattr(model_field, "primary_key", False):
+ pk_type = model_field.annotation
+ origin_type = get_origin(pk_type)
+ if origin_type is Union:
+ args = get_args(pk_type)
+ pk_type = next(
+ (
+ arg
+ for arg in args
+ if arg is not type(None) and arg is not None
+ ),
+ None,
+ )
+
+ has_uuid_factory = False
+ if model_field.default_factory:
+ factory_func = model_field.default_factory
+ if factory_func is uuid.uuid4 or (
+ callable(factory_func)
+ and getattr(factory_func, "__name__", "").lower() == "uuid4"
+ ):
+ has_uuid_factory = True
+
+ return PrimaryKeyInfo(
+ name=field_name, type=pk_type, has_uuid_factory=has_uuid_factory
+ )
+
+ return PrimaryKeyInfo(name=None, type=None, has_uuid_factory=False)
+
+ def _generate_pk_if_needed(
+ self, instance: SQLModelInstance, model_class: Type[SQLModel]
+ ) -> None:
+ """Generates a primary key for the instance if it's needed."""
+ pk_info = self._get_primary_key_info(model_class)
+
+ if not pk_info.name:
+ return
+
+ current_pk_value = getattr(instance, pk_info.name, None)
+
+ if current_pk_value is not None or pk_info.has_uuid_factory:
+ return
+
+ if pk_info.type is uuid.UUID:
+ setattr(instance, pk_info.name, uuid.uuid4())
+ elif pk_info.type is str:
+ setattr(instance, pk_info.name, str(uuid.uuid4()))
+
+ def _create_single_instance(
+ self,
+ entity_data: Dict[str, Any],
+ model_schema_map: Dict[str, Type[SQLModel]],
+ ) -> None:
+ """Creates a single SQLModel instance from its dictionary representation."""
+ _temp_id = entity_data.get("_temp_id")
+ _type = entity_data.get("_type")
+
+ if not _temp_id or not _type:
+ raise ValueError(
+ "Entity data in 'entities' list is missing '_temp_id' or '_type'."
+ )
+ if _type not in model_schema_map:
+ raise ValueError(
+ f"No SQLModel class found in model_schema_map for type: '{_type}'."
+ )
+ if _temp_id in self.temp_id_to_instance_map:
+ raise ValueError(
+ f"Duplicate _temp_id '{_temp_id}' found in 'entities' list."
+ )
+
+ model_class = model_schema_map[_type]
+
+ filtered_data = self._filter_special_fields(entity_data.copy())
+
+ pk_field_name: Optional[str] = None
+ for field_name, model_field in model_class.model_fields.items():
+ if getattr(model_field, "primary_key", False):
+ pk_field_name = field_name
+ break
+
+ if pk_field_name and pk_field_name in filtered_data:
+ # Store the original PK value for later foreign key resolution
+ original_pk = filtered_data[pk_field_name]
+ if original_pk is not None:
+ self.original_pk_map[(_type, original_pk)] = (
+ None # Will be set to instance later
+ )
+ del filtered_data[pk_field_name]
+
+ try:
+ instance = model_class.model_validate(filtered_data)
+ except Exception as e:
+ raise ValueError(
+ f"Failed to instantiate/validate SQLModel '{_type}' for _temp_id '{_temp_id}': {e}"
+ ) from e
+
+ # Update the original_pk_map with the actual instance
+ if pk_field_name and pk_field_name in entity_data:
+ original_pk = entity_data[pk_field_name]
+ if original_pk is not None:
+ self.original_pk_map[(_type, original_pk)] = instance
+
+ self._generate_pk_if_needed(instance, model_class)
+ self.temp_id_to_instance_map[_temp_id] = instance
+
+ def _create_and_map_instances(
+ self,
+ entities_list: List[Dict[str, Any]],
+ model_schema_map: Dict[str, Type[SQLModel]],
+ ) -> None:
+ """Pass 1: Creates and maps all SQLModel instances."""
+ for entity_data in entities_list:
+ self._create_single_instance(entity_data, model_schema_map)
+
+ def _link_to_one_relation(
+ self,
+ instance: SQLModelInstance,
+ relation_name: str,
+ ref_id: Any,
+ entity_data: Dict[str, Any],
+ ) -> None:
+ """Handles the logic for a single to-one relationship."""
+ if ref_id is None:
+ setattr(instance, relation_name, None)
+ return
+
+ if isinstance(ref_id, str) and ref_id in self.temp_id_to_instance_map:
+ related_instance = self.temp_id_to_instance_map[ref_id]
+ setattr(instance, relation_name, related_instance)
+ else:
+ _temp_id = entity_data.get("_temp_id", "N/A")
+ _type = entity_data.get("_type", "N/A")
+ self.logger.warning(
+ f"Referenced _temp_id '{ref_id}' for relation "
+ f"'{relation_name}' on instance '{_temp_id}' (type: {_type}) not found or invalid type."
+ )
+
+ def _link_to_many_relation(
+ self,
+ instance: SQLModelInstance,
+ relation_name: str,
+ ref_ids: Any,
+ entity_data: Dict[str, Any],
+ ) -> None:
+ """Handles the logic for a single to-many relationship."""
+ _temp_id = entity_data.get("_temp_id", "N/A")
+ _type = entity_data.get("_type", "N/A")
+
+ if not isinstance(ref_ids, list):
+ if ref_ids is not None:
+ self.logger.warning(
+ f"Value for '{relation_name}_ref_ids' on instance '{_temp_id}' is not a list as expected for '_ref_ids'. Value: {ref_ids}"
+ )
+ setattr(instance, relation_name, [])
+ return
+
+ related_instances = []
+ for ref_id in ref_ids:
+ if isinstance(ref_id, str) and ref_id in self.temp_id_to_instance_map:
+ related_instances.append(self.temp_id_to_instance_map[ref_id])
+ else:
+ self.logger.warning(
+ f"Referenced _temp_id '{ref_id}' in list for relation "
+ f"'{relation_name}' on instance '{_temp_id}' (type: {_type}) not found or invalid type."
+ )
+ setattr(instance, relation_name, related_instances)
+
+ def _link_relations_for_instance(self, entity_data: Dict[str, Any]) -> None:
+ """Links relationships for a single instance by dispatching to specialized helpers."""
+ _temp_id = entity_data["_temp_id"]
+ instance = self.temp_id_to_instance_map[_temp_id]
+
+ for key, value in entity_data.items():
+ if key.endswith("_ref_id"):
+ relation_name = key[:-7]
+ if hasattr(instance, relation_name):
+ self._link_to_one_relation(
+ instance, relation_name, value, entity_data
+ )
+ elif key.endswith("_ref_ids"):
+ relation_name = key[:-8]
+ if hasattr(instance, relation_name):
+ self._link_to_many_relation(
+ instance, relation_name, value, entity_data
+ )
+
+ def _link_relationships(self, entities_list: List[Dict[str, Any]]) -> None:
+ """Pass 2: Links all created instances together."""
+ for entity_data in entities_list:
+ self._link_relations_for_instance(entity_data)
+
+ def _add_instances_to_session(self) -> None:
+ """Adds all created instances to the SQLAlchemy session."""
+ for instance in self.temp_id_to_instance_map.values():
+ self.session.add(instance)
+
+ def hydrate(
+ self,
+ entities_list: List[Dict[str, Any]],
+ model_schema_map: Dict[str, Type[SQLModel]],
+ ) -> List[SQLModelInstance]:
+ """
+ Hydrates SQLModel objects from a list of entity data dictionaries.
+ """
+ self._validate_entities_list(entities_list)
+
+ self.temp_id_to_instance_map.clear()
+
+ # Pass 1: Create all object instances without relationships.
+ self._create_and_map_instances(entities_list, model_schema_map)
+
+ # Pass 2: Link the created instances together.
+ self._link_relationships(entities_list)
+
+ self._add_instances_to_session()
+
+ return list(self.temp_id_to_instance_map.values())
+
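The two-pass strategy can be illustrated in miniature without SQLModel: pass 1 builds one instance per `_temp_id`, and pass 2 rewires `*_ref_id` fields into object references. This is a hedged sketch; the `Author`/`Book` entities and a plain attribute container stand in for real SQLModel classes and the session.

```python
# Minimal sketch of SQLAlchemyHydrator's two-pass strategy, with a plain
# attribute container instead of SQLModel instances. The Author/Book
# entities are hypothetical.
entities = [
    {"_temp_id": "t1", "_type": "Author", "name": "Ada"},
    {"_temp_id": "t2", "_type": "Book", "title": "Notes", "author_ref_id": "t1"},
]

class Obj:
    def __init__(self, **kw):
        self.__dict__.update(kw)

# Pass 1: create every instance, keyed by _temp_id (no relationships yet).
by_temp_id = {}
for e in entities:
    fields = {k: v for k, v in e.items()
              if k not in ("_temp_id", "_type") and not k.endswith("_ref_id")}
    by_temp_id[e["_temp_id"]] = Obj(**fields)

# Pass 2: resolve *_ref_id values into references to the mapped instances.
for e in entities:
    inst = by_temp_id[e["_temp_id"]]
    for k, v in e.items():
        if k.endswith("_ref_id"):
            setattr(inst, k[: -len("_ref_id")], by_temp_id.get(v))
```

The split matters because an entity may reference a `_temp_id` that appears later in the list; creating everything first makes linking order-independent.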
+
+def persist_objects(
+ db_session: Session, objects_to_persist: List[Any], logger: logging.Logger
+) -> None:
+ """
+ Persists a list of SQLAlchemy objects to the database using the provided session.
+
+ Args:
+ db_session: The SQLAlchemy session to use for database operations.
+ objects_to_persist: A list of SQLAlchemy model instances to be saved.
+ logger: Logger used to report progress and errors.
+
+ Raises:
+ DatabaseWriterError: If an error occurs during the database commit.
+ """
+ if not objects_to_persist:
+ logger.info("No objects provided to persist.")
+ return
+
+ try:
+ # All objects should already be associated with the session
+ # from the hydration phase
+ db_session.add_all(objects_to_persist)
+ db_session.commit()
+ logger.info(
+ f"Successfully persisted {len(objects_to_persist)} objects to the database."
+ )
+ except SQLAlchemyError as e:
+ logger.error(f"Database commit failed: {e}", exc_info=True)
+ try:
+ db_session.rollback()
+ logger.info("Database session rolled back successfully.")
+ except SQLAlchemyError as rollback_e:
+ logger.error(
+ f"Failed to rollback database session: {rollback_e}", exc_info=True
+ )
+ # Potentially raise a more critical error or handle nested failure
+ raise DatabaseWriterError(f"Failed to persist objects due to: {e}")
+ except Exception as e:
+ logger.error(
+ f"An unexpected error occurred during object persistence: {e}",
+ exc_info=True,
+ )
+
+ if db_session.is_active:
+ db_session.rollback()
+ logger.info("Database session rolled back due to unexpected error.")
+
+ raise DatabaseWriterError(f"An unexpected error occurred: {e}")
+
+
+class ResultProcessor:
+ """Handles hydration and persistence of extraction results."""
+
+ def __init__(
+ self,
+ model_registry: ModelRegistry,
+ analytics_collector,
+ logger: logging.Logger,
+ ):
+ self.model_registry = model_registry
+ self.analytics_collector = analytics_collector
+ self.logger = logger
+ self.original_pk_map: Dict[tuple[str, Any], SQLModelInstance] = {}
+ self.all_hydrated_instances: List[SQLModelInstance] = []
+
+ def hydrate(
+ self,
+ results: List[Dict[str, Any]],
+ db_session: Optional[Session] = None,
+ default_model_type: Optional[str] = None,
+ ) -> List[Any]:
+ """
+ Hydrates dictionaries into SQLModel objects.
+
+ Args:
+ results: List of dictionaries to hydrate.
+ db_session: Optional SQLAlchemy session.
+ default_model_type: Optional override for the default model type.
+ If provided, the DirectHydrator is used, with this model as the
+ fallback class for items that carry no '_type' field.
+ """
+ if not results:
+ return []
+
+ session = self._get_or_create_session(db_session)
+
+ try:
+ self.logger.info(f"Hydrating {len(results)} objects...")
+
+ # Determine Strategy based on data content
+ first_item = results[0]
+ use_direct_hydration = False
+
+ # If _temp_id is missing, we must use DirectHydrator (Graph Reconstruction requires _temp_id)
+ if "_temp_id" not in first_item:
+ use_direct_hydration = True
+
+ # If default_model_type is explicitly provided, we assume DirectHydrator
+ if default_model_type:
+ use_direct_hydration = True
+
+ if use_direct_hydration:
+ self.logger.info(
+ f"Using DirectHydrator (default_model_type={default_model_type or 'Auto-detect'})"
+ )
+
+ default_model_class = None
+ if default_model_type:
+ default_model_class = self.model_registry.model_map.get(
+ default_model_type
+ )
+ if not default_model_class:
+ default_model_class = self.model_registry.root_model
+
+ hydrator = DirectHydrator(
+ session,
+ self.logger,
+ self.original_pk_map,
+ self.all_hydrated_instances,
+ )
+ hydrated = hydrator.hydrate(
+ results,
+ model_map=self.model_registry.model_map,
+ default_model_class=default_model_class,
+ )
+ else:
+ self.logger.info("Using SQLAlchemyHydrator for graph reconstruction")
+ hydrator = SQLAlchemyHydrator(
+ session=session,
+ logger=self.logger,
+ original_pk_map=self.original_pk_map,
+ all_instances=self.all_hydrated_instances,
+ )
+ hydrated = hydrator.hydrate(results, self.model_registry.model_map)
+ # DirectHydrator already records its instances in
+ # all_hydrated_instances via the shared all_instances list, so
+ # extend only for the SQLAlchemyHydrator path to avoid duplicates.
+ if not use_direct_hydration:
+ self.all_hydrated_instances.extend(hydrated)
+
+ self.analytics_collector.record_hydration_success(len(hydrated))
+ self.logger.info(f"Successfully hydrated {len(hydrated)} objects")
+
+ return hydrated
+
+ except Exception as e:
+ self.analytics_collector.record_hydration_failure()
+ raise HydrationError(f"Hydration failed: {e}") from e
+
+ finally:
+ if db_session is None and session:
+ session.close()
+
+ def persist(self, objects: List[Any], db_session: Session):
+ """Persists objects to database."""
+ if not objects:
+ self.logger.info("No objects to persist")
+ return
+
+ self._link_foreign_keys(objects)
+
+ try:
+ persist_objects(
+ db_session=db_session,
+ objects_to_persist=objects,
+ logger=self.logger,
+ )
+ except DatabaseWriterError:
+ db_session.rollback()
+ raise
+ except Exception as e:
+ db_session.rollback()
+ raise WorkflowError(f"Persistence failed: {e}") from e
+
+ def _link_foreign_keys(
+ self, instances: Optional[List[SQLModelInstance]] = None
+ ) -> None:
+ """
+ Links foreign keys for all hydrated instances before persisting.
+ """
+ target_instances = (
+ instances if instances is not None else self.all_hydrated_instances
+ )
+ if self.original_pk_map:
+ self._perform_fk_recovery(target_instances, self.original_pk_map)
+
+ def _perform_fk_recovery(
+ self,
+ instances: List[SQLModelInstance],
+ original_pk_map: Dict[tuple[str, Any], SQLModelInstance],
+ ) -> None:
+ """
+ Scans all hydrated instances for Foreign Key fields that are set (not None)
+ but might refer to an original ID that was stripped.
+ Attempts to link these to the correct instance using original_pk_map.
+ """
+ count_recovered = 0
+ for instance in instances:
+ model_class = type(instance)
+ mapper = inspect(model_class)
+
+ for rel in mapper.relationships:
+ # We only care about Many-to-One (FK holder)
+ if rel.direction.name != "MANYTOONE":
+ continue
+
+ if not rel.local_remote_pairs:
+ continue
+
+ local_col, remote_col = rel.local_remote_pairs[0]
+
+ # Check if FK field has a value on the instance
+ fk_value = getattr(instance, local_col.name, None)
+ if fk_value is None:
+ continue
+
+ # Check if relationship is already set
+ current_rel_value = getattr(instance, rel.key, None)
+ if current_rel_value is not None:
+ continue
+
+ # Try to find target instance in map
+ target_class = rel.mapper.class_
+ target_type = target_class.__name__
+
+ key = (target_type, fk_value)
+ if key in original_pk_map:
+ target_instance = original_pk_map[key]
+ setattr(instance, rel.key, target_instance)
+ count_recovered += 1
+ self.logger.debug(
+ f"Recovered relationship {model_class.__name__}.{rel.key} "
+ f"using FK {fk_value} -> {target_type}"
+ )
+
+ if count_recovered > 0:
+ self.logger.info(
+ f"Universal FK Recovery: Restored {count_recovered} relationships."
+ )
+
+ def _get_or_create_session(self, db_session: Optional[Session]) -> Session:
+ """Creates temporary in-memory session if none provided."""
+ if db_session:
+ return db_session
+
+ engine = create_engine("sqlite:///:memory:")
+ SQLModel.metadata.create_all(engine)
+ return Session(engine)
diff --git a/src/extrai/core/schema_inspector.py b/src/extrai/core/schema_inspector.py
index f7a7893..70544e2 100644
--- a/src/extrai/core/schema_inspector.py
+++ b/src/extrai/core/schema_inspector.py
@@ -1,778 +1,555 @@
-# extrai/core/schema_inspector.py
-
import json
+import logging
+import enum
+from typing import Type, List, Optional, Any, Dict, Set, Tuple
from sqlalchemy import inspect, Column, Table
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.schema import UniqueConstraint, PrimaryKeyConstraint
-import enum
-import datetime
-from typing import Any, Dict, Type, Set, Optional, List, get_origin, get_args, Tuple
from sqlmodel import SQLModel
+from extrai.utils.type_mapping import (
+ map_sql_type_to_llm_type,
+ get_python_type_str_from_pydantic_annotation,
+)
-from typing import Union as TypingUnion
-
-
-def _process_union_types(args, recurse_func):
- """Helper to process Union types, filtering and sorting."""
- if not args:
- return "union"
- union_types_str = [recurse_func(arg) for arg in args]
- processed_union_types = sorted(set(t for t in union_types_str if t != "none"))
- if len(processed_union_types) == 1:
- return processed_union_types[0]
- return f"union[{','.join(processed_union_types)}]"
-
-
-# Handler registry for different type origins
-ORIGIN_HANDLERS = {
- Optional: lambda args, r: r(args[0])
- if args and args[0] is not type(None)
- else "none",
- list: lambda args, r: f"list[{','.join([r(arg) for arg in args])}]"
- if args
- else "list",
- List: lambda args, r: f"list[{','.join([r(arg) for arg in args])}]"
- if args
- else "list",
- dict: lambda args, r: f"dict[{r(args[0])},{r(args[1])}]"
- if args and len(args) == 2
- else "dict",
- Dict: lambda args, r: f"dict[{r(args[0])},{r(args[1])}]"
- if args and len(args) == 2
- else "dict",
- TypingUnion: _process_union_types,
-}
-
-# Data-driven approach for base types
-BASE_TYPE_MAP = {
- int: "int",
- str: "str",
- bool: "bool",
- float: "float",
- datetime.date: "date",
- datetime.datetime: "datetime",
- bytes: "bytes",
- Any: "any",
- type(None): "none",
-}
-
-
-# Helper function to get a simplified string from Pydantic/SQLModel annotations
-def _get_python_type_str_from_pydantic_annotation(annotation: Any) -> str:
- origin = get_origin(annotation)
- args = get_args(annotation)
-
- if origin in ORIGIN_HANDLERS:
- return ORIGIN_HANDLERS[origin](
- args, _get_python_type_str_from_pydantic_annotation
- )
- if annotation in BASE_TYPE_MAP:
- return BASE_TYPE_MAP[annotation]
-
- if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
- return "enum"
-
- if hasattr(annotation, "__name__"):
- name_lower = annotation.__name__.lower()
- if name_lower == "secretstr":
- return "str"
- return name_lower
-
- # Fallback
- cleaned_annotation_str = str(annotation).lower().replace("typing.", "")
- if cleaned_annotation_str.startswith("~"):
- cleaned_annotation_str = cleaned_annotation_str[1:]
- return cleaned_annotation_str
-
-
-# --- Data-driven mappings for type conversion ---
-SIMPLE_PYTHON_TYPE_MAP = {
- "int": "integer",
- "str": "string",
- "bool": "boolean",
- "float": "number (float/decimal)",
- "date": "string (date format)",
- "datetime": "string (datetime format)",
- "bytes": "string (base64 encoded)",
- "enum": "string (enum)",
- "any": "any",
- "none": "null",
-}
-
-SQL_TYPE_KEYWORDS = [
- ("int", "integer"),
- ("char", "string"),
- ("text", "string"),
- ("clob", "string"),
- ("bool", "boolean"),
- ("date", "string (date/datetime format)"),
- ("time", "string (date/datetime format)"),
- ("numeric", "number (float/decimal)"),
- ("decimal", "number (float/decimal)"),
- ("float", "number (float/decimal)"),
- ("double", "number (float/decimal)"),
- ("json", "object"),
- ("array", "array"),
-]
-
-
-# --- Handlers for complex and generic types ---
-def _handle_list_type(python_type_lower: str) -> Optional[str]:
- """Handles list[...] and array[...] type mappings."""
- if python_type_lower.startswith("list[") and python_type_lower.endswith("]"):
- inner_type_str = python_type_lower[5:-1]
- mapped_inner_type = _map_sql_type_to_llm_type("", inner_type_str)
- return f"array[{mapped_inner_type}]"
- return None
-
-
-def _handle_dict_type(python_type_lower: str) -> Optional[str]:
- """Handles dict[...] and object[...] type mappings."""
- if python_type_lower.startswith("dict[") and python_type_lower.endswith("]"):
- inner_types_str = python_type_lower[5:-1]
- try:
- key_type_str, value_type_str = inner_types_str.split(",", 1)
- mapped_key_type = _map_sql_type_to_llm_type("", key_type_str.strip())
- mapped_value_type = _map_sql_type_to_llm_type("", value_type_str.strip())
- return f"object[{mapped_key_type},{mapped_value_type}]"
- except ValueError:
- return "object"
- return None
-
-
-def _handle_union_type(python_type_lower: str) -> Optional[str]:
- """Handles union[...] type mappings."""
- if python_type_lower.startswith("union[") and python_type_lower.endswith("]"):
- inner_types_str = python_type_lower[6:-1]
- union_parts = [p.strip() for p in inner_types_str.split(",") if p.strip()]
- mapped_parts = sorted(
- set(_map_sql_type_to_llm_type("", part) for part in union_parts)
- )
- if not mapped_parts:
- return "any"
- return (
- mapped_parts[0]
- if len(mapped_parts) == 1
- else f"union[{','.join(mapped_parts)}]"
- )
- return None
+class SchemaInspector:
+ """Helper class to inspect SQLAlchemy models and generate LLM schemas."""
+ def __init__(self, logger: Optional[logging.Logger] = None):
+ self.logger = logger or logging.getLogger(__name__)
-def _handle_generic_or_unknown_type(
- python_type_lower: str, sql_type_lower: str
-) -> Optional[str]:
- """Handles ambiguous types like plain 'list' or 'dict' and unknown types."""
- if python_type_lower == "list":
- if "text" in sql_type_lower: # Let the SQL keyword mapping handle this case
- return None
+ def _is_column_unique(self, column_obj: Column) -> bool:
+ """Checks if a column has a unique constraint."""
+ if column_obj.unique:
+ return True
+ if column_obj.table is not None:
+ for constraint in column_obj.table.constraints:
+ if isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
+ if column_obj.name in constraint.columns:
+ return True
+ return False
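The uniqueness check above consults both the column's own `unique` flag and any table-level unique or primary-key constraints. A minimal standalone sketch of the same dispatch order, using hypothetical stand-in classes rather than real SQLAlchemy objects:

```python
from dataclasses import dataclass
from typing import List, Optional

# Hypothetical stand-ins for SQLAlchemy's constraint and Column objects.
class UniqueConstraint:
    def __init__(self, *names: str) -> None:
        self.columns = set(names)

class PrimaryKeyConstraint(UniqueConstraint):
    pass

@dataclass
class Table:
    constraints: List[UniqueConstraint]

@dataclass
class Column:
    name: str
    unique: bool = False
    table: Optional[Table] = None

def is_column_unique(col: Column) -> bool:
    # 1) explicit unique flag, 2) table-level unique/PK constraints
    if col.unique:
        return True
    if col.table is not None:
        for constraint in col.table.constraints:
            if isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
                if col.name in constraint.columns:
                    return True
    return False

table = Table(constraints=[PrimaryKeyConstraint("id"), UniqueConstraint("email")])
print(is_column_unique(Column("email", table=table)))  # True
print(is_column_unique(Column("bio", table=table)))    # False
```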
- return "array"
-
- if python_type_lower == "dict":
- return "object"
-
- if python_type_lower.startswith("unknown"):
- if "json" in sql_type_lower:
- return "object"
- if "array" in sql_type_lower:
- return "array"
- return "string"
- return None
-
-
-def _map_sql_type_to_llm_type(sql_type_str: str, python_type_str: str) -> str:
- """
- Maps SQL/Python types to simpler LLM-friendly type strings using a dispatcher pattern.
- """
- sql_type_lower = str(sql_type_str).lower()
- python_type_lower = str(python_type_str).lower()
-
- # 1. Handle complex Python types first
- for handler in [_handle_list_type, _handle_dict_type, _handle_union_type]:
- result = handler(python_type_lower)
- if result:
- return result
-
- # 2. Look up in the simple Python type map
- if python_type_lower in SIMPLE_PYTHON_TYPE_MAP:
- return SIMPLE_PYTHON_TYPE_MAP[python_type_lower]
-
- # 3. Handle generic or unknown types, which have precedence over broad SQL keywords
- result = _handle_generic_or_unknown_type(python_type_lower, sql_type_lower)
- if result:
- return result
-
- # 4. Search through SQL type keywords as a fallback
- for keyword, llm_type in SQL_TYPE_KEYWORDS:
- if keyword in sql_type_lower:
- return llm_type
-
- # 5. Final fallback if no other rule matched
- return "string"
-
-
-def _is_column_unique(column_obj: Column) -> bool:
- """Checks if a column has a unique constraint."""
- if column_obj.unique:
- return True
- if column_obj.table is not None:
- for constraint in column_obj.table.constraints:
- if isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
- if column_obj.name in constraint.columns:
- return True
- return False
-
-
-def _get_python_type_from_column(column_obj: Column) -> str:
- """Safely extracts the Python type name from a column object."""
- try:
- return column_obj.type.python_type.__name__
- except NotImplementedError:
- return "unknown_not_implemented"
- except AttributeError:
- return "unknown_no_python_type_attr"
- except Exception:
- return "unknown_error_accessing_type"
-
-
-def _build_column_info(
- column_obj: Column, is_unique: bool, python_type_name: str
-) -> Dict[str, Any]:
- """Builds the column information dictionary."""
- col_info = {
- "type": str(column_obj.type),
- "python_type": python_type_name,
- "primary_key": column_obj.primary_key,
- "nullable": column_obj.nullable,
- "unique": is_unique,
- "foreign_key_to": None,
- "comment": column_obj.comment,
- "info_dict": column_obj.info,
- }
- if column_obj.foreign_keys:
- fk_constraint_obj = next(iter(column_obj.foreign_keys))
- col_info["foreign_key_to"] = str(fk_constraint_obj.column)
- return col_info
-
-
-def _get_columns_from_inspector(inspector) -> Dict[str, Any]:
- """Extracts all column properties from a SQLAlchemy inspector."""
- columns_info = {}
- for col_attr in inspector.column_attrs:
- if not isinstance(col_attr.expression, Column):
- continue
- column_obj = col_attr.expression
- is_unique = _is_column_unique(column_obj)
- python_type_name = _get_python_type_from_column(column_obj)
- columns_info[col_attr.key] = _build_column_info(
- column_obj, is_unique, python_type_name
- )
- return columns_info
-
-
-def _get_fks_from_secondary_table(rel_prop: RelationshipProperty) -> Set[str]:
- """Handles relationships that use a secondary table."""
- involved_fk_columns: Set[str] = set()
- if rel_prop.secondary is not None:
- for fk_constraint in rel_prop.secondary.foreign_key_constraints:
- for col in fk_constraint.columns:
- involved_fk_columns.add(str(col))
- return involved_fk_columns
-
-
-def _get_fks_from_synchronize_pairs(rel_prop: RelationshipProperty) -> Set[str]:
- """Handles relationships that use synchronize_pairs."""
- involved_fk_columns: Set[str] = set()
- if rel_prop.synchronize_pairs:
- for local_join_col, remote_join_col in rel_prop.synchronize_pairs:
- if hasattr(local_join_col, "foreign_keys") and local_join_col.foreign_keys:
- involved_fk_columns.add(str(local_join_col))
- if (
- hasattr(remote_join_col, "foreign_keys")
- and remote_join_col.foreign_keys
+ def _get_python_type_from_column(self, column_obj: Column) -> str:
+ """Safely extracts the Python type name from a column object."""
+ try:
+ return column_obj.type.python_type.__name__
+ except NotImplementedError:
+ return "unknown_not_implemented"
+ except AttributeError:
+ return "unknown_no_python_type_attr"
+ except Exception:
+ return "unknown_error_accessing_type"
+
+ def _build_column_info(
+ self, column_obj: Column, is_unique: bool, python_type_name: str
+ ) -> Dict[str, Any]:
+ """Builds the column information dictionary."""
+ enum_values = None
+ # Handle SQLAlchemy Enum types (both class-based and string-based)
+ if hasattr(column_obj.type, "enum_class") and column_obj.type.enum_class:
+ if isinstance(column_obj.type.enum_class, type) and issubclass(
+ column_obj.type.enum_class, enum.Enum
):
- involved_fk_columns.add(str(remote_join_col))
- return involved_fk_columns
-
-
-def _get_fks_from_direct_foreign_keys(rel_prop: RelationshipProperty) -> Set[str]:
- """Handles relationships that have direct foreign_keys."""
- involved_fk_columns: Set[str] = set()
- if hasattr(rel_prop, "foreign_keys") and rel_prop.foreign_keys is not None:
- for fk_col in rel_prop.foreign_keys:
- involved_fk_columns.add(str(fk_col))
- return involved_fk_columns
-
-
-def _get_involved_foreign_keys(rel_prop: RelationshipProperty) -> Set[str]:
- """
- Finds all foreign key columns involved in a relationship by dispatching to helper functions.
- """
- if rel_prop.secondary is not None:
- return _get_fks_from_secondary_table(rel_prop)
-
- if rel_prop.synchronize_pairs:
- return _get_fks_from_synchronize_pairs(rel_prop)
-
- if hasattr(rel_prop, "foreign_keys") and rel_prop.foreign_keys is not None:
- return _get_fks_from_direct_foreign_keys(rel_prop)
-
- return set()
-
-
-def _build_relationship_info(
- rel_prop: RelationshipProperty,
- involved_fk_columns: Set[str],
- recursion_path_tracker: Set[Type[Any]],
-) -> Dict[str, Any]:
- """Builds the relationship information dictionary, including recursion."""
- related_model_class = rel_prop.mapper.class_
- return {
- "type": rel_prop.direction.name,
- "uselist": rel_prop.uselist,
- "related_model_name": related_model_class.__name__,
- "secondary_table_name": rel_prop.secondary.name
- if rel_prop.secondary is not None
- else None,
- "local_columns": [str(c) for c in rel_prop.local_columns],
- "remote_columns_in_join": [str(pair[1]) for pair in rel_prop.local_remote_pairs]
- if rel_prop.local_remote_pairs
- else [],
- "foreign_key_constraints_involved": sorted(involved_fk_columns),
- "back_populates": rel_prop.back_populates,
- "info_dict": rel_prop.info,
- "nested_schema": _inspect_sqlalchemy_model_recursive(
- related_model_class, recursion_path_tracker
- ),
- }
-
-
-def _get_relationships_from_inspector(
- inspector, recursion_path_tracker: Set[Type[Any]]
-) -> Dict[str, Any]:
- """Extracts all relationship properties from a SQLAlchemy inspector."""
- relationships_info = {}
- for name, rel_prop in inspector.relationships.items():
- if isinstance(rel_prop, RelationshipProperty):
- involved_fk_columns = _get_involved_foreign_keys(rel_prop)
- relationships_info[name] = _build_relationship_info(
- rel_prop, involved_fk_columns, recursion_path_tracker
+ enum_values = [e.value for e in column_obj.type.enum_class]
+ elif hasattr(column_obj.type, "enums") and column_obj.type.enums:
+ enum_values = list(column_obj.type.enums)
+
+ col_info = {
+ "type": str(column_obj.type),
+ "python_type": python_type_name,
+ "primary_key": column_obj.primary_key,
+ "nullable": column_obj.nullable,
+ "unique": is_unique,
+ "foreign_key_to": None,
+ "comment": column_obj.comment,
+ "info_dict": column_obj.info,
+ "enum_values": enum_values,
+ }
+ if column_obj.foreign_keys:
+ fk_constraint_obj = next(iter(column_obj.foreign_keys))
+ col_info["foreign_key_to"] = str(fk_constraint_obj.column)
+ return col_info
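The new enum handling covers both flavors of SQLAlchemy `Enum` columns: those backed by a Python `enum.Enum` class (via `enum_class`) and those declared from plain strings (via `enums`). A standalone sketch of that branch logic, using hypothetical stand-in type objects:

```python
import enum

class Status(enum.Enum):
    OPEN = "open"
    CLOSED = "closed"

# Hypothetical stand-ins for sqlalchemy.Enum column types.
class ClassBackedEnumType:
    enum_class = Status

class StringBackedEnumType:
    enum_class = None
    enums = ("low", "high")

def extract_enum_values(col_type):
    # Mirrors the enum handling added to _build_column_info
    if hasattr(col_type, "enum_class") and col_type.enum_class:
        if isinstance(col_type.enum_class, type) and issubclass(
            col_type.enum_class, enum.Enum
        ):
            return [e.value for e in col_type.enum_class]
        return None
    if hasattr(col_type, "enums") and col_type.enums:
        return list(col_type.enums)
    return None

print(extract_enum_values(ClassBackedEnumType()))   # ['open', 'closed']
print(extract_enum_values(StringBackedEnumType()))  # ['low', 'high']
```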
+
+ def _get_columns_from_inspector(self, inspector) -> Dict[str, Any]:
+ """Extracts all column properties from a SQLAlchemy inspector."""
+ columns_info = {}
+ for col_attr in inspector.column_attrs:
+ if not isinstance(col_attr.expression, Column):
+ continue
+ column_obj = col_attr.expression
+ is_unique = self._is_column_unique(column_obj)
+ python_type_name = self._get_python_type_from_column(column_obj)
+ columns_info[col_attr.key] = self._build_column_info(
+ column_obj, is_unique, python_type_name
)
- return relationships_info
-
-
-def _inspect_sqlalchemy_model_recursive(
- model_class: Type[Any], recursion_path_tracker: Set[Type[Any]]
-) -> Dict[str, Any]:
- """
- Internal recursive function to introspect a SQLAlchemy model class,
- including column comments, info dictionaries, and handling recursion.
-
- Args:
- model_class: The SQLAlchemy model class to inspect.
- recursion_path_tracker: A set used to track visited models in the current
- recursion path to prevent infinite loops.
-
- Returns:
- A dictionary containing the schema information of the model.
- If recursion is detected for a model already in the path, a
- simplified dictionary indicating recursion is returned.
- If the model cannot be inspected, an error dictionary is returned.
- """
- try:
- inspector = inspect(model_class)
- except NoInspectionAvailable:
+ return columns_info
+
+ def _get_fks_from_secondary_table(self, rel_prop: RelationshipProperty) -> Set[str]:
+ """Handles relationships that use a secondary table."""
+ involved_fk_columns: Set[str] = set()
+ if rel_prop.secondary is not None:
+ for fk_constraint in rel_prop.secondary.foreign_key_constraints:
+ for col in fk_constraint.columns:
+ involved_fk_columns.add(str(col))
+ return involved_fk_columns
+
+ def _get_fks_from_synchronize_pairs(
+ self, rel_prop: RelationshipProperty
+ ) -> Set[str]:
+ """Handles relationships that use synchronize_pairs."""
+ involved_fk_columns: Set[str] = set()
+ if rel_prop.synchronize_pairs:
+ for local_join_col, remote_join_col in rel_prop.synchronize_pairs:
+ if (
+ hasattr(local_join_col, "foreign_keys")
+ and local_join_col.foreign_keys
+ ):
+ involved_fk_columns.add(str(local_join_col))
+ if (
+ hasattr(remote_join_col, "foreign_keys")
+ and remote_join_col.foreign_keys
+ ):
+ involved_fk_columns.add(str(remote_join_col))
+ return involved_fk_columns
+
+ def _get_fks_from_direct_foreign_keys(
+ self, rel_prop: RelationshipProperty
+ ) -> Set[str]:
+ """Handles relationships that have direct foreign_keys."""
+ involved_fk_columns: Set[str] = set()
+ if hasattr(rel_prop, "foreign_keys") and rel_prop.foreign_keys is not None:
+ for fk_col in rel_prop.foreign_keys:
+ involved_fk_columns.add(str(fk_col))
+ return involved_fk_columns
+
+ def _get_involved_foreign_keys(self, rel_prop: RelationshipProperty) -> Set[str]:
+ """
+ Finds all foreign key columns involved in a relationship by dispatching to helper functions.
+ """
+ if rel_prop.secondary is not None:
+ return self._get_fks_from_secondary_table(rel_prop)
+
+ if rel_prop.synchronize_pairs:
+ return self._get_fks_from_synchronize_pairs(rel_prop)
+
+ if hasattr(rel_prop, "foreign_keys") and rel_prop.foreign_keys is not None:
+ return self._get_fks_from_direct_foreign_keys(rel_prop)
+
+ return set()
+
+ def _build_relationship_info(
+ self,
+ rel_prop: RelationshipProperty,
+ involved_fk_columns: Set[str],
+ recursion_path_tracker: Set[Type[Any]],
+ ) -> Dict[str, Any]:
+ """Builds the relationship information dictionary, including recursion."""
+ related_model_class = rel_prop.mapper.class_
return {
- "error": f"Could not get an inspector for {model_class}. It might not be a valid SQLAlchemy mapped class."
+ "type": rel_prop.direction.name,
+ "uselist": rel_prop.uselist,
+ "related_model_name": related_model_class.__name__,
+ "secondary_table_name": rel_prop.secondary.name
+ if rel_prop.secondary is not None
+ else None,
+ "local_columns": [str(c) for c in rel_prop.local_columns],
+ "remote_columns_in_join": [
+ str(pair[1]) for pair in rel_prop.local_remote_pairs
+ ]
+ if rel_prop.local_remote_pairs
+ else [],
+ "foreign_key_constraints_involved": sorted(involved_fk_columns),
+ "back_populates": rel_prop.back_populates,
+ "info_dict": rel_prop.info,
+ "nested_schema": self._inspect_sqlalchemy_model_recursive(
+ related_model_class, recursion_path_tracker
+ ),
}
- if inspector is None:
- return {"error": f"Inspector is None for {model_class}."}
-
- table_obj = inspector.selectable
- table_info_dict = (
- getattr(table_obj, "info", None) if isinstance(table_obj, Table) else None
- )
- table_comment = (
- getattr(table_obj, "comment", None) if isinstance(table_obj, Table) else None
- )
-
- table_name_str = getattr(model_class, "__tablename__", model_class.__name__.lower())
- if hasattr(table_obj, "name") and table_obj.name:
- table_name_str = table_obj.name
+ def _get_relationships_from_inspector(
+ self, inspector, recursion_path_tracker: Set[Type[Any]]
+ ) -> Dict[str, Any]:
+ """Extracts all relationship properties from a SQLAlchemy inspector."""
+ relationships_info = {}
+ for name, rel_prop in inspector.relationships.items():
+ if isinstance(rel_prop, RelationshipProperty):
+ involved_fk_columns = self._get_involved_foreign_keys(rel_prop)
+ relationships_info[name] = self._build_relationship_info(
+ rel_prop, involved_fk_columns, recursion_path_tracker
+ )
+ return relationships_info
+
+ def _inspect_sqlalchemy_model_recursive(
+ self, model_class: Type[Any], recursion_path_tracker: Set[Type[Any]]
+ ) -> Dict[str, Any]:
+ """
+ Internal recursive function to introspect a SQLAlchemy model class,
+ tracking visited models in recursion_path_tracker to prevent infinite loops.
+ """
+ try:
+ inspector = inspect(model_class)
+ except NoInspectionAvailable:
+ return {
+ "error": f"Could not get an inspector for {model_class}. It might not be a valid SQLAlchemy mapped class."
+ }
+
+ if inspector is None:
+ return {"error": f"Inspector is None for {model_class}."}
+
+ table_obj = inspector.selectable
+ table_info_dict = (
+ getattr(table_obj, "info", None) if isinstance(table_obj, Table) else None
+ )
+ table_comment = (
+ getattr(table_obj, "comment", None)
+ if isinstance(table_obj, Table)
+ else None
+ )
- if model_class in recursion_path_tracker:
- return {
+ table_name_str = getattr(
+ model_class, "__tablename__", model_class.__name__.lower()
+ )
+ if hasattr(table_obj, "name") and table_obj.name:
+ table_name_str = table_obj.name
+
+ if model_class in recursion_path_tracker:
+ return {
+ "table_name": table_name_str,
+ "model_name": model_class.__name__,
+ "recursion_detected_for_type": model_class.__name__,
+ "info_dict": table_info_dict,
+ "comment": table_comment,
+ "description_note": "Schema for this model is detailed elsewhere in the current path.",
+ }
+
+ recursion_path_tracker.add(model_class)
+
+ schema_info: Dict[str, Any] = {
"table_name": table_name_str,
"model_name": model_class.__name__,
- "recursion_detected_for_type": model_class.__name__,
"info_dict": table_info_dict,
"comment": table_comment,
- "description_note": "Schema for this model is detailed elsewhere in the current path.",
+ "columns": self._get_columns_from_inspector(inspector),
+ "relationships": self._get_relationships_from_inspector(
+ inspector, recursion_path_tracker
+ ),
}
- recursion_path_tracker.add(model_class)
-
- schema_info: Dict[str, Any] = {
- "table_name": table_name_str,
- "model_name": model_class.__name__,
- "info_dict": table_info_dict,
- "comment": table_comment,
- "columns": _get_columns_from_inspector(inspector),
- "relationships": _get_relationships_from_inspector(
- inspector, recursion_path_tracker
- ),
- }
-
- recursion_path_tracker.remove(model_class)
- return schema_info
-
-
-def inspect_sqlalchemy_model(model_class: Type[Any]) -> Dict[str, Any]:
- """
- Public wrapper function to start the SQLAlchemy model introspection.
- Initializes an empty set for tracking the recursion path.
-
- Args:
- model_class: The SQLAlchemy model class to inspect.
-
- Returns:
- A dictionary containing the schema information of the model,
- including nested schemas, comments, and info dictionaries.
- """
- return _inspect_sqlalchemy_model_recursive(model_class, set())
-
-
-def _collect_all_sqla_models_recursively(
- current_model_class: Type[Any],
- all_discovered_models: List[Type[Any]],
- recursion_guard: Set[Type[Any]],
-) -> None:
- """
- Recursively collects all unique SQLAlchemy model classes related to current_model_class.
- This function populates the `all_discovered_models` list, preserving order.
-
- Args:
- current_model_class: The SQLAlchemy model class currently being processed.
- all_discovered_models: A list to store all unique model classes found.
- recursion_guard: A set to track models visited in the current recursion
- path to prevent infinite loops.
- """
- if current_model_class in recursion_guard:
- return
- recursion_guard.add(current_model_class)
-
- # Add the model if it's not already in the list to preserve order and uniqueness
- if current_model_class not in all_discovered_models:
- all_discovered_models.append(current_model_class)
-
- try:
- inspector = inspect(current_model_class)
- except NoInspectionAvailable:
- recursion_guard.remove(current_model_class)
- return
+ recursion_path_tracker.remove(model_class)
+ return schema_info
+
+ def inspect_sqlalchemy_model(self, model_class: Type[Any]) -> Dict[str, Any]:
+ """
+ Public entry point that starts SQLAlchemy model introspection with an empty recursion tracker.
+ """
+ return self._inspect_sqlalchemy_model_recursive(model_class, set())
+
+ def _collect_all_sqla_models_recursively(
+ self,
+ current_model_class: Type[Any],
+ all_discovered_models: List[Type[Any]],
+ recursion_guard: Set[Type[Any]],
+ ) -> None:
+ """
+ Recursively collects all unique SQLAlchemy model classes related to current_model_class.
+ """
+ if current_model_class in recursion_guard:
+ return
+ recursion_guard.add(current_model_class)
+
+ # Add the model if it's not already in the list to preserve order and uniqueness
+ if current_model_class not in all_discovered_models:
+ all_discovered_models.append(current_model_class)
- if inspector is None:
+ try:
+ inspector = inspect(current_model_class)
+ except NoInspectionAvailable:
+ recursion_guard.remove(current_model_class)
+ return
+
+ if inspector is None:
+ recursion_guard.remove(current_model_class)
+ return
+
+ for rel_prop in inspector.relationships:
+ related_sqla_model_class = rel_prop.mapper.class_
+ if related_sqla_model_class not in recursion_guard:
+ self._collect_all_sqla_models_recursively(
+ related_sqla_model_class, all_discovered_models, recursion_guard
+ )
recursion_guard.remove(current_model_class)
- return
-
- for rel_prop in inspector.relationships:
- related_sqla_model_class = rel_prop.mapper.class_
- if related_sqla_model_class not in recursion_guard:
- _collect_all_sqla_models_recursively(
- related_sqla_model_class, all_discovered_models, recursion_guard
- )
- recursion_guard.remove(current_model_class)
-
-
-def _get_prioritized_description(
- *,
- custom_desc: Optional[str] = None,
- pydantic_desc: Optional[str] = None,
- info_dict: Optional[Dict[str, Any]] = None,
- comment: Optional[str] = None,
-) -> Tuple[Optional[str], Dict[str, Any]]:
- """
- Centralized helper to determine the best description from multiple sources.
- Priority: custom -> pydantic -> info_dict['description'] -> comment.
- Also extracts any other key-value pairs from the info_dict.
- """
- description = None
- if custom_desc:
- description = custom_desc
- elif pydantic_desc:
- description = pydantic_desc
-
- other_info_from_dict = {}
- if isinstance(info_dict, dict):
- info_desc = info_dict.get("description")
- if info_desc and not description:
- description = info_desc
- other_info_from_dict = {
- k: v for k, v in info_dict.items() if k != "description"
- }
-
- if not description and comment:
- description = comment
- return description, other_info_from_dict
+ def _get_prioritized_description(
+ self,
+ *,
+ custom_desc: Optional[str] = None,
+ pydantic_desc: Optional[str] = None,
+ info_dict: Optional[Dict[str, Any]] = None,
+ comment: Optional[str] = None,
+ ) -> Tuple[Optional[str], Dict[str, Any]]:
+ """
+ Centralized helper to determine the best description from multiple sources.
+ Priority: custom -> pydantic -> info_dict['description'] -> comment.
+ """
+ description = None
+ if custom_desc:
+ description = custom_desc
+ elif pydantic_desc:
+ description = pydantic_desc
+
+ other_info_from_dict = {}
+ if isinstance(info_dict, dict):
+ info_desc = info_dict.get("description")
+ if info_desc and not description:
+ description = info_desc
+ other_info_from_dict = {
+ k: v for k, v in info_dict.items() if k != "description"
+ }
+
+ if not description and comment:
+ description = comment
+
+ return description, other_info_from_dict
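The priority chain (custom, then Pydantic, then `info_dict["description"]`, then column comment) can be sketched as a standalone function; this is an illustrative reimplementation, not the method itself:

```python
from typing import Any, Dict, Optional, Tuple

def prioritized_description(
    custom_desc: Optional[str] = None,
    pydantic_desc: Optional[str] = None,
    info_dict: Optional[Dict[str, Any]] = None,
    comment: Optional[str] = None,
) -> Tuple[Optional[str], Dict[str, Any]]:
    # Priority: custom -> pydantic -> info_dict["description"] -> comment.
    description = custom_desc or pydantic_desc
    other_info: Dict[str, Any] = {}
    if isinstance(info_dict, dict):
        if info_dict.get("description") and not description:
            description = info_dict["description"]
        # Non-description keys are surfaced separately.
        other_info = {k: v for k, v in info_dict.items() if k != "description"}
    if not description and comment:
        description = comment
    return description, other_info

desc, extra = prioritized_description(
    pydantic_desc="The customer's billing email.",
    info_dict={"description": "Ignored: lower priority.", "pii": True},
    comment="Also ignored.",
)
print(desc)   # The customer's billing email.
print(extra)  # {'pii': True}
```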
+
+ def _process_column_for_llm_schema(
+ self,
+ col_name: str,
+ col_data: Dict[str, Any],
+ pydantic_fields: Dict[str, Any],
+ custom_descs: Dict[str, str],
+ model_name: str,
+ ) -> Tuple[str, str]:
+ """Processes a single column to generate its LLM schema representation."""
+ python_type_for_mapping = str(col_data.get("python_type", ""))
+ pydantic_field_description = None
+
+ if col_name in pydantic_fields:
+ field_pydantic_info = pydantic_fields[col_name]
+ if field_pydantic_info.annotation:
+ pydantic_derived_type_str = (
+ get_python_type_str_from_pydantic_annotation(
+ field_pydantic_info.annotation
+ )
+ )
+ if (
+ pydantic_derived_type_str
+ and not pydantic_derived_type_str.startswith("unknown")
+ and pydantic_derived_type_str != "any"
+ ):
+ python_type_for_mapping = pydantic_derived_type_str
+
+ if field_pydantic_info.description:
+ pydantic_field_description = field_pydantic_info.description
+
+ llm_type = map_sql_type_to_llm_type(
+ str(col_data.get("type", "")),
+ python_type_for_mapping,
+ )
+ description, other_info = self._get_prioritized_description(
+ custom_desc=custom_descs.get(col_name),
+ pydantic_desc=pydantic_field_description,
+ info_dict=col_data.get("info_dict"),
+ comment=col_data.get("comment"),
+ )
-def _process_column_for_llm_schema(
- col_name: str,
- col_data: Dict[str, Any],
- pydantic_fields: Dict[str, Any],
- custom_descs: Dict[str, str],
- model_name: str,
-) -> Tuple[str, str]:
- """Processes a single column to generate its LLM schema representation."""
- python_type_for_mapping = str(col_data.get("python_type", ""))
- pydantic_field_description = None
+ if not description:
+ description = f"Field '{col_name}' of type {llm_type} for {model_name}."
- if col_name in pydantic_fields:
- field_pydantic_info = pydantic_fields[col_name]
- if field_pydantic_info.annotation:
- pydantic_derived_type_str = _get_python_type_str_from_pydantic_annotation(
- field_pydantic_info.annotation
+ if col_data.get("enum_values"):
+ description += (
+ f" Authorized values: {', '.join(map(str, col_data['enum_values']))}."
)
- if (
- pydantic_derived_type_str
- and not pydantic_derived_type_str.startswith("unknown")
- and pydantic_derived_type_str != "any"
- ):
- python_type_for_mapping = pydantic_derived_type_str
- if field_pydantic_info.description:
- pydantic_field_description = field_pydantic_info.description
+ additional_info_items_str = ""
+ if other_info:
+ try:
+ additional_info_items_str = f" (Info: {json.dumps(other_info)})"
+ except TypeError:
+ additional_info_items_str = f" (Info: {str(other_info)})"
- llm_type = _map_sql_type_to_llm_type(
- str(col_data.get("type", "")),
- python_type_for_mapping,
- )
+ final_description = f"{description}{additional_info_items_str}"
+ formatted_string = f"{llm_type} // {final_description.strip()}"
- description, other_info = _get_prioritized_description(
- custom_desc=custom_descs.get(col_name),
- pydantic_desc=pydantic_field_description,
- info_dict=col_data.get("info_dict"),
- comment=col_data.get("comment"),
- )
+ return col_name, formatted_string
- if not description:
- description = f"Field '{col_name}' of type {llm_type} for {model_name}."
+ def _process_relationship_for_llm_schema(
+ self, rel_name: str, rel_data: Dict[str, Any], custom_descs: Dict[str, str]
+ ) -> Optional[Tuple[str, str]]:
+ """Processes a single relationship to generate its LLM schema representation."""
+ related_model_name = rel_data.get("related_model_name", "UnknownRelatedModel")
- additional_info_items_str = ""
- if other_info:
- try:
- additional_info_items_str = f" (Info: {json.dumps(other_info)})"
- except TypeError:
- additional_info_items_str = f" (Info: {str(other_info)})"
+ temp_ref_field_name_single = f"{rel_name}_ref_id"
+ temp_ref_field_name_list = f"{rel_name}_ref_ids"
- final_description = f"{description}{additional_info_items_str}"
- formatted_string = f"{llm_type} // {final_description.strip()}"
+ custom_desc_lookup = (
+ custom_descs.get(rel_name)
+ or custom_descs.get(temp_ref_field_name_single)
+ or custom_descs.get(temp_ref_field_name_list)
+ )
- return col_name, formatted_string
+ description, other_info = self._get_prioritized_description(
+ custom_desc=custom_desc_lookup,
+ info_dict=rel_data.get("info_dict"),
+ )
+ additional_info_items_str = ""
+ if other_info:
+ try:
+ additional_info_items_str = f" (Info: {json.dumps(other_info)})"
+ except TypeError:
+ additional_info_items_str = f" (Info: {str(other_info)})"
+
+ ref_field_name_for_llm = ""
+ field_type_for_llm = ""
+ default_desc = ""
+
+ if rel_data.get("uselist") is True:
+ ref_field_name_for_llm = temp_ref_field_name_list
+ field_type_for_llm = "array of strings (temporary IDs)"
+ default_desc = f"A list of _temp_ids for related {related_model_name} entities in '{rel_name}'."
+ elif rel_data.get("uselist") is False:
+ ref_field_name_for_llm = temp_ref_field_name_single
+ field_type_for_llm = "string (temporary ID)"
+ default_desc = (
+ f"The _temp_id of the related {related_model_name} for '{rel_name}'."
+ )
-def _process_relationship_for_llm_schema(
- rel_name: str, rel_data: Dict[str, Any], custom_descs: Dict[str, str]
-) -> Optional[Tuple[str, str]]:
- """Processes a single relationship to generate its LLM schema representation."""
- related_model_name = rel_data.get("related_model_name", "UnknownRelatedModel")
+ if not ref_field_name_for_llm:
+ return None
- temp_ref_field_name_single = f"{rel_name}_ref_id"
- temp_ref_field_name_list = f"{rel_name}_ref_ids"
+ final_description = description or default_desc
+ full_description = f"{final_description}{additional_info_items_str}"
- custom_desc_lookup = (
- custom_descs.get(rel_name)
- or custom_descs.get(temp_ref_field_name_single)
- or custom_descs.get(temp_ref_field_name_list)
- )
+ formatted_string = f"{field_type_for_llm} // {full_description.strip()}"
- description, other_info = _get_prioritized_description(
- custom_desc=custom_desc_lookup,
- info_dict=rel_data.get("info_dict"),
- )
+ return ref_field_name_for_llm, formatted_string
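The method above flattens each relationship into a temp-ID reference field whose name and type depend on cardinality. A rough standalone sketch of just that naming rule (not the actual method):

```python
from typing import Optional, Tuple

def ref_field_for_relationship(
    rel_name: str, uselist: Optional[bool]
) -> Optional[Tuple[str, str]]:
    # to-many relationships become '<rel>_ref_ids' (a list of _temp_ids),
    # to-one relationships become '<rel>_ref_id' (a single _temp_id).
    if uselist is True:
        return f"{rel_name}_ref_ids", "array of strings (temporary IDs)"
    if uselist is False:
        return f"{rel_name}_ref_id", "string (temporary ID)"
    return None  # unknown cardinality: the relationship is omitted

print(ref_field_for_relationship("books", True))
# ('books_ref_ids', 'array of strings (temporary IDs)')
print(ref_field_for_relationship("author", False))
# ('author_ref_id', 'string (temporary ID)')
```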
- additional_info_items_str = ""
- if other_info:
- try:
- additional_info_items_str = f" (Info: {json.dumps(other_info)})"
- except TypeError:
- additional_info_items_str = f" (Info: {str(other_info)})"
-
- ref_field_name_for_llm = ""
- field_type_for_llm = ""
- default_desc = ""
-
- if rel_data.get("uselist") is True:
- ref_field_name_for_llm = temp_ref_field_name_list
- field_type_for_llm = "array of strings (temporary IDs)"
- default_desc = f"A list of _temp_ids for related {related_model_name} entities in '{rel_name}'."
- elif rel_data.get("uselist") is False:
- ref_field_name_for_llm = temp_ref_field_name_single
- field_type_for_llm = "string (temporary ID)"
- default_desc = (
- f"The _temp_id of the related {related_model_name} for '{rel_name}'."
+ def _generate_model_level_description(
+ self, model_name: str, raw_schema: Dict[str, Any], custom_descs: Dict[str, str]
+ ) -> str:
+ """Generates the complete model-level description block."""
+ description, other_info = self._get_prioritized_description(
+ custom_desc=custom_descs.get("_model_description"),
+ info_dict=raw_schema.get("info_dict"),
+ comment=raw_schema.get("comment"),
)
- if not ref_field_name_for_llm:
- return None
-
- final_description = description or default_desc
- full_description = f"{final_description}{additional_info_items_str}"
+ if not description:
+ description = f"Represents a {model_name} entity."
+
+ model_additional_info = ""
+ if other_info:
+ try:
+ model_additional_info = f" (Info: {json.dumps(other_info)})"
+ except TypeError:
+ model_additional_info = f" (Info: {str(other_info)})"
+
+ final_model_description_base = f"{description}{model_additional_info}"
+ final_model_overall_description = (
+ f"{final_model_description_base.strip()} "
+ f"When processing a {model_name}, the LLM should assign a unique '_temp_id' "
+ f"to each instance and use '{model_name}' as its '_type' field in the output 'entities' list."
+ )
+ return final_model_overall_description
+
+ def generate_llm_schema_from_models(
+ self,
+ initial_model_classes: List[Type[SQLModel]],
+ custom_field_descriptions: Optional[Dict[str, Dict[str, str]]] = None,
+ ) -> str:
+ """
+ Generates an LLM-friendly schema representation for a list of SQLModel classes.
+ """
+ if custom_field_descriptions is None:
+ custom_field_descriptions = {}
+
+ all_sqla_models_to_document: List[Type[Any]] = []
+ for root_model_class in initial_model_classes:
+ self._collect_all_sqla_models_recursively(
+ root_model_class, all_sqla_models_to_document, set()
+ )
- formatted_string = f"{field_type_for_llm} // {full_description.strip()}"
+ llm_schema_map = {}
- return ref_field_name_for_llm, formatted_string
+ for model_class in all_sqla_models_to_document:
+ model_name = model_class.__name__
+ raw_schema = self.inspect_sqlalchemy_model(model_class)
+ if raw_schema.get("error"):
+ self.logger.warning(
+ f"Could not inspect model {model_name} for LLM schema generation. Error: {raw_schema['error']}"
+ )
+ continue
-def _generate_model_level_description(
- model_name: str, raw_schema: Dict[str, Any], custom_descs: Dict[str, str]
-) -> str:
- """Generates the complete model-level description block."""
- description, other_info = _get_prioritized_description(
- custom_desc=custom_descs.get("_model_description"),
- info_dict=raw_schema.get("info_dict"),
- comment=raw_schema.get("comment"),
- )
+ model_custom_descs = custom_field_descriptions.get(model_name, {})
- if not description:
- description = f"Represents a {model_name} entity."
+ # Get pydantic model fields if applicable
+ pydantic_model_fields = {}
+ if hasattr(model_class, "model_fields") and issubclass(
+ model_class, SQLModel
+ ):
+ pydantic_model_fields = model_class.model_fields
+
+ fields_info = {}
+ for col_name, col_data in raw_schema.get("columns", {}).items():
+ processed_col_name, formatted_col_string = (
+ self._process_column_for_llm_schema(
+ col_name,
+ col_data,
+ pydantic_model_fields,
+ model_custom_descs,
+ model_name,
+ )
+ )
+ fields_info[processed_col_name] = formatted_col_string
+
+ for rel_name, rel_data in raw_schema.get("relationships", {}).items():
+ processed_rel = self._process_relationship_for_llm_schema(
+ rel_name, rel_data, model_custom_descs
+ )
+ if processed_rel:
+ field_name, formatted_string = processed_rel
+ fields_info[field_name] = formatted_string
+
+ final_model_overall_description = self._generate_model_level_description(
+ model_name, raw_schema, model_custom_descs
+ )
- model_additional_info = ""
- if other_info:
+ llm_schema_map[model_name] = {
+ "description": final_model_overall_description,
+ "fields": fields_info,
+ "notes_for_llm": (
+ f"For {model_name}: Ensure all fields conform to their types. "
+ "Relationship fields (like '{rel_name}_ref_id' or '{rel_name}_ref_ids') "
+ "must use the _temp_ids of corresponding related entities defined in this response. "
+ "Omit optional fields if no information is found."
+ ),
+ }
+ return json.dumps(llm_schema_map, indent=2)
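+
+ # Hedged usage sketch, not from the source: the host object name and the
+ # `Author`/`Book` models are hypothetical; only the method and the JSON
+ # shape built above are real.
+ #
+ #     schema_json = inspector.generate_llm_schema_from_models([Author, Book])
+ #     schema = json.loads(schema_json)
+ #     # each model entry carries "description", "fields", and "notes_for_llm"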
+
+ def discover_sqlmodels_from_root(
+ self,
+ root_sqlmodel_class: Type[SQLModel],
+ ) -> List[Type[SQLModel]]:
+ """
+ Discovers all unique SQLModel classes reachable from a root SQLModel class,
+ following SQLAlchemy relationships recursively and preserving discovery order.
+ Returns an empty list if the root class is invalid or discovery fails.
+ """
+ if not root_sqlmodel_class or not issubclass(root_sqlmodel_class, SQLModel):
+ self.logger.warning(f"{root_sqlmodel_class} is not a valid SQLModel class.")
+ return []
+
+ all_discovered_models: List[Type[SQLModel]] = []
try:
- model_additional_info = f" (Info: {json.dumps(other_info)})"
- except TypeError:
- model_additional_info = f" (Info: {str(other_info)})"
-
- final_model_description_base = f"{description}{model_additional_info}"
- final_model_overall_description = (
- f"{final_model_description_base.strip()} "
- f"When processing a {model_name}, the LLM should assign a unique '_temp_id' "
- f"to each instance and use '{model_name}' as its '_type' field in the output 'entities' list."
- )
- return final_model_overall_description
-
-
-def generate_llm_schema_from_models(
- initial_model_classes: List[Type[SQLModel]],
- custom_field_descriptions: Optional[Dict[str, Dict[str, str]]] = None,
-) -> str:
- """
- Generates an LLM-friendly schema representation for a list of SQLAlchemy models.
- It starts with `initial_model_classes`, discovers all related SQLAlchemy models
- recursively, and includes them in the generated schema.
- The schema utilizes comments and info dictionaries from the models and allows
- for custom descriptions to override or augment default ones.
-
- Args:
- initial_model_classes: A list of SQLAlchemy model classes to serve as
- starting points for schema generation.
- custom_field_descriptions: An optional dictionary to provide custom
- descriptions for models or their fields.
- Format: `{"ModelName": {"field_name": "desc", "_model_description": "model_desc"}}`
-
- Returns:
- A JSON string representing the LLM-friendly schema for all discovered models.
- """
- if custom_field_descriptions is None:
- custom_field_descriptions = {}
-
- all_sqla_models_to_document: List[Type[Any]] = []
- for root_model_class in initial_model_classes:
- _collect_all_sqla_models_recursively(
- root_model_class, all_sqla_models_to_document, set()
- )
-
- llm_schema_map = {}
-
- for model_class in all_sqla_models_to_document:
- model_name = model_class.__name__
- raw_schema = inspect_sqlalchemy_model(model_class)
-
- if raw_schema.get("error"):
- print(
- f"Warning: Could not inspect model {model_name} for LLM schema generation. Error: {raw_schema['error']}"
+ self._collect_all_sqla_models_recursively(
+ current_model_class=root_sqlmodel_class,
+ all_discovered_models=all_discovered_models, # type: ignore[arg-type]
+ recursion_guard=set(),
)
- continue
-
- model_custom_descs = custom_field_descriptions.get(model_name, {})
-
- # Get pydantic model fields if applicable
- pydantic_model_fields = {}
- if hasattr(model_class, "model_fields") and issubclass(model_class, SQLModel):
- pydantic_model_fields = model_class.model_fields
-
- fields_info = {}
- for col_name, col_data in raw_schema.get("columns", {}).items():
- processed_col_name, formatted_col_string = _process_column_for_llm_schema(
- col_name,
- col_data,
- pydantic_model_fields,
- model_custom_descs,
- model_name,
+ except Exception as e:
+ self.logger.error(
+ f"Error during SQLModel discovery starting from {root_sqlmodel_class.__name__}: {e}"
)
- fields_info[processed_col_name] = formatted_col_string
-
- for rel_name, rel_data in raw_schema.get("relationships", {}).items():
- processed_rel = _process_relationship_for_llm_schema(
- rel_name, rel_data, model_custom_descs
- )
- if processed_rel:
- field_name, formatted_string = processed_rel
- fields_info[field_name] = formatted_string
-
- final_model_overall_description = _generate_model_level_description(
- model_name, raw_schema, model_custom_descs
- )
-
- llm_schema_map[model_name] = {
- "description": final_model_overall_description,
- "fields": fields_info,
- "notes_for_llm": (
- f"For {model_name}: Ensure all fields conform to their types. "
- "Relationship fields (like '{rel_name}_ref_id' or '{rel_name}_ref_ids') "
- "must use the _temp_ids of corresponding related entities defined in this response. "
- "Omit optional fields if no information is found."
- ),
- }
- return json.dumps(llm_schema_map, indent=2)
-
-
-def discover_sqlmodels_from_root(
- root_sqlmodel_class: Type[SQLModel],
-) -> List[Type[SQLModel]]:
- """
- Discovers all unique SQLModel classes starting from a root SQLModel class,
- by recursively inspecting SQLAlchemy relationships, preserving discovery order.
-
- Args:
- root_sqlmodel_class: The primary SQLModel class to start discovery from.
-
- Returns:
- A list of all unique SQLModel classes discovered (including the root),
- in the order they were found. Returns an empty list if the
- root_sqlmodel_class is not a valid SQLModel or if no classes can be inspected.
- """
- if not root_sqlmodel_class or not issubclass(root_sqlmodel_class, SQLModel):
- print(f"Warning: {root_sqlmodel_class} is not a valid SQLModel class.")
- return []
-
- all_discovered_models: List[Type[SQLModel]] = []
- try:
- _collect_all_sqla_models_recursively(
- current_model_class=root_sqlmodel_class,
- all_discovered_models=all_discovered_models, # type: ignore[arg-type]
- recursion_guard=set(),
- )
- except Exception as e:
- print(
- f"Error during SQLModel discovery starting from {root_sqlmodel_class.__name__}: {e}"
- )
- return []
+ return []
- return all_discovered_models
+ return all_discovered_models
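+
+ # Hedged usage sketch, not from the source: `Author` and its related models
+ # are hypothetical.
+ #
+ #     models = inspector.discover_sqlmodels_from_root(Author)
+ #     # root first, then related models in the order they were discovered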
diff --git a/src/extrai/core/sqlalchemy_hydrator.py b/src/extrai/core/sqlalchemy_hydrator.py
deleted file mode 100644
index 40abc12..0000000
--- a/src/extrai/core/sqlalchemy_hydrator.py
+++ /dev/null
@@ -1,278 +0,0 @@
-from typing import (
- Dict,
- List,
- Any,
- Type,
- Optional,
- get_origin,
- get_args,
- Union,
- NamedTuple,
-)
-import uuid
-from sqlalchemy.orm import Session
-from sqlmodel import SQLModel
-
-SQLModelInstance = SQLModel
-
-
-class PrimaryKeyInfo(NamedTuple):
- name: Optional[str]
- type: Optional[Type[Any]]
- has_uuid_factory: bool
-
-
-class SQLAlchemyHydrator:
- """
- Hydrates SQLModel objects from consensus JSON data.
- It uses a two-pass strategy: first, create all object instances,
- then link their relationships using temporary IDs.
- """
-
- def __init__(self, session: Session):
- """
- Initializes the Hydrator.
-
- Args:
- session: The SQLAlchemy session to use for database operations
- and instance management (e.g., adding instances).
- """
- self.session: Session = session
- self.temp_id_to_instance_map: Dict[
- str, SQLModelInstance
- ] = {} # Stores _temp_id -> SQLModel instance
-
- def _filter_special_fields(self, data: Dict[str, Any]) -> Dict[str, Any]:
- """Removes _temp_id, _type, and relationship reference fields before Pydantic validation."""
- return {
- k: v
- for k, v in data.items()
- if k not in ["_temp_id", "_type"]
- and not k.endswith("_ref_id")
- and not k.endswith("_ref_ids")
- }
-
- def _validate_entities_list(self, entities_list: List[Dict[str, Any]]) -> None:
- """Performs initial validation on the input entities list."""
- if not isinstance(entities_list, list):
- raise TypeError(
- f"Input 'entities_list' must be a list. Got: {type(entities_list)}"
- )
- if not all(isinstance(item, dict) for item in entities_list):
- first_non_dict = next(
- (item for item in entities_list if not isinstance(item, dict)), None
- )
- raise ValueError(
- "All items in 'entities_list' must be dictionaries. "
- f"Found an item of type: {type(first_non_dict)}."
- )
-
- def _get_primary_key_info(self, model_class: Type[SQLModel]) -> PrimaryKeyInfo:
- """Introspects the model to find primary key details."""
- for field_name, model_field in model_class.model_fields.items():
- if getattr(model_field, "primary_key", False):
- pk_type = model_field.annotation
- origin_type = get_origin(pk_type)
- if origin_type is Union:
- args = get_args(pk_type)
- pk_type = next(
- (
- arg
- for arg in args
- if arg is not type(None) and arg is not None
- ),
- None,
- )
-
- has_uuid_factory = False
- if model_field.default_factory:
- factory_func = model_field.default_factory
- if factory_func is uuid.uuid4 or (
- callable(factory_func)
- and getattr(factory_func, "__name__", "").lower() == "uuid4"
- ):
- has_uuid_factory = True
-
- return PrimaryKeyInfo(
- name=field_name, type=pk_type, has_uuid_factory=has_uuid_factory
- )
-
- return PrimaryKeyInfo(name=None, type=None, has_uuid_factory=False)
-
- def _generate_pk_if_needed(
- self, instance: SQLModelInstance, model_class: Type[SQLModel]
- ) -> None:
- """Generates a primary key for the instance if it's needed."""
- pk_info = self._get_primary_key_info(model_class)
-
- if not pk_info.name:
- return
-
- current_pk_value = getattr(instance, pk_info.name, None)
-
- if current_pk_value is not None or pk_info.has_uuid_factory:
- return
-
- if pk_info.type is uuid.UUID:
- setattr(instance, pk_info.name, uuid.uuid4())
- elif pk_info.type is str:
- setattr(instance, pk_info.name, str(uuid.uuid4()))
-
- def _create_single_instance(
- self,
- entity_data: Dict[str, Any],
- model_schema_map: Dict[str, Type[SQLModel]],
- ) -> None:
- """Creates a single SQLModel instance from its dictionary representation."""
- _temp_id = entity_data.get("_temp_id")
- _type = entity_data.get("_type")
-
- if not _temp_id or not _type:
- raise ValueError(
- "Entity data in 'entities' list is missing '_temp_id' or '_type'."
- )
- if _type not in model_schema_map:
- raise ValueError(
- f"No SQLModel class found in model_schema_map for type: '{_type}'."
- )
- if _temp_id in self.temp_id_to_instance_map:
- raise ValueError(
- f"Duplicate _temp_id '{_temp_id}' found in 'entities' list."
- )
-
- model_class = model_schema_map[_type]
-
- filtered_data = self._filter_special_fields(entity_data.copy())
-
- pk_field_name: Optional[str] = None
- for field_name, model_field in model_class.model_fields.items():
- if getattr(model_field, "primary_key", False):
- pk_field_name = field_name
- break
-
- if pk_field_name and pk_field_name in filtered_data:
- del filtered_data[pk_field_name]
-
- try:
- instance = model_class.model_validate(filtered_data)
- except Exception as e:
- raise ValueError(
- f"Failed to instantiate/validate SQLModel '{_type}' for _temp_id '{_temp_id}': {e}"
- ) from e
-
- self._generate_pk_if_needed(instance, model_class)
- self.temp_id_to_instance_map[_temp_id] = instance
-
- def _create_and_map_instances(
- self,
- entities_list: List[Dict[str, Any]],
- model_schema_map: Dict[str, Type[SQLModel]],
- ) -> None:
- """Pass 1: Creates and maps all SQLModel instances."""
- for entity_data in entities_list:
- self._create_single_instance(entity_data, model_schema_map)
-
- def _link_to_one_relation(
- self,
- instance: SQLModelInstance,
- relation_name: str,
- ref_id: Any,
- entity_data: Dict[str, Any],
- ) -> None:
- """Handles the logic for a single to-one relationship."""
- if ref_id is None:
- setattr(instance, relation_name, None)
- return
-
- if isinstance(ref_id, str) and ref_id in self.temp_id_to_instance_map:
- related_instance = self.temp_id_to_instance_map[ref_id]
- setattr(instance, relation_name, related_instance)
- else:
- _temp_id = entity_data.get("_temp_id", "N/A")
- _type = entity_data.get("_type", "N/A")
- print(
- f"Warning: Referenced _temp_id '{ref_id}' for relation "
- f"'{relation_name}' on instance '{_temp_id}' (type: {_type}) not found or invalid type."
- )
-
- def _link_to_many_relation(
- self,
- instance: SQLModelInstance,
- relation_name: str,
- ref_ids: Any,
- entity_data: Dict[str, Any],
- ) -> None:
- """Handles the logic for a single to-many relationship."""
- _temp_id = entity_data.get("_temp_id", "N/A")
- _type = entity_data.get("_type", "N/A")
-
- if not isinstance(ref_ids, list):
- if ref_ids is not None:
- print(
- f"Warning: Value for '{relation_name}_ref_ids' on instance '{_temp_id}' is not a list as expected for '_ref_ids'. Value: {ref_ids}"
- )
- setattr(instance, relation_name, [])
- return
-
- related_instances = []
- for ref_id in ref_ids:
- if isinstance(ref_id, str) and ref_id in self.temp_id_to_instance_map:
- related_instances.append(self.temp_id_to_instance_map[ref_id])
- else:
- print(
- f"Warning: Referenced _temp_id '{ref_id}' in list for relation "
- f"'{relation_name}' on instance '{_temp_id}' (type: {_type}) not found or invalid type."
- )
- setattr(instance, relation_name, related_instances)
-
- def _link_relations_for_instance(self, entity_data: Dict[str, Any]) -> None:
- """Links relationships for a single instance by dispatching to specialized helpers."""
- _temp_id = entity_data["_temp_id"]
- instance = self.temp_id_to_instance_map[_temp_id]
-
- for key, value in entity_data.items():
- if key.endswith("_ref_id"):
- relation_name = key[:-7]
- if hasattr(instance, relation_name):
- self._link_to_one_relation(
- instance, relation_name, value, entity_data
- )
- elif key.endswith("_ref_ids"):
- relation_name = key[:-8]
- if hasattr(instance, relation_name):
- self._link_to_many_relation(
- instance, relation_name, value, entity_data
- )
-
- def _link_relationships(self, entities_list: List[Dict[str, Any]]) -> None:
- """Pass 2: Links all created instances together."""
- for entity_data in entities_list:
- self._link_relations_for_instance(entity_data)
-
- def _add_instances_to_session(self) -> None:
- """Adds all created instances to the SQLAlchemy session."""
- for instance in self.temp_id_to_instance_map.values():
- self.session.add(instance)
-
- def hydrate(
- self,
- entities_list: List[Dict[str, Any]],
- model_schema_map: Dict[str, Type[SQLModel]],
- ) -> List[SQLModelInstance]:
- """
- Hydrates SQLModel objects from a list of entity data dictionaries.
- """
- self._validate_entities_list(entities_list)
-
- self.temp_id_to_instance_map.clear()
-
- # Pass 1: Create all object instances without relationships.
- self._create_and_map_instances(entities_list, model_schema_map)
-
- # Pass 2: Link the created instances together.
- self._link_relationships(entities_list)
-
- # Add the completed object graph to the session.
- self._add_instances_to_session()
-
- return list(self.temp_id_to_instance_map.values())
diff --git a/src/extrai/core/sqlmodel_generator.py b/src/extrai/core/sqlmodel_generator.py
index b018278..9be472d 100644
--- a/src/extrai/core/sqlmodel_generator.py
+++ b/src/extrai/core/sqlmodel_generator.py
@@ -1,6 +1,5 @@
-import keyword
import logging
-from typing import Any, Dict, Set, Type, List as TypingList, Optional, Generator
+from typing import Any, Dict, Type, List as TypingList, Optional, Generator
import tempfile
import importlib.util
import sys
@@ -30,262 +29,7 @@
generate_sqlmodel_creation_system_prompt,
generate_user_prompt_for_docs,
)
-
-
-class _ImportManager:
- """Manages imports for the generated code, handling consolidation."""
-
- def __init__(self):
- self.typing_imports: Set[str] = set()
- self.sqlmodel_imports: Set[str] = {"SQLModel"}
- self.module_imports: Set[str] = set()
- self.custom_imports: Set[str] = set()
-
- def add_import_for_type(self, type_str: str):
- if "datetime." in type_str:
- self.module_imports.add("datetime")
- if "uuid." in type_str:
- self.module_imports.add("uuid")
- if "Optional[" in type_str:
- self.typing_imports.add("Optional")
- if "List[" in type_str:
- self.typing_imports.add("List")
- if "Dict[" in type_str:
- self.typing_imports.add("Dict")
- if "Union[" in type_str:
- self.typing_imports.add("Union")
- if "Any" in type_str:
- self.typing_imports.add("Any")
-
- def add_custom_imports(self, imports: TypingList[str]):
- for imp in imports:
- self.custom_imports.add(imp.strip())
-
- def render(self) -> str:
- import_lines = []
-
- # Consolidate custom imports with auto-detected ones
- for custom_imp in self.custom_imports:
- if custom_imp.startswith("from sqlmodel"):
- items = {
- item.strip() for item in custom_imp.split(" import ")[1].split(",")
- }
- self.sqlmodel_imports.update(items)
- elif custom_imp.startswith("from typing"):
- items = {
- item.strip() for item in custom_imp.split(" import ")[1].split(",")
- }
- self.typing_imports.update(items)
- elif custom_imp.startswith("import "):
- modules = {
- mod.strip() for mod in custom_imp.replace("import ", "").split(",")
- }
- self.module_imports.update(modules)
- else:
- import_lines.append(custom_imp) # Add other complex imports as is
-
- if self.sqlmodel_imports:
- import_lines.append(
- f"from sqlmodel import {', '.join(sorted(list(self.sqlmodel_imports)))}"
- )
- if self.typing_imports:
- import_lines.append(
- f"from typing import {', '.join(sorted(list(self.typing_imports)))}"
- )
- if self.module_imports:
- for mod in sorted(self.module_imports):
- import_lines.append(f"import {mod}")
-
- return "\n".join(sorted(set(import_lines)))
-
-
-class _CodeBuilder:
- """Builds the final Python code string from its components."""
-
- def __init__(
- self,
- model_name: str,
- import_manager: _ImportManager,
- description: str,
- table_name: str,
- base_classes: TypingList[str],
- is_table_model: bool,
- ):
- self.model_name = model_name
- self.import_manager = import_manager
- self.description = description
- self.table_name = table_name
- self.base_classes = base_classes
- self.is_table_model = is_table_model
- self.fields_code: TypingList[str] = []
-
- def add_field(self, field_code: str):
- self.fields_code.append(field_code)
-
- def render_class_definition(self) -> str:
- fields_str = "\n".join(self.fields_code) if self.fields_code else " pass"
-
- base_classes_str = ", ".join(self.base_classes)
- class_decorator_args = []
- if self.is_table_model:
- class_decorator_args.append("table=True")
-
- class_header = f"class {self.model_name}({base_classes_str}"
- if "SQLModel" in base_classes_str and class_decorator_args:
- class_header += f", {', '.join(class_decorator_args)}"
- class_header += "):"
-
- docstring_section = ""
- if self.description:
- docstring_section = f' """{self.description}"""\n'
-
- table_name_section = ""
- if self.is_table_model:
- table_name_section = f' __tablename__ = "{self.table_name}"\n\n'
- elif self.description:
- table_name_section = "\n"
-
- return f"{class_header}\n{docstring_section}{table_name_section}{fields_str}\n"
-
-
-class _FieldGenerator:
- """Generates the code for a single field in a SQLModel."""
-
- def __init__(self, field_info: Dict[str, Any], import_manager: _ImportManager):
- self.field_info = field_info
- self.imports = import_manager
- self.field_name_original = self.field_info["name"]
- self.field_name_python = self.field_name_original
- self.args_map: Dict[str, str] = {}
-
- def _handle_keyword_name(self):
- if keyword.iskeyword(self.field_name_original):
- self.field_name_python = self.field_name_original + "_"
- self.args_map["alias"] = f'"{self.field_name_original}"'
-
- def _get_default_value_arg(self):
- if "default_factory" in self.field_info:
- factory_str = self.field_info["default_factory"]
- self.args_map["default_factory"] = factory_str
- if "." in factory_str:
- potential_module = factory_str.split(".")[0]
- if potential_module.isidentifier() and potential_module not in [
- "list",
- "dict",
- "set",
- "tuple",
- ]:
- self.imports.module_imports.add(potential_module)
- elif "default" in self.field_info:
- default_val = self.field_info["default"]
- if isinstance(default_val, str):
- self.args_map["default"] = f'"{default_val}"'
- elif isinstance(default_val, bool):
- self.args_map["default"] = str(default_val)
- elif default_val is None:
- self.args_map["default"] = "None"
- else:
- self.args_map["default"] = str(default_val)
-
- def _get_nullable_arg(self):
- field_type_str = self.field_info["type"]
- is_optional_type = field_type_str.startswith("Optional[")
- is_pk = self.field_info.get("primary_key", False)
- explicit_nullable = self.field_info.get("nullable")
-
- if explicit_nullable is True:
- self.args_map["nullable"] = "True"
- elif explicit_nullable is False:
- if not (is_pk and not is_optional_type) or is_optional_type:
- self.args_map["nullable"] = "False"
- elif is_optional_type:
- self.args_map["nullable"] = "True"
-
- if is_pk and is_optional_type and self.args_map.get("nullable") != "False":
- self.args_map["nullable"] = "True"
-
- def _get_sa_column_args(self):
- sa_column_kwargs = self.field_info.get("sa_column_kwargs", {})
- for k, v in sa_column_kwargs.items():
- if k in ["server_default", "onupdate"]:
- self.args_map[k] = f'"{v}"' if isinstance(v, str) else str(v)
- elif k == "sa_type":
- self.args_map["sa_type"] = str(v)
- if str(v) == "JSON":
- self.imports.sqlmodel_imports.add("JSON")
- if "sqlalchemy." in str(v):
- self.imports.module_imports.add("sqlalchemy")
-
- def _get_common_args(self):
- if "description" in self.field_info:
- desc = self.field_info["description"]
- escaped_desc = (
- desc.replace("\\", "\\\\")
- .replace('"', '\\"')
- .replace("\n", "\\n")
- .replace("\r", "\\r")
- )
- self.args_map["description"] = f'"{escaped_desc}"'
- if "foreign_key" in self.field_info:
- self.args_map["foreign_key"] = f'"{self.field_info["foreign_key"]}"'
- if self.field_info.get("index"):
- self.args_map["index"] = "True"
- if self.field_info.get("primary_key"):
- self.args_map["primary_key"] = "True"
- if self.field_info.get("unique"):
- self.args_map["unique"] = "True"
-
- def _determine_field_arguments(self):
- self._handle_keyword_name()
- self._get_default_value_arg()
- self._get_nullable_arg()
- self._get_sa_column_args()
- self._get_common_args()
-
- def generate_code(self) -> str:
- field_type_str = self.field_info["type"]
- self.imports.add_import_for_type(field_type_str)
-
- if "field_options_str" in self.field_info:
- field_options = self.field_info["field_options_str"]
- if "JSON" in field_options:
- self.imports.sqlmodel_imports.add("JSON")
- if "Relationship" in field_options:
- self.imports.sqlmodel_imports.add("Relationship")
-
- if keyword.iskeyword(self.field_name_original):
- self.field_name_python = self.field_name_original + "_"
- return f" {self.field_name_python}: {field_type_str} = {field_options}"
-
- self._determine_field_arguments()
-
- if not self.args_map:
- return f" {self.field_name_python}: {field_type_str}"
-
- self.imports.sqlmodel_imports.add("Field")
- ordered_keys = [
- "primary_key",
- "alias",
- "default",
- "default_factory",
- "unique",
- "index",
- "foreign_key",
- "nullable",
- "sa_type",
- "description",
- "server_default",
- "onupdate",
- ]
- final_args_list = [
- f"{key}={self.args_map[key]}"
- for key in ordered_keys
- if key in self.args_map
- ]
- field_args_str = ", ".join(final_args_list)
- return (
- f" {self.field_name_python}: {field_type_str} = Field({field_args_str})"
- )
+from extrai.core.code_generation.python_builder import PythonModelBuilder
class SQLModelCodeGenerator:
@@ -357,45 +101,12 @@ def _load_sqlmodel_description_schema(self) -> Dict[str, Any]:
return SQLModelCodeGenerator._sqlmodel_description_schema_cache
def _generate_code_from_description(self, llm_json_output: Dict[str, Any]) -> str:
- import_manager = _ImportManager()
- class_definitions = []
+ """
+ Delegates code generation for the parsed model descriptions to PythonModelBuilder.
+ """
model_descriptions = llm_json_output.get("sql_models", [])
-
- # First pass to gather all imports from all model definitions
- for model_desc in model_descriptions:
- import_manager.add_custom_imports(model_desc.get("imports", []))
- base_classes = model_desc.get("base_classes_str", ["SQLModel"])
- if "SQLModel" in base_classes:
- import_manager.sqlmodel_imports.add("SQLModel")
- if "fields" in model_desc and model_desc["fields"]:
- for f_info in model_desc["fields"]:
- # This populates the import manager with types from fields
- _ = _FieldGenerator(f_info, import_manager).generate_code()
-
- # Now, build each class definition
- for model_desc in model_descriptions:
- model_name = model_desc["model_name"]
- base_classes = model_desc.get("base_classes_str", ["SQLModel"])
-
- builder = _CodeBuilder(
- model_name=model_name,
- import_manager=import_manager,
- description=model_desc.get("description", ""),
- table_name=model_desc.get("table_name", f"{model_name.lower()}s"),
- base_classes=base_classes,
- is_table_model=model_desc.get("is_table_model", True),
- )
-
- if "fields" in model_desc and model_desc["fields"]:
- for f_info in model_desc["fields"]:
- field_generator = _FieldGenerator(f_info, import_manager)
- builder.add_field(field_generator.generate_code())
-
- class_definitions.append(builder.render_class_definition())
-
- imports_str = import_manager.render()
- full_code = f"{imports_str}\n\n\n" + "\n\n".join(class_definitions)
- return full_code
+ builder = PythonModelBuilder()
+ return builder.generate_model_code(model_descriptions)
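+
+ # Hedged sketch of the input this method reads. The keys mirror those used by
+ # the deleted inline implementation shown in this diff ("sql_models",
+ # "model_name", "is_table_model", "fields"); concrete values are illustrative.
+ #
+ #     llm_json_output = {
+ #         "sql_models": [
+ #             {"model_name": "Author", "is_table_model": True, "fields": [...]}
+ #         ]
+ #     }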
@contextmanager
def _managed_temp_module(self, code: str) -> Generator[str, None, None]:
diff --git a/src/extrai/core/workflow_orchestrator.py b/src/extrai/core/workflow_orchestrator.py
index 3d3443a..ca15e89 100644
--- a/src/extrai/core/workflow_orchestrator.py
+++ b/src/extrai/core/workflow_orchestrator.py
@@ -1,62 +1,30 @@
# extrai/core/workflow_orchestrator.py
-import json
-import logging
import asyncio
-from typing import (
- List,
- Dict,
- Any,
- Type,
- Callable,
- Optional,
- Tuple,
- Union,
-)
-
-# SQLAlchemy imports
+import logging
+from typing import List, Dict, Any, Type, Optional, Union
+from extrai.core.base_llm_client import BaseLLMClient
+from extrai.core.batch_models import BatchJobStatus, BatchProcessResult
from sqlalchemy.orm import Session
-from sqlalchemy import create_engine
-
from sqlmodel import SQLModel
-# Project imports
-from .prompt_builder import generate_system_prompt, generate_user_prompt_for_docs
-from .json_consensus import JSONConsensus, default_conflict_resolver
-from .sqlalchemy_hydrator import SQLAlchemyHydrator
-from .db_writer import persist_objects, DatabaseWriterError
-from .base_llm_client import BaseLLMClient
-from .schema_inspector import (
- generate_llm_schema_from_models,
- discover_sqlmodels_from_root,
-)
-from .errors import (
- WorkflowError,
- LLMInteractionError,
- ConfigurationError,
- ConsensusProcessError,
- HydrationError,
- LLMConfigurationError,
- LLMOutputParseError,
- LLMOutputValidationError,
- LLMAPICallError,
-)
+from .extraction_config import ExtractionConfig
+from .extraction_pipeline import ExtractionPipeline
+from .batch_pipeline import BatchPipeline
+from .result_processor import ResultProcessor
+from .model_registry import ModelRegistry
from .analytics_collector import WorkflowAnalyticsCollector
-from .example_json_generator import ExampleJSONGenerator, ExampleGenerationError
class WorkflowOrchestrator:
"""
- Orchestrates the data extraction workflow, handling both standard and hierarchical extraction.
+ Orchestrates data extraction workflows by delegating to specialized components.
- This class manages the entire process from receiving unstructured text to outputting
- structured SQLModel objects. It integrates various components like LLM clients,
- a JSON consensus mechanism, and a SQLAlchemy hydrator.
-
- For hierarchical data, it uses a breadth-first traversal approach, extracting entities
- level by level and using parent entities as context for extracting children. This logic
- is now fully integrated within this class, removing the need for a separate
- HierarchicalExtractor.
+ This class serves as a facade, coordinating between:
+ - ModelRegistry: Schema discovery and management
+ - ExtractionPipeline: Standard extraction flow
+ - BatchPipeline: Batch extraction flow
+ - ResultProcessor: Result hydration and persistence
"""
def __init__(
@@ -66,534 +34,323 @@ def __init__(
num_llm_revisions: int = 3,
max_validation_retries_per_revision: int = 2,
consensus_threshold: float = 0.51,
- conflict_resolver: Callable[
- [Tuple[int | str, ...], List[str | int | float | bool | None]],
- Optional[str | int | float | bool | None],
- ] = default_conflict_resolver,
+ conflict_resolver=None,
analytics_collector: Optional[WorkflowAnalyticsCollector] = None,
use_hierarchical_extraction: bool = False,
+ use_structured_output: bool = False,
logger: Optional[logging.Logger] = None,
+ counting_llm_client: Optional[BaseLLMClient] = None,
):
- """
- Initializes the WorkflowOrchestrator.
+ self.logger = logger or self._create_default_logger()
- Args:
- root_sqlmodel_class: The primary SQLModel class for extraction.
- llm_client: An instance or a list of LLM clients.
- num_llm_revisions: The number of JSON revisions to request from the LLM for consensus.
- max_validation_retries_per_revision: Max retries for LLM output validation per revision.
- consensus_threshold: The agreement threshold for the consensus mechanism (0.0 to 1.0).
- conflict_resolver: A function to resolve disagreements during the consensus process.
- analytics_collector: An optional collector for workflow analytics.
- use_hierarchical_extraction: If True, enables the hierarchical extraction workflow
- for models with nested relationships.
- logger: An optional logger instance. If not provided, a default logger is created.
- """
- self.logger = logger or logging.getLogger(self.__class__.__name__)
- if not logger:
- self.logger.setLevel(logging.WARNING)
-
- self._validate_init_parameters(
- root_sqlmodel_class,
- num_llm_revisions,
- max_validation_retries_per_revision,
- consensus_threshold,
- )
- self._setup_llm_clients(llm_client)
- self._discover_models_and_generate_schema(root_sqlmodel_class)
-
- self.llm_client_index = 0
- self.llm_client = self.llm_clients[0]
- self.num_llm_revisions = num_llm_revisions
- self.max_validation_retries_per_revision = max_validation_retries_per_revision
- self.root_sqlmodel_class = root_sqlmodel_class
- self.use_hierarchical_extraction = use_hierarchical_extraction
-
- if self.use_hierarchical_extraction:
- self.logger.warning(
- "Hierarchical extraction is enabled. "
- "This may significantly increase LLM API calls and processing time "
- "based on model complexity and the number of entities."
- )
+ # Initialize registry first (validates root model)
+ self.model_registry = ModelRegistry(root_sqlmodel_class, self.logger)
- self.json_consensus = JSONConsensus(
+ # Create shared config
+ self.config = ExtractionConfig(
+ num_llm_revisions=num_llm_revisions,
+ max_validation_retries_per_revision=max_validation_retries_per_revision,
consensus_threshold=consensus_threshold,
conflict_resolver=conflict_resolver,
- logger=self.logger,
+ use_hierarchical_extraction=use_hierarchical_extraction,
+ use_structured_output=use_structured_output,
)
- if analytics_collector is None:
- self.analytics_collector = WorkflowAnalyticsCollector(logger=self.logger)
- else:
- self.analytics_collector = analytics_collector
+ # Initialize components
+ self.analytics_collector = analytics_collector or WorkflowAnalyticsCollector(
+ logger=self.logger
+ )
- def _validate_init_parameters(
- self,
- root_sqlmodel_class: Type[SQLModel],
- num_llm_revisions: int,
- max_validation_retries_per_revision: int,
- consensus_threshold: float,
- ):
- """Validates the initial parameters for the orchestrator."""
- if not root_sqlmodel_class or not issubclass(root_sqlmodel_class, SQLModel):
- raise ConfigurationError(
- "root_sqlmodel_class must be a valid SQLModel class."
- )
- if num_llm_revisions < 1:
- raise ConfigurationError("Number of LLM revisions must be at least 1.")
- if max_validation_retries_per_revision < 1:
- raise ConfigurationError(
- "Max validation retries per revision must be at least 1."
- )
- if not (0.0 <= consensus_threshold <= 1.0):
- raise ConfigurationError(
- "Extrai threshold must be between 0.0 and 1.0 inclusive."
- )
+ self.pipeline = ExtractionPipeline(
+ model_registry=self.model_registry,
+ llm_client=llm_client,
+ config=self.config,
+ analytics_collector=self.analytics_collector,
+ logger=self.logger,
+ counting_llm_client=counting_llm_client,
+ )
- def _setup_llm_clients(self, llm_client: Union[BaseLLMClient, List[BaseLLMClient]]):
- """Sets up the LLM clients list."""
- if isinstance(llm_client, list):
- if not all(isinstance(c, BaseLLMClient) for c in llm_client):
- raise ConfigurationError(
- "All items in llm_client list must be instances of BaseLLMClient."
- )
- if not llm_client:
- raise ConfigurationError("llm_client list cannot be empty.")
- self.llm_clients = llm_client
- elif isinstance(llm_client, BaseLLMClient):
- self.llm_clients = [llm_client]
- else:
- raise ConfigurationError(
- "llm_client must be an instance of BaseLLMClient or a list of them."
- )
- for client in self.llm_clients:
- client.logger = self.logger
-
- def _discover_models_and_generate_schema(self, root_sqlmodel_class: Type[SQLModel]):
- """Discovers SQLModels and generates the JSON schema for the LLM."""
- try:
- self.sqla_model_classes = discover_sqlmodels_from_root(root_sqlmodel_class)
- except Exception as e:
- raise ConfigurationError(f"Failed to discover SQLModel classes: {e}") from e
-
- if not self.sqla_model_classes:
- raise ConfigurationError(
- "No SQLModel classes were discovered from the root model."
- )
+ self.batch_pipeline = BatchPipeline(
+ model_registry=self.model_registry,
+ llm_client=llm_client,
+ config=self.config,
+ analytics_collector=self.analytics_collector,
+ logger=self.logger,
+ counting_llm_client=counting_llm_client,
+ )
- self.model_schema_map_for_hydration = {
- model_cls.__name__: model_cls for model_cls in self.sqla_model_classes
- }
+ self.result_processor = ResultProcessor(
+ model_registry=self.model_registry,
+ analytics_collector=self.analytics_collector,
+ logger=self.logger,
+ )
- try:
- generated_prompt_schema_str = generate_llm_schema_from_models(
- initial_model_classes=self.sqla_model_classes
- )
- if not generated_prompt_schema_str:
- raise ConfigurationError(
- "Generated target_json_schema_for_llm (prompt schema) is empty."
- )
- json.loads(generated_prompt_schema_str)
- self.target_json_schema_for_llm = generated_prompt_schema_str
- except json.JSONDecodeError as e:
- raise ConfigurationError(
- f"The internally generated LLM prompt JSON schema is not valid: {e}."
- )
- except Exception as e:
- raise ConfigurationError(
- f"Failed to generate the LLM prompt JSON schema: {e}"
- ) from e
-
- def _get_next_llm_client(self) -> BaseLLMClient:
- """Rotates through the list of LLM clients and returns the next one."""
- client = self.llm_clients[self.llm_client_index]
- self.llm_client_index = (self.llm_client_index + 1) % len(self.llm_clients)
- return client
-
- async def _prepare_extraction_example(self, extraction_example_json: str) -> str:
- """Prepares the extraction example, auto-generating it if necessary."""
- if extraction_example_json:
- return extraction_example_json
-
- try:
- llm_client_for_example = self._get_next_llm_client()
- example_generator = ExampleJSONGenerator(
- llm_client=llm_client_for_example,
- output_model=self.root_sqlmodel_class,
- analytics_collector=self.analytics_collector,
- max_validation_retries_per_revision=self.max_validation_retries_per_revision,
- logger=self.logger,
- )
- self.logger.info(
- f"Attempting to auto-generate extraction example for {self.root_sqlmodel_class.__name__}..."
- )
- generated_example = await example_generator.generate_example()
- if self.analytics_collector:
- self.analytics_collector.record_custom_event(
- "example_json_auto_generation_success"
- )
- self.logger.info("Successfully auto-generated extraction example.")
- return generated_example
- except ExampleGenerationError as e:
- if self.analytics_collector:
- self.analytics_collector.record_custom_event(
- "example_json_auto_generation_failure"
- )
- raise WorkflowError(
- f"Failed to auto-generate extraction example: {e}"
- ) from e
- except Exception as e:
- if self.analytics_collector:
- self.analytics_collector.record_custom_event(
- "example_json_auto_generation_unexpected_failure"
- )
- raise WorkflowError(
- f"An unexpected error occurred during auto-generation of extraction example: {e}"
- ) from e
+ def _create_default_logger(self) -> logging.Logger:
+ logger = logging.getLogger(self.__class__.__name__)
+ logger.setLevel(logging.WARNING)
+ return logger
+
+ # ==================== Standard Extraction ====================
async def synthesize(
self,
input_strings: List[str],
- db_session_for_hydration: Optional[Session],
+ db_session_for_hydration: Optional[Session] = None,
extraction_example_json: str = "",
extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]] = None,
custom_extraction_process: str = "",
custom_extraction_guidelines: str = "",
custom_final_checklist: str = "",
+ custom_context: str = "",
+ count_entities: bool = False,
+ custom_counting_context: str = "",
) -> List[Any]:
- """
- Executes the full pipeline: input strings -> LLM -> consensus -> SQLAlchemy objects.
-
- Args:
- input_strings: A list of input strings for data extraction.
- db_session_for_hydration: SQLAlchemy session for the hydrator.
- extraction_example_json: Optional JSON string for few-shot prompting.
- extraction_example_object: Optional SQLModel object or list of objects to use as example.
- custom_extraction_process: Optional custom instructions for LLM extraction process.
- custom_extraction_guidelines: Optional custom guidelines for LLM extraction.
- custom_final_checklist: Optional custom final checklist for LLM.
-
- Returns:
- A list of hydrated SQLAlchemy object instances.
-
- Raises:
- ValueError: If input_strings is empty.
- LLMInteractionError, ConsensusProcessError, HydrationError, WorkflowError: For pipeline failures.
- """
+ """Executes extraction pipeline and returns hydrated objects."""
if not input_strings:
raise ValueError("Input strings list cannot be empty.")
- if extraction_example_object and not extraction_example_json:
- objects_to_process = (
- extraction_example_object
- if isinstance(extraction_example_object, list)
- else [extraction_example_object]
- )
- processed_objects = []
- for obj in objects_to_process:
- if isinstance(obj, SQLModel):
- processed_objects.append(obj.model_dump(mode="json"))
- else:
- self.logger.warning(
- f"Skipping unsupported object type in extraction_example_object: {type(obj)}"
- )
- if processed_objects:
- extraction_example_json = json.dumps(
- processed_objects, default=str, indent=2
- )
-
- self.logger.info(
- f"Starting synthesis for {self.root_sqlmodel_class.__name__}..."
- )
- current_extraction_example_json = await self._prepare_extraction_example(
- extraction_example_json
+ # Extract to consensus JSON
+ consensus_results = await self.pipeline.extract(
+ input_strings=input_strings,
+ extraction_example_json=extraction_example_json,
+ extraction_example_object=extraction_example_object,
+ custom_extraction_process=custom_extraction_process,
+ custom_extraction_guidelines=custom_extraction_guidelines,
+ custom_final_checklist=custom_final_checklist,
+ custom_context=custom_context,
+ count_entities=count_entities,
+ custom_counting_context=custom_counting_context,
)
- if self.use_hierarchical_extraction:
- final_list_for_hydration = await self._execute_hierarchical_extraction(
- input_strings=input_strings,
- current_extraction_example_json=current_extraction_example_json,
- custom_extraction_process=custom_extraction_process,
- custom_extraction_guidelines=custom_extraction_guidelines,
- custom_final_checklist=custom_final_checklist,
- )
- else:
- final_list_for_hydration = await self._execute_standard_extraction(
- input_strings=input_strings,
- current_extraction_example_json=current_extraction_example_json,
- custom_extraction_process=custom_extraction_process,
- custom_extraction_guidelines=custom_extraction_guidelines,
- custom_final_checklist=custom_final_checklist,
- )
-
- return self._hydrate_results(final_list_for_hydration, db_session_for_hydration)
-
- def _process_consensus_output(
- self, consensus_output: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]
- ) -> List[Dict[str, Any]]:
- """Processes the raw output from the consensus mechanism."""
- if consensus_output is None:
- return []
- if isinstance(consensus_output, list):
- return consensus_output
- if isinstance(consensus_output, dict):
- if "results" in consensus_output and isinstance(
- consensus_output["results"], list
- ):
- return consensus_output["results"]
- return [consensus_output]
-
- raise ConsensusProcessError(
- f"Unexpected type from json_consensus.get_consensus: {type(consensus_output)}."
+ # Hydrate results
+ return self.result_processor.hydrate(
+ consensus_results, db_session_for_hydration
)
- def _hydrate_results(
+ async def synthesize_and_save(
self,
- final_list_for_hydration: List[Dict[str, Any]],
- db_session_for_hydration: Optional[Session],
+ input_strings: List[str],
+ db_session: Session,
+ extraction_example_json: str = "",
+ extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]] = None,
+ custom_extraction_process: str = "",
+ custom_extraction_guidelines: str = "",
+ custom_final_checklist: str = "",
+ custom_context: str = "",
+ count_entities: bool = False,
+ custom_counting_context: str = "",
) -> List[Any]:
- """Hydrates the final list of dictionaries into SQLModel objects."""
- session_to_use = db_session_for_hydration
- if session_to_use is None:
- # Create a temporary, in-memory SQLite database if no session is provided.
- engine = create_engine("sqlite:///:memory:")
- SQLModel.metadata.create_all(engine)
- session_to_use = Session(engine)
-
- hydrator = SQLAlchemyHydrator(session=session_to_use)
- try:
- self.logger.info(
- f"Starting hydration for {len(final_list_for_hydration)} consensus objects."
- )
- hydrated_objects = hydrator.hydrate(
- final_list_for_hydration, self.model_schema_map_for_hydration
- )
- self.analytics_collector.record_hydration_success(len(hydrated_objects))
- self.logger.info(
- f"Successfully hydrated {len(hydrated_objects)} SQLModel objects."
- )
- if db_session_for_hydration is None:
- session_to_use.close()
- return hydrated_objects
- except Exception as e:
- self.analytics_collector.record_hydration_failure()
- if db_session_for_hydration is None and session_to_use:
- session_to_use.close()
- raise HydrationError(
- f"Failed during SQLAlchemy object hydration: {e}"
- ) from e
-
- def _generate_contextual_prompt_for_hierarchical(
- self,
- model_name: str,
- results_store: Dict[Tuple[str, str], Dict[str, Any]],
- ) -> str:
- """Generates the contextual prompt for a hierarchical extraction step."""
- custom_context = f"Your current task is to extract **only** entities of type '{model_name}'. Do not extract any other types of entities in this step."
+ """Synthesizes and persists objects in a single transaction."""
+ hydrated_objects = await self.synthesize(
+ input_strings=input_strings,
+ db_session_for_hydration=db_session,
+ extraction_example_json=extraction_example_json,
+ extraction_example_object=extraction_example_object,
+ custom_extraction_process=custom_extraction_process,
+ custom_extraction_guidelines=custom_extraction_guidelines,
+ custom_final_checklist=custom_final_checklist,
+ custom_context=custom_context,
+ count_entities=count_entities,
+ custom_counting_context=custom_counting_context,
+ )
- if results_store:
- custom_context += "\n\nSo far, the following entities have been extracted. Use them as context to establish relationships:\n"
- custom_context += json.dumps(list(results_store.values()), indent=2)
+ if hydrated_objects:
+ self.result_processor.persist(hydrated_objects, db_session)
+
+ return hydrated_objects
- return custom_context
+ # ==================== Batch Extraction ====================
- async def _execute_hierarchical_extraction(
+ async def synthesize_batch(
self,
input_strings: List[str],
- current_extraction_example_json: str,
- custom_extraction_process: str,
- custom_extraction_guidelines: str,
- custom_final_checklist: str,
- ) -> List[Dict[str, Any]]:
- """Executes the hierarchical extraction process, processing each model type in order."""
- self.logger.info("Executing hierarchical extraction process...")
- models_to_process = discover_sqlmodels_from_root(self.root_sqlmodel_class)
- results_store: Dict[Tuple[str, str], Dict[str, Any]] = {}
-
- for model_class in models_to_process:
- model_name = model_class.__name__
- self.logger.info(f"Hierarchical step: Processing model '{model_name}'...")
-
- schema_json = generate_llm_schema_from_models([model_class])
- custom_context = self._generate_contextual_prompt_for_hierarchical(
- model_name, results_store
- )
-
- system_prompt = generate_system_prompt(
- schema_json=schema_json,
- extraction_example_json=current_extraction_example_json,
- custom_extraction_process=custom_extraction_process,
- custom_extraction_guidelines=custom_extraction_guidelines,
- custom_final_checklist=custom_final_checklist,
- custom_context=custom_context,
- )
- user_prompt = generate_user_prompt_for_docs(
- documents=input_strings, custom_context=custom_context
- )
-
- extracted_entities = await self._run_single_extraction_cycle(
- system_prompt, user_prompt
- )
-
- for entity in extracted_entities:
- temp_id = entity.get("_temp_id")
- if not temp_id:
- continue
- result_key = (model_name, temp_id)
- if result_key not in results_store:
- results_store[result_key] = entity
- self.logger.info(
- f"Hierarchical step for '{model_name}' completed. "
- f"Total entities in store: {len(results_store)}"
- )
-
- self.logger.info("Hierarchical extraction finished.")
- return list(results_store.values())
-
- async def _run_single_extraction_cycle(
- self, system_prompt: str, user_prompt: str
- ) -> List[Dict[str, Any]]:
- """Runs a single extraction cycle, including LLM revisions and consensus."""
- tasks = []
- try:
- for _ in range(self.num_llm_revisions):
- client_for_revision = self._get_next_llm_client()
- task = asyncio.create_task(
- client_for_revision.generate_json_revisions(
- system_prompt=system_prompt,
- user_prompt=user_prompt,
- num_revisions=1,
- model_schema_map=self.model_schema_map_for_hydration,
- max_validation_retries_per_revision=self.max_validation_retries_per_revision,
- analytics_collector=self.analytics_collector,
- )
- )
- tasks.append(task)
-
- llm_json_revisions = await asyncio.gather(*tasks)
-
- self.logger.debug(
- f"llm_json_revisions before consensus: {llm_json_revisions}"
- )
+ db_session: Session,
+ extraction_example_json: str = "",
+ extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]] = None,
+ custom_extraction_process: str = "",
+ custom_extraction_guidelines: str = "",
+ custom_final_checklist: str = "",
+ custom_context: str = "",
+ count_entities: bool = False,
+ custom_counting_context: str = "",
+ wait_for_completion: bool = False,
+ poll_interval: int = 60,
+ ) -> Union[str, BatchProcessResult]:
+ """Submits a batch job.
- if not llm_json_revisions and self.num_llm_revisions > 0:
- raise LLMInteractionError(
- "LLM client returned no revisions despite being requested."
- )
- except (
- LLMConfigurationError,
- LLMOutputParseError,
- LLMOutputValidationError,
- LLMAPICallError,
- ) as client_err:
- raise LLMInteractionError(
- f"LLM client operation failed: {client_err}"
- ) from client_err
- except Exception as e:
- raise LLMInteractionError(
- f"An unexpected error occurred during LLM interaction: {e}"
- ) from e
-
- try:
- consensus_output, consensus_analytics_details = (
- self.json_consensus.get_consensus(llm_json_revisions)
- )
- self.logger.debug(
- f"consensus_output from get_consensus: {consensus_output}"
- )
- if consensus_analytics_details:
- self.analytics_collector.record_consensus_run_details(
- consensus_analytics_details
- )
+ Args:
+ input_strings: List of strings to process
+ db_session: Database session for state persistence
+ ...
+ wait_for_completion: If True, waits for the batch job (and any hierarchical steps) to complete.
+ poll_interval: Interval in seconds to poll for status if wait_for_completion is True.
- processed_output = self._process_consensus_output(consensus_output)
- self.logger.debug(f"processed_output before hydration: {processed_output}")
- return processed_output
- except ConsensusProcessError:
- raise
- except Exception as e:
- raise ConsensusProcessError(
- f"Failed during JSON consensus processing: {e}"
- ) from e
-
- async def _execute_standard_extraction(
- self,
- input_strings: List[str],
- current_extraction_example_json: str,
- custom_extraction_process: str,
- custom_extraction_guidelines: str,
- custom_final_checklist: str,
- ) -> List[Dict[str, Any]]:
- """Executes the standard extraction process."""
- self.logger.info("Executing standard extraction...")
- system_prompt = generate_system_prompt(
- schema_json=self.target_json_schema_for_llm,
- extraction_example_json=current_extraction_example_json,
+ Returns:
+ root_batch_id (str) if wait_for_completion is False.
+ BatchProcessResult if wait_for_completion is True.
+ """
+ root_batch_id = await self.batch_pipeline.submit_batch(
+ db_session=db_session,
+ input_strings=input_strings,
+ extraction_example_json=extraction_example_json,
+ extraction_example_object=extraction_example_object,
custom_extraction_process=custom_extraction_process,
custom_extraction_guidelines=custom_extraction_guidelines,
custom_final_checklist=custom_final_checklist,
+ custom_context=custom_context,
+ count_entities=count_entities,
+ custom_counting_context=custom_counting_context,
)
- user_prompt = generate_user_prompt_for_docs(documents=input_strings)
- logging.info(f"System Prompt: {system_prompt}")
- logging.info(f"User Prompt: {user_prompt}")
+ if wait_for_completion:
+ return await self.monitor_batch_job(
+ root_batch_id, db_session, poll_interval
+ )
- return await self._run_single_extraction_cycle(system_prompt, user_prompt)
+ return root_batch_id
- async def synthesize_and_save(
+ async def create_continuation_batch(
self,
- input_strings: List[str],
+ original_batch_id: str,
db_session: Session,
+ start_from_step_index: int,
extraction_example_json: str = "",
extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]] = None,
custom_extraction_process: str = "",
custom_extraction_guidelines: str = "",
custom_final_checklist: str = "",
- ) -> List[Any]:
+ custom_context: str = "",
+ count_entities: bool = False,
+ custom_counting_context: str = "",
+ wait_for_completion: bool = False,
+ poll_interval: int = 60,
+ ) -> Union[str, BatchProcessResult]:
"""
- Synthesizes SQLAlchemy objects and persists them to the database.
- This method manages the transaction via the provided db_session.
+ Creates a new batch cycle that continues from a previous batch's state,
+ copying completed steps up to start_from_step_index into the new batch.
+ Configuration parameters may be overridden for the continuation run.
"""
- hydrated_objects = await self.synthesize(
- input_strings=input_strings,
- db_session_for_hydration=db_session,
- extraction_example_json=extraction_example_json,
- extraction_example_object=extraction_example_object,
- custom_extraction_process=custom_extraction_process,
- custom_extraction_guidelines=custom_extraction_guidelines,
- custom_final_checklist=custom_final_checklist,
+ # Prepare example
+ example_json = await self.batch_pipeline.context_preparer.prepare_example(
+ extraction_example_json,
+ extraction_example_object,
+ self.batch_pipeline.client_rotator.get_next_client,
)
- if hydrated_objects:
+ config_data = {
+ "extraction_example_json": example_json,
+ "custom_extraction_process": custom_extraction_process,
+ "custom_extraction_guidelines": custom_extraction_guidelines,
+ "custom_final_checklist": custom_final_checklist,
+ "custom_context": custom_context,
+ "count_entities": count_entities,
+ "custom_counting_context": custom_counting_context,
+ "schema_json": self.model_registry.llm_schema_json,
+ }
+
+ new_batch_id = await self.batch_pipeline.create_continuation_batch(
+ db_session, original_batch_id, config_data, start_from_step_index
+ )
+
+ if wait_for_completion:
+ return await self.monitor_batch_job(new_batch_id, db_session, poll_interval)
+
+ return new_batch_id
+
+ async def get_batch_status(
+ self, root_batch_id: str, db_session: Session
+ ) -> BatchJobStatus:
+ """Retrieves current batch job status."""
+ return await self.batch_pipeline.get_status(root_batch_id, db_session)
+
+ async def process_batch(
+ self,
+ root_batch_id: str,
+ db_session: Session,
+ ) -> "BatchProcessResult":
+ """Processes a completed batch job and persists results."""
+ result = await self.batch_pipeline.process_batch(
+ root_batch_id,
+ db_session,
+ )
+
+ if result.status == BatchJobStatus.COMPLETED and result.hydrated_objects:
try:
- persist_objects(
- db_session=db_session,
- objects_to_persist=hydrated_objects,
- logger=self.logger,
- )
- except DatabaseWriterError:
- db_session.rollback()
- raise
+ # Add the PK map from the batch pipeline to the main result processor
+ if result.original_pk_map:
+ self.result_processor.original_pk_map.update(result.original_pk_map)
+
+ self.result_processor.persist(result.hydrated_objects, db_session)
except Exception as e:
- db_session.rollback()
- raise WorkflowError(
- f"An unexpected error occurred during database persistence phase: {e}"
- ) from e
- else:
- self.logger.info(
- "WorkflowOrchestrator: No objects were hydrated, thus nothing to persist."
- )
+ self.logger.error(f"Persistence failed for batch {root_batch_id}: {e}")
+ result.message = f"Extraction successful but persistence failed: {e}"
+ raise
- return hydrated_objects
+ return result
- def get_analytics_report(self) -> Dict[str, Any]:
+ async def monitor_batch_job(
+ self, root_batch_id: str, db_session: Session, poll_interval: int = 60
+ ) -> "BatchProcessResult":
"""
- Retrieves the analytics report from the associated collector.
+ Polls the batch job status until it reaches a terminal state.
+ Automatically handles hierarchical extraction steps by re-polling
+ if an intermediate step is submitted.
+
+ Useful for scripts or simple workflows where blocking is acceptable.
"""
+ self.logger.info(f"Monitoring batch job {root_batch_id}...")
+
+ while True:
+ status = await self.get_batch_status(root_batch_id, db_session)
+ self.logger.info(f"Batch Status: {status}")
+
+ if status in [
+ BatchJobStatus.READY_TO_PROCESS,
+ BatchJobStatus.COUNTING_READY_TO_PROCESS,
+ ]:
+ self.logger.info("Batch ready! Processing...")
+ result = await self.process_batch(root_batch_id, db_session)
+
+ if result.status == BatchJobStatus.COMPLETED:
+ self.logger.info("Batch workflow completed successfully.")
+ return result
+
+ elif result.status in [
+ BatchJobStatus.PROCESSING,
+ BatchJobStatus.SUBMITTED,
+ ]:
+ self.logger.info(
+ f"Intermediate step processed (new status: {result.status}). Continuing workflow..."
+ )
+ continue
+
+ else:
+ self.logger.error(f"Batch processing failed: {result.message}")
+ return result
+
+ elif status in [
+ BatchJobStatus.COMPLETED,
+ BatchJobStatus.FAILED,
+ BatchJobStatus.CANCELLED,
+ ]:
+ # If it's already COMPLETED (e.g. checked before monitoring started), retrieve results
+ if status == BatchJobStatus.COMPLETED:
+ self.logger.info("Batch already completed. Retrieving results...")
+ return await self.process_batch(root_batch_id, db_session)
+
+ self.logger.error(f"Batch job ended with status: {status}")
+ return BatchProcessResult(
+ status=status, message=f"Batch ended with status: {status}"
+ )
+
+ await asyncio.sleep(poll_interval)
+
+ # ==================== Analytics ====================
+
+ def get_analytics_report(self) -> Dict[str, Any]:
+ """Retrieves analytics report."""
return self.analytics_collector.get_report()
def get_analytics_collector(self) -> WorkflowAnalyticsCollector:
- """
- Returns the instance of the analytics collector.
- """
+ """Returns the analytics collector instance."""
return self.analytics_collector
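The removed `_get_next_llm_client` helper (its role now lives behind the pipeline's `client_rotator`) is a plain round-robin over the configured clients. A minimal standalone sketch, with a hypothetical `ClientRotator` class standing in for the library's internals:

```python
from typing import List


class ClientRotator:
    """Round-robin over a non-empty client list (illustrative sketch only)."""

    def __init__(self, clients: List[str]):
        if not clients:
            raise ValueError("client list cannot be empty")
        self.clients = clients
        self.index = 0

    def get_next_client(self) -> str:
        # Return the current client, then advance the cursor modulo the list length.
        client = self.clients[self.index]
        self.index = (self.index + 1) % len(self.clients)
        return client


rotator = ClientRotator(["claude", "gemini", "deepseek"])
picks = [rotator.get_next_client() for _ in range(5)]
print(picks)  # ['claude', 'gemini', 'deepseek', 'claude', 'gemini']
```

Rotating clients this way spreads consensus revisions across providers, which is what lets the same request be answered by different LLMs.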
diff --git a/src/extrai/llm_providers/deepseek_client.py b/src/extrai/llm_providers/deepseek_client.py
index 0eb2be1..92b6a4c 100644
--- a/src/extrai/llm_providers/deepseek_client.py
+++ b/src/extrai/llm_providers/deepseek_client.py
@@ -12,7 +12,7 @@ def __init__(
self,
api_key: str,
model_name: str = "deepseek-chat",
- base_url: str = "https://api.deepseek.com/v1",
+ base_url: str = "https://api.deepseek.com",
temperature: Optional[float] = 0.3,
logger: Optional[logging.Logger] = None,
):
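The DeepSeek default `base_url` drops the `/v1` suffix: DeepSeek serves its OpenAI-compatible API at the bare domain, and the `/v1` path is kept only for OpenAI-SDK compatibility (it is unrelated to model versions). A rough sketch of how an OpenAI-style client composes the final route, assuming simple string joining:

```python
def chat_completions_url(base_url: str) -> str:
    # OpenAI-compatible clients append their route to the configured base URL.
    return base_url.rstrip("/") + "/chat/completions"


# The bare domain is the documented default for DeepSeek.
print(chat_completions_url("https://api.deepseek.com"))
# https://api.deepseek.com/chat/completions
```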
diff --git a/src/extrai/llm_providers/gemini_client.py b/src/extrai/llm_providers/gemini_client.py
index 0b3efe2..3e948f7 100644
--- a/src/extrai/llm_providers/gemini_client.py
+++ b/src/extrai/llm_providers/gemini_client.py
@@ -1,6 +1,10 @@
import logging
-from typing import Optional
+import json
+from typing import Optional, Dict, Any, List
+from extrai.utils.rate_limiter import AsyncRateLimiter
from .generic_openai_client import GenericOpenAIClient
+from extrai.core.errors import LLMAPICallError
+from extrai.core.analytics_collector import WorkflowAnalyticsCollector
class GeminiClient(GenericOpenAIClient):
@@ -16,6 +20,8 @@ def __init__(
base_url: str = "https://generativelanguage.googleapis.com/v1beta/",
temperature: Optional[float] = 0.3,
logger: Optional[logging.Logger] = None,
+ requests_per_minute: int = 15,
+ tokens_per_minute: int = 32000,
):
"""
Initializes the GeminiClient.
@@ -26,6 +32,8 @@ def __init__(
base_url: The base URL for the Gemini API (OpenAI-compatible endpoint).
temperature: The sampling temperature for generation.
logger: Logger.
+ requests_per_minute: Maximum number of requests allowed per minute.
+ tokens_per_minute: Maximum number of input tokens allowed per minute.
"""
super().__init__(
api_key=api_key,
@@ -34,3 +42,311 @@ def __init__(
temperature=temperature,
logger=logger,
)
+ self.request_limiter = AsyncRateLimiter(
+ max_capacity=requests_per_minute, period=60.0
+ )
+ self.token_limiter = AsyncRateLimiter(
+ max_capacity=tokens_per_minute, period=60.0
+ )
+ self.logger = logger or logging.getLogger(self.__class__.__name__)
+
+ async def _execute_llm_call(
+ self,
+ system_prompt: str,
+ user_prompt: str,
+ analytics_collector: Optional[WorkflowAnalyticsCollector] = None,
+ ) -> str:
+ """
+ Executes the LLM call with rate limiting.
+ """
+ # Estimate token count (simple character heuristic)
+ # 1 token ~= 4 chars
+ estimated_tokens = (len(system_prompt) + len(user_prompt)) // 4
+ # Minimum 1 token
+ estimated_tokens = max(1, estimated_tokens)
+
+ self.logger.debug(f"Estimated input tokens for rate limiting: {estimated_tokens}")
+ # Acquire rate limits
+ await self.request_limiter.acquire(1)
+ await self.token_limiter.acquire(estimated_tokens)
+
+ return await super()._execute_llm_call(
+ system_prompt, user_prompt, analytics_collector=analytics_collector
+ )
+
+ def _sanitize_schema_for_gemini(self, schema: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Ensures JSON schema compatibility with Gemini REST API by inlining $defs.
+ Gemini REST API does not support $defs/$ref in schema payloads.
+ This implements a dependency-free version of the 'jsonref' workaround.
+ """
+ import copy
+
+ schema = copy.deepcopy(schema)
+ defs = schema.pop("$defs", {}) or schema.pop("definitions", {})
+
+ def _resolve(node: Any) -> Any:
+ if isinstance(node, dict):
+ if "$ref" in node:
+ ref = node["$ref"].split("/")[-1]
+ if ref in defs:
+ return _resolve(defs[ref])
+ return {k: _resolve(v) for k, v in node.items()}
+ elif isinstance(node, list):
+ return [_resolve(x) for x in node]
+ return node
+
+ return _resolve(schema)
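The `$ref`-inlining performed by `_sanitize_schema_for_gemini` can be exercised standalone. A self-contained sketch (the `inline_refs` helper is hypothetical; like the method above, it assumes non-recursive schemas, since a self-referential `$ref` would recurse forever):

```python
import copy
from typing import Any, Dict


def inline_refs(schema: Dict[str, Any]) -> Dict[str, Any]:
    # Pop local definitions, then replace each {"$ref": "#/$defs/Name"}
    # node with the referenced sub-schema.
    schema = copy.deepcopy(schema)
    defs = schema.pop("$defs", {})

    def resolve(node: Any) -> Any:
        if isinstance(node, dict):
            if "$ref" in node:
                name = node["$ref"].split("/")[-1]
                if name in defs:
                    return resolve(defs[name])
            return {k: resolve(v) for k, v in node.items()}
        if isinstance(node, list):
            return [resolve(x) for x in node]
        return node

    return resolve(schema)


schema = {
    "type": "object",
    "properties": {"author": {"$ref": "#/$defs/Person"}},
    "$defs": {"Person": {"type": "object", "properties": {"name": {"type": "string"}}}},
}
print(inline_refs(schema))
# {'type': 'object', 'properties': {'author': {'type': 'object',
#  'properties': {'name': {'type': 'string'}}}}}
```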
+
+ async def create_batch_job(
+ self,
+ requests: List[Dict[str, Any]],
+ endpoint: Optional[str] = None,
+ completion_window: Optional[str] = None,
+ metadata: Optional[Dict[str, str]] = None,
+ ) -> Any:
+ """
+ Creates a Gemini batch job using the native REST API (Inline Requests).
+ """
+ import httpx
+
+ # Convert requests to Gemini 'contents' format
+ gemini_requests = []
+ for i, req in enumerate(requests):
+ body = req.get("body", req)
+ custom_id = req.get("custom_id", f"req-{i}")
+
+ messages = body.get("messages", [])
+ contents = []
+ system_instruction = None
+
+ for msg in messages:
+ role = msg.get("role")
+ content = msg.get("content")
+ if role == "system":
+ system_instruction = {"parts": [{"text": content}]}
+ elif role == "user":
+ contents.append({"role": "user", "parts": [{"text": content}]})
+ elif role == "assistant":
+ contents.append({"role": "model", "parts": [{"text": content}]})
+
+ # Build the inline batch entry. The target model is addressed in the URL
+ # (models/{model}:batchGenerateContent); each entry has the shape
+ # {"request": {...}, "metadata": {"key": custom_id}}.
+
+ # Map configuration
+ generation_config = {}
+ if "temperature" in body:
+ generation_config["temperature"] = body["temperature"]
+ if "max_tokens" in body:
+ generation_config["maxOutputTokens"] = body["max_tokens"]
+
+ # Map OpenAI response_format to Gemini generationConfig
+ response_format = body.get("response_format", {})
+ if response_format.get("type") == "json_schema":
+ generation_config["responseMimeType"] = "application/json"
+ if (
+ "json_schema" in response_format
+ and "schema" in response_format["json_schema"]
+ ):
+ raw_schema = response_format["json_schema"]["schema"]
+ generation_config["responseJsonSchema"] = (
+ self._sanitize_schema_for_gemini(raw_schema)
+ )
+ elif response_format.get("type") == "json_object":
+ generation_config["responseMimeType"] = "application/json"
+
+ g_req_inner = {"contents": contents, "generationConfig": generation_config}
+ if system_instruction:
+ g_req_inner["system_instruction"] = system_instruction
+
+ gemini_requests.append(
+ {"request": g_req_inner, "metadata": {"key": custom_id}}
+ )
+
+ # Construct Payload
+ payload = {
+ "batch": {"input_config": {"requests": {"requests": gemini_requests}}}
+ }
+ if metadata and "display_name" in metadata:
+ payload["batch"]["display_name"] = metadata["display_name"]
+
+ url = f"{self.base_url}models/{self.model_name}:batchGenerateContent?key={self.api_key}"
+
+ async with httpx.AsyncClient() as client:
+ resp = await client.post(
+ url, json=payload, headers={"Content-Type": "application/json"}
+ )
+ if resp.status_code >= 400:
+ raise LLMAPICallError(
+ f"Gemini Batch Creation Failed: {resp.status_code} - {resp.text}"
+ )
+
+ return self._wrap_batch_response(resp.json())
+
+ async def retrieve_batch_job(self, batch_id: str) -> Any:
+ """
+ Retrieves batch status using Native REST API.
+ """
+ import httpx
+
+ # batch_id is expected to be the full resource name e.g., "batches/12345"
+ url = f"{self.base_url}{batch_id}?key={self.api_key}"
+
+ async with httpx.AsyncClient() as client:
+ resp = await client.get(url)
+ if resp.status_code >= 400:
+ raise LLMAPICallError(
+ f"Gemini Batch Retrieve Failed: {resp.status_code} - {resp.text}"
+ )
+ return self._wrap_batch_response(resp.json())
+
+ async def cancel_batch_job(self, batch_id: str) -> Any:
+ """
+ Cancels batch job using Native REST API.
+ """
+ import httpx
+
+ url = f"{self.base_url}{batch_id}:cancel?key={self.api_key}"
+
+ async with httpx.AsyncClient() as client:
+ resp = await client.post(url)
+ if resp.status_code >= 400:
+ raise LLMAPICallError(
+ f"Gemini Batch Cancel Failed: {resp.status_code} - {resp.text}"
+ )
+ # The API may return an empty body or updated operation metadata on success;
+ # callers only need confirmation, so return True instead of re-fetching.
+ return True
+
+ def _wrap_batch_response(self, data: Dict[str, Any]) -> Any:
+ class GeminiBatchJob:
+ def __init__(self, data):
+ self.id = data.get("name") # "batches/..."
+ # State might be at top level or inside metadata (depending on API version/endpoint)
+ self.status = data.get("state")
+ if not self.status and "metadata" in data:
+ self.status = data["metadata"].get("state")
+ self.original_data = data
+
+ # Completed jobs report results either inline (response.inlinedResponses)
+ # or via a file (response.responsesFile). The OpenAI-style interface
+ # expects an output_file_id, which inline results cannot provide, so
+ # retrieve_batch_results accepts the batch ID in place of a file ID.
+
+ return GeminiBatchJob(data)
+
+ async def retrieve_batch_results(self, file_id: str) -> str:
+ """
+ Retrieves batch results.
+ For Gemini Inline, 'file_id' should be the batch ID.
+ """
+ # The batch reports results either as a downloadable file ('responsesFile')
+ # or inline ('inlinedResponses'). Since this method only receives an ID,
+ # fetch the batch first to determine which case applies.
+
+ batch = await self.retrieve_batch_job(file_id)
+ data = batch.original_data
+
+ # REST responses nest results under "response" (not "responses"):
+ # "response": {"inlinedResponses": [...]} or "response": {"responsesFile": "..."}
+ response_section = data.get("response", {})
+
+ inlined = response_section.get("inlinedResponses")
+ if inlined:
+ # Handle the case where inlined is a dict (unexpected but observed):
+ if isinstance(inlined, dict):
+ # the list may be nested one level deeper, or the items keyed by ID.
+ self.logger.warning(
+ f"inlinedResponses is a dict, keys: {list(inlined.keys())}"
+ )
+ # Try to find the actual list
+ if "inlinedResponses" in inlined:
+ inlined = inlined["inlinedResponses"]
+ elif "responses" in inlined:
+ inlined = inlined["responses"]
+ elif "results" in inlined:
+ inlined = inlined["results"]
+ else:
+ # Fallback: treat values as the list if they look like items
+ inlined = list(inlined.values())
+
+ # Convert to JSONL string to match OpenAI format
+ lines = []
+ for item in inlined:
+ # item has 'response' or 'error' and 'requestKey' (if we used metadata.key)
+ # We should map back to OpenAI-like format if possible
+ if isinstance(item, str):
+ self.logger.warning(
+ f"Unexpected string item in inlinedResponses: {item}"
+ )
+ continue
+ lines.append(json.dumps(item))
+ return "\n".join(lines)
+
+ file_name = response_section.get("responsesFile")
+ if file_name:
+ # Download file
+ # url: https://generativelanguage.googleapis.com/download/v1beta/$responses_file_name:download?alt=media
+ import httpx
+
+ url = f"https://generativelanguage.googleapis.com/download/v1beta/{file_name}:download?alt=media&key={self.api_key}"
+ async with httpx.AsyncClient() as client:
+ resp = await client.get(url)
+ if resp.status_code >= 400:
+ raise LLMAPICallError(
+ f"Gemini Result Download Failed: {resp.status_code}"
+ )
+ return resp.text
+
+ raise LLMAPICallError("No results found in batch (or batch not complete).")
+
+ async def list_batch_jobs(
+ self, limit: int = 20, after: Optional[str] = None
+ ) -> Any:
+ import httpx
+
+ url = f"{self.base_url}batches?key={self.api_key}&pageSize={limit}"
+ if after:
+ url += f"&pageToken={after}"
+
+ async with httpx.AsyncClient() as client:
+ resp = await client.get(url)
+ if resp.status_code >= 400:
+ raise LLMAPICallError(f"Gemini List Batches Failed: {resp.text}")
+
+ data = resp.json()
+ # Return the raw listing payload; no wrapper object is needed here.
+ return data
+
+ def extract_content_from_batch_response(
+ self, response: Dict[str, Any]
+ ) -> Optional[str]:
+ """
+ Extracts content from Gemini batch response item.
+ """
+ if "error" in response:
+ self.logger.error(f"Batch item contains error: {response['error']}")
+ return None
+
+ if "response" in response and "candidates" in response["response"]:
+ candidates = response["response"]["candidates"]
+ if candidates and "content" in candidates[0]:
+ parts = candidates[0]["content"].get("parts", [])
+ if parts:
+ return parts[0].get("text")
+ return None
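For reviewers, the OpenAI-to-Gemini parameter translation in `create_batch_job` above can be sketched as a standalone function. This is a simplified sketch: the real code also runs the schema through `_sanitize_schema_for_gemini`, which is omitted here.

```python
def map_openai_body_to_generation_config(body: dict) -> dict:
    """Translate OpenAI-style request fields into a Gemini generationConfig."""
    cfg = {}
    if "temperature" in body:
        cfg["temperature"] = body["temperature"]
    if "max_tokens" in body:
        cfg["maxOutputTokens"] = body["max_tokens"]

    rf = body.get("response_format", {})
    if rf.get("type") == "json_schema":
        cfg["responseMimeType"] = "application/json"
        schema = rf.get("json_schema", {}).get("schema")
        if schema is not None:
            # The real client sanitizes the schema for Gemini before attaching it.
            cfg["responseJsonSchema"] = schema
    elif rf.get("type") == "json_object":
        cfg["responseMimeType"] = "application/json"
    return cfg
```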
diff --git a/src/extrai/llm_providers/generic_openai_client.py b/src/extrai/llm_providers/generic_openai_client.py
index c38e716..afc9561 100644
--- a/src/extrai/llm_providers/generic_openai_client.py
+++ b/src/extrai/llm_providers/generic_openai_client.py
@@ -1,8 +1,11 @@
import logging
import openai
-from typing import Optional
+import json
+import io
+from typing import Optional, List, Dict, Any
from extrai.core.errors import LLMAPICallError
from extrai.core.base_llm_client import BaseLLMClient
+from extrai.core.analytics_collector import WorkflowAnalyticsCollector
class GenericOpenAIClient(BaseLLMClient):
@@ -38,13 +41,19 @@ def __init__(
)
self.client = openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
- async def _execute_llm_call(self, system_prompt: str, user_prompt: str) -> str:
+ async def _execute_llm_call(
+ self,
+ system_prompt: str,
+ user_prompt: str,
+ analytics_collector: Optional[WorkflowAnalyticsCollector] = None,
+ ) -> str:
"""
Makes the actual API call to an OpenAI-compatible LLM.
Args:
system_prompt: The system prompt for the LLM.
user_prompt: The user prompt for the LLM.
+ analytics_collector: Optional analytics collector.
Returns:
The raw string content from the LLM response. Returns an empty string
@@ -68,6 +77,19 @@ async def _execute_llm_call(self, system_prompt: str, user_prompt: str) -> str:
else openai.NOT_GIVEN,
)
+ if (
+ analytics_collector
+ and hasattr(chat_completion, "usage")
+ and chat_completion.usage
+ ):
+ analytics_collector.record_llm_usage(
+ input_tokens=getattr(chat_completion.usage, "prompt_tokens", 0),
+ output_tokens=getattr(
+ chat_completion.usage, "completion_tokens", 0
+ ),
+ model=self.model_name,
+ )
+
response_content = chat_completion.choices[0].message.content
return response_content if response_content is not None else ""
@@ -89,3 +111,181 @@ async def _execute_llm_call(self, system_prompt: str, user_prompt: str) -> str:
raise LLMAPICallError(
f"Unexpected error during API call: {type(e).__name__} - {str(e)}"
) from e
+
+ async def generate_structured(
+ self,
+ system_prompt: str,
+ user_prompt: str,
+ response_model: Any,
+ analytics_collector: Optional[WorkflowAnalyticsCollector] = None,
+ **kwargs: Any,
+ ) -> Any:
+ """
+ Generates structured output using OpenAI's beta.chat.completions.parse.
+ """
+ try:
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+ messages.append({"role": "user", "content": user_prompt})
+
+ completion = await self.client.beta.chat.completions.parse(
+ model=self.model_name,
+ messages=messages,
+ response_format=response_model,
+ temperature=self.temperature
+ if self.temperature is not None
+ else openai.NOT_GIVEN,
+ **kwargs,
+ )
+
+ if (
+ analytics_collector
+ and hasattr(completion, "usage")
+ and completion.usage
+ ):
+ analytics_collector.record_llm_usage(
+ input_tokens=getattr(completion.usage, "prompt_tokens", 0),
+ output_tokens=getattr(completion.usage, "completion_tokens", 0),
+ model=self.model_name,
+ )
+
+ message = completion.choices[0].message
+ if message.refusal:
+ raise LLMAPICallError(
+ f"Model refused to generate structured output: {message.refusal}"
+ )
+
+ return message.parsed
+
+ except openai.APIError as e:
+ error_message = str(e)
+ if hasattr(e, "message") and e.message:
+ error_message = e.message
+ elif hasattr(e, "body") and e.body:
+ if "message" in e.body:
+ error_message = e.body["message"]
+ elif "error" in e.body and "message" in e.body["error"]:
+ error_message = e.body["error"]["message"]
+
+ status_code = e.status_code if hasattr(e, "status_code") else "N/A"
+ raise LLMAPICallError(
+ f"API call failed. Status: {status_code}. Error: {error_message}"
+ ) from e
+ except Exception as e:
+ raise LLMAPICallError(
+ f"Unexpected error during API call: {type(e).__name__} - {str(e)}"
+ ) from e
+
+ async def create_batch_job(
+ self,
+ requests: List[Dict[str, Any]],
+ endpoint: str = "/v1/chat/completions",
+ completion_window: str = "24h",
+ metadata: Optional[Dict[str, str]] = None,
+ ) -> Any:
+ """
+ Creates a batch job for processing multiple requests.
+ """
+ try:
+ # 1. Create JSONL content
+ jsonl_lines = []
+ for i, req in enumerate(requests):
+ # Check if the request is already in batch format
+ if "method" in req and "url" in req and "body" in req:
+ jsonl_lines.append(json.dumps(req))
+ else:
+ # Construct batch request object
+ # Extract custom_id if present in the body, otherwise generate one
+ body = req.copy()
+ custom_id = body.pop("custom_id", f"req-{i}")
+
+ batch_req = {
+ "custom_id": custom_id,
+ "method": "POST",
+ "url": endpoint,
+ "body": body,
+ }
+ jsonl_lines.append(json.dumps(batch_req))
+
+ jsonl_content = "\n".join(jsonl_lines)
+
+ # 2. Upload File
+ # Create a bytes buffer for the file content
+ file_obj = io.BytesIO(jsonl_content.encode("utf-8"))
+ # The API accepts either a file-like object with a 'name' attribute or a
+ # (filename, content, content_type) tuple; the tuple form is used here.
+ file_tuple = ("batch_requests.jsonl", file_obj, "application/json")
+
+ batch_input_file = await self.client.files.create(
+ file=file_tuple, purpose="batch"
+ )
+
+ # 3. Create Batch
+ batch_job = await self.client.batches.create(
+ input_file_id=batch_input_file.id,
+ endpoint=endpoint,
+ completion_window=completion_window,
+ metadata=metadata,
+ )
+
+ return batch_job
+
+ except openai.APIError as e:
+ raise LLMAPICallError(f"Batch creation failed: {e}") from e
+ except Exception as e:
+ raise LLMAPICallError(f"Unexpected error creating batch: {e}") from e
+
+ async def retrieve_batch_job(self, batch_id: str) -> Any:
+ """
+ Retrieves the status and details of a batch job.
+ """
+ try:
+ return await self.client.batches.retrieve(batch_id)
+ except openai.APIError as e:
+ raise LLMAPICallError(f"Failed to retrieve batch {batch_id}: {e}") from e
+
+ async def list_batch_jobs(
+ self, limit: int = 20, after: Optional[str] = None
+ ) -> Any:
+ """
+ Lists batch jobs.
+ """
+ try:
+ return await self.client.batches.list(limit=limit, after=after)
+ except openai.APIError as e:
+ raise LLMAPICallError(f"Failed to list batches: {e}") from e
+
+ async def cancel_batch_job(self, batch_id: str) -> Any:
+ """
+ Cancels a batch job.
+ """
+ try:
+ return await self.client.batches.cancel(batch_id)
+ except openai.APIError as e:
+ raise LLMAPICallError(f"Failed to cancel batch {batch_id}: {e}") from e
+
+ async def retrieve_batch_results(self, file_id: str) -> str:
+ """
+ Retrieves the content of a batch output file.
+ """
+ try:
+ content = await self.client.files.content(file_id)
+ return content.text
+ except openai.APIError as e:
+ raise LLMAPICallError(
+ f"Failed to retrieve file content {file_id}: {e}"
+ ) from e
+
+ def extract_content_from_batch_response(
+ self, response: Dict[str, Any]
+ ) -> Optional[str]:
+ """
+ Extracts content from OpenAI batch response item.
+ """
+ if "response" in response and "body" in response["response"]:
+ body = response["response"]["body"]
+ if "choices" in body and body["choices"]:
+ return body["choices"][0]["message"]["content"]
+ return None
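The JSONL construction step in `create_batch_job` can be illustrated standalone, a sketch of the same pass-through/wrap logic without the file upload:

```python
import json

def build_batch_jsonl(requests: list, endpoint: str = "/v1/chat/completions") -> str:
    """Build the JSONL payload uploaded to the Files API for a batch."""
    lines = []
    for i, req in enumerate(requests):
        if {"method", "url", "body"} <= req.keys():
            # Already in batch format; pass through unchanged.
            lines.append(json.dumps(req))
        else:
            # Wrap a plain chat-completion body, pulling out custom_id if present.
            body = dict(req)
            custom_id = body.pop("custom_id", f"req-{i}")
            lines.append(json.dumps(
                {"custom_id": custom_id, "method": "POST", "url": endpoint, "body": body}
            ))
    return "\n".join(lines)
```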
diff --git a/src/extrai/utils/alignment_utils.py b/src/extrai/utils/alignment_utils.py
new file mode 100644
index 0000000..ee56ac0
--- /dev/null
+++ b/src/extrai/utils/alignment_utils.py
@@ -0,0 +1,179 @@
+from typing import Any, Dict, List
+from difflib import SequenceMatcher
+
+
+def normalize_json_revisions(revisions: List[Any]) -> List[Any]:
+ """
+ Aligns arrays across revisions using similarity-based matching.
+ Handles different structures and ensures consistent ordering.
+ """
+ if not revisions:
+ return revisions
+
+ # Check if all revisions are lists of dictionaries (entity arrays)
+ if all(
+ isinstance(rev, list) and rev and isinstance(rev[0], dict)
+ for rev in revisions
+ if rev
+ ):
+ return align_entity_arrays(revisions)
+
+ # Check if revisions have a "results" wrapper
+ if all(isinstance(rev, dict) and "results" in rev for rev in revisions):
+ results_arrays = [rev["results"] for rev in revisions]
+ aligned_results = align_entity_arrays(results_arrays)
+ # Reconstruct with aligned results
+ return [{"results": aligned} for aligned in aligned_results]
+
+ # Otherwise return as-is (single object extractions)
+ return revisions
+
+
+def align_entity_arrays(
+ arrays: List[List[Dict[str, Any]]],
+) -> List[List[Dict[str, Any]]]:
+ """
+ Aligns multiple arrays of entities so similar objects are in the same positions.
+ Uses the first array as reference and matches objects based on similarity.
+ """
+ if not arrays or not any(arrays):
+ return arrays
+
+ # Validate all arrays have the same length
+ lengths = [len(arr) for arr in arrays]
+ if len(set(lengths)) > 1:
+ import warnings
+
+ warnings.warn(
+ f"Arrays have different lengths {lengths}; truncating to the minimum."
+ )
+ min_length = min(lengths)
+ arrays = [arr[:min_length] for arr in arrays]
+
+ # Use first array as reference
+ reference = arrays[0]
+ aligned = [reference[:]]
+
+ # Align each subsequent array to match the reference
+ for arr in arrays[1:]:
+ reordered = []
+ used_indices = set()
+
+ for ref_obj in reference:
+ # Find best match in current array
+ best_idx = find_best_match(ref_obj, arr, used_indices)
+ reordered.append(arr[best_idx])
+ used_indices.add(best_idx)
+
+ aligned.append(reordered)
+
+ return aligned
+
+
+def find_best_match(
+ target: Dict[str, Any], candidates: List[Dict[str, Any]], used_indices: set
+) -> int:
+ """
+ Finds the index of the most similar object in candidates that hasn't been used.
+ """
+ best_idx = -1
+ best_score = -1.0
+
+ for idx, candidate in enumerate(candidates):
+ if idx in used_indices:
+ continue
+
+ score = calculate_similarity(target, candidate)
+ if score > best_score:
+ best_score = score
+ best_idx = idx
+
+ return best_idx
+
+
+def calculate_similarity(obj1: Dict[str, Any], obj2: Dict[str, Any]) -> float:
+ """
+ Calculates similarity score between two objects (0-1, higher is more similar).
+ Handles different field types recursively.
+ """
+ if not isinstance(obj1, dict) or not isinstance(obj2, dict):
+ return 1.0 if obj1 == obj2 else 0.0
+
+ # Check for ID fields first (quick exact match)
+ id1 = obj1.get("_temp_id") or obj1.get("id")
+ id2 = obj2.get("_temp_id") or obj2.get("id")
+ if id1 and id2 and str(id1) == str(id2):
+ return 1.0
+
+ # Get all unique fields
+ all_fields = set(obj1.keys()) | set(obj2.keys())
+ if not all_fields:
+ return 1.0
+
+ total_similarity = 0.0
+
+ for field in all_fields:
+ val1 = obj1.get(field)
+ val2 = obj2.get(field)
+
+ # If field missing in one object
+ if field not in obj1 or field not in obj2:
+ field_similarity = 0.0
+ else:
+ field_similarity = compare_values(val1, val2)
+
+ total_similarity += field_similarity
+
+ return total_similarity / len(all_fields)
+
+
+def compare_values(val1: Any, val2: Any) -> float:
+ """
+ Compares two values and returns similarity score (0-1).
+ """
+ # Handle None
+ if val1 is None and val2 is None:
+ return 1.0
+ if val1 is None or val2 is None:
+ return 0.0
+
+ # Prevent boolean vs number comparison (True == 1 is True in Python)
+ if isinstance(val1, bool) != isinstance(val2, bool):
+ return 0.0
+
+ # Exact equality
+ if val1 == val2:
+ return 1.0
+
+ # String comparison (fuzzy)
+ if isinstance(val1, str) and isinstance(val2, str):
+ # Case-insensitive comparison
+ if val1.strip().lower() == val2.strip().lower():
+ return 1.0
+ # Fuzzy string matching
+ return SequenceMatcher(None, val1, val2).ratio()
+
+ # Numeric comparison
+ if isinstance(val1, (int, float)) and isinstance(val2, (int, float)):
+ max_val = max(abs(val1), abs(val2), 1)
+ return 1.0 - min(abs(val1 - val2) / max_val, 1.0)
+
+ # List comparison (recursive)
+ if isinstance(val1, list) and isinstance(val2, list):
+ if len(val1) == 0 or len(val2) == 0:
+ return 0.0
+
+ # Find best matches for each element
+ similarities = []
+ for item1 in val1:
+ best_match = max(
+ (compare_values(item1, item2) for item2 in val2), default=0.0
+ )
+ similarities.append(best_match)
+
+ return sum(similarities) / len(similarities)
+
+ # Dict comparison (recursive)
+ if isinstance(val1, dict) and isinstance(val2, dict):
+ return calculate_similarity(val1, val2)
+
+ # Different types
+ return 0.0
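A condensed, standalone illustration of the field-averaged scoring that `calculate_similarity` performs, simplified to exact and fuzzy string checks (the real function also handles numeric closeness, nested lists/dicts, ID short-circuiting, and the bool-vs-int guard):

```python
from difflib import SequenceMatcher

def field_similarity(obj1: dict, obj2: dict) -> float:
    """Average per-field similarity over the union of keys (0.0-1.0)."""
    fields = set(obj1) | set(obj2)
    if not fields:
        return 1.0
    total = 0.0
    for f in fields:
        if f not in obj1 or f not in obj2:
            continue  # a field missing on either side scores 0
        a, b = obj1[f], obj2[f]
        if isinstance(a, str) and isinstance(b, str):
            if a.strip().lower() == b.strip().lower():
                total += 1.0
            else:
                total += SequenceMatcher(None, a, b).ratio()
        elif a == b:
            total += 1.0
    return total / len(fields)
```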
diff --git a/src/extrai/utils/llm_output_processing.py b/src/extrai/utils/llm_output_processing.py
index 6acc072..51c3f60 100644
--- a/src/extrai/utils/llm_output_processing.py
+++ b/src/extrai/utils/llm_output_processing.py
@@ -1,5 +1,5 @@
import json
-from typing import Any, Dict, Type, Optional, Union
+from typing import Any, Dict, Type, Optional, Union, Tuple
from extrai.core.analytics_collector import WorkflowAnalyticsCollector
from sqlmodel import SQLModel
@@ -31,25 +31,53 @@ def _filter_special_fields_for_validation(data: Dict[str, Any]) -> Dict[str, Any
}
+def _unwrap_priority_keys(data: Any) -> Tuple[Any, bool]:
+ """
+ Recursively unwraps priority keys (result, data, etc.) from a dictionary.
+ Returns a tuple (unwrapped_data, was_unwrapped).
+ """
+ if isinstance(data, dict):
+ if "_type" in data:
+ return data, False
+ for key in ["result", "data", "results", "entities"]:
+ if key in data:
+ # Found a priority key. Unwrap it and recurse.
+ val, _ = _unwrap_priority_keys(data[key])
+ return val, True
+ return data, False
+
+
def _unwrap_llm_output(data: Any) -> Any:
"""
Unwraps nested data from LLM JSON outputs.
It searches for a primary data payload, which could be a list or a dictionary,
within common wrapping structures like `{"result": [...]}` or `{"data": [...]}`.
+ Priority keys are unwrapped recursively; the single-key fallback is applied
+ only once, at the top level, and only if no priority keys were found.
"""
+ # 1. Handle list wrapper (special case where a list contains a single wrapper dict)
+ if isinstance(data, list) and len(data) == 1:
+ inner = data[0]
+ if isinstance(inner, dict) and "_type" not in inner:
+ # Check if inner has priority keys
+ val, found = _unwrap_priority_keys(inner)
+ if found:
+ return val
+
+ # 2. Try to unwrap priority keys recursively
+ val, found = _unwrap_priority_keys(data)
+ if found:
+ return val
+
+ # 3. If no priority keys found, try single-key fallback (once, non-recursive)
if isinstance(data, dict):
- # Prioritize keys that are likely to contain the main payload.
- for key in ["result", "data", "results", "entities"]:
- if key in data:
- return data[key]
+ if "_type" in data:
+ return data
- # If no priority key is found, and there's only one key,
- # return the value associated with that key.
if len(data) == 1:
return next(iter(data.values()))
- # If the data is not a dictionary or no specific unwrapping rule applies,
- # return the data as is.
+ # 4. Return data as is
return data
@@ -58,6 +86,7 @@ def process_and_validate_llm_output(
model_schema_map: Dict[str, Type[SQLModel]],
revision_info_for_error: str = "LLM Output",
analytics_collector: Optional[WorkflowAnalyticsCollector] = None,
+ default_model_type: Optional[str] = None,
) -> list[Dict[str, Any]]:
"""
Parses raw LLM JSON content, unwraps structures, and validates a list of objects
@@ -100,6 +129,10 @@ def process_and_validate_llm_output(
)
type_key = item.get("_type")
+ if not type_key and default_model_type:
+ type_key = default_model_type
+ item["_type"] = type_key # Inject it for consistency
+
if not type_key:
raise LLMOutputValidationError(
f"{revision_info_for_error}: Missing '_type' key in object.", item
@@ -139,6 +172,7 @@ def process_and_validate_raw_json(
raw_llm_content: str,
revision_info_for_error: str,
target_json_schema: Optional[Dict[str, Any]] = None,
+ attempt_unwrap: bool = True,
) -> Union[Dict[str, Any], list[Dict[str, Any]]]:
"""
Parses, unwraps, and validates raw JSON content against a schema.
@@ -147,6 +181,7 @@ def process_and_validate_raw_json(
raw_llm_content: The raw string from the LLM.
revision_info_for_error: A string for error reporting.
target_json_schema: An optional JSON schema for validation.
+ attempt_unwrap: Whether to attempt unwrapping the JSON content. Defaults to True.
Returns:
The validated dictionary or list of dictionaries.
@@ -170,7 +205,10 @@ def process_and_validate_raw_json(
original_exception=e,
)
- unwrapped_data = _unwrap_llm_output(parsed_json)
+ if attempt_unwrap:
+ unwrapped_data = _unwrap_llm_output(parsed_json)
+ else:
+ unwrapped_data = parsed_json
if not isinstance(unwrapped_data, (dict, list)):
raise LLMOutputParseError(
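The unwrap order described in the `_unwrap_llm_output` docstring (priority keys recursively, `_type` as a stop condition, single-key fallback applied once) can be condensed into a standalone sketch; the single-item list-wrapper special case is omitted:

```python
PRIORITY_KEYS = ("result", "data", "results", "entities")

def unwrap(data):
    """Peel common wrapper keys off LLM JSON until the payload is reached."""
    if isinstance(data, dict):
        if "_type" in data:
            return data  # typed objects are never unwrapped
        for key in PRIORITY_KEYS:
            if key in data:
                return unwrap(data[key])
        if len(data) == 1:
            # Single-key fallback, applied once and non-recursively.
            return next(iter(data.values()))
    return data
```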
diff --git a/src/extrai/utils/rate_limiter.py b/src/extrai/utils/rate_limiter.py
new file mode 100644
index 0000000..a82ec0e
--- /dev/null
+++ b/src/extrai/utils/rate_limiter.py
@@ -0,0 +1,74 @@
+import asyncio
+import time
+from typing import List, Tuple
+
+
+class AsyncRateLimiter:
+ """
+ A generic async rate limiter using a sliding window algorithm.
+ Tracks usage of a resource (calls, tokens, etc.) over a time period.
+ """
+
+ def __init__(self, max_capacity: int, period: float = 60.0):
+ """
+ Args:
+ max_capacity: The maximum amount of resource allowed in the period.
+ period: The time window in seconds (default 60.0 for 1 minute).
+ """
+ self.max_capacity = max_capacity
+ self.period = period
+ # List of (timestamp, cost)
+ self.history: List[Tuple[float, int]] = []
+ self._lock = asyncio.Lock()
+
+ async def acquire(self, cost: int = 1):
+ """
+ Acquires the specified amount of resource, waiting if necessary.
+
+ Args:
+ cost: The amount of resource to consume (default 1).
+ """
+ async with self._lock:
+ now = time.monotonic()
+
+ # 1. Clean up old history
+ self.history = [(t, c) for t, c in self.history if now - t <= self.period]
+
+ # 2. Calculate current usage
+ current_usage = sum(c for t, c in self.history)
+
+ # 3. Check if we need to wait
+ if current_usage + cost > self.max_capacity:
+ # We need to wait until enough usage expires.
+ # Find how much we need to free.
+ needed_to_free = (current_usage + cost) - self.max_capacity
+
+ freed = 0
+ wait_until = now
+
+ for t, c in self.history:
+ freed += c
+ if freed >= needed_to_free:
+ # Found the point where enough capacity is freed
+ wait_until = t + self.period
+ break
+
+ sleep_time = wait_until - now
+ if sleep_time > 0:
+ await asyncio.sleep(sleep_time)
+
+ # After sleep, update state
+ now = time.monotonic()
+ self.history = [
+ (t, c) for t, c in self.history if now - t <= self.period
+ ]
+
+ # 4. Record usage
+ self.history.append((now, cost))
+
+ async def __aenter__(self):
+ await self.acquire(1)
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ pass
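The wait computation inside `acquire` can be checked in isolation; this sketch takes explicit timestamps instead of `time.monotonic()` so the sliding-window math is directly testable:

```python
def sliding_window_wait(history, period, max_capacity, cost, now):
    """Seconds to wait before a request of `cost` fits in the window.

    history is a chronologically ordered list of (timestamp, cost) pairs.
    """
    live = [(t, c) for t, c in history if now - t <= period]
    usage = sum(c for _, c in live)
    if usage + cost <= max_capacity:
        return 0.0
    needed = usage + cost - max_capacity
    freed = 0
    for t, c in live:
        freed += c
        if freed >= needed:
            # Enough capacity frees up once this entry ages out of the window.
            return max(0.0, (t + period) - now)
    return period  # cost exceeds capacity even with an empty window
```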
diff --git a/src/extrai/utils/serialization_utils.py b/src/extrai/utils/serialization_utils.py
new file mode 100644
index 0000000..3b08680
--- /dev/null
+++ b/src/extrai/utils/serialization_utils.py
@@ -0,0 +1,69 @@
+from typing import Any, Dict, Set, Optional
+from sqlmodel import SQLModel
+from sqlalchemy.orm.collections import InstrumentedList
+
+
+def serialize_sqlmodel_with_relationships(
+ obj: SQLModel, seen: Optional[Set[int]] = None
+) -> Dict[str, Any]:
+ """
+ Recursively serializes a SQLModel instance, including its loaded relationships.
+ Uses model_dump(mode='json') to handle basic types (including Decimal -> str/float).
+
+ Args:
+ obj: The SQLModel instance to serialize.
+ seen: A set of object IDs visited in the current recursion stack to prevent infinite loops.
+
+ Returns:
+ A dictionary representation of the SQLModel instance, including relationships.
+ """
+ if seen is None:
+ seen = set()
+
+ obj_id = id(obj)
+ if obj_id in seen:
+ # Prevent infinite recursion for circular references
+ return {}
+
+ seen.add(obj_id)
+
+ # 1. Dump basic fields (handles Decimals, Datetimes, etc. via Pydantic serialization)
+ data = obj.model_dump(mode="json")
+
+ # 2. Inspect for relationships
+ # We rely on SQLModel's internal metadata which is consistent for SQLModel instances
+ relationships = getattr(obj, "__sqlmodel_relationships__", {})
+
+ for key in relationships.keys():
+ value = getattr(obj, key, None)
+
+ if value is None:
+ continue
+
+ if isinstance(value, (list, InstrumentedList)):
+ data[key] = [
+ serialize_sqlmodel_with_relationships(item, seen)
+ if isinstance(item, SQLModel)
+ else item
+ for item in value
+ ]
+ elif isinstance(value, SQLModel):
+ data[key] = serialize_sqlmodel_with_relationships(value, seen)
+
+ return data
+
+
+def make_json_serializable(obj: Any) -> Any:
+ """
+ Recursively converts objects to JSON-serializable formats.
+ Handles Decimals by converting to float.
+ """
+ from decimal import Decimal
+
+ if isinstance(obj, Decimal):
+ return float(obj)
+ elif isinstance(obj, dict):
+ return {k: make_json_serializable(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [make_json_serializable(item) for item in obj]
+ return obj
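A quick usage sketch of the Decimal handling; the helper is restated inline so the snippet runs standalone:

```python
import json
from decimal import Decimal

# Minimal copy of make_json_serializable so the snippet is self-contained.
def make_json_serializable(obj):
    if isinstance(obj, Decimal):
        return float(obj)
    if isinstance(obj, dict):
        return {k: make_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [make_json_serializable(i) for i in obj]
    return obj

row = {"price": Decimal("19.99"), "qty": [Decimal("2")]}
# json.dumps(row) would raise TypeError: Decimal is not JSON serializable
print(json.dumps(make_json_serializable(row)))
```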
diff --git a/src/extrai/utils/type_mapping.py b/src/extrai/utils/type_mapping.py
new file mode 100644
index 0000000..0b891db
--- /dev/null
+++ b/src/extrai/utils/type_mapping.py
@@ -0,0 +1,211 @@
+import datetime
+import enum
+from typing import (
+ Any,
+ Dict,
+ List,
+ Optional,
+ get_args,
+ get_origin,
+ Union as TypingUnion,
+)
+
+
+def _process_union_types(args, recurse_func):
+ """Helper to process Union types, filtering and sorting."""
+ if not args:
+ return "union"
+ union_types_str = [recurse_func(arg) for arg in args]
+ processed_union_types = sorted(set(t for t in union_types_str if t != "none"))
+ if len(processed_union_types) == 1:
+ return processed_union_types[0]
+ return f"union[{','.join(processed_union_types)}]"
+
+
+# Handler registry for different type origins
+ORIGIN_HANDLERS = {
+ Optional: lambda args, r: r(args[0])
+ if args and args[0] is not type(None)
+ else "none",
+ list: lambda args, r: f"list[{','.join([r(arg) for arg in args])}]"
+ if args
+ else "list",
+ List: lambda args, r: f"list[{','.join([r(arg) for arg in args])}]"
+ if args
+ else "list",
+ dict: lambda args, r: f"dict[{r(args[0])},{r(args[1])}]"
+ if args and len(args) == 2
+ else "dict",
+ Dict: lambda args, r: f"dict[{r(args[0])},{r(args[1])}]"
+ if args and len(args) == 2
+ else "dict",
+ TypingUnion: _process_union_types,
+}
+
+# Data-driven approach for base types
+BASE_TYPE_MAP = {
+ int: "int",
+ str: "str",
+ bool: "bool",
+ float: "float",
+ datetime.date: "date",
+ datetime.datetime: "datetime",
+ bytes: "bytes",
+ Any: "any",
+ type(None): "none",
+}
+
+
+def get_python_type_str_from_pydantic_annotation(annotation: Any) -> str:
+ """Helper function to get a simplified string from Pydantic/SQLModel annotations."""
+ origin = get_origin(annotation)
+ args = get_args(annotation)
+
+ if origin in ORIGIN_HANDLERS:
+ return ORIGIN_HANDLERS[origin](
+ args, get_python_type_str_from_pydantic_annotation
+ )
+
+ if annotation in BASE_TYPE_MAP:
+ return BASE_TYPE_MAP[annotation]
+
+ if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
+ return "enum"
+
+ if hasattr(annotation, "__name__"):
+ name_lower = annotation.__name__.lower()
+ if name_lower == "secretstr":
+ return "str"
+ return name_lower
+
+ # Fallback
+ cleaned_annotation_str = str(annotation).lower().replace("typing.", "")
+ if cleaned_annotation_str.startswith("~"):
+ cleaned_annotation_str = cleaned_annotation_str[1:]
+ return cleaned_annotation_str
+
+
+# --- Data-driven mappings for type conversion ---
+SIMPLE_PYTHON_TYPE_MAP = {
+ "int": "integer",
+ "str": "string",
+ "bool": "boolean",
+ "float": "number (float/decimal)",
+ "date": "string (date format)",
+ "datetime": "string (datetime format)",
+ "bytes": "string (base64 encoded)",
+ "enum": "string (enum)",
+ "any": "any",
+ "none": "null",
+}
+
+SQL_TYPE_KEYWORDS = [
+ ("int", "integer"),
+ ("char", "string"),
+ ("text", "string"),
+ ("clob", "string"),
+ ("bool", "boolean"),
+ ("date", "string (date/datetime format)"),
+ ("time", "string (date/datetime format)"),
+ ("numeric", "number (float/decimal)"),
+ ("decimal", "number (float/decimal)"),
+ ("float", "number (float/decimal)"),
+ ("double", "number (float/decimal)"),
+ ("json", "object"),
+ ("array", "array"),
+]
+
+
+# --- Handlers for complex and generic types ---
+def _handle_list_type(python_type_lower: str) -> Optional[str]:
+ """Handles list[...] and array[...] type mappings."""
+ if python_type_lower.startswith("list[") and python_type_lower.endswith("]"):
+ inner_type_str = python_type_lower[5:-1]
+ mapped_inner_type = map_sql_type_to_llm_type("", inner_type_str)
+ return f"array[{mapped_inner_type}]"
+ return None
+
+
+def _handle_dict_type(python_type_lower: str) -> Optional[str]:
+ """Handles dict[...] and object[...] type mappings."""
+ if python_type_lower.startswith("dict[") and python_type_lower.endswith("]"):
+ inner_types_str = python_type_lower[5:-1]
+ try:
+ key_type_str, value_type_str = inner_types_str.split(",", 1)
+ mapped_key_type = map_sql_type_to_llm_type("", key_type_str.strip())
+ mapped_value_type = map_sql_type_to_llm_type("", value_type_str.strip())
+ return f"object[{mapped_key_type},{mapped_value_type}]"
+ except ValueError:
+ return "object"
+ return None
+
+
+def _handle_union_type(python_type_lower: str) -> Optional[str]:
+ """Handles union[...] type mappings."""
+ if python_type_lower.startswith("union[") and python_type_lower.endswith("]"):
+ inner_types_str = python_type_lower[6:-1]
+ union_parts = [p.strip() for p in inner_types_str.split(",") if p.strip()]
+ mapped_parts = sorted(
+ set(map_sql_type_to_llm_type("", part) for part in union_parts)
+ )
+ if not mapped_parts:
+ return "any"
+ return (
+ mapped_parts[0]
+ if len(mapped_parts) == 1
+ else f"union[{','.join(mapped_parts)}]"
+ )
+ return None
+
+
+def _handle_generic_or_unknown_type(
+ python_type_lower: str, sql_type_lower: str
+) -> Optional[str]:
+ """Handles ambiguous types like plain 'list' or 'dict' and unknown types."""
+ if python_type_lower == "list":
+ if "text" in sql_type_lower: # Let the SQL keyword mapping handle this case
+ return None
+
+ return "array"
+
+ if python_type_lower == "dict":
+ return "object"
+
+ if python_type_lower.startswith("unknown"):
+ if "json" in sql_type_lower:
+ return "object"
+ if "array" in sql_type_lower:
+ return "array"
+ return "string"
+ return None
+
+
+def map_sql_type_to_llm_type(sql_type_str: str, python_type_str: str) -> str:
+ """
+ Maps SQL/Python types to simpler LLM-friendly type strings using a dispatcher pattern.
+ """
+ sql_type_lower = str(sql_type_str).lower()
+ python_type_lower = str(python_type_str).lower()
+
+ # 1. Handle complex Python types first
+ for handler in [_handle_list_type, _handle_dict_type, _handle_union_type]:
+ result = handler(python_type_lower)
+ if result:
+ return result
+
+ # 2. Look up in the simple Python type map
+ if python_type_lower in SIMPLE_PYTHON_TYPE_MAP:
+ return SIMPLE_PYTHON_TYPE_MAP[python_type_lower]
+
+ # 3. Handle generic or unknown types, which have precedence over broad SQL keywords
+ result = _handle_generic_or_unknown_type(python_type_lower, sql_type_lower)
+ if result:
+ return result
+
+ # 4. Search through SQL type keywords as a fallback
+ for keyword, llm_type in SQL_TYPE_KEYWORDS:
+ if keyword in sql_type_lower:
+ return llm_type
+
+ # 5. Final fallback if no other rule matched
+ return "string"
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..11cf9df
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+# Init file for tests module
diff --git a/tests/core/batch_pipeline/test_batch_counting.py b/tests/core/batch_pipeline/test_batch_counting.py
new file mode 100644
index 0000000..b3603d5
--- /dev/null
+++ b/tests/core/batch_pipeline/test_batch_counting.py
@@ -0,0 +1,151 @@
+import unittest
+from unittest.mock import MagicMock, patch, AsyncMock
+from extrai.core.batch_pipeline import BatchPipeline, BatchJobStatus
+from extrai.core.model_registry import ModelRegistry
+from extrai.core.extraction_config import ExtractionConfig
+from extrai.core.base_llm_client import BaseLLMClient
+from extrai.core.batch_models import BatchJobContext
+
+
+class TestBatchPipelineCounting(unittest.IsolatedAsyncioTestCase):
+ def setUp(self):
+ self.mock_model_registry = MagicMock(spec=ModelRegistry)
+ self.mock_model_registry.root_model = MagicMock()
+ self.mock_model_registry.root_model.__name__ = "RootModel"
+ self.mock_model_registry.models = [self.mock_model_registry.root_model]
+ self.mock_model_registry.llm_schema_json = "{}"
+ self.mock_model_registry.model_map = {
+ "RootModel": self.mock_model_registry.root_model
+ }
+ self.mock_model_registry.get_all_model_names.return_value = ["RootModel"]
+
+ self.mock_client = MagicMock(spec=BaseLLMClient)
+ self.mock_client.temperature = 0.0
+ self.mock_config = MagicMock(spec=ExtractionConfig)
+ self.mock_config.consensus_threshold = 0.6
+ self.mock_config.conflict_resolver = None
+ self.mock_config.num_llm_revisions = 1
+ self.mock_config.max_validation_retries_per_revision = 1
+ self.mock_config.use_structured_output = False
+ self.mock_config.use_hierarchical_extraction = False
+ self.mock_analytics = MagicMock()
+ self.mock_session = MagicMock()
+ self.mock_logger = MagicMock()
+
+ with (
+ patch("extrai.core.batch_pipeline.ClientRotator") as MockClientRotator,
+ patch(
+ "extrai.core.batch_pipeline.ExtractionContextPreparer"
+ ) as MockContextPreparer,
+ patch("extrai.core.batch_pipeline.PromptBuilder") as MockBuilder,
+ patch("extrai.core.batch_pipeline.EntityCounter") as MockCounter,
+ patch("extrai.core.batch_pipeline.JSONConsensus") as MockConsensus,
+ patch("extrai.core.batch_pipeline.ModelWrapperBuilder"),
+ ):
+ self.pipeline = BatchPipeline(
+ self.mock_model_registry,
+ self.mock_client,
+ self.mock_config,
+ self.mock_analytics,
+ self.mock_logger,
+ )
+ self.pipeline.client_rotator = MockClientRotator.return_value
+ self.pipeline.context_preparer = MockContextPreparer.return_value
+ self.pipeline.prompt_builder = MockBuilder.return_value
+ self.pipeline.entity_counter = MockCounter.return_value
+ self.pipeline.consensus = MockConsensus.return_value
+
+ async def test_submit_batch_counting(self):
+ # Setup mocks
+ self.pipeline.entity_counter.prepare_counting_prompts.return_value = (
+ "count_sys",
+ "count_user",
+ )
+ self.pipeline.context_preparer.prepare_example = AsyncMock(return_value="")
+
+ mock_batch_job = MagicMock()
+ mock_batch_job.id = "counting_batch_id"
+
+ # Mock the entity_counter's client for counting phase
+ self.pipeline.entity_counter.llm_client.create_batch_job = AsyncMock(
+ return_value=mock_batch_job
+ )
+
+ # Test submit
+ root_id = await self.pipeline.submit_batch(
+ self.mock_session, ["doc"], count_entities=True
+ )
+
+ # Verify
+ self.assertIsInstance(root_id, str)
+ self.assertEqual(self.mock_session.add.call_count, 2)
+ added_context = self.mock_session.add.call_args[0][0]
+
+ self.assertEqual(added_context.current_batch_id, "counting_batch_id")
+ self.assertEqual(added_context.status, BatchJobStatus.COUNTING_SUBMITTED)
+
+ config = added_context.config
+ self.assertTrue(config["count_entities"])
+
+ async def test_process_batch_counting_transition(self):
+ # Mock Context
+ context = BatchJobContext(
+ root_batch_id="root_1",
+ current_batch_id="counting_batch_id",
+ status=BatchJobStatus.COUNTING_SUBMITTED,
+ input_strings=["doc"],
+ config={"count_entities": True, "custom_extraction_process": "proc"},
+ )
+ self.mock_session.get.return_value = context
+
+ # Mock get_status to return COUNTING_READY_TO_PROCESS
+ mock_provider_job = MagicMock()
+ mock_provider_job.status = "completed"
+
+        # Mock entity_counter.llm_client for the counting status check and results retrieval
+ self.pipeline.entity_counter.llm_client.retrieve_batch_job = AsyncMock(
+ return_value=mock_provider_job
+ )
+
+ # Mock counting results
+ counting_results_file = '{"id": "line1"}'
+ self.pipeline.entity_counter.llm_client.retrieve_batch_results = AsyncMock(
+ return_value=counting_results_file
+ )
+ self.pipeline.entity_counter.llm_client.extract_content_from_batch_response.return_value = '{"RootModel": ["desc1"]}'
+
+ self.pipeline.entity_counter.validate_counts.return_value = {
+ "RootModel": ["desc1"]
+ }
+
+ # Mock extraction batch submission
+ # Extraction phase uses client_rotator client
+ mock_extraction_job = MagicMock()
+ mock_extraction_job.id = "extraction_batch_id"
+ mock_client_instance = self.pipeline.client_rotator.get_next_client.return_value
+ mock_client_instance.create_batch_job = AsyncMock(
+ return_value=mock_extraction_job
+ )
+
+ # Ensure build_prompts returns expected tuple
+ self.pipeline.prompt_builder.build_prompts.return_value = ("sys", "user")
+
+ # Test process
+ result = await self.pipeline.process_batch("root_1", self.mock_session)
+
+ # Verify transition
+ self.assertEqual(result.status, BatchJobStatus.PROCESSING)
+ self.assertEqual(result.message, "Transitioned from counting to extraction")
+
+ # Verify context updated
+ self.assertEqual(context.status, BatchJobStatus.SUBMITTED)
+ self.assertEqual(context.current_batch_id, "extraction_batch_id")
+
+ # Verify config updated with descriptions
+ config = context.config
+ self.assertIn("expected_entity_descriptions", config)
+ self.assertEqual(config["expected_entity_descriptions"], ["[RootModel] desc1"])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/batch_pipeline/test_batch_pipeline.py b/tests/core/batch_pipeline/test_batch_pipeline.py
new file mode 100644
index 0000000..6f25376
--- /dev/null
+++ b/tests/core/batch_pipeline/test_batch_pipeline.py
@@ -0,0 +1,106 @@
+import unittest
+from unittest.mock import MagicMock, patch, AsyncMock
+from extrai.core.batch_pipeline import BatchPipeline, BatchJobStatus
+from extrai.core.model_registry import ModelRegistry
+from extrai.core.extraction_config import ExtractionConfig
+from extrai.core.base_llm_client import BaseLLMClient
+from extrai.core.batch_models import BatchJobContext
+
+
+class TestBatchPipeline(unittest.IsolatedAsyncioTestCase):
+ def setUp(self):
+ self.mock_model_registry = MagicMock(spec=ModelRegistry)
+ self.mock_model_registry.root_model = MagicMock()
+ self.mock_model_registry.root_model.__name__ = "RootModel"
+ self.mock_model_registry.models = [self.mock_model_registry.root_model]
+ self.mock_model_registry.llm_schema_json = "{}"
+ self.mock_model_registry.model_map = {
+ "RootModel": self.mock_model_registry.root_model
+ }
+
+ self.mock_client = MagicMock(spec=BaseLLMClient)
+ self.mock_client.temperature = 0.0
+ self.mock_config = MagicMock(spec=ExtractionConfig)
+ self.mock_config.consensus_threshold = 0.6
+ self.mock_config.conflict_resolver = None
+ self.mock_config.num_llm_revisions = 1
+ self.mock_config.max_validation_retries_per_revision = 1
+ self.mock_config.use_structured_output = False
+ self.mock_config.use_hierarchical_extraction = False
+ self.mock_analytics = MagicMock()
+ self.mock_session = MagicMock()
+ self.mock_logger = MagicMock()
+
+ with (
+ patch("extrai.core.batch_pipeline.ClientRotator") as MockClientRotator,
+ patch(
+ "extrai.core.batch_pipeline.ExtractionContextPreparer"
+ ) as MockContextPreparer,
+ patch("extrai.core.batch_pipeline.PromptBuilder") as MockBuilder,
+ patch("extrai.core.batch_pipeline.EntityCounter") as MockCounter,
+ patch("extrai.core.batch_pipeline.JSONConsensus") as MockConsensus,
+ patch("extrai.core.batch_pipeline.ModelWrapperBuilder"),
+ ):
+ self.pipeline = BatchPipeline(
+ self.mock_model_registry,
+ self.mock_client,
+ self.mock_config,
+ self.mock_analytics,
+ self.mock_logger,
+ )
+            # Replace the collaborators created in __init__ with their mocks so tests can configure them
+ self.pipeline.client_rotator = MockClientRotator.return_value
+ self.pipeline.context_preparer = MockContextPreparer.return_value
+ self.pipeline.prompt_builder = MockBuilder.return_value
+ self.pipeline.entity_counter = MockCounter.return_value
+ self.pipeline.consensus = MockConsensus.return_value
+
+ async def test_submit_batch_success(self):
+ self.pipeline.prompt_builder.build_prompts.return_value = ("sys", "user")
+
+ mock_batch_job = MagicMock()
+ mock_batch_job.id = "provider_id_123"
+
+ mock_client_instance = self.pipeline.client_rotator.get_next_client.return_value
+ mock_client_instance.create_batch_job = AsyncMock(return_value=mock_batch_job)
+
+ self.pipeline.context_preparer.prepare_example = AsyncMock(return_value="")
+ self.pipeline._count_if_needed = AsyncMock(return_value=None)
+
+ root_id = await self.pipeline.submit_batch(self.mock_session, ["doc"])
+
+ self.assertIsInstance(root_id, str)
+ self.assertEqual(self.mock_session.add.call_count, 2)
+ self.assertEqual(self.mock_session.commit.call_count, 2)
+
+ added_context = self.mock_session.add.call_args[0][0]
+ self.assertIsInstance(added_context, BatchJobContext)
+ self.assertEqual(added_context.current_batch_id, "provider_id_123")
+ self.assertEqual(added_context.status, BatchJobStatus.SUBMITTED)
+
+ async def test_retrieve_and_validate_results(self):
+ mock_context = BatchJobContext(current_batch_id="prov_1")
+
+ mock_client = self.pipeline.client_rotator.get_next_client.return_value
+ mock_client.retrieve_batch_results = AsyncMock(
+ return_value='{"key": "value"}\n{"key": "value2"}'
+ )
+ mock_client.extract_content_from_batch_response.side_effect = [
+ '{"_type": "RootModel", "id": 1}',
+ '{"_type": "RootModel", "id": 2}',
+ ]
+
+ with patch(
+ "extrai.core.batch_pipeline.process_and_validate_llm_output"
+ ) as mock_validate:
+ mock_validate.side_effect = [[{"id": 1}], [{"id": 2}]]
+ results = await self.pipeline._retrieve_and_validate_results(mock_context)
+
+ # normalize_json_revisions wraps each revision in a list.
+ self.assertEqual(len(results), 2)
+ self.assertEqual(results[0], [{"id": 1}])
+ self.assertEqual(results[1], [{"id": 2}])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/batch_pipeline/test_batch_pipeline_structured.py b/tests/core/batch_pipeline/test_batch_pipeline_structured.py
new file mode 100644
index 0000000..c75c154
--- /dev/null
+++ b/tests/core/batch_pipeline/test_batch_pipeline_structured.py
@@ -0,0 +1,133 @@
+import unittest
+from unittest.mock import MagicMock, patch, AsyncMock
+import json
+from sqlmodel import SQLModel, Field
+
+from extrai.core.batch_pipeline import BatchPipeline
+from extrai.core.model_registry import ModelRegistry
+from extrai.core.extraction_config import ExtractionConfig
+from extrai.core.base_llm_client import BaseLLMClient
+from extrai.core.batch_models import BatchJobContext
+
+
+class Recipe(SQLModel):
+ name: str
+ ingredients: list[str] = Field(default_factory=list)
+ prep_time: int
+
+
+class TestBatchPipelineStructured(unittest.IsolatedAsyncioTestCase):
+ def setUp(self):
+ self.mock_model_registry = MagicMock(spec=ModelRegistry)
+ self.mock_model_registry.root_model = Recipe
+ self.mock_model_registry.models = [Recipe]
+ self.mock_model_registry.llm_schema_json = "{}"
+ self.mock_model_registry.model_map = {"Recipe": Recipe}
+
+ self.mock_client = MagicMock(spec=BaseLLMClient)
+ self.mock_client.temperature = 0.0
+
+ # Configure for Structured Output
+ self.mock_config = MagicMock(spec=ExtractionConfig)
+ self.mock_config.consensus_threshold = 0.6
+ self.mock_config.conflict_resolver = None
+ self.mock_config.num_llm_revisions = 1
+ self.mock_config.max_validation_retries_per_revision = 1
+ self.mock_config.use_structured_output = True # ENABLED
+ self.mock_config.use_hierarchical_extraction = False
+
+ self.mock_analytics = MagicMock()
+ self.mock_session = MagicMock()
+ self.mock_logger = MagicMock()
+
+ with (
+ patch("extrai.core.batch_pipeline.ClientRotator") as MockClientRotator,
+ patch(
+ "extrai.core.batch_pipeline.ExtractionContextPreparer"
+ ) as MockContextPreparer,
+ patch("extrai.core.batch_pipeline.PromptBuilder") as MockBuilder,
+ patch("extrai.core.batch_pipeline.EntityCounter") as MockCounter,
+ patch("extrai.core.batch_pipeline.JSONConsensus") as MockConsensus,
+ patch("extrai.core.batch_pipeline.ModelWrapperBuilder"),
+ ):
+ self.pipeline = BatchPipeline(
+ self.mock_model_registry,
+ self.mock_client,
+ self.mock_config,
+ self.mock_analytics,
+ self.mock_logger,
+ )
+ self.pipeline.client_rotator = MockClientRotator.return_value
+ self.pipeline.context_preparer = MockContextPreparer.return_value
+ self.pipeline.prompt_builder = MockBuilder.return_value
+ self.pipeline.entity_counter = MockCounter.return_value
+ self.pipeline.consensus = MockConsensus.return_value
+
+ async def test_retrieve_and_validate_results_missing_type(self):
+ """
+ Test that validation FAILS when _type is missing in structured output mode
+ (before the fix is applied).
+ """
+ mock_context = BatchJobContext(
+ current_batch_id="prov_1",
+ config={
+ "use_structured_output": True,
+ "schema_json": {},
+ },
+ )
+
+ mock_client = self.pipeline.client_rotator.get_next_client.return_value
+
+ # Simulating structured output which does NOT contain _type
+ structured_response = {
+ "entities": [
+ {"name": "Pancake", "ingredients": ["flour", "milk"], "prep_time": 10}
+ ]
+ }
+
+ mock_client.retrieve_batch_results = AsyncMock(
+ return_value=json.dumps(structured_response)
+ )
+
+        # Client implementations vary; here extract_content_from_batch_response
+        # is mocked to return the raw JSON string of the entities wrapper.
+ mock_client.extract_content_from_batch_response.return_value = json.dumps(
+ structured_response
+ )
+
+        # Let the real process_and_validate_llm_output run end-to-end: the
+        # structured response lacks _type, and the pipeline is expected to
+        # inject it before validation. Recipe is registered in the model map
+        # via setUp, so validation can succeed.
+
+ results = await self.pipeline._retrieve_and_validate_results(mock_context)
+
+        # The pipeline injects _type, so the entity validates and the field is present
+ self.assertEqual(len(results), 1)
+ self.assertEqual(len(results[0]), 1)
+ item = results[0][0]
+ self.assertEqual(item["name"], "Pancake")
+ self.assertEqual(item["_type"], "Recipe")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/code_generation/test_python_builder.py b/tests/core/code_generation/test_python_builder.py
new file mode 100644
index 0000000..e1dbea5
--- /dev/null
+++ b/tests/core/code_generation/test_python_builder.py
@@ -0,0 +1,25 @@
+from extrai.core.code_generation.python_builder import ImportManager
+
+
+class TestImportManager:
+ def test_malformed_imports_fallback(self):
+ """Tests that malformed imports are handled gracefully by falling back to appending as-is."""
+ manager = ImportManager()
+
+ # Add malformed imports that should trigger fallback behavior
+ malformed_sqlmodel = "from sqlmodel" # Missing " import "
+ malformed_typing = "from typing" # Missing " import "
+ malformed_import = "import ," # Missing modules (strips to "import ,")
+
+ manager.add_custom_imports(
+ [malformed_sqlmodel, malformed_typing, malformed_import]
+ )
+
+ rendered = manager.render()
+
+ assert "from sqlmodel" in rendered
+ assert "from typing" in rendered
+
+ # Check that "import ," line is present exactly as added
+ rendered_lines = rendered.split("\n")
+ assert "import ," in rendered_lines
diff --git a/tests/core/entity_counter/test_entity_counter.py b/tests/core/entity_counter/test_entity_counter.py
new file mode 100644
index 0000000..a5feed8
--- /dev/null
+++ b/tests/core/entity_counter/test_entity_counter.py
@@ -0,0 +1,81 @@
+import unittest
+from unittest.mock import MagicMock, patch, AsyncMock
+from extrai.core.entity_counter import EntityCounter
+from extrai.core.model_registry import ModelRegistry
+from extrai.core.extraction_config import ExtractionConfig
+from extrai.core.base_llm_client import BaseLLMClient
+
+
+class TestEntityCounter(unittest.IsolatedAsyncioTestCase):
+ def setUp(self):
+ self.mock_model_registry = MagicMock(spec=ModelRegistry)
+ self.mock_client = MagicMock(spec=BaseLLMClient)
+ self.mock_config = MagicMock(spec=ExtractionConfig)
+ self.mock_config.max_validation_retries_per_revision = 1
+ self.mock_analytics = MagicMock()
+ self.mock_logger = MagicMock()
+
+ self.counter = EntityCounter(
+ self.mock_model_registry,
+ self.mock_client,
+ self.mock_config,
+ self.mock_analytics,
+ self.mock_logger,
+ )
+
+ @patch("extrai.core.entity_counter.generate_entity_counting_system_prompt")
+ @patch("extrai.core.entity_counter.generate_entity_counting_user_prompt")
+ @patch("extrai.core.entity_counter.create_model")
+ async def test_count_entities_success(
+ self, mock_create_model, mock_user_prompt, mock_system_prompt
+ ):
+ # Setup mocks
+ self.mock_model_registry.get_schema_for_models.return_value = (
+ '{"type": "object"}'
+ )
+ self.mock_client.generate_and_validate_raw_json_output = AsyncMock(
+ return_value={"ModelA": 5}
+ )
+
+ mock_model_instance = MagicMock()
+ mock_model_instance.model_dump.return_value = {"ModelA": 5}
+
+ # Mock the dynamically created Pydantic model
+ MockPydanticModel = MagicMock()
+ MockPydanticModel.return_value = mock_model_instance
+ mock_create_model.return_value = MockPydanticModel
+
+ counts = await self.counter.count_entities(["doc"], ["ModelA"])
+
+ self.assertEqual(counts, {"ModelA": 5})
+ self.mock_model_registry.get_schema_for_models.assert_called_with(["ModelA"])
+ self.mock_client.generate_and_validate_raw_json_output.assert_called_once()
+ mock_create_model.assert_called_once()
+
+ async def test_count_entities_llm_failure(self):
+ self.mock_model_registry.get_schema_for_models.return_value = "{}"
+ self.mock_client.generate_and_validate_raw_json_output = AsyncMock(
+ side_effect=Exception("LLM Fail")
+ )
+
+ with patch("extrai.core.entity_counter.create_model"):
+ counts = await self.counter.count_entities(["doc"], ["ModelA"])
+
+ self.assertEqual(counts, {})
+ self.mock_logger.error.assert_called_once()
+
+ async def test_count_entities_invalid_output(self):
+ self.mock_model_registry.get_schema_for_models.return_value = "{}"
+ self.mock_client.generate_and_validate_raw_json_output = AsyncMock(
+ return_value="Not a dict"
+ )
+
+ with patch("extrai.core.entity_counter.create_model"):
+ counts = await self.counter.count_entities(["doc"], ["ModelA"])
+
+ self.assertEqual(counts, {})
+ self.mock_logger.warning.assert_called_once()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/extraction_pipeline/test_extraction_pipeline.py b/tests/core/extraction_pipeline/test_extraction_pipeline.py
new file mode 100644
index 0000000..9a5c0a3
--- /dev/null
+++ b/tests/core/extraction_pipeline/test_extraction_pipeline.py
@@ -0,0 +1,151 @@
+import unittest
+from unittest.mock import MagicMock, patch, AsyncMock
+from extrai.core.extraction_pipeline import ExtractionPipeline
+from extrai.core.model_registry import ModelRegistry
+from extrai.core.extraction_config import ExtractionConfig
+from extrai.core.base_llm_client import BaseLLMClient
+from extrai.core.extraction_request_factory import ExtractionRequest
+
+
+class TestExtractionPipeline(unittest.IsolatedAsyncioTestCase):
+ def setUp(self):
+ self.mock_model_registry = MagicMock(spec=ModelRegistry)
+ self.mock_model_registry.root_model = MagicMock()
+ self.mock_model_registry.root_model.__name__ = "RootModel"
+ self.mock_model_registry.llm_schema_json = "{}"
+
+ self.mock_client = MagicMock(spec=BaseLLMClient)
+ self.mock_config = MagicMock(spec=ExtractionConfig)
+ self.mock_config.use_hierarchical_extraction = False
+ self.mock_config.use_structured_output = False
+ self.mock_config.max_validation_retries_per_revision = 1
+ self.mock_analytics = MagicMock()
+ self.mock_logger = MagicMock()
+
+ with (
+ patch("extrai.core.extraction_pipeline.ClientRotator"),
+ patch("extrai.core.extraction_pipeline.ExtractionContextPreparer"),
+ patch("extrai.core.extraction_pipeline.PromptBuilder"),
+ patch("extrai.core.extraction_pipeline.EntityCounter"),
+ patch("extrai.core.extraction_pipeline.LLMRunner"),
+ patch("extrai.core.extraction_pipeline.ModelWrapperBuilder"),
+ patch(
+ "extrai.core.extraction_pipeline.ExtractionRequestFactory"
+ ) as MockFactory,
+ ):
+ # Setup default request factory behavior
+ self.mock_request_factory = MockFactory.return_value
+ self.mock_request = ExtractionRequest(
+ system_prompt="sys",
+ user_prompt="user",
+ json_schema=None,
+ model_name=None,
+ response_model=None,
+ )
+ self.mock_request_factory.prepare_request.return_value = self.mock_request
+
+ self.pipeline = ExtractionPipeline(
+ self.mock_model_registry,
+ self.mock_client,
+ self.mock_config,
+ self.mock_analytics,
+ self.mock_logger,
+ )
+
+ async def test_extract_standard_flow(self):
+ self.pipeline.llm_runner.run_extraction_cycle = AsyncMock(
+ return_value=[{"id": 1}]
+ )
+ self.pipeline.entity_counter.count_entities = AsyncMock(return_value={})
+
+ self.pipeline.context_preparer.prepare_example = AsyncMock(return_value="{}")
+
+ results = await self.pipeline.extract(["doc"])
+
+ self.assertEqual(results, [{"id": 1}])
+ self.pipeline.context_preparer.prepare_example.assert_called_once()
+ self.pipeline.request_factory.prepare_request.assert_called_once()
+ self.pipeline.llm_runner.run_extraction_cycle.assert_called_once_with(
+ system_prompt="sys", user_prompt="user"
+ )
+
+ async def test_extract_structured_flow(self):
+ self.pipeline.config.use_structured_output = True
+
+ # Setup structured request
+ structured_request = ExtractionRequest(
+ system_prompt="sys_struct",
+ user_prompt="user_struct",
+ json_schema={"type": "object"},
+ model_name=None,
+ response_model=MagicMock(),
+ )
+ self.mock_request_factory.prepare_request.return_value = structured_request
+ self.pipeline.llm_runner.run_structured_extraction_cycle = AsyncMock(
+ return_value=[{"id": 1}]
+ )
+ self.pipeline.context_preparer.prepare_example = AsyncMock(return_value="{}")
+
+ results = await self.pipeline.extract(["doc"])
+
+ self.assertEqual(results, [{"id": 1}])
+ self.pipeline.request_factory.prepare_request.assert_called_once()
+ self.pipeline.llm_runner.run_structured_extraction_cycle.assert_called_once_with(
+ system_prompt="sys_struct",
+ user_prompt="user_struct",
+ response_model=structured_request.response_model,
+ )
+
+ async def test_extract_hierarchical_flow(self):
+ self.pipeline.config.use_hierarchical_extraction = True
+ self.pipeline.hierarchical_extractor = MagicMock()
+ self.pipeline.hierarchical_extractor.extract = AsyncMock(
+ return_value=[{"id": 1}]
+ )
+
+ self.pipeline.context_preparer.prepare_example = AsyncMock(return_value="{}")
+
+ results = await self.pipeline.extract(["doc"])
+
+ self.assertEqual(results, [{"id": 1}])
+ self.pipeline.context_preparer.prepare_example.assert_called_once()
+ self.pipeline.hierarchical_extractor.extract.assert_called_once()
+
+ async def test_count_entities(self):
+ self.pipeline.entity_counter.count_entities = AsyncMock(
+ return_value={"RootModel": 10}
+ )
+ self.pipeline.context_preparer.prepare_example = AsyncMock(return_value="{}")
+ self.pipeline.llm_runner.run_extraction_cycle = AsyncMock(return_value=[])
+
+ await self.pipeline.extract(["doc"], count_entities=True)
+
+ self.pipeline.entity_counter.count_entities.assert_called_once()
+
+ async def test_count_entities_failure(self):
+ self.pipeline.entity_counter.count_entities = AsyncMock(
+ side_effect=Exception("Count failed")
+ )
+ self.pipeline.context_preparer.prepare_example = AsyncMock(return_value="{}")
+ self.pipeline.llm_runner.run_extraction_cycle = AsyncMock(return_value=[])
+
+ await self.pipeline.extract(["doc"], count_entities=True)
+
+ self.mock_logger.warning.assert_called_with(
+ "Entity counting failed: Count failed"
+ )
+
+ def test_repr(self):
+ self.pipeline.config.use_hierarchical_extraction = False
+ repr_str = repr(self.pipeline)
+ self.assertEqual(repr_str, "ExtractionPipeline(mode=standard, root=RootModel)")
+
+ self.pipeline.config.use_hierarchical_extraction = True
+ repr_str = repr(self.pipeline)
+ self.assertEqual(
+ repr_str, "ExtractionPipeline(mode=hierarchical, root=RootModel)"
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/llm_runner/test_llm_runner.py b/tests/core/llm_runner/test_llm_runner.py
new file mode 100644
index 0000000..0d6083b
--- /dev/null
+++ b/tests/core/llm_runner/test_llm_runner.py
@@ -0,0 +1,122 @@
+import unittest
+from unittest.mock import MagicMock, patch, AsyncMock
+from extrai.core.llm_runner import LLMRunner, LLMInteractionError
+from extrai.core.base_llm_client import BaseLLMClient
+from extrai.core.extraction_config import ExtractionConfig
+
+
+class TestLLMRunner(unittest.IsolatedAsyncioTestCase):
+ def setUp(self):
+ self.mock_model_registry = MagicMock()
+ self.mock_config = MagicMock(spec=ExtractionConfig)
+ self.mock_config.num_llm_revisions = 2
+ self.mock_config.consensus_threshold = 0.6
+ self.mock_config.conflict_resolver = None
+ self.mock_config.max_validation_retries_per_revision = 1
+ self.mock_analytics_collector = MagicMock()
+ self.mock_logger = MagicMock()
+
+ self.mock_client1 = MagicMock(spec=BaseLLMClient)
+ self.mock_client2 = MagicMock(spec=BaseLLMClient)
+
+ self.runner = LLMRunner(
+ self.mock_model_registry,
+ [self.mock_client1, self.mock_client2],
+ self.mock_config,
+ self.mock_analytics_collector,
+ self.mock_logger,
+ )
+
+ def test_init_validation(self):
+ with self.assertRaises(ValueError):
+ LLMRunner(
+ self.mock_model_registry,
+ [],
+ self.mock_config,
+ self.mock_analytics_collector,
+ self.mock_logger,
+ )
+
+ def test_client_rotation(self):
+ c1 = self.runner.get_next_client()
+ c2 = self.runner.get_next_client()
+ c3 = self.runner.get_next_client()
+
+ self.assertEqual(c1, self.mock_client1)
+ self.assertEqual(c2, self.mock_client2)
+ self.assertEqual(c3, self.mock_client1)
+
+ @patch("extrai.core.llm_runner.normalize_json_revisions")
+ async def test_run_extraction_cycle_success(self, mock_normalize):
+ # Setup mocks
+ self.mock_client1.generate_json_revisions = AsyncMock(return_value=[{"id": 1}])
+ self.mock_client2.generate_json_revisions = AsyncMock(return_value=[{"id": 1}])
+
+ mock_normalize.return_value = [{"id": 1}, {"id": 1}]
+
+ # Mock consensus
+ with patch.object(self.runner, "consensus") as mock_consensus:
+ mock_consensus.get_consensus.return_value = ([{"id": 1}], {})
+
+ results = await self.runner.run_extraction_cycle("sys", "user")
+
+ self.assertEqual(len(results), 1)
+ self.assertEqual(results[0], {"id": 1})
+
+ # Verify calls
+ self.assertEqual(self.mock_client1.generate_json_revisions.call_count, 1)
+ self.assertEqual(self.mock_client2.generate_json_revisions.call_count, 1)
+ mock_normalize.assert_called_once()
+ mock_consensus.get_consensus.assert_called_once()
+
+ async def test_run_extraction_cycle_llm_failure(self):
+ self.mock_client1.generate_json_revisions = AsyncMock(
+ side_effect=Exception("API Error")
+ )
+ self.mock_client2.generate_json_revisions = AsyncMock(return_value=[])
+
+        # A failing client propagates through asyncio.gather; the runner
+        # surfaces it as an LLMInteractionError.
+
+ with self.assertRaises(LLMInteractionError):
+ await self.runner.run_extraction_cycle("sys", "user")
+
+ def test_process_consensus_output(self):
+ # List
+ res = self.runner._process_consensus_output([{"a": 1}])
+ self.assertEqual(res, [{"a": 1}])
+
+ # None
+ res = self.runner._process_consensus_output(None)
+ self.assertEqual(res, [])
+
+ # Dict
+ res = self.runner._process_consensus_output({"a": 1})
+ self.assertEqual(res, [{"a": 1}])
+
+ # Dict with results
+ res = self.runner._process_consensus_output({"results": [{"a": 1}]})
+ self.assertEqual(res, [{"a": 1}])
+
+ def test_get_client_count(self):
+ self.assertEqual(self.runner.get_client_count(), 2)
+
+ def test_reset_client_rotation(self):
+ # Advance client index
+ self.runner.get_next_client()
+ self.assertEqual(self.runner.client_index, 1)
+
+ self.runner.reset_client_rotation()
+ self.assertEqual(self.runner.client_index, 0)
+
+ self.mock_logger.debug.assert_called_with("Client rotation reset to index 0")
+
+ def test_repr(self):
+ repr_str = repr(self.runner)
+ self.assertIn("LLMRunner", repr_str)
+ self.assertIn("clients=2", repr_str)
+ self.assertIn("revisions=2", repr_str)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/model_registry/test_model_registry.py b/tests/core/model_registry/test_model_registry.py
new file mode 100644
index 0000000..7138de0
--- /dev/null
+++ b/tests/core/model_registry/test_model_registry.py
@@ -0,0 +1,194 @@
+import unittest
+from unittest.mock import MagicMock, patch
+from typing import Optional
+from sqlmodel import SQLModel, Field
+
+from extrai.core.model_registry import ModelRegistry, ConfigurationError
+
+
+class MockRootModel(SQLModel):
+ id: Optional[int] = Field(default=None, primary_key=True)
+ name: str
+
+
+class TestModelRegistry(unittest.TestCase):
+ def setUp(self):
+ self.mock_logger = MagicMock()
+
+ def test_init_success(self):
+ with patch("extrai.core.model_registry.SchemaInspector") as MockInspector:
+ mock_inspector = MockInspector.return_value
+ mock_inspector.discover_sqlmodels_from_root.return_value = [MockRootModel]
+ mock_inspector.generate_llm_schema_from_models.return_value = (
+ '{"type": "object"}'
+ )
+
+ registry = ModelRegistry(MockRootModel, self.mock_logger)
+
+ self.assertEqual(len(registry.models), 1)
+ self.assertEqual(registry.models[0], MockRootModel)
+ self.assertEqual(registry.llm_schema_json, '{"type": "object"}')
+ self.assertEqual(registry.get_model_by_name("MockRootModel"), MockRootModel)
+
+ def test_init_invalid_root_model(self):
+ with self.assertRaises(ConfigurationError):
+ ModelRegistry("NotAModel", self.mock_logger)
+
+ def test_init_discovery_failure(self):
+ with patch("extrai.core.model_registry.SchemaInspector") as MockInspector:
+ mock_inspector = MockInspector.return_value
+ mock_inspector.discover_sqlmodels_from_root.side_effect = Exception(
+ "Discovery failed"
+ )
+
+ with self.assertRaises(ConfigurationError):
+ ModelRegistry(MockRootModel, self.mock_logger)
+
+ def test_init_empty_discovery(self):
+ with patch("extrai.core.model_registry.SchemaInspector") as MockInspector:
+ mock_inspector = MockInspector.return_value
+ mock_inspector.discover_sqlmodels_from_root.return_value = []
+
+ with self.assertRaises(ConfigurationError):
+ ModelRegistry(MockRootModel, self.mock_logger)
+
+ def test_init_schema_generation_failure(self):
+ with patch("extrai.core.model_registry.SchemaInspector") as MockInspector:
+ mock_inspector = MockInspector.return_value
+ mock_inspector.discover_sqlmodels_from_root.return_value = [MockRootModel]
+ mock_inspector.generate_llm_schema_from_models.side_effect = Exception(
+ "Schema gen failed"
+ )
+
+ with self.assertRaises(ConfigurationError):
+ ModelRegistry(MockRootModel, self.mock_logger)
+
+ def test_init_invalid_json_schema(self):
+ with patch("extrai.core.model_registry.SchemaInspector") as MockInspector:
+ mock_inspector = MockInspector.return_value
+ mock_inspector.discover_sqlmodels_from_root.return_value = [MockRootModel]
+ mock_inspector.generate_llm_schema_from_models.return_value = "invalid json"
+
+ with self.assertRaises(ConfigurationError):
+ ModelRegistry(MockRootModel, self.mock_logger)
+
+ def test_get_schema_for_models(self):
+ with patch("extrai.core.model_registry.SchemaInspector") as MockInspector:
+ mock_inspector = MockInspector.return_value
+ mock_inspector.discover_sqlmodels_from_root.return_value = [MockRootModel]
+ mock_inspector.generate_llm_schema_from_models.return_value = (
+ '{"full": "schema"}'
+ )
+
+ registry = ModelRegistry(MockRootModel, self.mock_logger)
+
+ # Swap the return value so the targeted call can be distinguished
+ mock_inspector.generate_llm_schema_from_models.return_value = (
+ '{"partial": "schema"}'
+ )
+
+ schema = registry.get_schema_for_models(["MockRootModel"])
+
+ self.assertEqual(schema, '{"partial": "schema"}')
+ mock_inspector.generate_llm_schema_from_models.assert_called_with(
+ [MockRootModel]
+ )
+
+ def test_get_schema_for_models_fallback(self):
+ with patch("extrai.core.model_registry.SchemaInspector") as MockInspector:
+ mock_inspector = MockInspector.return_value
+ mock_inspector.discover_sqlmodels_from_root.return_value = [MockRootModel]
+ mock_inspector.generate_llm_schema_from_models.return_value = (
+ '{"full": "schema"}'
+ )
+
+ registry = ModelRegistry(MockRootModel, self.mock_logger)
+
+ # Reset mock
+ mock_inspector.generate_llm_schema_from_models.reset_mock()
+
+ schema = registry.get_schema_for_models(["NonExistentModel"])
+
+ self.assertEqual(schema, '{"full": "schema"}')
+ mock_inspector.generate_llm_schema_from_models.assert_not_called()
+
+ def test_init_root_model_class_not_subclass(self):
+ class NotAModel:
+ pass
+
+ with self.assertRaisesRegex(
+ ConfigurationError, "root_model must be a valid SQLModel class"
+ ):
+ ModelRegistry(NotAModel, self.mock_logger)
+
+ def test_init_empty_generated_schema(self):
+ with patch("extrai.core.model_registry.SchemaInspector") as MockInspector:
+ mock_inspector = MockInspector.return_value
+ mock_inspector.discover_sqlmodels_from_root.return_value = [MockRootModel]
+ # Return empty schema
+ mock_inspector.generate_llm_schema_from_models.return_value = ""
+
+ with self.assertRaisesRegex(
+ ConfigurationError, "Generated LLM schema is empty"
+ ):
+ ModelRegistry(MockRootModel, self.mock_logger)
+
+ def test_get_schema_for_models_exception(self):
+ with patch("extrai.core.model_registry.SchemaInspector") as MockInspector:
+ mock_inspector = MockInspector.return_value
+ mock_inspector.discover_sqlmodels_from_root.return_value = [MockRootModel]
+ mock_inspector.generate_llm_schema_from_models.return_value = (
+ '{"full": "schema"}'
+ )
+
+ registry = ModelRegistry(MockRootModel, self.mock_logger)
+
+ # Make generation fail for specific call
+ mock_inspector.generate_llm_schema_from_models.side_effect = Exception(
+ "Boom"
+ )
+
+ schema = registry.get_schema_for_models(["MockRootModel"])
+
+ self.assertEqual(schema, '{"full": "schema"}')
+ # Verify error log
+ self.mock_logger.error.assert_called()
+
+ def test_get_all_model_names(self):
+ with patch("extrai.core.model_registry.SchemaInspector") as MockInspector:
+ mock_inspector = MockInspector.return_value
+ mock_inspector.discover_sqlmodels_from_root.return_value = [MockRootModel]
+ mock_inspector.generate_llm_schema_from_models.return_value = "{}"
+
+ registry = ModelRegistry(MockRootModel, self.mock_logger)
+
+ names = registry.get_all_model_names()
+ self.assertEqual(names, ["MockRootModel"])
+
+ def test_has_model(self):
+ with patch("extrai.core.model_registry.SchemaInspector") as MockInspector:
+ mock_inspector = MockInspector.return_value
+ mock_inspector.discover_sqlmodels_from_root.return_value = [MockRootModel]
+ mock_inspector.generate_llm_schema_from_models.return_value = "{}"
+
+ registry = ModelRegistry(MockRootModel, self.mock_logger)
+
+ self.assertTrue(registry.has_model("MockRootModel"))
+ self.assertFalse(registry.has_model("Unknown"))
+
+ def test_repr(self):
+ with patch("extrai.core.model_registry.SchemaInspector") as MockInspector:
+ mock_inspector = MockInspector.return_value
+ mock_inspector.discover_sqlmodels_from_root.return_value = [MockRootModel]
+ mock_inspector.generate_llm_schema_from_models.return_value = "{}"
+
+ registry = ModelRegistry(MockRootModel, self.mock_logger)
+
+ repr_str = repr(registry)
+ self.assertIn("ModelRegistry", repr_str)
+ self.assertIn("MockRootModel", repr_str)
+ self.assertIn("models=1", repr_str)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/prompts/test_counting.py b/tests/core/prompts/test_counting.py
new file mode 100644
index 0000000..e1ef874
--- /dev/null
+++ b/tests/core/prompts/test_counting.py
@@ -0,0 +1,23 @@
+import unittest
+from extrai.core.prompts.counting import (
+ generate_entity_counting_system_prompt,
+ generate_entity_counting_user_prompt,
+)
+
+
+class TestCountingPrompts(unittest.TestCase):
+ def test_generate_entity_counting_system_prompt(self):
+ prompt = generate_entity_counting_system_prompt(
+ model_names=["TestModel"], schema_json="{}"
+ )
+ self.assertIn("You are an expert data analyst", prompt)
+
+ def test_generate_entity_counting_user_prompt(self):
+ docs = ["doc1", "doc2"]
+ prompt = generate_entity_counting_user_prompt(docs)
+ self.assertIn("doc1", prompt)
+ self.assertIn("doc2", prompt)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/prompts/test_examples.py b/tests/core/prompts/test_examples.py
new file mode 100644
index 0000000..99b72d4
--- /dev/null
+++ b/tests/core/prompts/test_examples.py
@@ -0,0 +1,64 @@
+import unittest
+import json
+from extrai.core.prompts.examples import generate_prompt_for_example_json_generation
+
+
+class TestExamplePrompts(unittest.TestCase):
+ def setUp(self):
+ """Set up common test data."""
+ self.sample_schema_dict = {
+ "type": "object",
+ "title": "TestEntity",
+ "properties": {
+ "id": {"type": "string"},
+ "name": {"type": "string"},
+ "value": {"type": "number"},
+ },
+ "required": ["id", "name"],
+ }
+ self.sample_schema_json_str = json.dumps(self.sample_schema_dict, indent=2)
+
+ def test_generate_prompt_for_example_json_generation(self):
+ """Test the prompt generation for creating an example JSON."""
+ root_model_name = "SampleOutputModel"
+
+ prompt = generate_prompt_for_example_json_generation(
+ target_model_schema_str=self.sample_schema_json_str,
+ root_model_name=root_model_name,
+ )
+
+ # General instructions
+ self.assertIn(
+ "You are an AI assistant tasked with generating a sample JSON object.",
+ prompt,
+ )
+ self.assertIn(
+ f"The goal is to create a single, valid JSON object that conforms to the provided schema for a model named '{root_model_name}' and its related models.",
+ prompt,
+ )
+ self.assertIn("# JSON SCHEMA TO ADHERE TO:", prompt)
+ self.assertIn(self.sample_schema_json_str, prompt)
+ self.assertIn(
+ "Your output MUST be a single JSON object with a top-level key named 'entities'.",
+ prompt,
+ )
+ self.assertIn(
+ "Each object inside the 'entities' list MUST include two metadata fields:",
+ prompt,
+ )
+ self.assertIn(
+ "`_type`: This field's value MUST be a string matching the name of the model it represents",
+ prompt,
+ )
+ self.assertIn(
+ "`_temp_id`: This field's value MUST be a unique temporary string identifier for that specific entity instance",
+ prompt,
+ )
+ self.assertIn(
+ "Your 'entities' list should contain an instance of the root model", prompt
+ )
+ self.assertIn("at least one instance of each of its related models", prompt)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/prompts/test_extraction.py b/tests/core/prompts/test_extraction.py
new file mode 100644
index 0000000..b0f506f
--- /dev/null
+++ b/tests/core/prompts/test_extraction.py
@@ -0,0 +1,324 @@
+import unittest
+import json
+from extrai.core.prompts.common import generate_user_prompt_for_docs
+from extrai.core.prompts.extraction import generate_system_prompt
+
+
+class TestExtractionPrompts(unittest.TestCase):
+ def setUp(self):
+ """Set up common test data."""
+ self.sample_schema_dict = {
+ "type": "object",
+ "title": "TestEntity",
+ "properties": {
+ "id": {"type": "string"},
+ "name": {"type": "string"},
+ "value": {"type": "number"},
+ },
+ "required": ["id", "name"],
+ }
+ self.sample_schema_json_str = json.dumps(self.sample_schema_dict, indent=2)
+
+ self.sample_extraction_example_dict = {
+ "id": "test001",
+ "name": "Test Name",
+ "value": 123.45,
+ }
+ self.sample_extraction_example_json_str = json.dumps(
+ self.sample_extraction_example_dict, indent=2
+ )
+
+ def test_generate_system_prompt_basic(self):
+ """Test system prompt with only schema."""
+ prompt = generate_system_prompt(schema_json=self.sample_schema_json_str)
+ self.assertIn("You are an advanced AI specializing in data extraction", prompt)
+ self.assertIn("# JSON SCHEMA TO ADHERE TO:", prompt)
+ self.assertIn(self.sample_schema_json_str, prompt)
+ self.assertIn("# EXTRACTION PROCESS", prompt) # Default process
+ self.assertIn("# IMPORTANT EXTRACTION GUIDELINES", prompt) # Default guidelines
+ self.assertIn("# FINAL CHECK BEFORE SUBMISSION", prompt) # Default checklist
+ self.assertNotIn("# EXAMPLE OF EXTRACTION:", prompt)
+
+ def test_generate_system_prompt_with_example(self):
+ """Test system prompt with schema and example."""
+ prompt = generate_system_prompt(
+ schema_json=self.sample_schema_json_str,
+ extraction_example_json=self.sample_extraction_example_json_str,
+ )
+ self.assertIn(self.sample_schema_json_str, prompt)
+ self.assertIn("# EXAMPLE OF EXTRACTION:", prompt)
+ self.assertIn("## CONCEPTUAL INPUT TEXT", prompt)
+ self.assertIn(self.sample_extraction_example_json_str, prompt)
+
+ def test_generate_system_prompt_with_custom_process(self):
+ """Test system prompt with custom extraction process."""
+ custom_process = "# MY CUSTOM PROCESS\n1. Do this first.\n2. Then do that."
+ prompt = generate_system_prompt(
+ schema_json=self.sample_schema_json_str,
+ custom_extraction_process=custom_process,
+ )
+ self.assertIn(custom_process, prompt)
+ self.assertNotIn(
+ "Follow this step-by-step process meticulously:", prompt
+ ) # Default process heading
+ self.assertIn(
+ "# IMPORTANT EXTRACTION GUIDELINES", prompt
+ ) # Default guidelines should still be there
+ self.assertIn("# FINAL CHECK BEFORE SUBMISSION", prompt) # Default checklist
+
+ def test_generate_system_prompt_with_custom_guidelines(self):
+ """Test system prompt with custom extraction guidelines."""
+ custom_guidelines = "# MY CUSTOM GUIDELINES\n- Be awesome.\n- Be accurate."
+ prompt = generate_system_prompt(
+ schema_json=self.sample_schema_json_str,
+ custom_extraction_guidelines=custom_guidelines,
+ )
+ self.assertIn(custom_guidelines, prompt)
+ self.assertNotIn(
+ "- **Output Format:** Your entire output must be a single, valid JSON object.",
+ prompt,
+ ) # Default guideline
+ self.assertIn("# EXTRACTION PROCESS", prompt) # Default process
+ self.assertIn("# FINAL CHECK BEFORE SUBMISSION", prompt) # Default checklist
+
+ def test_generate_system_prompt_with_custom_checklist(self):
+ """Test system prompt with custom final checklist."""
+ custom_checklist = "# MY CUSTOM CHECKLIST\n1. Did I do well?\n2. Yes."
+ prompt = generate_system_prompt(
+ schema_json=self.sample_schema_json_str,
+ custom_final_checklist=custom_checklist,
+ )
+ self.assertIn(custom_checklist, prompt)
+ self.assertNotIn("1. **Valid JSON?**", prompt) # Default checklist item
+ self.assertIn("# EXTRACTION PROCESS", prompt) # Default process
+ self.assertIn("# IMPORTANT EXTRACTION GUIDELINES", prompt) # Default guidelines
+
+ def test_generate_system_prompt_with_custom_context(self):
+ """Test system prompt with custom_context."""
+ custom_context_content = "This is some important external context for the LLM."
+ prompt = generate_system_prompt(
+ schema_json=self.sample_schema_json_str,
+ custom_context=custom_context_content,
+ )
+ self.assertIn("# ADDITIONAL CONTEXT:", prompt)
+ self.assertIn(custom_context_content, prompt)
+ self.assertIn(
+ self.sample_schema_json_str, prompt
+ ) # Ensure schema is still there
+ self.assertIn(
+ "# EXTRACTION PROCESS", prompt
+ ) # Default process should still be there
+
+ # Test that it's not included if empty
+ prompt_no_custom_context = generate_system_prompt(
+ schema_json=self.sample_schema_json_str, custom_context=""
+ )
+ self.assertNotIn("# ADDITIONAL CONTEXT:", prompt_no_custom_context)
+ self.assertNotIn(custom_context_content, prompt_no_custom_context)
+
+ def test_generate_system_prompt_all_custom(self):
+ """Test system prompt with all custom sections."""
+ custom_process = "# CUSTOM PROCESS V2"
+ custom_guidelines = "# CUSTOM GUIDELINES V2"
+ custom_checklist = "# CUSTOM CHECKLIST V2"
+ custom_context_content = "All custom context here for the all_custom test."
+ prompt = generate_system_prompt(
+ schema_json=self.sample_schema_json_str,
+ extraction_example_json=self.sample_extraction_example_json_str,
+ custom_extraction_process=custom_process,
+ custom_extraction_guidelines=custom_guidelines,
+ custom_final_checklist=custom_checklist,
+ custom_context=custom_context_content,
+ )
+ self.assertIn(custom_process, prompt)
+ self.assertIn(custom_guidelines, prompt)
+ self.assertIn(custom_checklist, prompt)
+ self.assertIn("# ADDITIONAL CONTEXT:", prompt)
+ self.assertIn(custom_context_content, prompt)
+ self.assertIn(self.sample_schema_json_str, prompt)
+ self.assertIn(self.sample_extraction_example_json_str, prompt)
+ self.assertNotIn("Follow this step-by-step process meticulously:", prompt)
+ self.assertNotIn(
+ "- **Output Format:** Your entire output must be a single, valid JSON object.",
+ prompt,
+ )
+ self.assertNotIn("1. **Valid JSON?**", prompt)
+
+ def test_generate_user_prompt_single_document(self):
+ """Test user prompt with a single document."""
+ doc1 = "This is the first document."
+ prompt = generate_user_prompt_for_docs([doc1])
+ self.assertIn(
+ "Please extract information from the following document(s).", prompt
+ )
+ self.assertIn("# DOCUMENT(S) FOR EXTRACTION:", prompt)
+ self.assertIn(doc1, prompt)
+ self.assertNotIn("---END OF DOCUMENT---", prompt) # No separator for single doc
+ self.assertIn(
+ "Remember: Your output must be only a single, valid JSON object.", prompt
+ )
+
+ def test_generate_user_prompt_multiple_documents(self):
+ """Test user prompt with multiple documents."""
+ doc1 = "Document one content."
+ doc2 = "Document two content here."
+ doc3 = "And a third document."
+ separator = "\n\n---END OF DOCUMENT---\n\n---START OF NEW DOCUMENT---\n\n"
+ prompt = generate_user_prompt_for_docs([doc1, doc2, doc3])
+ self.assertIn(doc1, prompt)
+ self.assertIn(doc2, prompt)
+ self.assertIn(doc3, prompt)
+ self.assertEqual(prompt.count(separator), 2) # Two separators for three docs
+ self.assertIn("# DOCUMENT(S) FOR EXTRACTION:", prompt)
+
+ def test_generate_user_prompt_empty_documents_list(self):
+ """Test user prompt with an empty list of documents."""
+ prompt = generate_user_prompt_for_docs([])
+ self.assertIn(
+ "Please extract information from the following document(s).",
+ prompt,
+ )
+ self.assertIn("# DOCUMENT(S) FOR EXTRACTION:", prompt)
+
+ # Check that the space between "EXTRACTION:" and "---" is just newlines (or empty)
+ extraction_header_end_index = prompt.find("# DOCUMENT(S) FOR EXTRACTION:")
+ if extraction_header_end_index != -1:
+ extraction_header_end_index += len("# DOCUMENT(S) FOR EXTRACTION:")
+
+ reminder_start_index = prompt.find(
+ "---"
+ ) # Assuming "---" is the start of the reminder section
+
+ if (
+ extraction_header_end_index != -1
+ and reminder_start_index != -1
+ and extraction_header_end_index < reminder_start_index
+ ):
+ content_between = prompt[extraction_header_end_index:reminder_start_index]
+ self.assertEqual(
+ content_between.strip(),
+ "",
+ f"Content between extraction header and reminder should be whitespace only, but was: '{content_between}'",
+ )
+ elif extraction_header_end_index == -1:
+ self.fail("Could not find '# DOCUMENT(S) FOR EXTRACTION:' marker.")
+ elif reminder_start_index == -1:
+ self.fail("Could not find '---' marker for reminder section.")
+ else: # Markers found but order is wrong or overlap
+ self.fail(
+ f"Problem with marker positions: extraction_header_end_index={extraction_header_end_index}, reminder_start_index={reminder_start_index}"
+ )
+
+ self.assertIn(
+ "Remember: Your output must be only a single, valid JSON object.", prompt
+ )
+ self.assertNotIn(
+ "---END OF DOCUMENT---", prompt
+ ) # Still important for empty list
+
+ def test_generate_user_prompt_documents_with_special_chars(self):
+ """Test user prompt with documents containing special JSON characters."""
+ doc1 = 'This document has "quotes" and {curly braces}.'
+ doc2 = "Another one with a backslash \\ and newlines \n in theory."
+ prompt = generate_user_prompt_for_docs([doc1, doc2])
+ self.assertIn(doc1, prompt)
+ self.assertIn(doc2, prompt)
+ self.assertIn("---END OF DOCUMENT---", prompt)
+
+ def test_generate_system_prompt_with_non_json_example_string(self):
+ """
+ Test system prompt with an extraction_example_json that is a non-empty,
+ non-JSON string. This should cover line 108 and the 'else' branch
+ of the inner conditional (line 119 in prompt_builder.py).
+ """
+ non_json_example_str = "This is a raw example string, not a JSON object."
+ prompt = generate_system_prompt(
+ schema_json=self.sample_schema_json_str,
+ extraction_example_json=non_json_example_str,
+ )
+
+ # Check that line 108's content ("# EXAMPLE OF EXTRACTION:") is present
+ self.assertIn("# EXAMPLE OF EXTRACTION:", prompt)
+
+ # Check that the non_json_example_str itself is used (from line 119)
+ self.assertIn(non_json_example_str, prompt)
+
+ # Check that it's NOT wrapped with {"result": ...} (line 117 should not be hit)
+ # Constructing the exact f-string format for the negative assertion
+ wrapped_example_check = f'{{\n "result": {non_json_example_str}\n}}'
+ self.assertNotIn(wrapped_example_check, prompt)
+
+ # Also check for other parts of the example section to be sure they are still there
+ self.assertIn("## CONCEPTUAL INPUT TEXT", prompt)
+ self.assertIn("## EXAMPLE EXTRACTED JSON", prompt)
+ # Ensure the ```json block markers are present around the example
+ # The example string is non_json_example_str
+ # So we expect "```json\n\n" + non_json_example_str + "\n\n```" (joined by \n\n)
+ # More robustly, check that the example string is between ```json and ```
+ # Find the start of the example section text
+ example_section_header_idx = prompt.find("## EXAMPLE EXTRACTED JSON")
+ self.assertNotEqual(
+ example_section_header_idx, -1, "Example JSON header not found"
+ )
+
+ # Find ```json after this header
+ json_block_start_marker = "```json"
+ json_block_start_idx = prompt.find(
+ json_block_start_marker, example_section_header_idx
+ )
+ self.assertNotEqual(
+ json_block_start_idx, -1,
+ "```json start marker not found after example header",
+ )
+
+ # Find the example string after the ```json marker
+ example_str_idx = prompt.find(
+ non_json_example_str, json_block_start_idx + len(json_block_start_marker)
+ )
+ self.assertNotEqual(
+ example_str_idx, -1,
+ "Non-JSON example string not found after ```json marker",
+ )
+
+ # Find ``` end marker after the example string
+ json_block_end_marker = "```"
+ json_block_end_idx = prompt.find(
+ json_block_end_marker, example_str_idx + len(non_json_example_str)
+ )
+ self.assertNotEqual(
+ json_block_end_idx, -1,
+ "``` end marker not found after non-JSON example string",
+ )
+
+ def test_generate_system_prompt_includes_ordering_and_id_rules(self):
+ """Test that the system prompt includes instructions for ordering and semantic IDs."""
+ prompt = generate_system_prompt(schema_json=self.sample_schema_json_str)
+
+ # Check for ordering instruction
+ self.assertIn(
+ "Maintain the order of items as they appear in the source text",
+ prompt,
+ "Prompt missing instruction about preserving order",
+ )
+
+ # Check for semantic ID instruction
+ self.assertIn(
+ "based on the entity's key attributes",
+ prompt,
+ "Prompt missing instruction about semantic IDs",
+ )
+ self.assertIn(
+ "E.g., `user_john_doe`", prompt, "Prompt missing example of semantic IDs"
+ )
+
+ def test_generate_user_prompt_with_custom_context(self):
+ """Test user prompt with custom context."""
+ doc1 = "This is a document."
+ custom_ctx = "Pay attention to X."
+ prompt = generate_user_prompt_for_docs([doc1], custom_context=custom_ctx)
+ self.assertIn(custom_ctx, prompt)
+ self.assertIn(doc1, prompt)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/prompts/test_sqlmodel.py b/tests/core/prompts/test_sqlmodel.py
new file mode 100644
index 0000000..1810db2
--- /dev/null
+++ b/tests/core/prompts/test_sqlmodel.py
@@ -0,0 +1,193 @@
+import unittest
+import json
+from extrai.core.prompts.sqlmodel import generate_sqlmodel_creation_system_prompt
+
+
+class TestSQLModelPrompts(unittest.TestCase):
+ def test_generate_sqlmodel_creation_system_prompt(self):
+ """Test the specialized system prompt for SQLModel creation."""
+ sample_sqlmodel_schema_str = json.dumps(
+ {
+ "title": "SQLModelDescription",
+ "type": "object",
+ "properties": {
+ "model_name": {"type": "string"},
+ "fields": {"type": "array", "items": {"type": "object"}},
+ },
+ "required": ["model_name", "fields"],
+ },
+ indent=2,
+ )
+ user_task_desc = "Create a model for tracking library books, including title, author, and ISBN."
+
+ prompt = generate_sqlmodel_creation_system_prompt(
+ schema_json=sample_sqlmodel_schema_str, user_task_description=user_task_desc
+ )
+
+ self.assertIn(
+ "You are an AI assistant tasked with designing one or more SQLModel class definitions.",
+ prompt,
+ )
+
+ self.assertIn(
+ "3. Each object in the `sql_models` list MUST strictly adhere to the following JSON schema for a SQLModel description:",
+ prompt,
+ )
+ # Correctly check for the schema block - note the double newlines from how prompt_parts are joined
+ self.assertIn(f"```json\n\n{sample_sqlmodel_schema_str}\n\n```", prompt)
+
+ # Check for the new "IMPORTANT CONSIDERATIONS" section
+ self.assertIn("# IMPORTANT CONSIDERATIONS FOR DATABASE TABLE MODELS:", prompt)
+ self.assertIn(
+ 'field_options_str": "Field(default_factory=list, sa_type=JSON)"', prompt
+ ) # Example for List
+ self.assertIn(
+ 'add `"from sqlmodel import JSON"` to the main `imports` array', prompt
+ )
+ self.assertIn(
+ "MUST provide a sensible `default` value", prompt
+ ) # Instruction for defaults
+
+ # Check for the user task description section
+ self.assertIn("# USER'S TASK:", prompt)
+ self.assertIn(
+ f'The user wants to define a SQLModel based on the following objective: "{user_task_desc}"',
+ prompt,
+ )
+ self.assertIn(
+ "Pay close attention to the requirements for List/Dict types if the model is a table, and try to provide default values for required fields.",
+ prompt,
+ )
+
+ # Check for the hardcoded example section
+ self.assertIn(
+ "# EXAMPLE OF A VALID SQLMODEL DESCRIPTION JSON (Illustrating a list of models):",
+ prompt,
+ )
+ # Check for a snippet from the hardcoded example to ensure it's present
+ self.assertIn('"model_name": "ExampleItem"', prompt)
+ self.assertIn('"table_name": "example_items"', prompt)
+ self.assertIn("Timestamp of when the item was created.", prompt)
+ self.assertIn('"name": "categories"', prompt) # Part of the new example
+ self.assertIn(
+ '"field_options_str": "Field(default_factory=list, sa_type=JSON)"', prompt
+ ) # Part of the new example
+ self.assertIn(
+ '"from sqlmodel import SQLModel, Field, JSON"', prompt
+ ) # Import in the new example
+
+ def test_generate_sqlmodel_creation_system_prompt_structure_and_hardcoded_example(
+ self,
+ ):
+ """
+ Tests the overall structure and presence of the hardcoded example in
+ the SQLModel creation prompt.
+ """
+ sample_schema_for_description_str = json.dumps(
+ {
+ "title": "SQLModelDesc",
+ "type": "object",
+ "properties": {"model_name": {"type": "string"}},
+ "required": ["model_name"],
+ },
+ indent=2,
+ )
+ user_task = "Define a product model."
+
+ prompt = generate_sqlmodel_creation_system_prompt(
+ schema_json=sample_schema_for_description_str,
+ user_task_description=user_task,
+ )
+
+ # General structure checks
+ self.assertTrue(
+ prompt.startswith(
+ "You are an AI assistant tasked with designing one or more SQLModel class definitions."
+ )
+ )
+ self.assertIn("# REQUIREMENTS FOR YOUR OUTPUT:", prompt)
+ self.assertIn(
+ "3. Each object in the `sql_models` list MUST strictly adhere to the following JSON schema for a SQLModel description:",
+ prompt,
+ )
+ self.assertIn(
+ sample_schema_for_description_str, prompt
+ ) # Check if the passed schema is there
+ self.assertIn("# USER'S TASK:", prompt)
+ self.assertIn(user_task, prompt)
+ self.assertIn(
+ "# IMPORTANT CONSIDERATIONS FOR DATABASE TABLE MODELS:", prompt
+ ) # New section check
+ self.assertIn(
+ "# EXAMPLE OF A VALID SQLMODEL DESCRIPTION JSON (Illustrating a list of models):",
+ prompt,
+ )
+ self.assertTrue(
+ prompt.endswith(
+ "Do not include any other narrative, explanations, or conversational elements in your output."
+ )
+ )
+
+ # Specific checks for the hardcoded example's content (which now includes List[str] example)
+ self.assertIn('"model_name": "ExampleItem"', prompt)
+ self.assertIn('"table_name": "example_items"', prompt)
+ self.assertIn('primary_key": true', prompt)
+ self.assertIn("datetime.datetime.utcnow", prompt)
+ self.assertIn('"name": "categories"', prompt)
+ self.assertIn(
+ '"field_options_str": "Field(default_factory=list, sa_type=JSON)"', prompt
+ )
+ self.assertIn('"from sqlmodel import SQLModel, Field, JSON"', prompt)
+
+ # Ensure the example JSON block is correctly formatted
+ example_intro_text = "This is an example of the kind of JSON object you should produce (it conforms to the schema above):"
+ self.assertIn(example_intro_text, prompt)
+
+ # Extract the example JSON part to validate it
+ try:
+ # Find the start of the example JSON block
+ json_block_marker = "```json\n\n" # Note the double newline
+ # Find the specific example block after the intro text
+ example_intro_end_idx = prompt.find(example_intro_text) + len(
+ example_intro_text
+ )
+ json_code_block_start_idx = prompt.find(
+ json_block_marker, example_intro_end_idx
+ )
+
+ if json_code_block_start_idx == -1:
+ self.fail(
+ f"Could not find the start of the example JSON code block ('{json_block_marker}')."
+ )
+
+ # Move past the marker itself
+ actual_json_start_idx = json_code_block_start_idx + len(json_block_marker)
+
+ # Find the end of this specific JSON code block (which is \n\n```)
+ json_code_block_end_marker = "\n\n```"
+ json_code_block_end_idx = prompt.find(
+ json_code_block_end_marker, actual_json_start_idx
+ )
+ if json_code_block_end_idx == -1:
+ self.fail(
+ f"Could not find the end of the example JSON code block ('{json_code_block_end_marker}')."
+ )
+
+ example_json_str_from_prompt = prompt[
+ actual_json_start_idx:json_code_block_end_idx
+ ].strip()
+
+ # Validate that this extracted string is valid JSON
+ json.loads(example_json_str_from_prompt)
+ except json.JSONDecodeError as e:
+ self.fail(
+ f"Hardcoded example JSON in prompt is not valid JSON: {e}\nExtracted JSON string:\n'{example_json_str_from_prompt}'"
+ )
+ except Exception as e: # Catch other potential errors during extraction
+ self.fail(
+ f"Failed to extract or validate the hardcoded example JSON from prompt: {e}"
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/prompts/test_structured_extraction.py b/tests/core/prompts/test_structured_extraction.py
new file mode 100644
index 0000000..e9ba38c
--- /dev/null
+++ b/tests/core/prompts/test_structured_extraction.py
@@ -0,0 +1,22 @@
+from extrai.core.prompts.structured_extraction import (
+ generate_structured_system_prompt,
+)
+from extrai.core.prompts.common import generate_user_prompt_for_docs
+
+
+def test_generate_structured_system_prompt():
+ prompt = generate_structured_system_prompt()
+ assert "# EXTRACTION INSTRUCTIONS" in prompt
+ assert "Extract Entities" in prompt
+
+ custom = "Pay attention to dates."
+ prompt_custom = generate_structured_system_prompt(custom_extraction_process=custom)
+ assert custom in prompt_custom
+
+
+def test_generate_user_prompt_for_docs():
+ docs = ["Doc 1 content", "Doc 2 content"]
+ prompt = generate_user_prompt_for_docs(docs)
+ assert "Doc 1 content" in prompt
+ assert "Doc 2 content" in prompt
+ assert "---END OF DOCUMENT---" in prompt
diff --git a/tests/core/test_sqlalchemy_hydrator.py b/tests/core/result_processor/test_hydration.py
similarity index 85%
rename from tests/core/test_sqlalchemy_hydrator.py
rename to tests/core/result_processor/test_hydration.py
index b82fb60..59f1feb 100644
--- a/tests/core/test_sqlalchemy_hydrator.py
+++ b/tests/core/result_processor/test_hydration.py
@@ -1,4 +1,6 @@
import unittest
+import unittest.mock
+import logging
import uuid
import sys
import io
@@ -13,7 +15,7 @@
select,
)
-from extrai.core.sqlalchemy_hydrator import SQLAlchemyHydrator
+from extrai.core.result_processor import SQLAlchemyHydrator
# 1. Setup SQLModel Models
@@ -81,7 +83,8 @@ def setUp(self):
self.engine = create_engine("sqlite:///:memory:")
SQLModel.metadata.create_all(self.engine)
self.session = SQLModelSession(self.engine)
- self.hydrator = SQLAlchemyHydrator(self.session)
+ self.test_logger = logging.getLogger("test_logger")
+ self.hydrator = SQLAlchemyHydrator(self.session, logger=self.test_logger)
# Correctly define the map with SQLModel classes
self.model_sqlmodel_map: Dict[str, Type[SQLModel]] = {
"author": Author,
@@ -302,17 +305,26 @@ def test_hydrate_broken_ref_id(self):
"author_ref_id": "non_existent_author",
}
]
- with self.captured_stdout() as captured:
+ with unittest.mock.patch.object(self.test_logger, "warning") as mock_warning:
instances = self._hydrate_and_query(entities_list)
- self.assertEqual(len(instances), 1)
- book = instances[0]
- self.assertEqual(book.title, "Book with Broken Link")
- self.assertIsNone(book.author)
- self.assertIn(
- "Warning: Referenced _temp_id 'non_existent_author' for relation 'author'",
- captured.getvalue(),
- )
+ self.assertEqual(len(instances), 1)
+ book = instances[0]
+ self.assertEqual(book.title, "Book with Broken Link")
+ self.assertIsNone(book.author)
+
+ self.assertTrue(mock_warning.called)
+ # Check if any call contains the expected message
+ found = False
+ for call in mock_warning.call_args_list:
+ args, _ = call
+ if (
+ "Referenced _temp_id 'non_existent_author' for relation 'author'"
+ in args[0]
+ ):
+ found = True
+ break
+ self.assertTrue(found, "Expected warning message not found in logger calls")
def test_hydrate_null_ref_id(self):
entities_list = [
@@ -484,16 +496,24 @@ def test_hydrate_ref_ids_with_invalid_value_type(self):
"books_ref_ids": "not_a_list",
},
]
- with self.captured_stdout() as captured:
+ with unittest.mock.patch.object(self.test_logger, "warning") as mock_warning:
instances = self._hydrate_and_query(entities_list)
- author = next((i for i in instances if isinstance(i, Author)), None)
- self.assertIsNotNone(author)
- self.assertEqual(len(author.books), 0)
- self.assertIn(
- "Warning: Value for 'books_ref_ids' on instance 'auth_invalid_ref_ids' is not a list",
- captured.getvalue(),
- )
+ author = next((i for i in instances if isinstance(i, Author)), None)
+ self.assertIsNotNone(author)
+ self.assertEqual(len(author.books), 0)
+
+        self.assertTrue(mock_warning.called)
+        found = any(
+            "Value for 'books_ref_ids' on instance 'auth_invalid_ref_ids' is not a list"
+            in args[0]
+            for args, _ in mock_warning.call_args_list
+        )
+        self.assertTrue(found, "Expected warning message not found in logger calls")
def test_hydrate_ref_ids_list_with_invalid_item_type_or_missing_ref(self):
with self.subTest("invalid_item_type"):
@@ -506,12 +526,24 @@ def test_hydrate_ref_ids_list_with_invalid_item_type_or_missing_ref(self):
},
{"_type": "book", "_temp_id": "book1", "title": "Book 1"},
]
- with self.captured_stdout() as captured:
+ with unittest.mock.patch.object(
+ self.test_logger, "warning"
+ ) as mock_warning:
self._hydrate_and_query(entities_list)
- self.assertIn(
- "Warning: Referenced _temp_id '123' in list for relation 'books' on instance 'auth1' (type: author) not found or invalid type.",
- captured.getvalue(),
- )
+
+            self.assertTrue(mock_warning.called)
+            found = any(
+                "Referenced _temp_id '123' in list for relation 'books' on instance 'auth1' (type: author) not found or invalid type."
+                in args[0]
+                for args, _ in mock_warning.call_args_list
+            )
+            self.assertTrue(
+                found, "Expected warning message not found in logger calls"
+            )
with self.subTest("missing_ref"):
entities_list = [
@@ -523,12 +555,24 @@ def test_hydrate_ref_ids_list_with_invalid_item_type_or_missing_ref(self):
},
{"_type": "book", "_temp_id": "book2", "title": "Book 2"},
]
- with self.captured_stdout() as captured:
+ with unittest.mock.patch.object(
+ self.test_logger, "warning"
+ ) as mock_warning:
self._hydrate_and_query(entities_list)
- self.assertIn(
- "Warning: Referenced _temp_id 'non_existent' in list for relation 'books' on instance 'auth2' (type: author) not found or invalid type.",
- captured.getvalue(),
- )
+
+            self.assertTrue(mock_warning.called)
+            found = any(
+                "Referenced _temp_id 'non_existent' in list for relation 'books' on instance 'auth2' (type: author) not found or invalid type."
+                in args[0]
+                for args, _ in mock_warning.call_args_list
+            )
+            self.assertTrue(
+                found, "Expected warning message not found in logger calls"
+            )
def test_no_pk_coverage(self):
"""
diff --git a/tests/core/result_processor/test_hydration_strategies.py b/tests/core/result_processor/test_hydration_strategies.py
new file mode 100644
index 0000000..93e7a6c
--- /dev/null
+++ b/tests/core/result_processor/test_hydration_strategies.py
@@ -0,0 +1,147 @@
+import unittest
+import unittest.mock
+import logging
+from typing import List, Optional
+from sqlmodel import Relationship, SQLModel, Field, create_engine, Session
+from extrai.core.result_processor import (
+ ResultProcessor,
+ DirectHydrator,
+)
+from extrai.core.model_registry import ModelRegistry
+
+
+# Define models
+class StrategyAuthor(SQLModel, table=True):
+ __tablename__ = "strategy_author"
+ id: Optional[int] = Field(default=None, primary_key=True)
+ name: str
+ nested_books: List["StrategyNestedBook"] = Relationship(back_populates="author")
+
+
+class StrategyBook(SQLModel, table=True):
+ __tablename__ = "strategy_book"
+ id: Optional[int] = Field(default=None, primary_key=True)
+ title: str
+ author_id: Optional[int] = Field(default=None, foreign_key="strategy_author.id")
+
+
+class StrategyNestedBook(SQLModel, table=True):
+ __tablename__ = "strategy_nested_book"
+ id: Optional[int] = Field(default=None, primary_key=True)
+ title: str
+ author_id: Optional[int] = Field(default=None, foreign_key="strategy_author.id")
+ author: Optional[StrategyAuthor] = Relationship(back_populates="nested_books")
+
+
+class StrategyLibrary(SQLModel):
+ books: List[StrategyNestedBook]
+
+
+# Update forward refs
+StrategyAuthor.model_rebuild()
+StrategyNestedBook.model_rebuild()
+
+
+class TestResultProcessorStrategies(unittest.TestCase):
+ def setUp(self):
+ self.engine = create_engine("sqlite:///:memory:")
+ SQLModel.metadata.create_all(self.engine)
+ self.session = Session(self.engine)
+ self.logger = logging.getLogger("test")
+
+ # Mock ModelRegistry
+ self.model_registry = unittest.mock.Mock(spec=ModelRegistry)
+ self.model_registry.model_map = {
+ "author": StrategyAuthor,
+ "book": StrategyBook,
+ "library": StrategyLibrary,
+ "nested_book": StrategyNestedBook,
+ }
+ self.model_registry.root_model = StrategyAuthor # Default
+
+ self.processor = ResultProcessor(
+ self.model_registry, unittest.mock.Mock(), self.logger
+ )
+
+ def tearDown(self):
+ self.session.close()
+
+ def test_auto_detect_sqlalchemy_hydrator(self):
+ data = [{"_type": "author", "_temp_id": "1", "name": "Test"}]
+
+ # Mock hydrators to verify which one is called
+ with unittest.mock.patch(
+ "extrai.core.result_processor.SQLAlchemyHydrator"
+ ) as MockSQLHydrator:
+ mock_instance = MockSQLHydrator.return_value
+ mock_instance.hydrate.return_value = []
+
+ self.processor.hydrate(data, self.session)
+
+ MockSQLHydrator.assert_called()
+
+ def test_auto_detect_direct_hydrator_no_temp_id(self):
+ data = [{"name": "Test"}] # No _temp_id
+
+ with unittest.mock.patch(
+ "extrai.core.result_processor.DirectHydrator"
+ ) as MockDirectHydrator:
+ mock_instance = MockDirectHydrator.return_value
+ mock_instance.hydrate.return_value = []
+
+ self.processor.hydrate(data, self.session)
+
+ MockDirectHydrator.assert_called()
+
+ def test_auto_detect_direct_hydrator_explicit_default(self):
+ data = [{"name": "Test"}]
+
+ with unittest.mock.patch(
+ "extrai.core.result_processor.DirectHydrator"
+ ) as MockDirectHydrator:
+ mock_instance = MockDirectHydrator.return_value
+ mock_instance.hydrate.return_value = []
+
+ self.processor.hydrate(data, self.session, default_model_type="author")
+
+ MockDirectHydrator.assert_called()
+
+ def test_direct_hydrator_with_type(self):
+ hydrator = DirectHydrator(self.session, self.logger)
+ data = [{"_type": "author", "name": "Direct Author"}]
+
+ results = hydrator.hydrate(data, self.model_registry.model_map)
+
+ self.assertEqual(len(results), 1)
+ self.assertIsInstance(results[0], StrategyAuthor)
+ self.assertEqual(results[0].name, "Direct Author")
+
+ def test_direct_hydrator_default_model(self):
+ hydrator = DirectHydrator(self.session, self.logger)
+ data = [{"name": "Default Author"}] # No _type
+
+ results = hydrator.hydrate(
+ data, self.model_registry.model_map, default_model_class=StrategyAuthor
+ )
+
+ self.assertEqual(len(results), 1)
+ self.assertIsInstance(results[0], StrategyAuthor)
+ self.assertEqual(results[0].name, "Default Author")
+
+ def test_direct_hydrator_nested(self):
+ hydrator = DirectHydrator(self.session, self.logger)
+ data = [
+ {
+ "_type": "nested_book",
+ "title": "Nested Title",
+ "author": {"name": "Nested Author"},
+ }
+ ]
+
+ results = hydrator.hydrate(data, self.model_registry.model_map)
+
+ self.assertEqual(len(results), 1)
+ self.assertIsInstance(results[0], StrategyNestedBook)
+ self.assertEqual(results[0].title, "Nested Title")
+ self.assertIsInstance(results[0].author, StrategyAuthor)
+ self.assertEqual(results[0].author.name, "Nested Author")
diff --git a/tests/core/test_db_writer.py b/tests/core/result_processor/test_persistence.py
similarity index 98%
rename from tests/core/test_db_writer.py
rename to tests/core/result_processor/test_persistence.py
index 38ad22c..dca61b1 100644
--- a/tests/core/test_db_writer.py
+++ b/tests/core/result_processor/test_persistence.py
@@ -6,7 +6,7 @@
from sqlalchemy.exc import SQLAlchemyError
# Adjust the import path based on your project structure
-from extrai.core.db_writer import (
+from extrai.core.result_processor import (
persist_objects,
DatabaseWriterError,
)
diff --git a/tests/core/result_processor/test_pk_conflict_prevention.py b/tests/core/result_processor/test_pk_conflict_prevention.py
new file mode 100644
index 0000000..0656f67
--- /dev/null
+++ b/tests/core/result_processor/test_pk_conflict_prevention.py
@@ -0,0 +1,136 @@
+import logging
+from typing import List, Optional
+from sqlmodel import Field, SQLModel, Relationship, create_engine, Session, select
+from extrai.core.result_processor import ResultProcessor
+from extrai.core.model_registry import ModelRegistry
+
+# Setup Logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+# Define Models
+class Parent(SQLModel, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+ name: str
+ children: List["Child"] = Relationship(back_populates="parent")
+
+
+class Child(SQLModel, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+ name: str
+ parent_id: Optional[int] = Field(default=None, foreign_key="parent.id")
+ parent: Optional[Parent] = Relationship(back_populates="children")
+
+
+# Mock Analytics
+class MockAnalytics:
+ def record_hydration_success(self, count):
+ logger.info(f"Analytics: Hydration Success {count}")
+
+ def record_hydration_failure(self):
+ logger.info("Analytics: Hydration Failure")
+
+
+def test_pk_collision_prevention():
+ # 1. Setup DB
+ engine = create_engine("sqlite:///:memory:")
+ SQLModel.metadata.create_all(engine)
+
+ with Session(engine) as session:
+ # 2. Pre-populate DB with ID=1 for Parent and Child
+ existing_parent = Parent(id=1, name="Existing Parent")
+ existing_child = Child(id=1, name="Existing Child", parent=existing_parent)
+ session.add(existing_parent)
+ session.add(existing_child)
+ session.commit()
+
+ # Verify they exist
+ p1 = session.get(Parent, 1)
+ c1 = session.get(Child, 1)
+ assert p1 is not None and p1.name == "Existing Parent"
+ assert c1 is not None and c1.name == "Existing Child"
+ logger.info("Pre-existing Parent ID=1 and Child ID=1 confirmed.")
+
+ # 3. Prepare ResultProcessor
+ registry = ModelRegistry(root_model=Parent, logger=logger)
+ processor = ResultProcessor(
+ model_registry=registry, analytics_collector=MockAnalytics(), logger=logger
+ )
+
+ # 4. Create Input Data with ID Collision (ID=1) for both Parent and Child
+ # The LLM outputs ID=1, but we want it to be ignored and a new ID assigned.
+ input_data = [
+ {
+ "_type": "Parent",
+ "id": 1, # CONFLICTING ID!
+ "name": "New LLM Parent",
+ "children": [
+ {
+ "_type": "Child",
+ "id": 1, # CONFLICTING ID!
+ "name": "New LLM Child",
+ }
+ ],
+ }
+ ]
+
+ # 5. Run Hydration (DirectHydrator)
+        # The input data has no _temp_id, so ResultProcessor selects the
+        # DirectHydrator rather than the _temp_id-based SQLAlchemyHydrator.
+
+ logger.info("Starting Hydration...")
+ # Use a new session for the operation
+ with Session(engine) as db_session:
+            # Pass the real DB session so the hydrator can see pre-existing rows
+ hydrated_objects = processor.hydrate(input_data, db_session=db_session)
+ processor.persist(hydrated_objects, db_session)
+
+ # 6. Verify Results
+ # We expect:
+ # - The original Parent (ID=1) is UNTOUCHED.
+ # - A NEW Parent is created (ID!=1, likely 2).
+ # - The Child is linked to the NEW Parent.
+
+ parents = db_session.exec(select(Parent)).all()
+ children = db_session.exec(select(Child)).all()
+ logger.info(f"Total Parents in DB: {len(parents)}")
+ logger.info(f"Total Children in DB: {len(children)}")
+
+ for p in parents:
+ logger.info(f"Parent ID: {p.id}, Name: {p.name}")
+ for c in children:
+ logger.info(f"Child ID: {c.id}, Name: {c.name}, ParentID: {c.parent_id}")
+
+ assert len(parents) == 2, "Should have 2 parents"
+ assert len(children) == 2, "Should have 2 children"
+
+ # Verify New Parent
+ new_parent = next(p for p in parents if p.name == "New LLM Parent")
+ assert new_parent.id != 1, (
+ f"New parent should NOT have ID 1, got {new_parent.id}"
+ )
+
+ # Verify New Child
+ new_child = next(c for c in children if c.name == "New LLM Child")
+ assert new_child.id != 1, f"New child should NOT have ID 1, got {new_child.id}"
+ assert new_child.parent_id == new_parent.id, (
+ f"New child should be linked to new parent {new_parent.id}, got {new_child.parent_id}"
+ )
+
+ # Verify Integrity of Existing Data
+ existing_p = db_session.get(Parent, 1)
+ existing_c = db_session.get(Child, 1)
+ assert existing_p.name == "Existing Parent"
+ assert existing_c.name == "Existing Child"
+ assert existing_c.parent_id == 1
+
+
+if __name__ == "__main__":
+ try:
+ test_pk_collision_prevention()
+ logger.info("✅ TEST PASSED: ID Collision successfully prevented!")
+ except Exception as e:
+ logger.error(f"❌ TEST FAILED: {e}", exc_info=True)
+ exit(1)
diff --git a/tests/core/result_processor/test_pk_relationship_recovery.py b/tests/core/result_processor/test_pk_relationship_recovery.py
new file mode 100644
index 0000000..acd91db
--- /dev/null
+++ b/tests/core/result_processor/test_pk_relationship_recovery.py
@@ -0,0 +1,151 @@
+import logging
+import pytest
+from typing import List, Optional
+from decimal import Decimal
+from sqlmodel import Field, SQLModel, Session, create_engine, Relationship, select
+from extrai.core.result_processor import ResultProcessor, ModelRegistry
+from extrai.core.analytics_collector import WorkflowAnalyticsCollector
+
+# Setup Logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# --- Simplified Models for Reproducing the Issue ---
+
+
+class Zone(SQLModel, table=True):
+ """A named geographic zone used in pricing rules."""
+
+ __tablename__ = "zones"
+
+ id: Optional[int] = Field(default=None, primary_key=True)
+ name: str = Field(max_length=100)
+ items_text: str = Field(description="Comma-separated list of locations.")
+
+ origin_rules: List["PricingRule"] = Relationship(
+ back_populates="origin_zone",
+ sa_relationship_kwargs={
+ "primaryjoin": "PricingRule.origin_zone_id==Zone.id",
+ "foreign_keys": "[PricingRule.origin_zone_id]",
+ },
+ )
+ destination_rules: List["PricingRule"] = Relationship(
+ back_populates="destination_zone",
+ sa_relationship_kwargs={
+ "primaryjoin": "PricingRule.destination_zone_id==Zone.id",
+ "foreign_keys": "[PricingRule.destination_zone_id]",
+ },
+ )
+
+
+class PricingRule(SQLModel, table=True):
+ """Simplified pricing rules for pet transport using Zone-based logic."""
+
+ __tablename__ = "pricing_rules"
+
+ rule_id: Optional[int] = Field(default=None, primary_key=True)
+
+ origin_zone_id: int = Field(foreign_key="zones.id")
+ destination_zone_id: int = Field(foreign_key="zones.id")
+
+ base_fee: Decimal = Field(max_digits=10, decimal_places=2)
+ currency: str = Field(default="USD", max_length=7)
+
+ origin_zone: Optional[Zone] = Relationship(
+ back_populates="origin_rules",
+ sa_relationship_kwargs={"foreign_keys": "[PricingRule.origin_zone_id]"},
+ )
+ destination_zone: Optional[Zone] = Relationship(
+ back_populates="destination_rules",
+ sa_relationship_kwargs={"foreign_keys": "[PricingRule.destination_zone_id]"},
+ )
+
+
+@pytest.fixture
+def db_session():
+ engine = create_engine("sqlite:///:memory:")
+ SQLModel.metadata.create_all(engine)
+ with Session(engine) as session:
+ # Pre-populate with EXISTING objects to create the ID offset
+ z1 = Zone(name="Existing Zone 1", items_text="country:US")
+ z2 = Zone(name="Existing Zone 2", items_text="country:CA")
+ session.add(z1)
+ session.add(z2)
+
+ session.commit()
+
+ # Verify IDs are 1 and 2
+ session.refresh(z1)
+ session.refresh(z2)
+ logger.info(f"Pre-populated DB. Existing Zone IDs: {z1.id}, {z2.id}")
+
+ yield session
+
+
+@pytest.fixture
+def result_processor():
+ # Register only the necessary models
+ model_registry = ModelRegistry(Zone, logger)
+ model_registry.model_map["Zone"] = Zone
+ model_registry.model_map["PricingRule"] = PricingRule
+
+ analytics = WorkflowAnalyticsCollector(logger)
+ return ResultProcessor(model_registry, analytics, logger)
+
+
+def test_pricing_rule_zone_linkage_offset_repro(db_session, result_processor):
+ """
+ Reproduce the issue where PricingRule links to the WRONG Zone ID because
+ existing objects in the DB cause an ID offset.
+ """
+
+ input_data = [
+ # New Zones (Input IDs 10, 11)
+ {"_type": "Zone", "id": 10, "name": "New Zone A", "items_text": "country:FR"},
+ {"_type": "Zone", "id": 11, "name": "New Zone B", "items_text": "country:DE"},
+ # PricingRule linking to Zones 10, 11
+ {
+ "_type": "PricingRule",
+ "rule_id": 999,
+ "origin_zone_id": 10, # Should link to "New Zone A" (which will get ID 3)
+ "destination_zone_id": 11, # Should link to "New Zone B" (which will get ID 4)
+ "base_fee": 150.00,
+ "currency": "USD",
+ },
+ ]
+
+ # Process
+ objects = result_processor.hydrate(input_data, db_session=db_session)
+ result_processor.persist(objects, db_session)
+
+ # Verify persistence
+ zone_a = db_session.exec(select(Zone).where(Zone.name == "New Zone A")).first()
+ zone_b = db_session.exec(select(Zone).where(Zone.name == "New Zone B")).first()
+
+ assert zone_a is not None
+ assert zone_b is not None
+
+ # Expected IDs for new zones should be 3 and 4 (since 1 and 2 exist)
+ logger.info(f"Zone A ID: {zone_a.id} (Input: 10)")
+ logger.info(f"Zone B ID: {zone_b.id} (Input: 11)")
+
+ assert zone_a.id == 3
+ assert zone_b.id == 4
+
+ # Verify PricingRule
+ rule = db_session.exec(
+ select(PricingRule).where(PricingRule.base_fee == 150.00)
+ ).first()
+ assert rule is not None
+
+ logger.info(f"PricingRule OriginFK: {rule.origin_zone_id}")
+ logger.info(f"PricingRule DestFK: {rule.destination_zone_id}")
+
+ # ASSERTIONS
+ assert rule.origin_zone_id == zone_a.id, (
+ f"PricingRule Origin FK ({rule.origin_zone_id}) does not match Zone A ID ({zone_a.id}). "
+ f"It might be pointing to the input ID (10) or an offset ID."
+ )
+ assert rule.destination_zone_id == zone_b.id, (
+ f"PricingRule Dest FK ({rule.destination_zone_id}) does not match Zone B ID ({zone_b.id})."
+ )
diff --git a/tests/core/result_processor/test_result_processor.py b/tests/core/result_processor/test_result_processor.py
new file mode 100644
index 0000000..343e07b
--- /dev/null
+++ b/tests/core/result_processor/test_result_processor.py
@@ -0,0 +1,90 @@
+import unittest
+from unittest.mock import MagicMock, patch
+
+from extrai.core.result_processor import ResultProcessor, HydrationError, WorkflowError
+from sqlmodel import Session
+
+
+class TestResultProcessor(unittest.TestCase):
+ def setUp(self):
+ self.mock_model_registry = MagicMock()
+ self.mock_model_registry.model_map = {}
+ self.mock_analytics_collector = MagicMock()
+ self.mock_logger = MagicMock()
+ self.result_processor = ResultProcessor(
+ self.mock_model_registry, self.mock_analytics_collector, self.mock_logger
+ )
+
+ def test_hydrate_success(self):
+ results = [{"_type": "test", "_temp_id": "1"}]
+ self.mock_model_registry.model_map = {"test": MagicMock()}
+
+ with patch("extrai.core.result_processor.SQLAlchemyHydrator") as MockHydrator:
+ mock_hydrator_instance = MockHydrator.return_value
+ mock_hydrator_instance.hydrate.return_value = [MagicMock()]
+
+ hydrated = self.result_processor.hydrate(results)
+
+ self.assertEqual(len(hydrated), 1)
+ self.mock_analytics_collector.record_hydration_success.assert_called_with(1)
+ mock_hydrator_instance.hydrate.assert_called_once()
+
+ def test_hydrate_empty(self):
+ results = []
+ hydrated = self.result_processor.hydrate(results)
+ self.assertEqual(hydrated, [])
+ self.mock_analytics_collector.record_hydration_success.assert_not_called()
+
+ def test_hydrate_failure(self):
+ results = [{"_type": "test", "_temp_id": "1"}]
+
+ with patch("extrai.core.result_processor.SQLAlchemyHydrator") as MockHydrator:
+ mock_hydrator_instance = MockHydrator.return_value
+ mock_hydrator_instance.hydrate.side_effect = Exception("Hydration failed")
+
+ with self.assertRaises(HydrationError):
+ self.result_processor.hydrate(results)
+
+ self.mock_analytics_collector.record_hydration_failure.assert_called_once()
+
+ def test_persist_success(self):
+ objects = [MagicMock()]
+ mock_session = MagicMock(spec=Session)
+
+ with patch(
+ "extrai.core.result_processor.persist_objects"
+ ) as mock_persist_objects:
+ self.result_processor.persist(objects, mock_session)
+ mock_persist_objects.assert_called_once_with(
+ db_session=mock_session,
+ objects_to_persist=objects,
+ logger=self.mock_logger,
+ )
+
+ def test_persist_empty(self):
+ objects = []
+ mock_session = MagicMock(spec=Session)
+
+ with patch(
+ "extrai.core.result_processor.persist_objects"
+ ) as mock_persist_objects:
+ self.result_processor.persist(objects, mock_session)
+ mock_persist_objects.assert_not_called()
+
+ def test_persist_failure(self):
+ objects = [MagicMock()]
+ mock_session = MagicMock(spec=Session)
+
+ with patch(
+ "extrai.core.result_processor.persist_objects"
+ ) as mock_persist_objects:
+ mock_persist_objects.side_effect = Exception("Persistence failed")
+
+ with self.assertRaises(WorkflowError):
+ self.result_processor.persist(objects, mock_session)
+
+ mock_session.rollback.assert_called_once()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/result_processor/test_saving_issue.py b/tests/core/result_processor/test_saving_issue.py
new file mode 100644
index 0000000..5ef4f4b
--- /dev/null
+++ b/tests/core/result_processor/test_saving_issue.py
@@ -0,0 +1,93 @@
+import logging
+import pytest
+from typing import Optional
+from sqlmodel import Field, SQLModel, Session, create_engine, select
+from extrai.core.result_processor import ResultProcessor
+from extrai.core.model_registry import ModelRegistry
+from extrai.core.analytics_collector import WorkflowAnalyticsCollector
+
+
+# 1. Define a simple SQLModel for testing
+class PetConfig(SQLModel, table=True):
+ config_id: Optional[int] = Field(default=None, primary_key=True)
+ airline_id: Optional[int] = Field(default=1)
+ transport_type: str
+ max_weight_lbs: Optional[float] = None
+ advance_booking_days: Optional[int] = None
+ notes: Optional[str] = None
+
+
+# Mock data mimicking consensus output for structured extraction
+mock_consensus_data = [
+ {
+ "transport_type": "cabin",
+ "max_weight_lbs": 17.64,
+ "notes": "For cats and dogs weighing less than 8 kg / 17.64 lb.",
+ },
+ {
+ "transport_type": "cargo",
+ "max_weight_lbs": 165.35,
+ "notes": "For cats and dogs weighing more than 8 kg/17.64 lb. and up to 75 kg/165.35 lb.",
+ },
+ {
+ "transport_type": "assistance",
+ "advance_booking_days": 2,
+ "notes": "Trained guide or service dogs.",
+ },
+]
+
+
+@pytest.fixture
+def session():
+ engine = create_engine("sqlite:///:memory:")
+ SQLModel.metadata.create_all(engine)
+ with Session(engine) as session:
+ yield session
+
+
+def test_hydrate_and_persist_structured_output(session: Session):
+ # Arrange
+ logger = logging.getLogger("test_logger")
+ model_registry = ModelRegistry(PetConfig, logger)
+ analytics_collector = WorkflowAnalyticsCollector(logger)
+
+ result_processor = ResultProcessor(
+ model_registry=model_registry,
+ analytics_collector=analytics_collector,
+ logger=logger,
+ )
+
+ # Act
+ # Hydrate objects, simulating direct hydration for structured output
+ hydrated_objects = result_processor.hydrate(
+ results=mock_consensus_data, db_session=session, default_model_type="PetConfig"
+ )
+
+ # Assert
+ assert len(hydrated_objects) == 3
+
+    # Hydrated objects should already be pending in the session before commit
+ assert len(session.new) > 0
+
+ # Persist the hydrated objects
+ result_processor.persist(hydrated_objects, session)
+
+ # Query the database to confirm persistence
+ all_configs = session.exec(select(PetConfig)).all()
+ assert len(all_configs) == 3
+
+ # Check some data points
+ cabin_config = session.exec(
+ select(PetConfig).where(PetConfig.transport_type == "cabin")
+ ).one()
+ assert cabin_config.max_weight_lbs == 17.64
+
+ cargo_config = session.exec(
+ select(PetConfig).where(PetConfig.transport_type == "cargo")
+ ).one()
+ assert cargo_config.max_weight_lbs == 165.35
+
+ assistance_config = session.exec(
+ select(PetConfig).where(PetConfig.transport_type == "assistance")
+ ).one()
+ assert assistance_config.advance_booking_days == 2
diff --git a/tests/core/test_base_llm_client.py b/tests/core/test_base_llm_client.py
index e1e92e7..60a3ca0 100644
--- a/tests/core/test_base_llm_client.py
+++ b/tests/core/test_base_llm_client.py
@@ -614,3 +614,22 @@ async def test_generate_and_validate_raw_json_output_success(
assert len(results) == 1
assert results[0] == expected_output
mock_validate.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_batch_methods_raise_not_implemented(mock_client: MockLLMClient):
+ """Tests that batch processing methods raise NotImplementedError by default."""
+ with pytest.raises(NotImplementedError, match="Batch processing is not supported"):
+ await mock_client.create_batch_job([])
+
+ with pytest.raises(NotImplementedError, match="Batch processing is not supported"):
+ await mock_client.retrieve_batch_job("id")
+
+ with pytest.raises(NotImplementedError, match="Batch processing is not supported"):
+ await mock_client.list_batch_jobs()
+
+ with pytest.raises(NotImplementedError, match="Batch processing is not supported"):
+ await mock_client.cancel_batch_job("id")
+
+ with pytest.raises(NotImplementedError, match="Batch processing is not supported"):
+ await mock_client.retrieve_batch_results("id")
diff --git a/tests/core/test_conflict_resolvers.py b/tests/core/test_conflict_resolvers.py
new file mode 100644
index 0000000..3592e03
--- /dev/null
+++ b/tests/core/test_conflict_resolvers.py
@@ -0,0 +1,78 @@
+# tests/core/test_conflict_resolvers.py
+import unittest
+from extrai.core.conflict_resolvers import (
+ levenshtein_similarity,
+ SimilarityClusterResolver,
+ prefer_most_common_resolver,
+)
+
+
+class TestConflictResolvers(unittest.TestCase):
+ def test_levenshtein_similarity(self):
+ # Identical
+ self.assertEqual(levenshtein_similarity("abc", "abc"), 1.0)
+ # Completely different
+ self.assertEqual(levenshtein_similarity("abc", "def"), 0.0)
+ # Partial
+ # "apple" vs "aple" (1 deletion). Ratio > 0.8
+ sim = levenshtein_similarity("apple", "aple")
+ self.assertGreater(sim, 0.8)
+ self.assertLess(sim, 1.0)
+
+ def test_similarity_clustering_simple(self):
+ resolver = SimilarityClusterResolver(similarity_threshold=0.6)
+        resolver = SimilarityClusterResolver(similarity_threshold=0.6)
+        # "Christmas Party" and "Christmas PArty" differ by one character
+        # (Levenshtein ratio ~0.9), so they form one cluster; "War Zone"
+        # is dissimilar to both (ratio ~0.1) and stays outside it.
+        values = ["Christmas Party", "Christmas PArty", "War Zone"]
+ path = ("event",)
+ result = resolver(path, values)
+        # The resolver should pick a value from the majority cluster,
+        # never the dissimilar outlier.
+ self.assertIn(result, ["Christmas Party", "Christmas PArty"])
+ self.assertNotEqual(result, "War Zone")
+
+ def test_similarity_clustering_outlier(self):
+ resolver = SimilarityClusterResolver(similarity_threshold=0.5)
+ # 3 values. A and B close. C far.
+ values = ["abcdefg", "abcdefh", "zzzzzzz"]
+ result = resolver(("p",), values)
+ self.assertIn(result, ["abcdefg", "abcdefh"])
+
+ def test_similarity_clustering_weighted(self):
+        resolver = SimilarityClusterResolver(similarity_threshold=0.5)
+        # "abc" and "abd" are similar and cluster together; "zzz" is an
+        # outlier in a cluster of its own. With weights supplied, the
+        # resolver picks the cluster with the highest total weight, not
+        # the largest one:
+        #   ["abc", "abd"] -> weight 0.1 + 0.1 = 0.2
+        #   ["zzz"]        -> weight 0.8
+        values = ["abc", "abd", "zzz"]
+        weights = [0.1, 0.1, 0.8]
+        result = resolver(("p",), values, weights)
+        # "zzz" wins despite being a string-space outlier, because of its
+        # high trust (weight).
+        self.assertEqual(result, "zzz")
+
+ def test_prefer_most_common_weighted(self):
+ values = ["A", "B", "A"]
+ weights = [0.1, 0.8, 0.05]
+ # A total: 0.15
+ # B total: 0.8
+ # B wins
+ result = prefer_most_common_resolver(("p",), values, weights)
+ self.assertEqual(result, "B")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/test_cost_tracking.py b/tests/core/test_cost_tracking.py
new file mode 100644
index 0000000..f1fb1cc
--- /dev/null
+++ b/tests/core/test_cost_tracking.py
@@ -0,0 +1,59 @@
+import pytest
+from unittest.mock import MagicMock, AsyncMock
+from extrai.core.analytics_collector import WorkflowAnalyticsCollector
+from extrai.llm_providers.generic_openai_client import GenericOpenAIClient
+
+
+@pytest.mark.asyncio
+async def test_cost_tracking_generic_openai():
+ # Setup
+ collector = WorkflowAnalyticsCollector()
+ client = GenericOpenAIClient(
+ api_key="test", model_name="gpt-4o", base_url="http://test"
+ )
+ client.client = MagicMock()
+
+ # Mock chat completion response with usage
+ mock_completion = MagicMock()
+ mock_completion.choices = [MagicMock(message=MagicMock(content="test response"))]
+ mock_completion.usage = MagicMock(prompt_tokens=10, completion_tokens=20)
+
+ client.client.chat.completions.create = AsyncMock(return_value=mock_completion)
+
+ # Execute
+ await client._execute_llm_call("system", "user", analytics_collector=collector)
+
+ # Verify
+ assert collector.total_input_tokens == 10
+ assert collector.total_output_tokens == 20
+ assert len(collector._llm_cost_details) == 1
+ assert collector._llm_cost_details[0]["model"] == "gpt-4o"
+ assert collector._llm_cost_details[0]["input_tokens"] == 10
+
+
+@pytest.mark.asyncio
+async def test_structured_cost_tracking():
+ # Setup
+ collector = WorkflowAnalyticsCollector()
+ client = GenericOpenAIClient(
+ api_key="test", model_name="gpt-4o", base_url="http://test"
+ )
+ client.client = MagicMock()
+
+ # Mock parse response with usage
+ mock_completion = MagicMock()
+ mock_completion.choices = [
+ MagicMock(message=MagicMock(parsed={"foo": "bar"}, refusal=None))
+ ]
+ mock_completion.usage = MagicMock(prompt_tokens=5, completion_tokens=15)
+
+ client.client.beta.chat.completions.parse = AsyncMock(return_value=mock_completion)
+
+ # Execute
+ await client.generate_structured(
+ "system", "user", MagicMock(), analytics_collector=collector
+ )
+
+ # Verify
+ assert collector.total_input_tokens == 5
+ assert collector.total_output_tokens == 15
diff --git a/tests/core/test_example_json_generator.py b/tests/core/test_example_json_generator.py
index 0669a58..dc39b6c 100644
--- a/tests/core/test_example_json_generator.py
+++ b/tests/core/test_example_json_generator.py
@@ -53,25 +53,23 @@ def setUp(self):
analytics_collector=self.mock_analytics_collector,
max_validation_retries_per_revision=self.max_retries,
)
- # Expected derived values
- # In the new implementation, schema generation is more complex.
- # We will mock the functions responsible for it to isolate the test.
- self.patcher_discover = patch(
- "extrai.core.example_json_generator.discover_sqlmodels_from_root"
- )
- self.patcher_generate_schema = patch(
- "extrai.core.example_json_generator.generate_llm_schema_from_models"
- )
- self.mock_discover = self.patcher_discover.start()
- self.mock_generate_schema = self.patcher_generate_schema.start()
+ self.patcher_inspector = patch(
+ "extrai.core.example_json_generator.SchemaInspector"
+ )
+ self.mock_inspector_cls = self.patcher_inspector.start()
+ self.mock_inspector = self.mock_inspector_cls.return_value
- # Define what the mocked functions will return
- self.mock_discover.return_value = {self.mock_output_model}
+ # Define what the mocked inspector methods will return
+ self.mock_inspector.discover_sqlmodels_from_root.return_value = {
+ self.mock_output_model
+ }
self.expected_schema_str = json.dumps(
self.mock_output_model.model_json_schema()
)
- self.mock_generate_schema.return_value = self.expected_schema_str
+ self.mock_inspector.generate_llm_schema_from_models.return_value = (
+ self.expected_schema_str
+ )
self.generator = ExampleJSONGenerator(
llm_client=self.mock_llm_client,
@@ -88,13 +86,14 @@ def setUp(self):
}
def tearDown(self):
- self.patcher_discover.stop()
- self.patcher_generate_schema.stop()
+ self.patcher_inspector.stop()
def test_initialization_success(self):
"""Test that the ExampleJSONGenerator initializes correctly with the new schema logic."""
- self.mock_discover.assert_called_once_with(self.mock_output_model)
- self.mock_generate_schema.assert_called_once_with(
+ self.mock_inspector.discover_sqlmodels_from_root.assert_called_once_with(
+ self.mock_output_model
+ )
+ self.mock_inspector.generate_llm_schema_from_models.assert_called_once_with(
initial_model_classes={self.mock_output_model}
)
@@ -166,29 +165,23 @@ def test_initialization_failure_schema_derivation_error(self):
object,
)
- # Mock issubclass to return True for this specific mock
- def mock_issubclass(obj, classinfo):
- if obj is mock_model_with_bad_schema and classinfo is SQLModel:
- return True
- return orig_issubclass(obj, classinfo)
-
- orig_issubclass = issubclass # Store original issubclass
+ # Configure mock to raise exception
+ self.mock_inspector.discover_sqlmodels_from_root.side_effect = Exception(
+ "Discovery failed"
+ )
- with patch(
- "extrai.core.example_json_generator.discover_sqlmodels_from_root",
- side_effect=Exception("Discovery failed"),
+ with self.assertRaisesRegex(
+ ConfigurationError,
+ "Failed to derive JSON schema from output_model MockSQLModel: Discovery failed",
):
- with self.assertRaisesRegex(
- ConfigurationError,
- "Failed to derive JSON schema from output_model MockSQLModel: Discovery failed",
- ):
- ExampleJSONGenerator(
- llm_client=self.mock_llm_client,
- output_model=self.mock_output_model,
- max_validation_retries_per_revision=1,
- )
- # Restore original issubclass if necessary, though patch should handle it
- # For safety, or if not using patch.object for issubclass on the module itself.
+ ExampleJSONGenerator(
+ llm_client=self.mock_llm_client,
+ output_model=self.mock_output_model,
+ max_validation_retries_per_revision=1,
+ )
+
+ # Reset side effect for other tests (though not strictly needed as tests are isolated)
+ self.mock_inspector.discover_sqlmodels_from_root.side_effect = None
@patch(
"extrai.core.example_json_generator.generate_prompt_for_example_json_generation"
diff --git a/tests/core/test_json_consensus.py b/tests/core/test_json_consensus.py
index a82d3d1..8010e63 100644
--- a/tests/core/test_json_consensus.py
+++ b/tests/core/test_json_consensus.py
@@ -116,11 +116,11 @@ def test_no_clear_majority_prefer_most_common_resolver_dict(self):
# Test when most common does NOT meet threshold
consensus_processor_strict_threshold = JSONConsensus(
- consensus_threshold=0.7, # Requires 3/3 for 3 revisions (0.7*3=2.1, so need 3)
+ consensus_threshold=0.8, # Requires 3/3 for 3 revisions (0.8*3=2.4, so need 3)
conflict_resolver=prefer_most_common_resolver,
)
- # color: red (2/3), blue (1/3) -> red is most common but 2/3 is not > 0.7
- # value: 10 (2/3), 20 (1/3) -> 10 is most common but 2/3 is not > 0.7
+ # color: red (2/3), blue (1/3) -> red is most common but 2/3 is not > 0.8
+ # value: 10 (2/3), 20 (1/3) -> 10 is most common but 2/3 is not > 0.8
# prefer_most_common_resolver will still pick them.
expected_conflict_resolved = {"color": "red", "value": 10}
consensus_obj_strict, analytics_strict = (
@@ -129,7 +129,7 @@ def test_no_clear_majority_prefer_most_common_resolver_dict(self):
self.assertEqual(consensus_obj_strict, expected_conflict_resolved)
self.assertEqual(
analytics_strict["paths_agreed_by_threshold"], 0
- ) # None met the 0.7 threshold
+ ) # None met the 0.8 threshold
self.assertEqual(
analytics_strict["paths_resolved_by_conflict_resolver"], 2
) # Both resolved
@@ -577,7 +577,7 @@ def test_special_conflict_resolution_for_temp_id_and_type(self):
# resolution logic applied, where the most common value is chosen even if
# the consensus threshold is not met.
consensus_processor = JSONConsensus(
- consensus_threshold=0.7, # High threshold to force conflict
+ consensus_threshold=0.8, # High threshold to force conflict
conflict_resolver=default_conflict_resolver, # Default would omit
)
revisions = [
@@ -585,8 +585,8 @@ def test_special_conflict_resolution_for_temp_id_and_type(self):
{"_type": "A", "_temp_id": "id2", "other": "y"},
{"_type": "B", "_temp_id": "id1", "other": "z"},
]
- # _type: A (2/3), B (1/3). 2/3 = 0.66 < 0.7. Conflict.
- # _temp_id: id1 (2/3), id2 (1/3). 2/3 = 0.66 < 0.7. Conflict.
+ # _type: A (2/3), B (1/3). 2/3 ≈ 0.67 < 0.8. Conflict.
+ # _temp_id: id1 (2/3), id2 (1/3). 2/3 ≈ 0.67 < 0.8. Conflict.
# other: x,y,z all 1/3. Conflict.
# Expected: _type and _temp_id are resolved to most common, 'other' is omitted.
expected = {"_type": "A", "_temp_id": "id1"}
diff --git a/tests/core/test_json_consensus_weighted.py b/tests/core/test_json_consensus_weighted.py
new file mode 100644
index 0000000..ee86ade
--- /dev/null
+++ b/tests/core/test_json_consensus_weighted.py
@@ -0,0 +1,90 @@
+# tests/core/test_json_consensus_weighted.py
+import unittest
+from extrai.core.json_consensus import JSONConsensus
+from extrai.core.conflict_resolvers import prefer_most_common_resolver
+
+
+class TestJSONConsensusWeighted(unittest.TestCase):
+ def test_weighted_consensus_outlier_rejection(self):
+ # Scenario:
+ # Rev 1 and Rev 2 agree on 10 context fields.
+ # Rev 3 disagrees on those 10.
+ # On the 'target' field, all 3 disagree.
+ # R1 -> X, R2 -> Z, R3 -> Y.
+
+ # Standard voting (if order is R3, R1, R2) would pick Y because it appears first in the tie.
+ # Weighted voting should identify R1 and R2 as a cluster (high mutual similarity) and R3 as an outlier.
+ # Thus R1 and R2 get high weights. R3 gets low weight.
+ # Tie between X and Z (high weights) should win over Y (low weight).
+
+ rev1 = {f"f{i}": "A" for i in range(10)}
+ rev1["target"] = "X"
+
+ rev2 = {f"f{i}": "A" for i in range(10)}
+ rev2["target"] = "Z"
+
+ rev3 = {f"f{i}": "B" for i in range(10)}
+ rev3["target"] = "Y"
+
+ # Put R3 first to bias standard voting towards Y
+ revisions = [rev3, rev1, rev2]
+
+ consensus = JSONConsensus(
+ consensus_threshold=0.5, conflict_resolver=prefer_most_common_resolver
+ )
+ result, analytics = consensus.get_consensus(revisions)
+
+
+ # Check that target is NOT Y (the outlier's choice)
+ self.assertNotEqual(result.get("target"), "Y")
+ self.assertIn(result.get("target"), ["X", "Z"])
+
+ # Check weights in analytics
+ w = analytics.get("revision_weights")
+ # R3 is index 0. Should have low weight.
+ # R1 (idx 1) and R2 (idx 2) match on 10 fields.
+ # R3 matches none.
+ self.assertLess(w[0], w[1])
+ self.assertLess(w[0], w[2])
+
+ def test_weighted_consensus_with_strings(self):
+ from extrai.core.conflict_resolvers import SimilarityClusterResolver
+
+ # R1 and R2 share similar string content; R3 diverges.
+ rev1 = {"message": "hello world"}
+ rev2 = {"message": "hello people"}
+ rev3 = {"message": "goodbye moon"}
+
+ # Conflict field
+ # R1: "Apple"
+ # R2: "Banana"
+ # R3: "Cherry"
+ # All distinct. All semantically far.
+ # But R1 and R2 have high weight. R3 low.
+
+ rev1["fruit"] = "Apple"
+ rev2["fruit"] = "Banana"
+ rev3["fruit"] = "Cherry"
+
+ revisions = [rev3, rev1, rev2]
+
+ # Use SimilarityClusterResolver (which falls back to weighted most common if no clusters)
+ resolver = SimilarityClusterResolver(similarity_threshold=0.8)
+ consensus = JSONConsensus(consensus_threshold=1.0, conflict_resolver=resolver)
+
+ result, analytics = consensus.get_consensus(revisions)
+
+ # Should pick Apple or Banana. Definitely not Cherry.
+ self.assertNotEqual(result.get("fruit"), "Cherry")
+ self.assertIn(result.get("fruit"), ["Apple", "Banana"])
+
+ # Check avg string similarity
+ # On the 'message' path, R1-R2 similarity is high while R1-R3 and
+ # R2-R3 are low, so the average should still be positive.
+ self.assertGreater(analytics.get("average_string_similarity"), 0.0)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/test_model_wrapper_builder.py b/tests/core/test_model_wrapper_builder.py
new file mode 100644
index 0000000..55e31ac
--- /dev/null
+++ b/tests/core/test_model_wrapper_builder.py
@@ -0,0 +1,66 @@
+from typing import List, Optional
+from sqlmodel import SQLModel, Field, Relationship
+from pydantic import BaseModel
+from extrai.core.model_wrapper_builder import ModelWrapperBuilder
+
+
+# Define test models
+class ChildModel1(SQLModel, table=True):
+ __tablename__ = "test_child"
+ __table_args__ = {"extend_existing": True}
+ id: Optional[int] = Field(default=None, primary_key=True)
+ name: str
+ parent_id: Optional[int] = Field(default=None, foreign_key="test_parent.id")
+ parent: Optional["ParentModel1"] = Relationship(back_populates="children")
+
+
+class ParentModel1(SQLModel, table=True):
+ __tablename__ = "test_parent"
+ __table_args__ = {"extend_existing": True}
+ id: Optional[int] = Field(default=None, primary_key=True)
+ name: str
+ children: List[ChildModel1] = Relationship(back_populates="parent")
+
+
+class TestModelWrapperBuilder:
+ def test_generate_wrapper_model(self):
+ builder = ModelWrapperBuilder()
+ wrapper_model = builder.generate_wrapper_model(ParentModel1)
+
+ # Check wrapper structure
+ assert issubclass(wrapper_model, BaseModel)
+ assert "entities" in wrapper_model.model_fields
+ assert wrapper_model.__name__ == "ParentModel1ExtractionResult"
+
+ # Check nested structure
+ entities_field = wrapper_model.model_fields["entities"]
+
+ # Unwrap the List annotation to inspect the nested structure class
+ parent_structure_cls = entities_field.annotation.__args__[0]
+ assert parent_structure_cls.__name__ == "ParentModel1Structure"
+
+ assert "name" in parent_structure_cls.model_fields
+ assert "children" in parent_structure_cls.model_fields
+
+ children_field = parent_structure_cls.model_fields["children"]
+ # Should be List[ChildModel1Structure]
+ child_structure_cls = children_field.annotation.__args__[0]
+ assert child_structure_cls.__name__ == "ChildModel1Structure"
+
+ assert "name" in child_structure_cls.model_fields
+ # Child should NOT have a 'parent' field (the back-reference is skipped to avoid recursion)
+ assert "parent" not in child_structure_cls.model_fields
+
+ def test_circular_reference_handling(self):
+ # Already covered by the check above (Child should not have 'parent' field)
+ # because the builder skips MANYTOONE relationships.
+ builder = ModelWrapperBuilder()
+ wrapper_model = builder.generate_wrapper_model(ParentModel1)
+
+ parent_structure_cls = wrapper_model.model_fields[
+ "entities"
+ ].annotation.__args__[0]
+ children_field = parent_structure_cls.model_fields["children"]
+ child_structure_cls = children_field.annotation.__args__[0]
+
+ assert "parent" not in child_structure_cls.model_fields
diff --git a/tests/core/test_prompt_builder.py b/tests/core/test_prompt_builder.py
index a9a8be7..9cd22ff 100644
--- a/tests/core/test_prompt_builder.py
+++ b/tests/core/test_prompt_builder.py
@@ -1,524 +1,28 @@
import unittest
-import json
from extrai.core.prompt_builder import (
+ PromptBuilder,
generate_system_prompt,
generate_user_prompt_for_docs,
generate_sqlmodel_creation_system_prompt,
- generate_prompt_for_example_json_generation, # New function to be tested
+ generate_prompt_for_example_json_generation,
+ generate_entity_counting_system_prompt,
+ generate_entity_counting_user_prompt,
)
-class TestPromptBuilder(unittest.TestCase):
- def setUp(self):
- """Set up common test data."""
- self.sample_schema_dict = {
- "type": "object",
- "title": "TestEntity",
- "properties": {
- "id": {"type": "string"},
- "name": {"type": "string"},
- "value": {"type": "number"},
- },
- "required": ["id", "name"],
- }
- self.sample_schema_json_str = json.dumps(self.sample_schema_dict, indent=2)
-
- self.sample_extraction_example_dict = {
- "id": "test001",
- "name": "Test Name",
- "value": 123.45,
- }
- self.sample_extraction_example_json_str = json.dumps(
- self.sample_extraction_example_dict, indent=2
- )
-
- def test_generate_system_prompt_basic(self):
- """Test system prompt with only schema."""
- prompt = generate_system_prompt(schema_json=self.sample_schema_json_str)
- self.assertIn("You are an advanced AI specializing in data extraction", prompt)
- self.assertIn("# JSON SCHEMA TO ADHERE TO:", prompt)
- self.assertIn(self.sample_schema_json_str, prompt)
- self.assertIn("# EXTRACTION PROCESS", prompt) # Default process
- self.assertIn("# IMPORTANT EXTRACTION GUIDELINES", prompt) # Default guidelines
- self.assertIn("# FINAL CHECK BEFORE SUBMISSION", prompt) # Default checklist
- self.assertNotIn("# EXAMPLE OF EXTRACTION:", prompt)
-
- def test_generate_system_prompt_with_example(self):
- """Test system prompt with schema and example."""
- prompt = generate_system_prompt(
- schema_json=self.sample_schema_json_str,
- extraction_example_json=self.sample_extraction_example_json_str,
- )
- self.assertIn(self.sample_schema_json_str, prompt)
- self.assertIn("# EXAMPLE OF EXTRACTION:", prompt)
- self.assertIn("## CONCEPTUAL INPUT TEXT", prompt)
- self.assertIn(self.sample_extraction_example_json_str, prompt)
-
- def test_generate_system_prompt_with_custom_process(self):
- """Test system prompt with custom extraction process."""
- custom_process = "# MY CUSTOM PROCESS\n1. Do this first.\n2. Then do that."
- prompt = generate_system_prompt(
- schema_json=self.sample_schema_json_str,
- custom_extraction_process=custom_process,
- )
- self.assertIn(custom_process, prompt)
- self.assertNotIn(
- "Follow this step-by-step process meticulously:", prompt
- ) # Default process heading
- self.assertIn(
- "# IMPORTANT EXTRACTION GUIDELINES", prompt
- ) # Default guidelines should still be there
- self.assertIn("# FINAL CHECK BEFORE SUBMISSION", prompt) # Default checklist
-
- def test_generate_system_prompt_with_custom_guidelines(self):
- """Test system prompt with custom extraction guidelines."""
- custom_guidelines = "# MY CUSTOM GUIDELINES\n- Be awesome.\n- Be accurate."
- prompt = generate_system_prompt(
- schema_json=self.sample_schema_json_str,
- custom_extraction_guidelines=custom_guidelines,
- )
- self.assertIn(custom_guidelines, prompt)
- self.assertNotIn(
- "- **Output Format:** Your entire output must be a single, valid JSON object.",
- prompt,
- ) # Default guideline
- self.assertIn("# EXTRACTION PROCESS", prompt) # Default process
- self.assertIn("# FINAL CHECK BEFORE SUBMISSION", prompt) # Default checklist
-
- def test_generate_system_prompt_with_custom_checklist(self):
- """Test system prompt with custom final checklist."""
- custom_checklist = "# MY CUSTOM CHECKLIST\n1. Did I do well?\n2. Yes."
- prompt = generate_system_prompt(
- schema_json=self.sample_schema_json_str,
- custom_final_checklist=custom_checklist,
- )
- self.assertIn(custom_checklist, prompt)
- self.assertNotIn("1. **Valid JSON?**", prompt) # Default checklist item
- self.assertIn("# EXTRACTION PROCESS", prompt) # Default process
- self.assertIn("# IMPORTANT EXTRACTION GUIDELINES", prompt) # Default guidelines
-
- def test_generate_system_prompt_with_custom_context(self):
- """Test system prompt with custom_context."""
- custom_context_content = "This is some important external context for the LLM."
- prompt = generate_system_prompt(
- schema_json=self.sample_schema_json_str,
- custom_context=custom_context_content,
- )
- self.assertIn("# ADDITIONAL CONTEXT:", prompt)
- self.assertIn(custom_context_content, prompt)
- self.assertIn(
- self.sample_schema_json_str, prompt
- ) # Ensure schema is still there
- self.assertIn(
- "# EXTRACTION PROCESS", prompt
- ) # Default process should still be there
-
- # Test that it's not included if empty
- prompt_no_custom_context = generate_system_prompt(
- schema_json=self.sample_schema_json_str, custom_context=""
- )
- self.assertNotIn("# ADDITIONAL CONTEXT:", prompt_no_custom_context)
- self.assertNotIn(custom_context_content, prompt_no_custom_context)
-
- def test_generate_system_prompt_all_custom(self):
- """Test system prompt with all custom sections."""
- custom_process = "# CUSTOM PROCESS V2"
- custom_guidelines = "# CUSTOM GUIDELINES V2"
- custom_checklist = "# CUSTOM CHECKLIST V2"
- custom_context_content = "All custom context here for the all_custom test."
- prompt = generate_system_prompt(
- schema_json=self.sample_schema_json_str,
- extraction_example_json=self.sample_extraction_example_json_str,
- custom_extraction_process=custom_process,
- custom_extraction_guidelines=custom_guidelines,
- custom_final_checklist=custom_checklist,
- custom_context=custom_context_content,
- )
- self.assertIn(custom_process, prompt)
- self.assertIn(custom_guidelines, prompt)
- self.assertIn(custom_checklist, prompt)
- self.assertIn("# ADDITIONAL CONTEXT:", prompt)
- self.assertIn(custom_context_content, prompt)
- self.assertIn(self.sample_schema_json_str, prompt)
- self.assertIn(self.sample_extraction_example_json_str, prompt)
- self.assertNotIn("Follow this step-by-step process meticulously:", prompt)
- self.assertNotIn(
- "- **Output Format:** Your entire output must be a single, valid JSON object.",
- prompt,
- )
- self.assertNotIn("1. **Valid JSON?**", prompt)
-
- def test_generate_user_prompt_single_document(self):
- """Test user prompt with a single document."""
- doc1 = "This is the first document."
- prompt = generate_user_prompt_for_docs([doc1])
- self.assertIn(
- "Please extract information from the following document(s)", prompt
- )
- self.assertIn("# DOCUMENT(S) FOR EXTRACTION:", prompt)
- self.assertIn(doc1, prompt)
- self.assertNotIn("---END OF DOCUMENT---", prompt) # No separator for single doc
- self.assertIn(
- "Remember: Your output must be only a single, valid JSON object.", prompt
- )
-
- def test_generate_user_prompt_multiple_documents(self):
- """Test user prompt with multiple documents."""
- doc1 = "Document one content."
- doc2 = "Document two content here."
- doc3 = "And a third document."
- separator = "\n\n---END OF DOCUMENT---\n\n---START OF NEW DOCUMENT---\n\n"
- prompt = generate_user_prompt_for_docs([doc1, doc2, doc3])
- self.assertIn(doc1, prompt)
- self.assertIn(doc2, prompt)
- self.assertIn(doc3, prompt)
- self.assertEqual(prompt.count(separator), 2) # Two separators for three docs
- self.assertIn("# DOCUMENT(S) FOR EXTRACTION:", prompt)
-
- def test_generate_user_prompt_empty_documents_list(self):
- """Test user prompt with an empty list of documents."""
- prompt = generate_user_prompt_for_docs([])
- self.assertIn(
- "Please extract information from the following document(s) strictly according to the schema and instructions previously provided (in the system prompt).",
- prompt,
- )
- self.assertIn("# DOCUMENT(S) FOR EXTRACTION:", prompt)
-
- # Check that the space between "EXTRACTION:" and "---" is just newlines (or empty)
- extraction_header_end_index = prompt.find("# DOCUMENT(S) FOR EXTRACTION:")
- if extraction_header_end_index != -1:
- extraction_header_end_index += len("# DOCUMENT(S) FOR EXTRACTION:")
-
- reminder_start_index = prompt.find(
- "---"
- ) # Assuming "---" is the start of the reminder section
-
- if (
- extraction_header_end_index != -1
- and reminder_start_index != -1
- and extraction_header_end_index < reminder_start_index
- ):
- content_between = prompt[extraction_header_end_index:reminder_start_index]
- self.assertEqual(
- content_between.strip(),
- "",
- f"Content between extraction header and reminder should be whitespace only, but was: '{content_between}'",
- )
- elif extraction_header_end_index == -1:
- self.fail("Could not find '# DOCUMENT(S) FOR EXTRACTION:' marker.")
- elif reminder_start_index == -1:
- self.fail("Could not find '---' marker for reminder section.")
- else: # Markers found but order is wrong or overlap
- self.fail(
- f"Problem with marker positions: extraction_header_end_index={extraction_header_end_index}, reminder_start_index={reminder_start_index}"
- )
-
- self.assertIn(
- "Remember: Your output must be only a single, valid JSON object.", prompt
- )
- self.assertNotIn(
- "---END OF DOCUMENT---", prompt
- ) # Still important for empty list
-
- def test_generate_user_prompt_documents_with_special_chars(self):
- """Test user prompt with documents containing special JSON characters."""
- doc1 = 'This document has "quotes" and {curly braces}.'
- doc2 = "Another one with a backslash \\ and newlines \n in theory."
- prompt = generate_user_prompt_for_docs([doc1, doc2])
- self.assertIn(doc1, prompt)
- self.assertIn(doc2, prompt)
- self.assertIn("---END OF DOCUMENT---", prompt)
-
- def test_generate_sqlmodel_creation_system_prompt(self):
- """Test the specialized system prompt for SQLModel creation."""
- sample_sqlmodel_schema_str = json.dumps(
- {
- "title": "SQLModelDescription",
- "type": "object",
- "properties": {
- "model_name": {"type": "string"},
- "fields": {"type": "array", "items": {"type": "object"}},
- },
- "required": ["model_name", "fields"],
- },
- indent=2,
- )
- user_task_desc = "Create a model for tracking library books, including title, author, and ISBN."
-
- prompt = generate_sqlmodel_creation_system_prompt(
- schema_json=sample_sqlmodel_schema_str, user_task_description=user_task_desc
- )
-
- self.assertIn(
- "You are an AI assistant tasked with designing one or more SQLModel class definitions.",
- prompt,
- )
-
- self.assertIn(
- "3. Each object in the `sql_models` list MUST strictly adhere to the following JSON schema for a SQLModel description:",
- prompt,
- )
- # Correctly check for the schema block - note the double newlines from how prompt_parts are joined
- self.assertIn(f"```json\n\n{sample_sqlmodel_schema_str}\n\n```", prompt)
-
- # Check for the new "IMPORTANT CONSIDERATIONS" section
- self.assertIn("# IMPORTANT CONSIDERATIONS FOR DATABASE TABLE MODELS:", prompt)
- self.assertIn(
- 'field_options_str": "Field(default_factory=list, sa_type=JSON)"', prompt
- ) # Example for List
- self.assertIn(
- 'add `"from sqlmodel import JSON"` to the main `imports` array', prompt
- )
- self.assertIn(
- "MUST provide a sensible `default` value", prompt
- ) # Instruction for defaults
-
- # Check for the user task description section
- self.assertIn("# USER'S TASK:", prompt)
- self.assertIn(
- f'The user wants to define a SQLModel based on the following objective: "{user_task_desc}"',
- prompt,
- )
- self.assertIn(
- "Pay close attention to the requirements for List/Dict types if the model is a table, and try to provide default values for required fields.",
- prompt,
- )
-
- # Check for the hardcoded example section
- self.assertIn(
- "# EXAMPLE OF A VALID SQLMODEL DESCRIPTION JSON (Illustrating a list of models):",
- prompt,
- )
- # Check for a snippet from the hardcoded example to ensure it's present
- self.assertIn('"model_name": "ExampleItem"', prompt)
- self.assertIn('"table_name": "example_items"', prompt)
- self.assertIn("Timestamp of when the item was created.", prompt)
- self.assertIn('"name": "categories"', prompt) # Part of the new example
- self.assertIn(
- '"field_options_str": "Field(default_factory=list, sa_type=JSON)"', prompt
- ) # Part of the new example
- self.assertIn(
- '"from sqlmodel import SQLModel, Field, JSON"', prompt
- ) # Import in the new example
-
- def test_generate_sqlmodel_creation_system_prompt_structure_and_hardcoded_example(
- self,
- ):
- """
- Tests the overall structure and presence of the hardcoded example in
- the SQLModel creation prompt.
- """
- sample_schema_for_description_str = json.dumps(
- {
- "title": "SQLModelDesc",
- "type": "object",
- "properties": {"model_name": {"type": "string"}},
- "required": ["model_name"],
- },
- indent=2,
- )
- user_task = "Define a product model."
-
- prompt = generate_sqlmodel_creation_system_prompt(
- schema_json=sample_schema_for_description_str,
- user_task_description=user_task,
- )
-
- # General structure checks
- self.assertTrue(
- prompt.startswith(
- "You are an AI assistant tasked with designing one or more SQLModel class definitions."
- )
- )
- self.assertIn("# REQUIREMENTS FOR YOUR OUTPUT:", prompt)
- self.assertIn(
- "3. Each object in the `sql_models` list MUST strictly adhere to the following JSON schema for a SQLModel description:",
- prompt,
- )
- self.assertIn(
- sample_schema_for_description_str, prompt
- ) # Check if the passed schema is there
- self.assertIn("# USER'S TASK:", prompt)
- self.assertIn(user_task, prompt)
- self.assertIn(
- "# IMPORTANT CONSIDERATIONS FOR DATABASE TABLE MODELS:", prompt
- ) # New section check
- self.assertIn(
- "# EXAMPLE OF A VALID SQLMODEL DESCRIPTION JSON (Illustrating a list of models):",
- prompt,
- )
- self.assertTrue(
- prompt.endswith(
- "Do not include any other narrative, explanations, or conversational elements in your output."
- )
- )
-
- # Specific checks for the hardcoded example's content (which now includes List[str] example)
- self.assertIn('"model_name": "ExampleItem"', prompt)
- self.assertIn('"table_name": "example_items"', prompt)
- self.assertIn('primary_key": true', prompt)
- self.assertIn("datetime.datetime.utcnow", prompt)
- self.assertIn('"name": "categories"', prompt)
- self.assertIn(
- '"field_options_str": "Field(default_factory=list, sa_type=JSON)"', prompt
- )
- self.assertIn('"from sqlmodel import SQLModel, Field, JSON"', prompt)
-
- # Ensure the example JSON block is correctly formatted
- example_intro_text = "This is an example of the kind of JSON object you should produce (it conforms to the schema above):"
- self.assertIn(example_intro_text, prompt)
-
- # Extract the example JSON part to validate it
- try:
- # Find the start of the example JSON block
- json_block_marker = "```json\n\n" # Note the double newline
- # Find the specific example block after the intro text
- example_intro_end_idx = prompt.find(example_intro_text) + len(
- example_intro_text
- )
- json_code_block_start_idx = prompt.find(
- json_block_marker, example_intro_end_idx
- )
-
- if json_code_block_start_idx == -1:
- self.fail(
- f"Could not find the start of the example JSON code block ('{json_block_marker}')."
- )
-
- # Move past the marker itself
- actual_json_start_idx = json_code_block_start_idx + len(json_block_marker)
-
- # Find the end of this specific JSON code block (which is \n\n```)
- json_code_block_end_marker = "\n\n```"
- json_code_block_end_idx = prompt.find(
- json_code_block_end_marker, actual_json_start_idx
- )
- if json_code_block_end_idx == -1:
- self.fail(
- f"Could not find the end of the example JSON code block ('{json_code_block_end_marker}')."
- )
-
- example_json_str_from_prompt = prompt[
- actual_json_start_idx:json_code_block_end_idx
- ].strip()
-
- # Validate that this extracted string is valid JSON
- json.loads(example_json_str_from_prompt)
- except json.JSONDecodeError as e:
- self.fail(
- f"Hardcoded example JSON in prompt is not valid JSON: {e}\nExtracted JSON string:\n'{example_json_str_from_prompt}'"
- )
- except Exception as e: # Catch other potential errors during extraction
- self.fail(
- f"Failed to extract or validate the hardcoded example JSON from prompt: {e}"
- )
-
- def test_generate_prompt_for_example_json_generation(self):
- """Test the prompt generation for creating an example JSON."""
- root_model_name = "SampleOutputModel"
-
- prompt = generate_prompt_for_example_json_generation(
- target_model_schema_str=self.sample_schema_json_str,
- root_model_name=root_model_name,
- )
-
- # General instructions
- self.assertIn(
- "You are an AI assistant tasked with generating a sample JSON object.",
- prompt,
- )
- self.assertIn(
- f"The goal is to create a single, valid JSON object that conforms to the provided schema for a model named '{root_model_name}' and its related models.",
- prompt,
- )
- self.assertIn("# JSON SCHEMA TO ADHERE TO:", prompt)
- self.assertIn(self.sample_schema_json_str, prompt)
- self.assertIn(
- "Your output MUST be a single JSON object with a top-level key named 'entities'.",
- prompt,
- )
- self.assertIn(
- "Each object inside the 'entities' list MUST include two metadata fields:",
- prompt,
- )
- self.assertIn(
- "`_type`: This field's value MUST be a string matching the name of the model it represents",
- prompt,
- )
- self.assertIn(
- "`_temp_id`: This field's value MUST be a unique temporary string identifier for that specific entity instance",
- prompt,
- )
- self.assertIn(
- "Your 'entities' list should contain an instance of the root model", prompt
- )
- self.assertIn("at least one instance of each of its related models", prompt)
-
- def test_generate_system_prompt_with_non_json_example_string(self):
- """
- Test system prompt with an extraction_example_json that is a non-empty,
- non-JSON string. This should cover line 108 and the 'else' branch
- of the inner conditional (line 119 in prompt_builder.py).
- """
- non_json_example_str = "This is a raw example string, not a JSON object."
- prompt = generate_system_prompt(
- schema_json=self.sample_schema_json_str,
- extraction_example_json=non_json_example_str,
- )
-
- # Check that line 108's content ("# EXAMPLE OF EXTRACTION:") is present
- self.assertIn("# EXAMPLE OF EXTRACTION:", prompt)
-
- # Check that the non_json_example_str itself is used (from line 119)
- self.assertIn(non_json_example_str, prompt)
-
- # Check that it's NOT wrapped with {"result": ...} (line 117 should not be hit)
- # Constructing the exact f-string format for the negative assertion
- wrapped_example_check = f'{{\n "result": {non_json_example_str}\n}}'
- self.assertNotIn(wrapped_example_check, prompt)
-
- # Also check for other parts of the example section to be sure they are still there
- self.assertIn("## CONCEPTUAL INPUT TEXT", prompt)
- self.assertIn("## EXAMPLE EXTRACTED JSON", prompt)
- # Ensure the ```json block markers are present around the example
- # The example string is non_json_example_str
- # So we expect "```json\n\n" + non_json_example_str + "\n\n```" (joined by \n\n)
- # More robustly, check that the example string is between ```json and ```
- # Find the start of the example section text
- example_section_header_idx = prompt.find("## EXAMPLE EXTRACTED JSON")
- self.assertNotEqual(
- example_section_header_idx != -1, "Example JSON header not found"
- )
-
- # Find ```json after this header
- json_block_start_marker = "```json"
- json_block_start_idx = prompt.find(
- json_block_start_marker, example_section_header_idx
- )
- self.assertNotEqual(
- json_block_start_idx != -1,
- "```json start marker not found after example header",
- )
-
- # Find the example string after the ```json marker
- example_str_idx = prompt.find(
- non_json_example_str, json_block_start_idx + len(json_block_start_marker)
- )
- self.assertNotEqual(
- example_str_idx != -1,
- "Non-JSON example string not found after ```json marker",
- )
-
- # Find ``` end marker after the example string
- json_block_end_marker = "```"
- json_block_end_idx = prompt.find(
- json_block_end_marker, example_str_idx + len(non_json_example_str)
- )
- self.assertNotEqual(
- json_block_end_idx != -1,
- "``` end marker not found after non-JSON example string",
- )
+class TestPromptBuilderFacade(unittest.TestCase):
+ def test_facade_exports(self):
+ """Verify that the facade module exports the expected functions and class."""
+ # Checking callability is sufficient for a facade test;
+ # deeper logic is tested in tests/core/prompts/*.
+ self.assertTrue(callable(PromptBuilder))
+ self.assertTrue(callable(generate_system_prompt))
+ self.assertTrue(callable(generate_user_prompt_for_docs))
+ self.assertTrue(callable(generate_sqlmodel_creation_system_prompt))
+ self.assertTrue(callable(generate_prompt_for_example_json_generation))
+ self.assertTrue(callable(generate_entity_counting_system_prompt))
+ self.assertTrue(callable(generate_entity_counting_user_prompt))
if __name__ == "__main__":
- unittest.main(argv=["first-arg-is-ignored"], exit=False)
+ unittest.main()
diff --git a/tests/core/test_schema_inspector.py b/tests/core/test_schema_inspector.py
index 579c6a9..ca7345c 100644
--- a/tests/core/test_schema_inspector.py
+++ b/tests/core/test_schema_inspector.py
@@ -1,47 +1,32 @@
-import unittest
-import enum
import json
-from typing import (
- Any,
- Dict,
- List,
- Optional,
- Set,
- Union,
- ForwardRef,
- get_origin as typing_get_origin,
- get_args as typing_get_args,
-)
-import datetime
+import pytest
+from typing import List, Optional
from unittest.mock import patch, MagicMock
-from sqlalchemy import create_engine, Column, Integer, String, ForeignKey
-from sqlalchemy.exc import NoInspectionAvailable
+from sqlalchemy import (
+ create_engine,
+ Column,
+ String,
+ Enum,
+ UniqueConstraint,
+)
from sqlalchemy.orm import (
- relationship as sa_relationship,
declarative_base,
Mapped,
mapped_column,
+ relationship,
RelationshipProperty,
)
+from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlmodel import SQLModel, Field, Relationship
-from extrai.core.schema_inspector import (
- inspect_sqlalchemy_model,
- generate_llm_schema_from_models,
- _get_python_type_str_from_pydantic_annotation,
- _map_sql_type_to_llm_type,
- _collect_all_sqla_models_recursively,
- discover_sqlmodels_from_root,
- _get_involved_foreign_keys,
- _process_relationship_for_llm_schema,
-)
-
+from extrai.core.schema_inspector import SchemaInspector
from tests.core.helpers.orchestrator_test_models import (
Base,
Department,
Employee,
+ StatusEnum,
Project,
Member,
ArticleScenarioModel,
@@ -59,1155 +44,463 @@
)
-class SQLAlchemyBaseTestCase(unittest.TestCase):
- engine = None
-
- @classmethod
- def setUpClass(cls):
- cls.engine = create_engine("sqlite:///:memory:")
- Base.metadata.create_all(cls.engine, checkfirst=True) # Uses imported Base
- SQLModel.metadata.create_all(
- cls.engine, checkfirst=True
- ) # Uses imported SQLModel
-
-
-class TestInspectSqlalchemyModel(SQLAlchemyBaseTestCase):
- def test_employee_model_inspection_with_metadata(self):
- schema = inspect_sqlalchemy_model(Employee)
- self.assertEqual(
- schema["comment"], "Stores detailed information about company employees."
- )
- self.assertEqual(schema["info_dict"], {"confidentiality": "high"})
- self.assertEqual(schema["columns"]["id"]["comment"], "Unique Employee ID (PK)")
- self.assertEqual(
- schema["columns"]["email"]["info_dict"],
- {"validation_rule": "standard_email_format"},
- )
- self.assertIsNone(schema["columns"]["email"]["comment"])
- self.assertEqual(
- schema["relationships"]["department"]["info_dict"],
- {"description": "The department this employee is assigned to."},
- )
-
- def test_department_model_inspection_with_metadata(self):
- schema = inspect_sqlalchemy_model(Department)
- self.assertEqual(schema["comment"], "Stores all company departments.")
- self.assertEqual(schema["info_dict"], {"schema_version": "1.2"})
- self.assertEqual(
- schema["columns"]["id"]["comment"], "Unique Department ID (PK)"
- )
- self.assertEqual(schema["columns"]["id"]["info_dict"], {"pk_strategy": "auto"})
- self.assertEqual(
- schema["columns"]["name"]["comment"], "Official name of the department."
- )
- self.assertEqual(
- schema["relationships"]["employees"]["info_dict"],
- {"relationship_detail": "All employees belonging to this department."},
- )
-
- def test_employee_with_nested_department_and_loops(self):
- schema = inspect_sqlalchemy_model(Employee)
- self.assertEqual(schema["table_name"], "employees")
- dept_rel_info = schema["relationships"]["department"]
- nested_dept_schema = dept_rel_info["nested_schema"]
- self.assertEqual(nested_dept_schema["table_name"], "departments")
- dept_employees_rel = nested_dept_schema["relationships"]["employees"]
- employee_recursion_schema = dept_employees_rel["nested_schema"]
- self.assertEqual(
- employee_recursion_schema["recursion_detected_for_type"], "Employee"
- )
- manager_rel_info = schema["relationships"]["manager"]
- manager_recursion_schema = manager_rel_info["nested_schema"]
- self.assertEqual(
- manager_recursion_schema["recursion_detected_for_type"], "Employee"
- )
-
- def test_department_with_nested_employees(self):
- schema = inspect_sqlalchemy_model(Department)
- emp_rel_info = schema["relationships"]["employees"]
- nested_emp_schema = emp_rel_info["nested_schema"]
- self.assertEqual(nested_emp_schema["table_name"], "employees")
- emp_dept_rel = nested_emp_schema["relationships"]["department"]
- department_recursion_schema = emp_dept_rel["nested_schema"]
- self.assertEqual(
- department_recursion_schema["recursion_detected_for_type"], "Department"
- )
-
- def test_project_with_nested_members_and_loops_m2m(self):
- schema = inspect_sqlalchemy_model(Project)
- members_rel_info = schema["relationships"]["members"]
- self.assertEqual(members_rel_info["related_model_name"], "Member")
- nested_member_schema = members_rel_info["nested_schema"]
- member_projects_rel = nested_member_schema["relationships"]["projects"]
- project_recursion_schema = member_projects_rel["nested_schema"]
- self.assertEqual(
- project_recursion_schema["recursion_detected_for_type"], "Project"
- )
- self.assertEqual(
- members_rel_info["info_dict"],
- {"description": "Members participating in this project."},
- )
- self.assertEqual(members_rel_info["secondary_table_name"], "project_member")
-
- def test_foreign_key_details_still_correct(self):
- emp_schema = inspect_sqlalchemy_model(Employee)
- dept_rel = emp_schema["relationships"]["department"]
- self.assertIn(
- "employees.department_id", dept_rel["foreign_key_constraints_involved"]
- )
- proj_schema = inspect_sqlalchemy_model(Project)
- members_rel = proj_schema["relationships"]["members"]
- self.assertIn(
- "project_member.project_id", members_rel["foreign_key_constraints_involved"]
- )
- self.assertIn(
- "project_member.member_id", members_rel["foreign_key_constraints_involved"]
- )
-
- def test_table_support_screws_deep_inspection(self):
- schema = inspect_sqlalchemy_model(
- TableModel
- ) # Updated from Table to TableModel
- self.assertEqual(schema["table_name"], "tables")
- self.assertEqual(schema["comment"], "Stores information about tables.")
-
- supports_rel_info = schema["relationships"]["supports"]
- self.assertEqual(supports_rel_info["related_model_name"], "Support")
- self.assertEqual(
- supports_rel_info["info_dict"]["description"],
- "List of supports for this table",
- )
- nested_support_schema = supports_rel_info["nested_schema"]
- self.assertEqual(nested_support_schema["table_name"], "supports")
- self.assertEqual(
- nested_support_schema["comment"], "Stores information about supports."
- )
- self.assertEqual(
- nested_support_schema["columns"]["name"]["comment"], "Name of the support"
- )
-
- screws_rel_info = nested_support_schema["relationships"]["screws_list"]
- self.assertEqual(screws_rel_info["related_model_name"], "Screws")
- self.assertEqual(
- screws_rel_info["info_dict"]["description"],
- "List of screws for this support",
- )
- nested_screws_schema = screws_rel_info["nested_schema"]
- self.assertEqual(nested_screws_schema["table_name"], "screws")
- self.assertEqual(
- nested_screws_schema["comment"], "Stores information about screws."
- )
- self.assertEqual(
- nested_screws_schema["columns"]["size"]["comment"],
- "Size of the screw (e.g., M5x20)",
- )
-
- screw_support_rel_info = nested_screws_schema["relationships"]["support"]
- screw_support_recursion_schema = screw_support_rel_info["nested_schema"]
- self.assertEqual(
- screw_support_recursion_schema["recursion_detected_for_type"], "Support"
- )
-
- support_table_rel_info = nested_support_schema["relationships"]["table"]
- support_table_recursion_schema = support_table_rel_info["nested_schema"]
- self.assertEqual(
- support_table_recursion_schema["recursion_detected_for_type"], "TableModel"
- ) # Updated from "Table"
-
- def test_skip_column_property_in_columns_list(self):
- schema = inspect_sqlalchemy_model(ModelWithColumnProperty)
- self.assertNotIn(
- "error", schema, f"Schema inspection failed: {schema.get('error')}"
- )
- self.assertIn("id", schema["columns"])
- self.assertIn("data", schema["columns"])
- self.assertNotIn(
- "data_length",
- schema["columns"],
- "Column properties should not be in the 'columns' dict",
- )
-
- @patch("extrai.core.schema_inspector.inspect")
- def test_relationship_not_relationship_property(self, mock_inspect_outer):
- local_base_for_rel_test: Any = declarative_base()
-
- class ModelForNonRelPropTest(local_base_for_rel_test):
- __tablename__ = "model_non_rel_prop_test_local"
- id: Mapped[int] = mapped_column(primary_key=True)
-
- ModelForNonRelPropTest.metadata.create_all(self.engine, checkfirst=True)
+@pytest.fixture(scope="module")
+def engine():
+ e = create_engine("sqlite:///:memory:")
+ Base.metadata.create_all(e)
+ SQLModel.metadata.create_all(e)
+ return e
+
+
+@pytest.fixture
+def inspector(engine):
+ return SchemaInspector()
+
+
+@pytest.mark.parametrize(
+ "model_cls, checks",
+ [
+ (
+ Employee,
+ [
+ lambda s: s["comment"]
+ == "Stores detailed information about company employees.",
+ lambda s: s["info_dict"] == {"confidentiality": "high"},
+ lambda s: s["columns"]["id"]["comment"] == "Unique Employee ID (PK)",
+ lambda s: s["columns"]["email"]["info_dict"]["validation_rule"]
+ == "standard_email_format",
+ lambda s: s["relationships"]["department"]["info_dict"]["description"]
+ == "The department this employee is assigned to.",
+ lambda s: s["relationships"]["department"]["nested_schema"][
+ "table_name"
+ ]
+ == "departments",
+ lambda s: s["relationships"]["department"]["nested_schema"][
+ "relationships"
+ ]["employees"]["nested_schema"]["recursion_detected_for_type"]
+ == "Employee",
+ lambda s: "employees.department_id"
+ in s["relationships"]["department"]["foreign_key_constraints_involved"],
+ ],
+ ),
+ (
+ Department,
+ [
+ lambda s: s["comment"] == "Stores all company departments.",
+ lambda s: s["columns"]["id"]["comment"] == "Unique Department ID (PK)",
+ lambda s: s["relationships"]["employees"]["info_dict"][
+ "relationship_detail"
+ ]
+ == "All employees belonging to this department.",
+ lambda s: s["relationships"]["employees"]["nested_schema"]["table_name"]
+ == "employees",
+ lambda s: s["relationships"]["employees"]["nested_schema"][
+ "relationships"
+ ]["department"]["nested_schema"]["recursion_detected_for_type"]
+ == "Department",
+ ],
+ ),
+ (
+ Project,
+ [
+ lambda s: s["relationships"]["members"]["related_model_name"]
+ == "Member",
+ lambda s: s["relationships"]["members"]["secondary_table_name"]
+ == "project_member",
+ lambda s: "project_member.project_id"
+ in s["relationships"]["members"]["foreign_key_constraints_involved"],
+ ],
+ ),
+ (
+ TableModel,
+ [
+ lambda s: s["table_name"] == "tables",
+ lambda s: s["relationships"]["supports"]["nested_schema"][
+ "relationships"
+ ]["screws_list"]["nested_schema"]["relationships"]["support"][
+ "nested_schema"
+ ]["recursion_detected_for_type"]
+ == "Support",
+ ],
+ ),
+ (ModelWithColumnProperty, [lambda s: "data_length" not in s["columns"]]),
+ (
+ ModelWithCustomColType,
+ [
+ lambda s: s["columns"]["custom_field"]["python_type"]
+ == "unknown_error_accessing_type"
+ ],
+ ),
+ (
+ FKParent,
+ [
+ lambda s: "fk_child_sync.parent_id_col"
+ in s["relationships"]["children_sync"][
+ "foreign_key_constraints_involved"
+ ]
+ ],
+ ),
+ (
+ FKParentDirect,
+ [
+ lambda s: "fk_child_direct.parent_fk_col_name"
+ in s["relationships"]["children_direct"][
+ "foreign_key_constraints_involved"
+ ]
+ ],
+ ),
+ (
+ ViewOnlyParent,
+ [
+ lambda s: "viewonly_child.parent_id"
+ in s["relationships"]["children"]["foreign_key_constraints_involved"]
+ ],
+ ),
+ ],
+)
+def test_inspect_sqlalchemy_model_basics(inspector, model_cls, checks):
+ schema = inspector.inspect_sqlalchemy_model(model_cls)
+ assert "error" not in schema, f"Inspection failed: {schema.get('error')}"
+ for check in checks:
+ assert check(schema)
- class MockNonRelProperty:
- key = "fake_non_rel"
- mock_inspector_instance = MagicMock()
- mock_id_attr = MagicMock(spec=InstrumentedAttribute)
- mock_id_attr.expression = Column(Integer)
- mock_id_attr.key = "id"
- mock_inspector_instance.column_attrs = [mock_id_attr]
+def test_relationship_not_relationship_property(inspector):
+ local_base = declarative_base()
- mock_inspector_instance.relationships = {"fake_non_rel": MockNonRelProperty()}
- mock_inspector_instance.selectable = ModelForNonRelPropTest.__table__
- mock_inspector_instance.mapper = MagicMock()
- mock_inspector_instance.mapper.class_ = ModelForNonRelPropTest
+ class ModelNonRel(local_base):
+ __tablename__ = "model_non_rel"
+ id: Mapped[int] = mapped_column(primary_key=True)
- mock_inspect_outer.return_value = mock_inspector_instance
+ mock_inspector = MagicMock()
+ mock_inspector.relationships = {"fake": MagicMock()} # Not RelationshipProperty
+ mock_inspector.column_attrs = []
+ mock_inspector.selectable = ModelNonRel.__table__
+ mock_inspector.mapper.class_ = ModelNonRel
- schema = inspect_sqlalchemy_model(ModelForNonRelPropTest)
+ with patch("extrai.core.schema_inspector.inspect", return_value=mock_inspector):
+ schema = inspector.inspect_sqlalchemy_model(ModelNonRel)
+ assert "fake" not in schema.get("relationships", {})
- self.assertNotIn(
- "error", schema, f"Schema inspection failed: {schema.get('error')}"
- )
- self.assertNotIn(
- "fake_non_rel",
- schema.get("relationships", {}),
- "Items not of type RelationshipProperty should be skipped.",
- )
- def test_inspector_is_none_error(self):
- class SomeModelInternal(Base):
- __tablename__ = "some_model_for_inspector_none_internal"
- id: Mapped[int] = mapped_column(primary_key=True)
-
- SomeModelInternal.metadata.create_all(self.engine, checkfirst=True)
-
- with patch(
- "extrai.core.schema_inspector.inspect",
- return_value=None,
- ) as mock_inspect:
- schema = inspect_sqlalchemy_model(SomeModelInternal)
- mock_inspect.assert_called_once_with(SomeModelInternal)
- self.assertIn("error", schema)
- self.assertIn(f"Inspector is None for {SomeModelInternal}", schema["error"])
-
- def test_column_python_type_access_error(self):
- schema = inspect_sqlalchemy_model(ModelWithCustomColType)
- self.assertNotIn(
- "error", schema, f"Schema inspection failed: {schema.get('error')}"
- )
- self.assertIn("custom_field", schema["columns"])
- self.assertEqual(
- schema["columns"]["custom_field"]["python_type"],
- "unknown_error_accessing_type",
- )
+def test_inspector_none_error(inspector, engine):
+ class SomeModel(Base):
+ __tablename__ = "some_model_none"
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ SomeModel.metadata.create_all(engine)
+
+ with patch("extrai.core.schema_inspector.inspect", return_value=None):
+ schema = inspector.inspect_sqlalchemy_model(SomeModel)
+ assert "error" in schema
- def test_relationship_fk_path_synchronize_pairs(self):
- schema = inspect_sqlalchemy_model(FKParent)
- self.assertNotIn(
- "error", schema, f"Schema inspection failed: {schema.get('error')}"
- )
- children_sync_rel = schema["relationships"].get("children_sync")
- self.assertIsNotNone(children_sync_rel, "children_sync relationship not found")
- self.assertIn(
- "fk_child_sync.parent_id_col",
- children_sync_rel["foreign_key_constraints_involved"],
- )
+def test_column_python_type_attribute_error(inspector, engine):
+ local_base = declarative_base()
+
+ class ModelAttrError(local_base):
+ __tablename__ = "model_attr_error"
+ id: Mapped[int] = mapped_column(primary_key=True)
+ col: Mapped[str] = mapped_column(String)
- def test_relationship_fk_path_direct_foreign_keys_attr(self):
- schema = inspect_sqlalchemy_model(FKParentDirect)
- self.assertNotIn(
- "error", schema, f"Schema inspection failed: {schema.get('error')}"
- )
+ mock_insp = MagicMock()
+ mock_insp.selectable = ModelAttrError.__table__
+ mock_insp.mapper.class_ = ModelAttrError
+ mock_col = MagicMock(spec=Column)
+ mock_col.name = "col"
+ mock_col.unique = False
+ mock_col.table = None
+ mock_col.primary_key = False
+ mock_col.nullable = True
+ mock_col.comment = None
+ mock_col.info = {}
+ mock_col.foreign_keys = set()
+
+ # spec=[] gives the mock no attributes, so accessing .python_type raises
+ # AttributeError without patching the shared MagicMock class itself
+ mock_col.type = MagicMock(spec=[])
+
+ mock_attr = MagicMock(spec=InstrumentedAttribute)
+ mock_attr.key = "col"
+ mock_attr.expression = mock_col
+ mock_insp.column_attrs = [mock_attr]
+ mock_insp.relationships = {}
+
+ with patch("extrai.core.schema_inspector.inspect", return_value=mock_insp):
+ schema = inspector.inspect_sqlalchemy_model(ModelAttrError)
+ assert schema["columns"]["col"]["python_type"] == "unknown_no_python_type_attr"
+
+
+def test_relationship_no_fks(inspector, engine):
+ local_base = declarative_base()
+
+ class Dest(local_base):
+ __tablename__ = "dest"
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ class Src(local_base):
+ __tablename__ = "src"
+ id: Mapped[int] = mapped_column(primary_key=True)
+ rel = relationship(
+ "Dest", viewonly=True, primaryjoin="foreign(Src.id)==remote(Dest.id)"
+ )
+
+ local_base.metadata.create_all(engine)
+
+ schema = inspector.inspect_sqlalchemy_model(Src)
+ assert schema["relationships"]["rel"]["foreign_key_constraints_involved"] == []
+
+
+def test_string_based_enum_in_column(inspector, engine):
+ local_base = declarative_base()
+
+ class ModelEnum(local_base):
+ __tablename__ = "model_enum"
+ id: Mapped[int] = mapped_column(primary_key=True)
+ status: Mapped[str] = mapped_column(Enum("A", "B", "C", name="status_enum"))
+
+ local_base.metadata.create_all(engine)
+ schema = inspector.inspect_sqlalchemy_model(ModelEnum)
+ assert schema["columns"]["status"]["enum_values"] == ["A", "B", "C"]
+
+
+def test_helper_coverage(inspector):
+ # _get_fks_from_direct_foreign_keys
+ m_fk = MagicMock()
+ m_fk.__str__.return_value = "t.c"
+ mock_rel = MagicMock(spec=RelationshipProperty)
+ mock_rel.secondary = None
+ mock_rel.synchronize_pairs = None
+ mock_rel.foreign_keys = {m_fk}
+ assert inspector._get_involved_foreign_keys(mock_rel) == {"t.c"}
+
+ # _get_fks_from_synchronize_pairs
+ m_local = MagicMock()
+ m_local.__str__.return_value = "t.l"
+ m_local.foreign_keys = {"fk"}
+ m_remote = MagicMock()
+ m_remote.__str__.return_value = "t.r"
+ m_remote.foreign_keys = set()
+
+ mock_rel_sync = MagicMock(spec=RelationshipProperty)
+ mock_rel_sync.secondary = None
+ mock_rel_sync.synchronize_pairs = [(m_local, m_remote)]
+ type(mock_rel_sync).foreign_keys = None
+ assert inspector._get_involved_foreign_keys(mock_rel_sync) == {"t.l"}
+
+ # fallback: no secondary, no synchronize_pairs, no foreign_keys -> empty set
+ mock_rel_fall = MagicMock(spec=RelationshipProperty)
+ mock_rel_fall.secondary = None
+ mock_rel_fall.synchronize_pairs = None
+ type(mock_rel_fall).foreign_keys = None
+ assert inspector._get_involved_foreign_keys(mock_rel_fall) == set()
+
+
+def test_default_model_description(inspector, engine):
+ class NoDesc(SQLModel, table=True):
+ __tablename__ = "no_desc"
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+ SQLModel.metadata.create_all(engine)
+
+ schema = json.loads(inspector.generate_llm_schema_from_models([NoDesc]))
+ assert "Represents a NoDesc entity" in schema["NoDesc"]["description"]
+
+
+@pytest.mark.parametrize(
+ "rel_data, expected_id, expected_desc_part",
+ [
+ (
+ {"related_model_name": "M", "type": "MANYTOONE", "uselist": False},
+ "some_rel_ref_id",
+ "_temp_id of the related M",
+ ),
+ (
+ {"related_model_name": "M", "type": "ONETOMANY", "uselist": True},
+ "some_rel_ref_ids",
+ "list of _temp_ids for related M",
+ ),
+ ({"related_model_name": "M", "type": "ONETOMANY"}, None, None), # No uselist
+ ],
+)
+def test_process_relationship_for_llm_schema(
+ inspector, rel_data, expected_id, expected_desc_part
+):
+ res = inspector._process_relationship_for_llm_schema("some_rel", rel_data, {})
+ if expected_id is None:
+ assert res is None
+ else:
+ assert res[0] == expected_id
+ assert expected_desc_part in res[1]
+
+
+@pytest.mark.parametrize(
+ "models, check_fn",
+ [
+ ([Department], lambda s: "employees_ref_ids" in s["Department"]["fields"]),
+ (
+ [Employee, Department],
+ lambda s: "department_ref_id" in s["Employee"]["fields"],
+ ),
+ (
+ [Project, Member],
+ lambda s: "members_ref_ids" in s["Project"]["fields"]
+ and "projects_ref_ids" in s["Member"]["fields"],
+ ),
+ ([], lambda s: s == {}),
+ (
+ [Employee],
+ lambda s: all(
+ m.value in s["Employee"]["fields"]["status"] for m in StatusEnum
+ ),
+ ),
+ (
+ [ArticleScenarioModel],
+ lambda s: "array[string]"
+ in s["ArticleScenarioModel"]["fields"]["key_topics"].lower(),
+ ),
+ (
+ [PlainSQLAlchemyModelWithPydanticHints],
+ lambda s: s["PlainSQLAlchemyModelWithPydanticHints"]["fields"][
+ "complex_field"
+ ].startswith("object //"),
+ ),
+ ],
+)
+def test_generate_llm_schema(inspector, models, check_fn):
+ schema_str = inspector.generate_llm_schema_from_models(models)
+ schema = json.loads(schema_str)
+ assert check_fn(schema)
- children_direct_rel = schema["relationships"].get("children_direct")
- self.assertIsNotNone(
- children_direct_rel, "children_direct relationship not found"
- )
- self.assertIn(
- "fk_child_direct.parent_fk_col_name",
- children_direct_rel["foreign_key_constraints_involved"],
- )
- def test_relationship_fk_path_viewonly_rel_foreign_keys_attr(self):
- schema = inspect_sqlalchemy_model(ViewOnlyParent)
- self.assertNotIn(
- "error", schema, f"Schema inspection failed: {schema.get('error')}"
+def test_custom_descriptions_override(inspector):
+ custom = {"Employee": {"name": "Custom Name", "_model_description": "Custom Model"}}
+ schema = json.loads(
+ inspector.generate_llm_schema_from_models(
+ [Employee], custom_field_descriptions=custom
)
+ )
+ assert "Custom Name" in schema["Employee"]["fields"]["name"]
+ assert "Custom Model" in schema["Employee"]["description"]
- children_rel = schema["relationships"].get("children")
- self.assertIsNotNone(
- children_rel, "children relationship not found in ViewOnlyParent schema"
- )
- self.assertIn(
- "viewonly_child.parent_id",
- children_rel["foreign_key_constraints_involved"],
- "Expected 'viewonly_child.parent_id' to be in foreign_key_constraints_involved for viewonly relationship.",
- )
- def test_column_python_type_attribute_error(self):
- # Covers line 265: python_type_name = 'unknown_no_python_type_attr'
- local_base_for_attr_error: Any = declarative_base()
-
- class ModelForAttrError(local_base_for_attr_error):
- __tablename__ = "model_for_attr_error"
- id: Mapped[int] = mapped_column(primary_key=True)
- problematic_col: Mapped[str] = mapped_column(String)
-
- ModelForAttrError.metadata.create_all(self.engine, checkfirst=True)
-
- mock_inspector = MagicMock()
- mock_inspector.selectable = ModelForAttrError.__table__
- mock_inspector.mapper = MagicMock()
- mock_inspector.mapper.class_ = ModelForAttrError
-
- # Mock column attribute and its expression (the Column object)
- mock_col_attr = MagicMock(spec=InstrumentedAttribute)
- mock_col_attr.key = "problematic_col"
-
- mock_column_obj = MagicMock(spec=Column)
- mock_column_obj.name = "problematic_col"
- mock_column_obj.unique = False
- mock_column_obj.table = ModelForAttrError.__table__
- mock_column_obj.primary_key = False
- mock_column_obj.nullable = True
- mock_column_obj.comment = None
- mock_column_obj.info = {}
- mock_column_obj.foreign_keys = set()
-
- # Mock the 'type' attribute of the column object
- mock_type_obj = MagicMock()
- # Make accessing 'python_type' on mock_type_obj raise AttributeError
- del mock_type_obj.python_type # Ensure it doesn't exist
- type_prop = MagicMock(
- side_effect=AttributeError("Simulated AttributeError for python_type")
- )
- type(mock_type_obj).python_type = type_prop # Mocking the property itself
+def test_not_a_model_skip(inspector):
+ class NotModel:
+ pass
- mock_column_obj.type = mock_type_obj
- mock_col_attr.expression = mock_column_obj
+ schema = json.loads(
+ inspector.generate_llm_schema_from_models([Department, NotModel])
+ )
+ assert "Department" in schema
+ assert "NotModel" not in schema
- mock_inspector.column_attrs = [mock_col_attr]
- mock_inspector.relationships = {} # No relationships for this test
- with patch(
- "extrai.core.schema_inspector.inspect",
- return_value=mock_inspector,
- ):
- schema = inspect_sqlalchemy_model(ModelForAttrError)
+def test_non_serializable_info(inspector):
+ # Column info
+ schema_col = json.loads(
+ inspector.generate_llm_schema_from_models([ModelWithNonSerializableColInfo])
+ )
+ assert "Info: {" in schema_col["ModelWithNonSerializableColInfo"]["fields"]["data"]
- self.assertNotIn(
- "error", schema, f"Schema inspection failed: {schema.get('error')}"
- )
- self.assertIn("problematic_col", schema["columns"])
- self.assertEqual(
- schema["columns"]["problematic_col"]["python_type"],
- "unknown_no_python_type_attr",
+ # Relationship info
+ schema_rel = json.loads(
+ inspector.generate_llm_schema_from_models(
+ [ModelWithNonSerializableRelInfo, ModelRelatedToNonSerializable]
)
+ )
+ assert (
+ "Info: {"
+ in schema_rel["ModelWithNonSerializableRelInfo"]["fields"]["related_ref_id"]
+ )
- def test_relationship_with_no_foreign_keys(self):
- """
- Tests that a relationship with no discernible foreign keys results in an empty set
- for 'foreign_key_constraints_involved' when testing the public API.
- """
- # Define models locally to avoid polluting the global Base metadata
- local_base: Any = declarative_base()
-
- class LocalSomeOtherModel(local_base):
- __tablename__ = "local_some_other_model"
- id: Mapped[int] = mapped_column(primary_key=True)
-
- class LocalModelWithNoFKs(local_base):
- __tablename__ = "local_model_no_fks"
- id: Mapped[int] = mapped_column(primary_key=True)
- related = sa_relationship(
- "LocalSomeOtherModel",
- viewonly=True,
- primaryjoin="foreign(LocalModelWithNoFKs.id) == remote(LocalSomeOtherModel.id)",
- )
-
- # Create tables locally for this test
- local_base.metadata.create_all(self.engine)
-
- schema = inspect_sqlalchemy_model(LocalModelWithNoFKs)
- self.assertNotIn(
- "error", schema, f"Schema inspection failed: {schema.get('error')}"
- )
- self.assertIn("related", schema["relationships"])
- self.assertEqual(
- schema["relationships"]["related"]["foreign_key_constraints_involved"], []
- )
+ # Table info
+ schema_table = json.loads(
+ inspector.generate_llm_schema_from_models([ModelWithNonSerializableTableInfo])
+ )
+ assert "Info: {" in schema_table["ModelWithNonSerializableTableInfo"]["description"]
- def test_direct_foreign_keys_and_sync_pairs_coverage(self):
- """
- Covers the remaining lines using mocks to ensure the exact conditions are met.
- """
- with self.subTest("Test _get_fks_from_direct_foreign_keys"):
- mock_rel_prop = MagicMock(spec=RelationshipProperty)
- mock_fk_col = MagicMock(spec=Column)
- mock_fk_col.__str__.return_value = "mock_table.fk_col"
- # Configure all attributes that will be accessed on the mock
- mock_rel_prop.secondary = None
- mock_rel_prop.synchronize_pairs = None
- mock_rel_prop.foreign_keys = {mock_fk_col}
-
- result = _get_involved_foreign_keys(mock_rel_prop)
- self.assertEqual(result, {"mock_table.fk_col"})
-
- with self.subTest("Test _get_fks_from_synchronize_pairs with local FK"):
- mock_rel_prop_sync = MagicMock(spec=RelationshipProperty)
- mock_local_col = MagicMock(spec=Column)
- mock_local_col.foreign_keys = {"a_foreign_key"}
- mock_local_col.__str__.return_value = "mock_table.local_col"
- mock_remote_col = MagicMock(spec=Column)
- mock_remote_col.foreign_keys = set()
- mock_remote_col.__str__.return_value = "mock_table.remote_col"
-
- # Configure all attributes that will be accessed on the mock
- mock_rel_prop_sync.secondary = None
- mock_rel_prop_sync.synchronize_pairs = [(mock_local_col, mock_remote_col)]
- # To ensure we don't short-circuit, make foreign_keys None
- # This is because hasattr check is now more robust.
- type(mock_rel_prop_sync).foreign_keys = None
-
- result_sync = _get_involved_foreign_keys(mock_rel_prop_sync)
- self.assertEqual(result_sync, {"mock_table.local_col"})
-
- with self.subTest("Test fallback to empty set"):
- mock_rel_prop_fallback = MagicMock(spec=RelationshipProperty)
- mock_rel_prop_fallback.secondary = None
- mock_rel_prop_fallback.synchronize_pairs = None
- type(mock_rel_prop_fallback).foreign_keys = None
-
- result_fallback = _get_involved_foreign_keys(mock_rel_prop_fallback)
- self.assertEqual(result_fallback, set())
-
-
-class TestGenerateLLMSchemaFromModels(SQLAlchemyBaseTestCase):
- def test_process_relationship_for_llm_schema_uselist_flag(self):
- """
- Tests that the `uselist` flag correctly determines whether to generate
- a singular or plural reference ID field.
- """
- # Test case for uselist=False (e.g., ManyToOne, OneToOne)
- rel_data_single = {
- "related_model_name": "SomeModel",
- "type": "MANYTOONE",
- "uselist": False,
- }
- custom_descs = {}
- result_single = _process_relationship_for_llm_schema(
- "some_rel", rel_data_single, custom_descs
- )
- self.assertIsNotNone(result_single)
- self.assertEqual(result_single[0], "some_rel_ref_id")
- self.assertIn("The _temp_id of the related SomeModel", result_single[1])
-
- # Test case for uselist=True (e.g., OneToMany, ManyToMany)
- rel_data_list = {
- "related_model_name": "SomeModel",
- "type": "ONETOMANY",
- "uselist": True,
- }
- result_list = _process_relationship_for_llm_schema(
- "some_rel", rel_data_list, custom_descs
- )
- self.assertIsNotNone(result_list)
- self.assertEqual(result_list[0], "some_rel_ref_ids")
- self.assertIn("A list of _temp_ids for related SomeModel", result_list[1])
-
- def test_process_relationship_for_llm_schema_no_uselist_flag(self):
- """
- Tests that if the `uselist` flag is missing, the function returns None.
- """
- rel_data_no_uselist = {
- "related_model_name": "SomeModel",
- "type": "ONETOMANY",
- # 'uselist' key is intentionally omitted
- }
- custom_descs = {}
- result = _process_relationship_for_llm_schema(
- "some_rel", rel_data_no_uselist, custom_descs
- )
- self.assertIsNone(result)
-
- def test_single_model_basic_schema(self):
- llm_schema_str = generate_llm_schema_from_models([Department])
- self.assertIsInstance(llm_schema_str, str)
- schema = json.loads(llm_schema_str)
- self.assertIn("Department", schema)
- dept_schema = schema["Department"]
- self.assertIn("description", dept_schema)
- self.assertIn("fields", dept_schema)
- self.assertIn("notes_for_llm", dept_schema)
- self.assertIn("Unique Department ID (PK)", dept_schema["fields"]["id"])
- self.assertIn("Official name of the department.", dept_schema["fields"]["name"])
- self.assertIn("employees_ref_ids", dept_schema["fields"])
- self.assertIn(
- "All employees belonging to this department.",
- dept_schema["fields"]["employees_ref_ids"],
- )
- def test_related_models_schema_and_descriptions(self):
- llm_schema_str = generate_llm_schema_from_models([Employee, Department])
- schema = json.loads(llm_schema_str)
- self.assertIn("Employee", schema)
- self.assertIn("Department", schema)
- emp_schema = schema["Employee"]
- self.assertIn("Full legal name of the employee.", emp_schema["fields"]["name"])
- self.assertIn("standard_email_format", emp_schema["fields"]["email"])
- self.assertIn(
- "The department this employee is assigned to.",
- emp_schema["fields"]["department_ref_id"],
- )
- dept_schema = schema["Department"]
- self.assertIn("Official name of the department.", dept_schema["fields"]["name"])
- self.assertIn("pk_strategy", dept_schema["fields"]["id"])
-
- def test_many_to_many_relationship_schema(self):
- llm_schema_str = generate_llm_schema_from_models([Project, Member])
- schema = json.loads(llm_schema_str)
- self.assertIn("Project", schema)
- proj_schema = schema["Project"]
- self.assertIn("members_ref_ids", proj_schema["fields"])
- self.assertIn(
- "Members participating in this project.",
- proj_schema["fields"]["members_ref_ids"],
- )
- self.assertIn("Member", schema)
- mem_schema = schema["Member"]
- self.assertIn("projects_ref_ids", mem_schema["fields"])
- self.assertIn(
- "Projects this member is associated with.",
- mem_schema["fields"]["projects_ref_ids"],
- )
+def test_default_desc_bare_rel(inspector, engine):
+ class P(SQLModel, table=True):
+ __tablename__ = "p_bare"
+ id: Optional[int] = Field(default=None, primary_key=True)
+ children: List["C"] = Relationship(back_populates="parent")
- def test_custom_field_descriptions_override(self):
- custom_descriptions = {
- "Employee": {
- "name": "Employee's complete name as per official records.",
- "_model_description": "Custom model description for Employee.",
- "department_ref_id": "Custom description for department link.",
- }
- }
- llm_schema_str = generate_llm_schema_from_models(
- [Employee], custom_field_descriptions=custom_descriptions
- )
- schema = json.loads(llm_schema_str)
- emp_schema = schema["Employee"]
- self.assertIn(
- "Employee's complete name as per official records.",
- emp_schema["fields"]["name"],
- )
- self.assertIn(
- "Custom model description for Employee.", emp_schema["description"]
- )
- self.assertIn(
- "Custom description for department link.",
- emp_schema["fields"]["department_ref_id"],
- )
+ class C(SQLModel, table=True):
+ __tablename__ = "c_bare"
+ id: Optional[int] = Field(default=None, primary_key=True)
+ parent_id: Optional[int] = Field(default=None, foreign_key="p_bare.id")
+ parent: Optional[P] = Relationship(back_populates="children")
- def test_empty_model_list(self):
- llm_schema_str = generate_llm_schema_from_models([])
- schema = json.loads(llm_schema_str)
- self.assertEqual(schema, {})
-
- def test_model_inspection_error_handling(self):
- class NotAModel:
- __name__ = "NotAModel"
-
- llm_schema_str = generate_llm_schema_from_models([Department, NotAModel]) # type: ignore
- schema = json.loads(llm_schema_str)
- self.assertIn("Department", schema)
- self.assertNotIn("NotAModel", schema)
-
- def test_article_scenario_list_str_as_json_is_array_string(self):
- llm_schema_str = generate_llm_schema_from_models([ArticleScenarioModel])
- self.assertIsInstance(llm_schema_str, str)
- try:
- schema = json.loads(llm_schema_str)
- except json.JSONDecodeError as e:
- self.fail(
- f"LLM Schema is not valid JSON: {e}\nSchema content:\n{llm_schema_str}"
- )
-
- self.assertIn("ArticleScenarioModel", schema)
- model_schema = schema["ArticleScenarioModel"]
- self.assertIn("fields", model_schema)
- fields = model_schema["fields"]
-
- self.assertIn("key_topics", fields)
- field_desc_key_topics = fields["key_topics"].lower()
- self.assertIn("list of key topics.", field_desc_key_topics)
- self.assertNotIn("object //", field_desc_key_topics)
- self.assertNotIn("dict //", field_desc_key_topics)
- self.assertTrue(
- "array[string]" in field_desc_key_topics
- or "list[string]" in field_desc_key_topics,
- f"Expected 'array[string]' or 'list[string]' for key_topics, got: {fields['key_topics']}",
- )
+ P.model_rebuild()
+ C.model_rebuild()
+ SQLModel.metadata.create_all(engine)
- self.assertIn("categories", fields)
- field_desc_categories = fields["categories"].lower()
- self.assertIn("list of categories.", field_desc_categories)
- self.assertNotIn("object //", field_desc_categories)
- self.assertNotIn("dict //", field_desc_categories)
- self.assertTrue(
- "array[string]" in field_desc_categories
- or "list[string]" in field_desc_categories,
- f"Expected 'array[string]' or 'list[string]' for categories, got: {fields['categories']}",
- )
+ schema = json.loads(inspector.generate_llm_schema_from_models([C]))
+ assert "_temp_id of the related P" in schema["C"]["fields"]["parent_ref_id"]
- self.assertIn("meta_data", fields)
- field_desc_meta_data = fields["meta_data"].lower()
- self.assertIn("meta data dictionary.", field_desc_meta_data)
- self.assertTrue(
- "object[string,any]" in field_desc_meta_data
- or "dict[string,any]" in field_desc_meta_data,
- f"Expected 'object[string,any]' or 'dict[string,any]' for meta_data, got: {fields['meta_data']}",
- )
- def test_generate_schema_with_non_serializable_info_dicts(self):
- llm_schema_col_str = generate_llm_schema_from_models(
- [ModelWithNonSerializableColInfo]
- )
- schema_col = json.loads(llm_schema_col_str)
- self.assertIn("ModelWithNonSerializableColInfo", schema_col)
- field_info = schema_col["ModelWithNonSerializableColInfo"]["fields"]["data"]
- self.assertIn("non_serializable", field_info)
- self.assertIn("Info: {", field_info)
- self.assertTrue(
- "set()" in field_info
- or "{1, 2, 3}" in field_info
- or "{2, 1, 3}" in field_info
- )
+def test_discovery_functions(inspector):
+ class NoInspect:
+ pass
- llm_schema_rel_str = generate_llm_schema_from_models(
- [ModelWithNonSerializableRelInfo, ModelRelatedToNonSerializable]
- )
- schema_rel = json.loads(llm_schema_rel_str)
- self.assertIn("ModelWithNonSerializableRelInfo", schema_rel)
- rel_field_name = "related_ref_id"
- self.assertIn(
- rel_field_name, schema_rel["ModelWithNonSerializableRelInfo"]["fields"]
- )
- rel_field_description = schema_rel["ModelWithNonSerializableRelInfo"]["fields"][
- rel_field_name
- ]
- self.assertIn("non_serializable", rel_field_description)
- self.assertIn("(Info: {", rel_field_description)
- with self.subTest(scenario="Optional[int] -> int (args[1] is NoneType)"):
- self.assertEqual(
- _get_python_type_str_from_pydantic_annotation(Optional[int]), "int"
- )
+ class GuardM(base):
+ __tablename__ = "guard_m"
+ id: Mapped[int] = mapped_column(primary_key=True)
- mock_get_origin.side_effect = (
- lambda x: Optional if x is Optional[bool] else typing_get_origin(x)
- )
- mock_get_args.side_effect = (
- lambda x: (bool,) if x is Optional[bool] else typing_get_args(x)
- )
- with self.subTest(scenario="Optional[bool] -> bool (single arg)"):
- self.assertEqual(
- _get_python_type_str_from_pydantic_annotation(Optional[bool]), "bool"
- )
+ with patch("extrai.core.schema_inspector.inspect") as mock_insp:
+ inspector._collect_all_sqla_models_recursively(GuardM, set(), {GuardM})
+ mock_insp.assert_not_called()
- mock_get_origin.side_effect = (
- lambda x: Optional if x is Optional[type(None)] else typing_get_origin(x)
- )
- mock_get_args.side_effect = (
- lambda x: (type(None),) if x is Optional[type(None)] else typing_get_args(x)
- )
- with self.subTest(
- scenario="Optional[type(None)] -> none (args[0] is NoneType)"
- ):
- self.assertEqual(
- _get_python_type_str_from_pydantic_annotation(Optional[type(None)]),
- "none",
- )
-
- mock_get_origin.side_effect = (
- lambda x: Optional if x is Optional[Any] else typing_get_origin(x)
- )
- mock_get_args.side_effect = (
- lambda x: () if x is Optional[Any] else typing_get_args(x)
- )
- with self.subTest(scenario="Optional[Any] -> none (empty args)"):
- self.assertEqual(
- _get_python_type_str_from_pydantic_annotation(Optional[Any]), "none"
- )
-
- def test_union_no_args_fallback(self):
- with patch(
- "extrai.core.schema_inspector.get_args",
- return_value=(),
- ) as mock_get_args_union:
- with patch(
- "extrai.core.schema_inspector.get_origin",
- side_effect=lambda x: Union if x is Union else typing_get_origin(x),
- ) as mock_get_origin_union:
- result = _get_python_type_str_from_pydantic_annotation(Union)
- self.assertEqual(result, "union")
- mock_get_args_union.assert_called_with(Union)
- mock_get_origin_union.assert_called_with(Union)
-
- def test_fallback_typing_prefix_in_cleaned_str(self):
- class WeirdTypingList:
- def __str__(self):
- return "typing.typing.List[custom.SubType]"
-
- class WeirdTypingDict:
- def __str__(self):
- return "typing.typing.Dict[str, int]"
-
- cases = [
- (WeirdTypingList(), "list[custom.subtype]"),
- (WeirdTypingDict(), "dict[str, int]"),
- ]
- for annotation, expected in cases:
- with self.subTest(annotation=str(annotation)):
- with (
- patch(
- "extrai.core.schema_inspector.get_origin",
- return_value=None,
- ),
- patch(
- "extrai.core.schema_inspector.get_args",
- return_value=(),
- ),
- ):
- self.assertEqual(
- _get_python_type_str_from_pydantic_annotation(annotation),
- expected,
- )
-
-
-class TestMapSqlTypeToLlmType(unittest.TestCase):
- def test_map_sql_type_to_llm_type(self):
- cases = [
- # From test_basic_python_type_mappings
- (("INT", "int"), "integer"),
- (("VARCHAR", "str"), "string"),
- (("BOOLEAN", "bool"), "boolean"),
- (("FLOAT", "float"), "number (float/decimal)"),
- (("DATE", "date"), "string (date format)"),
- (("DATETIME", "datetime"), "string (datetime format)"),
- (("BLOB", "bytes"), "string (base64 encoded)"),
- (("ENUM", "enum"), "string (enum)"),
- (("ANYSQL", "any"), "any"),
- (("ANYSQL", "none"), "null"),
- # From test_list_python_types
- (("JSON", "list[str]"), "array[string]"),
- (("ARRAY", "list[int]"), "array[integer]"),
- (("JSON", "list[list[str]]"), "array[array[string]]"),
- (("ANYSQL", "list[float]"), "array[number (float/decimal)]"),
- # From test_dict_python_types
- (("JSON", "dict[str,int]"), "object[string,integer]"),
- (("JSON", "dict[str,list[int]]"), "object[string,array[integer]]"),
- (("JSON", "dict[invalidformat]"), "object"),
- (("JSON", "dict[str]"), "object"),
- # From test_union_python_types
- (("ANYSQL", "union[int,str]"), "union[integer,string]"),
- (("ANYSQL", "union[str,int]"), "union[integer,string]"),
- (("ANYSQL", "union[int,none,str]"), "union[integer,null,string]"),
- (("ANYSQL", "union[int,null,str]"), "union[integer,string]"),
- (("ANYSQL", "union[]"), "any"),
- (("ANYSQL", "union[ ]"), "any"),
- (("ANYSQL", "union[int]"), "integer"),
- (("ANYSQL", "union[str,str]"), "string"),
- # From test_sql_type_fallbacks_and_priority
- (("INT", "str"), "string"),
- (("VARCHAR", "int"), "integer"),
- (("INTEGER", "some_generic_type"), "integer"),
- (("TEXT", "some_generic_type"), "string"),
- (("CHAR(50)", "some_generic_type"), "string"),
- (("CLOB", "some_generic_type"), "string"),
- (("BOOLEAN", "some_generic_type"), "boolean"),
- (("TIMESTAMP", "some_generic_type"), "string (date/datetime format)"),
- (("TIME", "some_other_type"), "string (date/datetime format)"),
- (("TIMESTAMP", "date"), "string (date format)"),
- (("DATE", "datetime"), "string (datetime format)"),
- (("NUMERIC(10,2)", "some_generic_type"), "number (float/decimal)"),
- (("DECIMAL", "some_generic_type"), "number (float/decimal)"),
- (("FLOAT", "some_generic_type"), "number (float/decimal)"),
- (("DOUBLE", "some_generic_type"), "number (float/decimal)"),
- # From test_list_dict_fallbacks_with_sql_types
- (("JSON", "list"), "array"),
- (("TEXT[]", "list"), "string"),
- (("CUSTOM_ARRAY_TYPE", "list"), "array"),
- (("SOME_OTHER_SQL_TYPE", "list"), "array"),
- (("JSON", "dict"), "object"),
- (("CUSTOM_OBJECT_TYPE", "dict"), "object"),
- # From test_generic_json_sql_type
- (("JSONB", "some_other_py_type"), "object"),
- (("JSON", "unspecific_python"), "object"),
- # From test_unknown_python_types
- (("JSON", "unknown_py_type"), "object"),
- (("TEXT[]", "unknown_py_type_array"), "string"),
- (("SOME_ARRAY_TYPE", "unknown_py_type_array_fallback"), "array"),
- (("VARCHAR", "unknown_py_type_str"), "string"),
- (("SOME_OTHER_SQL", "unknown_py_type_str_fallback"), "string"),
- # From test_final_fallback
- (("RARE_SQL_TYPE", "rare_python_type"), "string"),
- # From test_map_dict_with_json_in_sql_type_returns_object
- (("JSON", "dict"), "object"),
- (("jsonb", "dict"), "object"),
- (("some_json_type", "dict"), "object"),
- (("APPLICATION/JSON", "dict"), "object"),
- # From test_dict_fallback_with_non_json_non_empty_sql_type
- (("ARBITRARYSQLTYPE", "dict"), "object"),
- (("SOMEOTHEROPAQUETYPE", "dict"), "object"),
- # From test_unknown_python_type_with_json_sql_type
- (("JSON_VARIANT", "unknown_specific_case"), "object"),
- (("this_is_json_too", "unknown_anything"), "object"),
- ]
- for (sql_type, py_type), expected in cases:
- with self.subTest(sql_type=sql_type, py_type=py_type):
- self.assertEqual(_map_sql_type_to_llm_type(sql_type, py_type), expected)
-
-
-class TestDiscoveryAndCollectionFunctions(SQLAlchemyBaseTestCase):
- def test_collect_all_sqla_models_no_inspection(self):
- class NonInspectableModelForCollect:
- __name__ = "NonInspectableModelForCollect"
-
- discovered_models: List[Any] = []
- with patch(
- "extrai.core.schema_inspector.inspect",
- side_effect=NoInspectionAvailable,
- ) as mock_inspect:
- _collect_all_sqla_models_recursively(
- NonInspectableModelForCollect, discovered_models, set()
- )
- mock_inspect.assert_called_once_with(NonInspectableModelForCollect)
- self.assertIn(NonInspectableModelForCollect, discovered_models)
-
- def test_collect_all_sqla_models_inspector_none(self):
- class InspectReturnsNoneModelForCollect(Base):
- __tablename__ = "inspect_none_model_collect"
- id: Mapped[int] = mapped_column(primary_key=True)
-
- discovered_models: List[Any] = []
- with patch(
- "extrai.core.schema_inspector.inspect",
- return_value=None,
- ) as mock_inspect:
- _collect_all_sqla_models_recursively(
- InspectReturnsNoneModelForCollect, discovered_models, set()
- )
- mock_inspect.assert_called_once_with(InspectReturnsNoneModelForCollect)
- self.assertIn(InspectReturnsNoneModelForCollect, discovered_models)
-
- def test_discover_sqlmodels_from_root_with_invalid_input(self):
- class NotASQLModelForDiscover:
- __name__ = "NotASQLModelForDiscover"
-
- invalid_inputs = [None, NotASQLModelForDiscover]
- for invalid_input in invalid_inputs:
- with self.subTest(invalid_input=invalid_input):
- with patch("builtins.print") as mock_print:
- result = discover_sqlmodels_from_root(invalid_input)
- self.assertEqual(result, [])
- mock_print.assert_called_once()
- self.assertIn(
- "is not a valid SQLModel class", mock_print.call_args[0][0]
- )
-
- def test_discover_sqlmodels_from_root_collection_exception(self):
- class RootSQLModelForDiscoverException(SQLModel, table=True):
- __tablename__ = "root_sql_model_discover_exc"
- id: Optional[int] = Field(default=None, primary_key=True)
-
- with (
- patch(
- "extrai.core.schema_inspector._collect_all_sqla_models_recursively",
- side_effect=Exception("Test collection error"),
- ) as mock_collect,
- patch("builtins.print") as mock_print,
- ):
- result = discover_sqlmodels_from_root(RootSQLModelForDiscoverException)
- self.assertEqual(result, [])
- mock_collect.assert_called_once()
- mock_print.assert_called_once()
- self.assertIn(
- "Error during SQLModel discovery starting from RootSQLModelForDiscoverException: Test collection error",
- mock_print.call_args[0][0],
- )
-
- def test_collect_all_sqla_models_recursion_guard(self):
- local_base_for_rec_guard_hit: Any = declarative_base()
-
- class ModelToGuard(local_base_for_rec_guard_hit):
- __tablename__ = "model_to_guard"
- id: Mapped[int] = mapped_column(primary_key=True)
- related_dummy_id: Mapped[Optional[int]] = mapped_column(
- ForeignKey("model_to_guard.id")
- ) # Self-referential to simplify
- dummy_rel = sa_relationship("ModelToGuard")
-
- local_base_for_rec_guard_hit.metadata.create_all(self.engine, checkfirst=True)
-
- discovered_models: Set[Any] = set()
- recursion_guard: Set[Any] = {ModelToGuard}
-
- with patch("extrai.core.schema_inspector.inspect") as mock_inspect_call:
- _collect_all_sqla_models_recursively(
- ModelToGuard, discovered_models, recursion_guard
- )
- mock_inspect_call.assert_not_called()
-
- self.assertEqual(
- len(discovered_models),
- 0,
- "Discovered models should be empty if recursion guard is hit immediately.",
- )
- self.assertIn(
- ModelToGuard,
- recursion_guard,
- "ModelToGuard should remain in recursion_guard as it was pre-populated and hit.",
- )
+ # Success discovery
+ class R(SQLModel, table=True):
+ __tablename__ = "root_disc"
+ id: Optional[int] = Field(default=None, primary_key=True)
- def test_discover_sqlmodels_from_root_successful_collection(self):
- # Test the successful path of discover_sqlmodels_from_root (line 706)
- # where _collect_all_sqla_models_recursively completes without error.
-
- # Define models locally for clarity and to avoid altering shared test models
- class DiscoverRoot(SQLModel, table=True):
- __tablename__ = (
- "discover_root_table_for_line_706" # Ensure unique table name
- )
- id: Optional[int] = Field(default=None, primary_key=True)
- name: str = "Root"
-
- related_items: List["DiscoverRelatedForLine706"] = Relationship(
- back_populates="root_model"
- )
-
- class DiscoverRelatedForLine706(SQLModel, table=True):
- __tablename__ = (
- "discover_related_table_for_line_706" # Ensure unique table name
- )
- id: Optional[int] = Field(default=None, primary_key=True)
- data: str = "Related Data"
- root_model_id: Optional[int] = Field(
- default=None, foreign_key="discover_root_table_for_line_706.id"
- )
-
- root_model: Optional[DiscoverRoot] = Relationship(
- back_populates="related_items"
- )
-
- DiscoverRoot.model_rebuild()
- DiscoverRelatedForLine706.model_rebuild()
-
- SQLModel.metadata.create_all(
- self.engine,
- tables=[DiscoverRoot.__table__, DiscoverRelatedForLine706.__table__],
- checkfirst=True,
- )
+ SQLModel.metadata.create_all(create_engine("sqlite:///:memory:"))
+ assert len(inspector.discover_sqlmodels_from_root(R)) == 1
- discovered_models = discover_sqlmodels_from_root(DiscoverRoot)
- self.assertIsInstance(discovered_models, list)
- self.assertIn(DiscoverRoot, discovered_models)
- self.assertIn(DiscoverRelatedForLine706, discovered_models)
- self.assertEqual(len(discovered_models), 2)
-
- # Test with a standalone model to ensure it also hits the line
- class DiscoverStandaloneForLine706(SQLModel, table=True):
- __tablename__ = "discover_standalone_for_line_706"
- id: Optional[int] = Field(default=None, primary_key=True)
-
- SQLModel.metadata.create_all(
- self.engine,
- tables=[DiscoverStandaloneForLine706.__table__],
- checkfirst=True,
- )
- discovered_standalone = discover_sqlmodels_from_root(
- DiscoverStandaloneForLine706
- )
- self.assertIsInstance(discovered_standalone, list)
- self.assertIn(DiscoverStandaloneForLine706, discovered_standalone)
- self.assertEqual(len(discovered_standalone), 1)
+def test_unique_constraint_on_table(inspector, engine):
+ local_base = declarative_base()
+ class ModelUnique(local_base):
+ __tablename__ = "model_unique"
+ id: Mapped[int] = mapped_column(primary_key=True)
+ col: Mapped[str] = mapped_column(String)
+ __table_args__ = (UniqueConstraint("col"),)
-if __name__ == "__main__":
- unittest.main(argv=["first-arg-is-ignored"], verbosity=2, exit=False)
+ local_base.metadata.create_all(engine)
+ schema = inspector.inspect_sqlalchemy_model(ModelUnique)
+ assert schema["columns"]["col"]["unique"] is True
diff --git a/tests/core/test_sqlmodel_generator.py b/tests/core/test_sqlmodel_generator.py
index cfabfc9..16ebb09 100644
--- a/tests/core/test_sqlmodel_generator.py
+++ b/tests/core/test_sqlmodel_generator.py
@@ -2,34 +2,21 @@
import pytest_asyncio
import json
import os
-import re
import sys
-import ast
from unittest import mock
-from sqlalchemy import inspect as sqlalchemy_inspect
-from sqlmodel import SQLModel, create_engine as sqlmodel_create_engine
+from sqlmodel import SQLModel
from extrai.core.errors import (
LLMInteractionError,
LLMAPICallError,
ConfigurationError,
SQLModelCodeGeneratorError,
- SQLModelInstantiationValidationError,
)
from extrai.core.sqlmodel_generator import SQLModelCodeGenerator
from extrai.core.analytics_collector import (
WorkflowAnalyticsCollector,
)
from tests.core.helpers.mock_llm_clients import MockLLMClientSqlGen
-from tests.core.helpers.sqlmodel_generator_test_utils import (
- is_valid_python_code,
- get_field_call_kwargs,
- get_annotation_str,
- get_class_def_node,
-)
-
-# This file will contain the refactored tests for both
-# test_sqlmodel_generator_llm.py and test_sqlmodel_generator_codegen.py.
class TestSQLModelCodeGeneratorLLMIntegrationRefactored:
@@ -386,677 +373,3 @@ async def test_generate_and_load_models_via_llm_multiple_models_in_description(
event_name="dynamic_sqlmodel_class_generated_and_loaded_successfully",
details={"model_name": "Model1", "models_loaded": ["Model1", "Model2"]},
)
-
-
-class TestSQLModelCodeGeneratorCodeGenRefactored:
- @pytest.fixture(autouse=True)
- def clear_sqlmodel_metadata(self):
- """Fixture to clear SQLModel metadata before each test to prevent table redefinition errors."""
- original_metadata_tables = dict(SQLModel.metadata.tables)
- SQLModel.metadata.clear()
- yield
- SQLModel.metadata.clear()
- for table_obj in original_metadata_tables.values():
- table_obj.to_metadata(SQLModel.metadata)
-
- def setup_method(self):
- self.mock_llm_client = MockLLMClientSqlGen()
- self.mock_analytics_collector = mock.Mock(spec=WorkflowAnalyticsCollector)
- self.generator = SQLModelCodeGenerator(
- llm_client=self.mock_llm_client,
- analytics_collector=self.mock_analytics_collector,
- )
- SQLModelCodeGenerator._sqlmodel_description_schema_cache = None
-
- @pytest.fixture(scope="function")
- def comprehensive_code_ast_and_desc(self):
- comprehensive_description = {
- "sql_models": [
- {
- "table_name": "comprehensive_items",
- "model_name": "ComprehensiveItem",
- "description": 'A comprehensive model with "quotes" and\nnewlines.',
- "fields": [
- {
- "name": "id",
- "type": "Optional[int]",
- "primary_key": True,
- "nullable": True,
- },
- {
- "name": "name",
- "type": "str",
- "description": "Name of the item.",
- },
- {
- "name": "entity_uuid",
- "type": "uuid.UUID",
- "primary_key": False,
- "default_factory": "uuid.uuid4",
- "index": True,
- "nullable": False,
- },
- {
- "name": "unique_name",
- "type": "str",
- "unique": True,
- "index": True,
- },
- {
- "name": "description_field",
- "type": "Optional[str]",
- "nullable": True,
- "description": 'A field with "special" chars & new\nline.',
- },
- {"name": "count_val", "type": "int", "default": 0},
- {"name": "amount_val", "type": "float", "default": 0.0},
- {"name": "is_active_flag", "type": "bool", "default": True},
- {
- "name": "created_timestamp",
- "type": "datetime.datetime",
- "default_factory": "datetime.utcnow",
- "sa_column_kwargs": {"server_default": "FUNC.now()"},
- },
- {
- "name": "updated_timestamp",
- "type": "Optional[datetime.datetime]",
- "nullable": True,
- "sa_column_kwargs": {"onupdate": "FUNC.now()"},
- },
- {"name": "tags_list", "type": "List[str]"},
- {"name": "config_dict", "type": "Dict[str, Any]"},
- {
- "name": "related_id",
- "type": "Optional[int]",
- "foreign_key": "other_table.id",
- "nullable": True,
- },
- {"name": "optional_only_field", "type": "Optional[float]"},
- {
- "name": "class",
- "type": "str",
- "description": "A field named 'class'",
- },
- {
- "name": "json_payload",
- "type": "Dict",
- "sa_column_kwargs": {"sa_type": "JSON"},
- },
- {
- "name": "sqlalchemy_json_payload",
- "type": "List",
- "sa_column_kwargs": {"sa_type": "sqlalchemy.JSON"},
- },
- ],
- }
- ]
- }
- code = self.generator._generate_code_from_description(comprehensive_description)
- assert is_valid_python_code(code), (
- f"Generated code is not valid Python:\n{code}"
- )
- tree = ast.parse(code)
- return tree, comprehensive_description
-
- def test_comprehensive_model_generation_and_ast_validation(
- self, comprehensive_code_ast_and_desc
- ):
- tree, _ = comprehensive_code_ast_and_desc
-
- # 1. Test Imports
- expected_imports_from_typing = {"Optional", "List", "Dict", "Any"}
- actual_imports_from_typing = set()
- expected_imports_from_sqlmodel = {"SQLModel", "Field", "JSON"}
- actual_imports_from_sqlmodel = set()
- imported_modules = set()
-
- for node in tree.body:
- if isinstance(node, ast.ImportFrom):
- if node.module == "typing":
- actual_imports_from_typing.update(
- alias.name for alias in node.names
- )
- elif node.module == "sqlmodel":
- actual_imports_from_sqlmodel.update(
- alias.name for alias in node.names
- )
- elif isinstance(node, ast.Import):
- imported_modules.update(alias.name for alias in node.names)
-
- assert actual_imports_from_typing.issuperset(expected_imports_from_typing)
- assert actual_imports_from_sqlmodel.issuperset(expected_imports_from_sqlmodel)
- assert "uuid" in imported_modules
- assert "datetime" in imported_modules
- assert "sqlalchemy" in imported_modules
-
- # 2. Test Class Definition
- class_def_node = get_class_def_node(tree, "ComprehensiveItem")
- assert class_def_node is not None, "Class 'ComprehensiveItem' not found."
- assert class_def_node.name == "ComprehensiveItem"
- assert class_def_node.bases[0].id == "SQLModel"
- assert any(
- kw.arg == "table" and kw.value.value is True
- for kw in class_def_node.keywords
- )
-
- # 3. Test Fields
- fields_ast = {
- node.target.id: node
- for node in class_def_node.body
- if isinstance(node, ast.AnnAssign)
- }
-
- assert get_annotation_str(fields_ast["id"].annotation) == "Optional[int]"
- assert get_field_call_kwargs(fields_ast["id"].value) == {
- "primary_key": True,
- "nullable": True,
- }
-
- assert get_annotation_str(fields_ast["name"].annotation) == "str"
- assert get_field_call_kwargs(fields_ast["name"].value) == {
- "description": "Name of the item."
- }
-
- assert get_annotation_str(fields_ast["entity_uuid"].annotation) == "uuid.UUID"
- assert get_field_call_kwargs(fields_ast["entity_uuid"].value) == {
- "default_factory": "uuid.uuid4",
- "index": True,
- "nullable": False,
- }
-
- assert get_annotation_str(fields_ast["class_"].annotation) == "str"
- assert get_field_call_kwargs(fields_ast["class_"].value) == {
- "description": "A field named 'class'",
- "alias": "class",
- }
-
- assert get_annotation_str(fields_ast["json_payload"].annotation) == "Dict"
- assert get_field_call_kwargs(fields_ast["json_payload"].value) == {
- "sa_type": "JSON"
- }
-
- assert (
- get_annotation_str(fields_ast["sqlalchemy_json_payload"].annotation)
- == "List"
- )
- assert get_field_call_kwargs(fields_ast["sqlalchemy_json_payload"].value) == {
- "sa_type": "sqlalchemy.JSON"
- }
-
- @pytest.mark.parametrize(
- "test_id, model_desc, expected_exception, match_message",
- [
- (
- "instantiation_validation_error",
- {
- "model_name": "ValidationErrorModel",
- "fields": [
- {"name": "id", "type": "Optional[int]", "primary_key": True},
- {"name": "field_a", "type": "int"},
- ],
- },
- SQLModelInstantiationValidationError,
- r"Default instantiation of 'ValidationErrorModel' failed with ValidationError.*",
- ),
- (
- "instantiation_unexpected_error",
- {
- "model_name": "UnexpectedErrorModel",
- "fields": [
- {"name": "id", "type": "Optional[int]", "primary_key": True},
- {
- "name": "bad_field",
- "type": "int",
- "default_factory": "list.append",
- },
- ],
- },
- SQLModelCodeGeneratorError,
- r"failed instantiation with an unexpected error.*unbound method list\.append.*",
- ),
- ],
- )
- def test_generate_and_load_with_natural_errors(
- self, test_id, model_desc, expected_exception, match_message
- ):
- """Tests errors that occur naturally from the code generation and validation process."""
- with pytest.raises(expected_exception, match=match_message):
- self.generator._generate_and_load_class_from_description(
- {"sql_models": [model_desc]}
- )
-
- @pytest.mark.parametrize(
- "test_id, model_desc, mock_setup, expected_exception, match_message",
- [
- (
- "spec_creation_fails",
- {"model_name": "SpecFailModel", "fields": []},
- lambda mocks: (
- setattr(
- mocks["generate_code"],
- "return_value",
- "class SpecFailModel: pass",
- ),
- setattr(mocks["spec_from_file"], "return_value", None),
- ),
- SQLModelCodeGeneratorError,
- "Failed to create import spec",
- ),
- (
- "attr_not_a_class",
- {"model_name": "NotAClassModel", "fields": []},
- lambda mocks: (
- setattr(
- mocks["generate_code"],
- "return_value",
- "class NotAClassModel: pass",
- ),
- setattr(
- mocks["module_from_spec"].return_value, "NotAClassModel", 123
- ),
- ),
- SQLModelCodeGeneratorError,
- "Loaded attribute 'NotAClassModel' is not a class or not a subclass of SQLModel.",
- ),
- (
- "class_not_found_in_module",
- {"model_name": "MissingModel", "fields": []},
- lambda mocks: (
- setattr(
- mocks["generate_code"],
- "return_value",
- "class FoundModel(SQLModel): pass",
- ),
- setattr(
- mocks["module_from_spec"].return_value, "MissingModel", None
- ),
- ),
- SQLModelCodeGeneratorError,
- r"Class 'MissingModel' not found in dynamically loaded module",
- ),
- ],
- )
- @mock.patch(
- "extrai.core.sqlmodel_generator.SQLModelCodeGenerator._generate_code_from_description"
- )
- @mock.patch("importlib.util.module_from_spec")
- @mock.patch("importlib.util.spec_from_file_location")
- def test_load_process_with_mocked_errors(
- self,
- mock_spec_from_file,
- mock_module_from_spec,
- mock_generate_code,
- test_id,
- model_desc,
- mock_setup,
- expected_exception,
- match_message,
- ):
- """Tests errors in the loading process that require mocking importlib."""
- mock_spec = mock.Mock()
- mock_spec.loader = mock.Mock()
- mock_spec_from_file.return_value = mock_spec
-
- if mock_setup:
- mocks = {
- "spec_from_file": mock_spec_from_file,
- "module_from_spec": mock_module_from_spec,
- "generate_code": mock_generate_code,
- }
- mock_setup(mocks)
-
- with pytest.raises(expected_exception, match=match_message):
- self.generator._generate_and_load_class_from_description(
- {"sql_models": [model_desc]}
- )
-
- @pytest.mark.parametrize(
- "test_id, description, expected_substrings",
- [
- (
- "keyword_field_name_with_options",
- {
- "model_name": "KeywordFieldWithOptions",
- "fields": [
- {
- "name": "class",
- "type": "str",
- "field_options_str": 'Field(alias="class")',
- }
- ],
- },
- ['class_: str = Field(alias="class")'],
- ),
- (
- "custom_complex_import",
- {
- "model_name": "ComplexImportModel",
- "fields": [],
- "imports": ["from a.b import c as d"],
- },
- ["from a.b import c as d"],
- ),
- (
- "no_fields",
- {
- "model_name": "NoFieldsModel",
- "table_name": "no_fields",
- "fields": [],
- },
- ["class NoFieldsModel(SQLModel, table=True):", " pass"],
- ),
- (
- "no_model_description",
- {"model_name": "NoDescItem", "fields": [{"name": "id", "type": "int"}]},
- [],
- ), # second assertion checks for absence
- (
- "non_table_model_with_description",
- {
- "model_name": "NonTableModel",
- "description": "A test model.",
- "is_table_model": False,
- "fields": [],
- },
- [
- "class NonTableModel(SQLModel):",
- ' """A test model."""',
- " pass",
- ],
- ),
- (
- "multiple_base_classes_with_sqlmodel",
- {
- "model_name": "MultiBase",
- "is_table_model": True,
- "base_classes_str": ["CustomBase", "SQLModel"],
- "fields": [{"name": "id", "type": "int", "primary_key": True}],
- },
- ["class MultiBase(CustomBase, SQLModel, table=True):"],
- ),
- ],
- )
- def test_special_code_generation_cases(
- self, test_id, description, expected_substrings
- ):
- code = self.generator._generate_code_from_description(
- {"sql_models": [description]}
- )
- assert is_valid_python_code(code)
- for substring in expected_substrings:
- assert substring in code
-
- if test_id == "no_model_description":
- class_def_line_index = code.find(f"class {description['model_name']}")
- assert '"""' not in code[class_def_line_index:]
-
- def test_init_with_none_analytics_collector(self):
- generator_no_collector = SQLModelCodeGenerator(
- llm_client=self.mock_llm_client, analytics_collector=None
- )
- assert isinstance(
- generator_no_collector.analytics_collector, WorkflowAnalyticsCollector
- )
-
- @mock.patch("os.rmdir")
- @mock.patch("os.remove")
- def test_generate_load_and_validate_sqlmodel_class_e2e(
- self, mock_os_remove, mock_os_rmdir
- ):
- model_desc = {
- "sql_models": [
- {
- "model_name": "E2EProduct",
- "table_name": "e2e_products",
- "fields": [
- {
- "name": "id",
- "type": "Optional[int]",
- "primary_key": True,
- "default": None,
- "nullable": True,
- },
- {"name": "name", "type": "str", "default": "Default Product"},
- ],
- }
- ]
- }
-
- (
- loaded_models,
- _,
- ) = self.generator._generate_and_load_class_from_description(model_desc)
- loaded_product_model = loaded_models.get("E2EProduct")
-
- assert loaded_product_model is not None
- assert issubclass(loaded_product_model, SQLModel)
-
- engine = sqlmodel_create_engine("sqlite:///:memory:")
- SQLModel.metadata.create_all(engine)
-
- inspector = sqlalchemy_inspect(engine)
- assert "e2e_products" in inspector.get_table_names()
- engine.dispose()
-
- # Verify cleanup was called
- assert mock_os_remove.call_count > 0
- assert mock_os_rmdir.call_count > 0
-
- @pytest.mark.parametrize(
- "test_id, model_description, expected_code_snippets",
- [
- (
- "union_import",
- {
- "model_name": "M1",
- "fields": [{"name": "f", "type": "Union[int, str]"}],
- },
- ["from typing import Union"],
- ),
- (
- "complex_import",
- {"model_name": "M3", "imports": ["import my_library"], "fields": []},
- ["import my_library"],
- ),
- (
- "custom_sqlmodel_import_merge",
- {
- "model_name": "M4",
- "imports": ["from sqlmodel import Field, Session"],
- "fields": [{"name": "id", "type": "int", "primary_key": True}],
- },
- [r"from sqlmodel import Field, SQLModel, Session"],
- ),
- ],
- )
- def test_code_generation_import_logic(
- self, test_id, model_description, expected_code_snippets
- ):
- """Covers multiple import logic paths in _ImportManager and _FieldGenerator."""
- code = self.generator._generate_code_from_description(
- {"sql_models": [model_description]}
- )
- assert is_valid_python_code(code)
- for snippet in expected_code_snippets:
- assert re.search(snippet, code)
-
- @mock.patch("os.remove")
- @mock.patch("os.rmdir")
- def test_generate_and_load_class_cleanup_os_error(self, mock_rmdir, mock_remove):
- mock_remove.side_effect = OSError("Cannot remove file")
- mock_rmdir.side_effect = OSError("Cannot remove dir")
- model_desc = {
- "sql_models": [
- {
- "model_name": "CleanupFailModel",
- "fields": [
- {"name": "id", "type": "int", "primary_key": True, "default": 0}
- ],
- }
- ]
- }
-
- with mock.patch.object(self.generator.logger, "warning") as mock_logger_warning:
- self.generator._generate_and_load_class_from_description(model_desc)
-
- # Check that the logger's warning method was called with the expected messages
- call_args_list = mock_logger_warning.call_args_list
- warnings = [call[0][0] for call in call_args_list]
-
- assert any(
- "Could not remove temporary file" in warning for warning in warnings
- )
- assert any(
- "Could not remove temporary directory" in warning
- for warning in warnings
- )
-
- def test_generate_and_load_from_multiple_model_description(self):
- """Covers the 'models' list path in _generate_and_load_class_from_description."""
- multi_model_desc = {
- "sql_models": [
- {
- "model_name": "MultiModel1",
- "fields": [{"name": "id", "type": "int", "primary_key": True}],
- },
- {
- "model_name": "MultiModel2",
- "fields": [{"name": "name", "type": "str"}],
- },
- ]
- }
-
- # This test primarily ensures the loading logic correctly identifies the models to load.
- # We can mock the code generation part to simplify.
- generated_code = """
-from sqlmodel import SQLModel
-from typing import Optional
-class MultiModel1(SQLModel):
- id: Optional[int] = None
-class MultiModel2(SQLModel):
- name: Optional[str] = None
-"""
- with mock.patch.object(
- self.generator,
- "_generate_code_from_description",
- return_value=generated_code,
- ):
- (
- loaded_classes,
- _,
- ) = self.generator._generate_and_load_class_from_description(
- multi_model_desc
- )
- assert "MultiModel1" in loaded_classes
- assert "MultiModel2" in loaded_classes
- assert issubclass(loaded_classes["MultiModel1"], SQLModel)
- assert issubclass(loaded_classes["MultiModel2"], SQLModel)
-
- def test_code_generation_with_json_in_field_options_str(self):
- """Covers the JSON import logic when 'JSON' is in field_options_str."""
- model_description = {
- "sql_models": [
- {
- "model_name": "JsonFieldModel",
- "fields": [
- {
- "name": "data",
- "type": "Dict",
- "field_options_str": "Field(sa_column=Column(JSON))",
- }
- ],
- }
- ]
- }
- code = self.generator._generate_code_from_description(model_description)
- assert is_valid_python_code(code)
-
- tree = ast.parse(code)
-
- # Verify import
- json_imported = False
- for node in tree.body:
- if isinstance(node, ast.ImportFrom) and node.module == "sqlmodel":
- if any(alias.name == "JSON" for alias in node.names):
- json_imported = True
- break
- assert json_imported, (
- "from sqlmodel import JSON was not found in generated code."
- )
-
- @mock.patch(
- "extrai.core.sqlmodel_generator.SQLModelCodeGenerator._import_module_from_path"
- )
- def test_generate_and_load_catches_and_wraps_generic_exception(
- self, mock_import_module
- ):
- """Covers the generic exception handling block in _generate_and_load_class_from_description."""
- mock_import_module.side_effect = Exception("A mocked generic error occurred")
- model_desc = {
- "sql_models": [{"model_name": "GenericExceptionTestModel", "fields": []}]
- }
-
- with pytest.raises(SQLModelCodeGeneratorError) as exc_info:
- self.generator._generate_and_load_class_from_description(model_desc)
-
- # Check that the outer exception message is correct
- assert (
- "Failed to dynamically generate and load SQLModel class(es): A mocked generic error occurred"
- in str(exc_info.value)
- )
-
- # Check that the generated code is included in the error message
- assert "class GenericExceptionTestModel(SQLModel, table=True):" in str(
- exc_info.value
- )
-
- # Check that the original exception is preserved in the cause chain
- assert isinstance(exc_info.value.__cause__, Exception)
- assert str(exc_info.value.__cause__) == "A mocked generic error occurred"
-
- def test_code_generation_with_relationship_import(self):
- """Covers the Relationship import logic when 'Relationship' is in field_options_str."""
- model_description = {
- "sql_models": [
- {
- "model_name": "Invoice",
- "fields": [
- {"name": "id", "type": "Optional[int]", "primary_key": True},
- {
- "name": "line_items",
- "type": "List['LineItem']",
- "field_options_str": 'Relationship(back_populates="invoice")',
- },
- ],
- },
- {
- "model_name": "LineItem",
- "fields": [
- {"name": "id", "type": "Optional[int]", "primary_key": True},
- {
- "name": "invoice_id",
- "type": "Optional[int]",
- "foreign_key": "invoice.id",
- },
- {
- "name": "invoice",
- "type": "Optional['Invoice']",
- "field_options_str": 'Relationship(back_populates="line_items")',
- },
- ],
- },
- ]
- }
- code = self.generator._generate_code_from_description(model_description)
- assert is_valid_python_code(code)
-
- assert re.search(r"from sqlmodel import .*Relationship", code)
-
- def test_load_fails_when_no_models_in_description(self):
- """Covers the error path when the 'sql_models' list is empty."""
- model_desc = {"sql_models": []}
- with pytest.raises(
- SQLModelCodeGeneratorError,
- match="No models found in the 'sql_models' list from the LLM description.",
- ):
- self.generator._generate_and_load_class_from_description(model_desc)
diff --git a/tests/core/test_sqlmodel_generator_codegen.py b/tests/core/test_sqlmodel_generator_codegen.py
new file mode 100644
index 0000000..1f2005c
--- /dev/null
+++ b/tests/core/test_sqlmodel_generator_codegen.py
@@ -0,0 +1,625 @@
+import ast
+import pytest
+from unittest import mock
+from sqlalchemy import inspect as sqlalchemy_inspect
+from sqlmodel import SQLModel, create_engine as sqlmodel_create_engine
+
+from extrai.core.errors import (
+ SQLModelCodeGeneratorError,
+ SQLModelInstantiationValidationError,
+)
+from extrai.core.sqlmodel_generator import SQLModelCodeGenerator
+from extrai.core.analytics_collector import (
+ WorkflowAnalyticsCollector,
+)
+from tests.core.helpers.mock_llm_clients import MockLLMClientSqlGen
+from tests.core.helpers.sqlmodel_generator_test_utils import (
+ is_valid_python_code,
+ get_field_call_kwargs,
+ get_annotation_str,
+ get_class_def_node,
+)
+
+
+class TestSQLModelCodeGeneratorCodeGen:
+ @pytest.fixture(autouse=True)
+ def clear_sqlmodel_metadata(self):
+ """Fixture to clear SQLModel metadata before each test to prevent table redefinition errors."""
+ original_metadata_tables = dict(SQLModel.metadata.tables)
+ SQLModel.metadata.clear()
+ yield
+ SQLModel.metadata.clear()
+ for table_obj in original_metadata_tables.values():
+ table_obj.to_metadata(SQLModel.metadata)
+
+ def setup_method(self):
+ self.mock_llm_client = MockLLMClientSqlGen()
+ self.mock_analytics_collector = mock.Mock(spec=WorkflowAnalyticsCollector)
+ self.generator = SQLModelCodeGenerator(
+ llm_client=self.mock_llm_client,
+ analytics_collector=self.mock_analytics_collector,
+ )
+ SQLModelCodeGenerator._sqlmodel_description_schema_cache = None
+
+ @pytest.fixture(scope="function")
+ def comprehensive_code_ast_and_desc(self):
+ comprehensive_description = {
+ "sql_models": [
+ {
+ "table_name": "comprehensive_items",
+ "model_name": "ComprehensiveItem",
+ "description": 'A comprehensive model with "quotes" and\nnewlines.',
+ "fields": [
+ {
+ "name": "id",
+ "type": "Optional[int]",
+ "primary_key": True,
+ "nullable": True,
+ },
+ {
+ "name": "entity_uuid",
+ "type": "uuid.UUID",
+ "default_factory": "uuid.uuid4",
+ "index": True,
+ "nullable": False,
+ },
+ {
+ "name": "unique_name",
+ "type": "str",
+ "unique": True,
+ "index": True,
+ },
+ {
+ "name": "description_field",
+ "type": "Optional[str]",
+ "nullable": True,
+ "description": 'A field with "special" chars & new\nline.',
+ },
+ {"name": "count_val", "type": "int", "default": 0},
+ {"name": "is_active_flag", "type": "bool", "default": True},
+ {
+ "name": "created_timestamp",
+ "type": "datetime.datetime",
+ "default_factory": "datetime.utcnow",
+ "sa_column_kwargs": {"server_default": "FUNC.now()"},
+ },
+ {
+ "name": "updated_timestamp",
+ "type": "Optional[datetime.datetime]",
+ "nullable": True,
+ "sa_column_kwargs": {"onupdate": "FUNC.now()"},
+ },
+ {"name": "tags_list", "type": "List[str]"},
+ {"name": "config_dict", "type": "Dict[str, Any]"},
+ {
+ "name": "related_id",
+ "type": "Optional[int]",
+ "foreign_key": "other_table.id",
+ "nullable": True,
+ },
+ {
+ "name": "class",
+ "type": "str",
+ "description": "A field named 'class'",
+ },
+ {
+ "name": "json_payload",
+ "type": "Dict",
+ "sa_column_kwargs": {"sa_type": "JSON"},
+ },
+ {
+ "name": "sqlalchemy_json_payload",
+ "type": "List",
+ "sa_column_kwargs": {"sa_type": "sqlalchemy.JSON"},
+ },
+ ],
+ }
+ ]
+ }
+ code = self.generator._generate_code_from_description(comprehensive_description)
+ assert is_valid_python_code(code), (
+ f"Generated code is not valid Python:\n{code}"
+ )
+ tree = ast.parse(code)
+ return tree, comprehensive_description
+
+ def test_comprehensive_model_generation_and_ast_validation(
+ self, comprehensive_code_ast_and_desc
+ ):
+ tree, _ = comprehensive_code_ast_and_desc
+
+ # 1. Test Imports
+ expected_imports_from_typing = {"Optional", "List", "Dict", "Any"}
+ actual_imports_from_typing = set()
+ expected_imports_from_sqlmodel = {"SQLModel", "Field", "JSON"}
+ actual_imports_from_sqlmodel = set()
+ imported_modules = set()
+
+ for node in tree.body:
+ if isinstance(node, ast.ImportFrom):
+ if node.module == "typing":
+ actual_imports_from_typing.update(
+ alias.name for alias in node.names
+ )
+ elif node.module == "sqlmodel":
+ actual_imports_from_sqlmodel.update(
+ alias.name for alias in node.names
+ )
+ elif isinstance(node, ast.Import):
+ imported_modules.update(alias.name for alias in node.names)
+
+ assert actual_imports_from_typing.issuperset(expected_imports_from_typing)
+ assert actual_imports_from_sqlmodel.issuperset(expected_imports_from_sqlmodel)
+ assert "uuid" in imported_modules
+ assert "datetime" in imported_modules
+ assert "sqlalchemy" in imported_modules
+
+ # 2. Test Class Definition
+ class_def_node = get_class_def_node(tree, "ComprehensiveItem")
+ assert class_def_node is not None, "Class 'ComprehensiveItem' not found."
+ assert class_def_node.name == "ComprehensiveItem"
+ assert class_def_node.bases[0].id == "SQLModel"
+ assert any(
+ kw.arg == "table" and kw.value.value is True
+ for kw in class_def_node.keywords
+ )
+
+ # 3. Test Fields
+ fields_ast = {
+ node.target.id: node
+ for node in class_def_node.body
+ if isinstance(node, ast.AnnAssign)
+ }
+
+ assert get_annotation_str(fields_ast["id"].annotation) == "Optional[int]"
+ assert get_field_call_kwargs(fields_ast["id"].value) == {
+ "primary_key": True,
+ "nullable": True,
+ }
+
+ assert get_annotation_str(fields_ast["entity_uuid"].annotation) == "uuid.UUID"
+ assert get_field_call_kwargs(fields_ast["entity_uuid"].value) == {
+ "default_factory": "uuid.uuid4",
+ "index": True,
+ "nullable": False,
+ }
+
+ assert get_annotation_str(fields_ast["class_"].annotation) == "str"
+ assert get_field_call_kwargs(fields_ast["class_"].value) == {
+ "description": "A field named 'class'",
+ "alias": "class",
+ }
+
+ assert get_annotation_str(fields_ast["json_payload"].annotation) == "Dict"
+ assert get_field_call_kwargs(fields_ast["json_payload"].value) == {
+ "sa_type": "JSON"
+ }
+
+ assert (
+ get_annotation_str(fields_ast["sqlalchemy_json_payload"].annotation)
+ == "List"
+ )
+ assert get_field_call_kwargs(fields_ast["sqlalchemy_json_payload"].value) == {
+ "sa_type": "sqlalchemy.JSON"
+ }
+
+ @pytest.mark.parametrize(
+ "test_id, model_desc, expected_exception, match_message",
+ [
+ (
+ "instantiation_validation_error",
+ {
+ "model_name": "ValidationErrorModel",
+ "fields": [
+ {"name": "id", "type": "Optional[int]", "primary_key": True},
+ {"name": "field_a", "type": "int"},
+ ],
+ },
+ SQLModelInstantiationValidationError,
+ r"Default instantiation of 'ValidationErrorModel' failed with ValidationError.*",
+ ),
+ (
+ "instantiation_unexpected_error",
+ {
+ "model_name": "UnexpectedErrorModel",
+ "fields": [
+ {"name": "id", "type": "Optional[int]", "primary_key": True},
+ {
+ "name": "bad_field",
+ "type": "int",
+ "default_factory": "list.append",
+ },
+ ],
+ },
+ SQLModelCodeGeneratorError,
+ r"failed instantiation with an unexpected error.*unbound method list\.append.*",
+ ),
+ ],
+ )
+ def test_generate_and_load_with_natural_errors(
+ self, test_id, model_desc, expected_exception, match_message
+ ):
+        """Tests errors raised organically by the code generation and validation process."""
+ with pytest.raises(expected_exception, match=match_message):
+ self.generator._generate_and_load_class_from_description(
+ {"sql_models": [model_desc]}
+ )
+
+ @pytest.mark.parametrize(
+ "test_id, model_desc, mock_setup, expected_exception, match_message",
+ [
+ (
+ "spec_creation_fails",
+ {"model_name": "SpecFailModel", "fields": []},
+ lambda mocks: (
+ setattr(
+ mocks["generate_code"],
+ "return_value",
+ "class SpecFailModel: pass",
+ ),
+ setattr(mocks["spec_from_file"], "return_value", None),
+ ),
+ SQLModelCodeGeneratorError,
+ "Failed to create import spec",
+ ),
+ (
+ "attr_not_a_class",
+ {"model_name": "NotAClassModel", "fields": []},
+ lambda mocks: (
+ setattr(
+ mocks["generate_code"],
+ "return_value",
+ "class NotAClassModel: pass",
+ ),
+ setattr(
+ mocks["module_from_spec"].return_value, "NotAClassModel", 123
+ ),
+ ),
+ SQLModelCodeGeneratorError,
+ "Loaded attribute 'NotAClassModel' is not a class or not a subclass of SQLModel.",
+ ),
+ (
+ "class_not_found_in_module",
+ {"model_name": "MissingModel", "fields": []},
+ lambda mocks: (
+ setattr(
+ mocks["generate_code"],
+ "return_value",
+ "class FoundModel(SQLModel): pass",
+ ),
+ setattr(
+ mocks["module_from_spec"].return_value, "MissingModel", None
+ ),
+ ),
+ SQLModelCodeGeneratorError,
+ r"Class 'MissingModel' not found in dynamically loaded module",
+ ),
+ ],
+ )
+ @mock.patch(
+ "extrai.core.sqlmodel_generator.SQLModelCodeGenerator._generate_code_from_description"
+ )
+ @mock.patch("importlib.util.module_from_spec")
+ @mock.patch("importlib.util.spec_from_file_location")
+ def test_load_process_with_mocked_errors(
+ self,
+ mock_spec_from_file,
+ mock_module_from_spec,
+ mock_generate_code,
+ test_id,
+ model_desc,
+ mock_setup,
+ expected_exception,
+ match_message,
+ ):
+ """Tests errors in the loading process that require mocking importlib."""
+ mock_spec = mock.Mock()
+ mock_spec.loader = mock.Mock()
+ mock_spec_from_file.return_value = mock_spec
+
+ if mock_setup:
+ mocks = {
+ "spec_from_file": mock_spec_from_file,
+ "module_from_spec": mock_module_from_spec,
+ "generate_code": mock_generate_code,
+ }
+ mock_setup(mocks)
+
+ with pytest.raises(expected_exception, match=match_message):
+ self.generator._generate_and_load_class_from_description(
+ {"sql_models": [model_desc]}
+ )
+
+ @pytest.mark.parametrize(
+ "test_id, description, expected_substrings",
+ [
+ (
+ "keyword_field_name_with_options",
+ {
+ "model_name": "KeywordFieldWithOptions",
+ "fields": [
+ {
+ "name": "class",
+ "type": "str",
+ "field_options_str": 'Field(alias="class")',
+ }
+ ],
+ },
+ ['class_: str = Field(alias="class")'],
+ ),
+ (
+ "custom_complex_import",
+ {
+ "model_name": "ComplexImportModel",
+ "fields": [],
+ "imports": ["from a.b import c as d"],
+ },
+ ["from a.b import c as d"],
+ ),
+ (
+ "no_fields",
+ {
+ "model_name": "NoFieldsModel",
+ "table_name": "no_fields",
+ "fields": [],
+ },
+ ["class NoFieldsModel(SQLModel, table=True):", " pass"],
+ ),
+ (
+ "no_model_description",
+ {"model_name": "NoDescItem", "fields": [{"name": "id", "type": "int"}]},
+ [],
+            ),  # docstring absence is asserted separately after the loop
+ (
+ "non_table_model_with_description",
+ {
+ "model_name": "NonTableModel",
+ "description": "A test model.",
+ "is_table_model": False,
+ "fields": [],
+ },
+ [
+ "class NonTableModel(SQLModel):",
+ ' """A test model."""',
+ " pass",
+ ],
+ ),
+ (
+ "multiple_base_classes_with_sqlmodel",
+ {
+ "model_name": "MultiBase",
+ "is_table_model": True,
+ "base_classes_str": ["CustomBase", "SQLModel"],
+ "fields": [{"name": "id", "type": "int", "primary_key": True}],
+ },
+ ["class MultiBase(CustomBase, SQLModel, table=True):"],
+ ),
+ (
+ "union_import",
+ {
+ "model_name": "M1",
+ "fields": [{"name": "f", "type": "Union[int, str]"}],
+ },
+ ["from typing import Union"],
+ ),
+ (
+ "complex_import",
+ {"model_name": "M3", "imports": ["import my_library"], "fields": []},
+ ["import my_library"],
+ ),
+ (
+ "custom_sqlmodel_import_merge",
+ {
+ "model_name": "M4",
+ "imports": ["from sqlmodel import Field, Session"],
+ "fields": [{"name": "id", "type": "int", "primary_key": True}],
+ },
+ ["from sqlmodel import Field, SQLModel, Session"],
+ ),
+ (
+ "json_in_field_options",
+ {
+ "model_name": "JsonFieldModel",
+ "fields": [
+ {
+ "name": "data",
+ "type": "Dict",
+ "field_options_str": "Field(sa_column=Column(JSON))",
+ }
+ ],
+ },
+ ["from sqlmodel import JSON"],
+ ),
+ (
+ "relationship_import",
+ {
+ "model_name": "Invoice",
+ "fields": [
+ {"name": "id", "type": "Optional[int]", "primary_key": True},
+ {
+ "name": "line_items",
+ "type": "List['LineItem']",
+ "field_options_str": 'Relationship(back_populates="invoice")',
+ },
+ ],
+ },
+ ["from sqlmodel import Field, Relationship, SQLModel"],
+ ),
+ ],
+ )
+ def test_special_code_generation_cases(
+ self, test_id, description, expected_substrings
+ ):
+ code = self.generator._generate_code_from_description(
+ {"sql_models": [description]}
+ )
+ assert is_valid_python_code(code)
+ for substring in expected_substrings:
+ assert substring in code
+
+ if test_id == "no_model_description":
+ class_def_line_index = code.find(f"class {description['model_name']}")
+ assert '"""' not in code[class_def_line_index:]
+
+ def test_init_with_none_analytics_collector(self):
+ generator_no_collector = SQLModelCodeGenerator(
+ llm_client=self.mock_llm_client, analytics_collector=None
+ )
+ assert isinstance(
+ generator_no_collector.analytics_collector, WorkflowAnalyticsCollector
+ )
+
+ @mock.patch("os.rmdir")
+ @mock.patch("os.remove")
+ def test_generate_load_and_validate_sqlmodel_class_e2e(
+ self, mock_os_remove, mock_os_rmdir
+ ):
+ model_desc = {
+ "sql_models": [
+ {
+ "model_name": "E2EProduct",
+ "table_name": "e2e_products",
+ "fields": [
+ {
+ "name": "id",
+ "type": "Optional[int]",
+ "primary_key": True,
+ "default": None,
+ "nullable": True,
+ },
+ {"name": "name", "type": "str", "default": "Default Product"},
+ ],
+ }
+ ]
+ }
+
+ (
+ loaded_models,
+ _,
+ ) = self.generator._generate_and_load_class_from_description(model_desc)
+ loaded_product_model = loaded_models.get("E2EProduct")
+
+ assert loaded_product_model is not None
+ assert issubclass(loaded_product_model, SQLModel)
+
+ engine = sqlmodel_create_engine("sqlite:///:memory:")
+ SQLModel.metadata.create_all(engine)
+
+ inspector = sqlalchemy_inspect(engine)
+ assert "e2e_products" in inspector.get_table_names()
+ engine.dispose()
+
+ # Verify cleanup was called
+ assert mock_os_remove.call_count > 0
+ assert mock_os_rmdir.call_count > 0
+
+ @mock.patch("os.remove")
+ @mock.patch("os.rmdir")
+ def test_generate_and_load_class_cleanup_os_error(self, mock_rmdir, mock_remove):
+ mock_remove.side_effect = OSError("Cannot remove file")
+ mock_rmdir.side_effect = OSError("Cannot remove dir")
+ model_desc = {
+ "sql_models": [
+ {
+ "model_name": "CleanupFailModel",
+ "fields": [
+ {"name": "id", "type": "int", "primary_key": True, "default": 0}
+ ],
+ }
+ ]
+ }
+
+ with mock.patch.object(self.generator.logger, "warning") as mock_logger_warning:
+ self.generator._generate_and_load_class_from_description(model_desc)
+
+ # Check that the logger's warning method was called with the expected messages
+ call_args_list = mock_logger_warning.call_args_list
+ warnings = [call[0][0] for call in call_args_list]
+
+ assert any(
+ "Could not remove temporary file" in warning for warning in warnings
+ )
+ assert any(
+ "Could not remove temporary directory" in warning
+ for warning in warnings
+ )
+
+ def test_generate_and_load_from_multiple_model_description(self):
+        """Covers the multi-entry 'sql_models' list path in _generate_and_load_class_from_description."""
+ multi_model_desc = {
+ "sql_models": [
+ {
+ "model_name": "MultiModel1",
+ "fields": [{"name": "id", "type": "int", "primary_key": True}],
+ },
+ {
+ "model_name": "MultiModel2",
+ "fields": [{"name": "name", "type": "str"}],
+ },
+ ]
+ }
+
+        # This test primarily ensures the loading logic correctly identifies the models to load,
+        # so the code-generation step is mocked out to keep the test focused.
+ generated_code = """
+from sqlmodel import SQLModel
+from typing import Optional
+class MultiModel1(SQLModel):
+ id: Optional[int] = None
+class MultiModel2(SQLModel):
+ name: Optional[str] = None
+"""
+ with mock.patch.object(
+ self.generator,
+ "_generate_code_from_description",
+ return_value=generated_code,
+ ):
+ (
+ loaded_classes,
+ _,
+ ) = self.generator._generate_and_load_class_from_description(
+ multi_model_desc
+ )
+ assert "MultiModel1" in loaded_classes
+ assert "MultiModel2" in loaded_classes
+ assert issubclass(loaded_classes["MultiModel1"], SQLModel)
+ assert issubclass(loaded_classes["MultiModel2"], SQLModel)
+
+ @mock.patch(
+ "extrai.core.sqlmodel_generator.SQLModelCodeGenerator._import_module_from_path"
+ )
+ def test_generate_and_load_catches_and_wraps_generic_exception(
+ self, mock_import_module
+ ):
+ """Covers the generic exception handling block in _generate_and_load_class_from_description."""
+ mock_import_module.side_effect = Exception("A mocked generic error occurred")
+ model_desc = {
+ "sql_models": [{"model_name": "GenericExceptionTestModel", "fields": []}]
+ }
+
+ with pytest.raises(SQLModelCodeGeneratorError) as exc_info:
+ self.generator._generate_and_load_class_from_description(model_desc)
+
+ # Check that the outer exception message is correct
+ assert (
+ "Failed to dynamically generate and load SQLModel class(es): A mocked generic error occurred"
+ in str(exc_info.value)
+ )
+
+ # Check that the generated code is included in the error message
+ assert "class GenericExceptionTestModel(SQLModel, table=True):" in str(
+ exc_info.value
+ )
+
+ # Check that the original exception is preserved in the cause chain
+ assert isinstance(exc_info.value.__cause__, Exception)
+ assert str(exc_info.value.__cause__) == "A mocked generic error occurred"
+
+ def test_load_fails_when_no_models_in_description(self):
+ """Covers the error path when the 'sql_models' list is empty."""
+ model_desc = {"sql_models": []}
+ with pytest.raises(
+ SQLModelCodeGeneratorError,
+ match="No models found in the 'sql_models' list from the LLM description.",
+ ):
+ self.generator._generate_and_load_class_from_description(model_desc)
diff --git a/tests/core/workflow_orchestrator/test_hierarchical_extractor.py b/tests/core/workflow_orchestrator/test_hierarchical_extractor.py
index d491feb..136db6e 100644
--- a/tests/core/workflow_orchestrator/test_hierarchical_extractor.py
+++ b/tests/core/workflow_orchestrator/test_hierarchical_extractor.py
@@ -1,5 +1,3 @@
-# tests/core/workflow_orchestrator/test_hierarchical_extractor.py
-
import unittest
import json
from unittest import mock
@@ -8,8 +6,7 @@
from sqlmodel import SQLModel, create_engine, Session as SQLModelSession
from extrai.core.workflow_orchestrator import WorkflowOrchestrator
-from extrai.core.errors import WorkflowError
-from extrai.core.example_json_generator import ExampleGenerationError
+from extrai.core.errors import LLMInteractionError
from tests.core.helpers.orchestrator_test_models import (
SimpleModel,
ParentModel,
@@ -22,9 +19,8 @@
class TestHierarchicalExtractor(unittest.IsolatedAsyncioTestCase):
- @mock.patch("extrai.core.workflow_orchestrator.discover_sqlmodels_from_root")
- @mock.patch("extrai.core.workflow_orchestrator.generate_llm_schema_from_models")
- def setUp(self, mock_generate_llm_schema, mock_discover_sqlmodels):
+ @mock.patch("extrai.core.model_registry.SchemaInspector")
+ def setUp(self, MockSchemaInspector):
self.mock_llm_client = MockLLMClient()
self.engine = create_engine("sqlite:///:memory:")
SQLModel.metadata.create_all(self.engine)
@@ -40,95 +36,25 @@ def setUp(self, mock_generate_llm_schema, mock_discover_sqlmodels):
{"schema_for_prompt": "mock_schema"}
)
- mock_discover_sqlmodels.return_value = self.mock_discovered_sqlmodels
- mock_generate_llm_schema.return_value = self.mock_prompt_llm_schema_str
+ mock_inspector = MockSchemaInspector.return_value
+ mock_inspector.discover_sqlmodels_from_root.return_value = (
+ self.mock_discovered_sqlmodels
+ )
+ mock_inspector.generate_llm_schema_from_models.return_value = (
+ self.mock_prompt_llm_schema_str
+ )
def tearDown(self):
self.db_session.close()
SQLModel.metadata.drop_all(self.engine)
- @mock.patch("extrai.core.workflow_orchestrator.ExampleJSONGenerator")
- async def test_prepare_extraction_example_auto_generation_failure(
- self, mocked_example_generator
- ):
- # This test covers the error handling in `_prepare_extraction_example` (line 282)
- mock_generator_instance = mocked_example_generator.return_value
- mock_generator_instance.generate_example = AsyncMock(
- side_effect=ExampleGenerationError("Generation failed")
- )
-
- orchestrator = WorkflowOrchestrator(
- root_sqlmodel_class=SimpleModel,
- llm_client=self.mock_llm_client,
- use_hierarchical_extraction=True,
- )
-
- with self.assertRaisesRegex(
- WorkflowError,
- "Failed to auto-generate extraction example: Generation failed",
- ):
- await orchestrator.synthesize(["Some input text"], self.db_session)
-
- mocked_example_generator.assert_called_once()
- mock_generator_instance.generate_example.assert_called_once()
-
- @mock.patch("builtins.print")
- async def test_hierarchical_extraction_handles_missing_temp_id_and_duplicates(
- self, mock_print
- ):
- # This test covers lines 389-395 (missing temp_id and duplicate entities)
- orchestrator = WorkflowOrchestrator(
- root_sqlmodel_class=ParentModel,
- llm_client=self.mock_llm_client,
- use_hierarchical_extraction=True,
- )
-
- # Entity with a temp_id, one without, and a duplicate
- entity1 = {"_type": "ParentModel", "_temp_id": "p1", "name": "Parent 1"}
- entity_no_id = {"_type": "ParentModel", "name": "Parent No ID"}
- entity_duplicate = {
- "_type": "ParentModel",
- "_temp_id": "p1",
- "name": "Duplicate",
- }
-
- def mock_extraction_cycle(*args, **kwargs):
- return [entity1, entity_no_id, entity_duplicate]
-
- with (
- mock.patch(
- "extrai.core.workflow_orchestrator.discover_sqlmodels_from_root",
- return_value=[ParentModel],
- ),
- mock.patch.object(
- orchestrator,
- "_run_single_extraction_cycle",
- side_effect=mock_extraction_cycle,
- ),
- ):
- final_list = await orchestrator._execute_hierarchical_extraction(
- input_strings=["Some text"],
- current_extraction_example_json="",
- custom_extraction_process="",
- custom_extraction_guidelines="",
- custom_final_checklist="",
- )
-
- # Only the first valid entity should be in the final list
- self.assertEqual(len(final_list), 1)
- self.assertIn(entity1, final_list)
- self.assertNotIn(entity_no_id, final_list)
- # Ensure the duplicate did not overwrite the original
- self.assertEqual(final_list[0]["name"], "Parent 1")
-
async def test_synthesize_and_hierarchical_extraction_full_traversal(self):
- # This test covers the main logic of `_execute_hierarchical_extraction` (lines 342-395)
- # and ensures it's called correctly by `synthesize`.
orchestrator = WorkflowOrchestrator(
root_sqlmodel_class=ParentModel,
llm_client=self.mock_llm_client,
use_hierarchical_extraction=True,
)
+ orchestrator.model_registry.models = [ParentModel, ChildModel]
parent_entity = {"_type": "ParentModel", "_temp_id": "p1", "name": "Parent 1"}
child_entity = {
@@ -139,30 +65,30 @@ async def test_synthesize_and_hierarchical_extraction_full_traversal(self):
}
final_entities = [parent_entity, child_entity]
- # Mock the extraction cycle to simulate LLM responses for each model type
def mock_extraction_cycle(*args, **kwargs):
- system_prompt = args[0]
+ system_prompt = kwargs.get("system_prompt") or args[0]
if "extract **only** entities of type 'ParentModel'" in system_prompt:
return [parent_entity]
if "extract **only** entities of type 'ChildModel'" in system_prompt:
return [child_entity]
return []
- # We need to control the order of models processed in the loop
+ orchestrator.model_registry.inspector.discover_sqlmodels_from_root = mock.Mock(
+ return_value=[ParentModel, ChildModel]
+ )
+
with (
- mock.patch(
- "extrai.core.workflow_orchestrator.discover_sqlmodels_from_root",
- return_value=[ParentModel, ChildModel],
- ),
mock.patch.object(
- orchestrator,
- "_run_single_extraction_cycle",
+ orchestrator.pipeline.llm_runner,
+ "run_extraction_cycle",
side_effect=mock_extraction_cycle,
) as mock_run_cycle,
mock.patch.object(
- orchestrator, "_prepare_extraction_example", new_callable=AsyncMock
+ orchestrator.pipeline.context_preparer,
+ "prepare_example",
+ new_callable=AsyncMock,
) as mock_prepare,
- mock.patch.object(orchestrator, "_hydrate_results") as mock_hydrate,
+ mock.patch.object(orchestrator.result_processor, "hydrate") as mock_hydrate,
):
mock_prepare.return_value = "mock_example_json"
mock_hydrate.return_value = [
@@ -172,50 +98,20 @@ def mock_extraction_cycle(*args, **kwargs):
await orchestrator.synthesize(["Some text"], self.db_session)
- # Verify that the extraction cycle was called for each model
self.assertEqual(mock_run_cycle.call_count, 2)
- # Check that the context for the second call (ChildModel) contains the first entity
- second_call_system_prompt = mock_run_cycle.call_args_list[1].args[0]
- self.assertIn("'ChildModel'", second_call_system_prompt)
- self.assertIn(
- json.dumps([parent_entity], indent=2), second_call_system_prompt
+ call_args = mock_run_cycle.call_args_list[1]
+ second_call_system_prompt = (
+ call_args.kwargs.get("system_prompt") or call_args.args[0]
)
+ self.assertIn("'ChildModel'", second_call_system_prompt)
+ self.assertIn("Parent 1", second_call_system_prompt)
- # Verify that the final list was passed to the hydrator
mock_hydrate.assert_called_once()
- # The order might not be guaranteed, so we check for content equivalence
hydrator_arg = mock_hydrate.call_args[0][0]
self.assertCountEqual(hydrator_arg, final_entities)
- async def test_run_single_extraction_cycle_success(self):
- # This test covers a successful run of `_run_single_extraction_cycle` (lines 408-447)
- orchestrator = WorkflowOrchestrator(
- root_sqlmodel_class=SimpleModel,
- llm_client=self.mock_llm_client,
- num_llm_revisions=2,
- use_hierarchical_extraction=True,
- )
-
- llm_revision = {"results": [{"_type": "SimpleModel", "name": "Test"}]}
- self.mock_llm_client.set_revisions_to_return([llm_revision] * 2)
-
- consensus_output = [{"_type": "SimpleModel", "name": "Consensus Result"}]
- mock_consensus_details = {"revisions_processed": 2}
-
- with mock.patch.object(
- orchestrator.json_consensus,
- "get_consensus",
- return_value=(consensus_output, mock_consensus_details),
- ) as mock_get_consensus:
- result = await orchestrator._run_single_extraction_cycle("system", "user")
-
- self.assertEqual(self.mock_llm_client.call_count, 2)
- mock_get_consensus.assert_called_once()
- self.assertEqual(result, consensus_output)
-
async def test_run_single_extraction_cycle_llm_failure(self):
- # This test covers error handling in `_run_single_extraction_cycle`
orchestrator = WorkflowOrchestrator(
root_sqlmodel_class=SimpleModel,
llm_client=self.mock_llm_client,
@@ -226,10 +122,12 @@ async def test_run_single_extraction_cycle_llm_failure(self):
self.mock_llm_client.set_should_raise_exception(ValueError("LLM client failed"))
with self.assertRaisesRegex(
- Exception,
+ LLMInteractionError,
"An unexpected error occurred during LLM interaction: LLM client failed",
):
- await orchestrator._run_single_extraction_cycle("system", "user")
+ await orchestrator.pipeline.llm_runner.run_extraction_cycle(
+ "system", "user"
+ )
self.assertEqual(self.mock_llm_client.call_count, 1)
diff --git a/tests/core/workflow_orchestrator/test_integration.py b/tests/core/workflow_orchestrator/test_integration.py
new file mode 100644
index 0000000..b1f9bf2
--- /dev/null
+++ b/tests/core/workflow_orchestrator/test_integration.py
@@ -0,0 +1,135 @@
+import unittest
+import json
+from unittest import mock
+from unittest.mock import AsyncMock
+
+from sqlmodel import SQLModel, create_engine, Session as SQLModelSession
+
+from extrai.core.workflow_orchestrator import WorkflowOrchestrator
+
+
+from tests.core.helpers.orchestrator_test_models import DepartmentModel, EmployeeModel
+from tests.core.helpers.mock_llm_clients import (
+ MockLLMClientForWorkflow as MockLLMClient,
+)
+
+
+class TestWorkflowOrchestratorExecution(unittest.IsolatedAsyncioTestCase):
+ def setUp(self):
+ self.mock_llm_client1 = MockLLMClient()
+ self.mock_llm_client2 = MockLLMClient()
+
+ self.discovered_sqlmodels_for_execution = [DepartmentModel, EmployeeModel]
+ self.prompt_llm_schema_for_execution = json.dumps(
+ {"schema_for_prompt": "mock_llm_prompt_schema"}
+ )
+
+ self.patcher_inspector = mock.patch(
+ "extrai.core.model_registry.SchemaInspector"
+ )
+ self.MockSchemaInspector = self.patcher_inspector.start()
+ mock_inspector_instance = self.MockSchemaInspector.return_value
+ mock_inspector_instance.discover_sqlmodels_from_root.return_value = (
+ self.discovered_sqlmodels_for_execution
+ )
+ mock_inspector_instance.generate_llm_schema_from_models.return_value = (
+ self.prompt_llm_schema_for_execution
+ )
+
+ self.orchestrator = WorkflowOrchestrator(
+ root_sqlmodel_class=DepartmentModel,
+ llm_client=[self.mock_llm_client1, self.mock_llm_client2],
+ num_llm_revisions=2,
+ max_validation_retries_per_revision=1,
+ )
+
+ self.engine = create_engine("sqlite:///:memory:")
+ SQLModel.metadata.create_all(self.engine)
+ self.db_session: SQLModelSession = SQLModelSession(self.engine)
+
+ def tearDown(self):
+ self.patcher_inspector.stop()
+ self.db_session.close()
+ SQLModel.metadata.drop_all(self.engine)
+
+ async def test_successful_synthesis_clear_consensus(self):
+ dept_rev_content = {
+ "_type": "DepartmentModel",
+ "_temp_id": "dept1",
+ "name": "Engineering",
+ }
+ emp_rev_content = {
+ "_type": "EmployeeModel",
+ "_temp_id": "emp1",
+ "name": "Jane Doe",
+ "department_ref_id": "dept1",
+ }
+
+ revision_content = [dept_rev_content, emp_rev_content]
+
+ llm_output_to_unwrap = {"results": revision_content}
+ self.mock_llm_client1.set_revisions_to_return([llm_output_to_unwrap])
+ self.mock_llm_client2.set_revisions_to_return([llm_output_to_unwrap])
+
+ from extrai.utils.flattening_utils import flatten_json
+
+ example_flat_revision = flatten_json(revision_content)
+ num_unique_paths = len(example_flat_revision)
+
+ mock_consensus_output = revision_content
+ mock_analytics_for_clear_consensus = {
+ "revisions_processed": 2,
+ "unique_paths_considered": num_unique_paths,
+ "paths_agreed_by_threshold": num_unique_paths,
+ "paths_resolved_by_conflict_resolver": 0,
+ "paths_omitted_due_to_no_consensus_or_resolver_omission": 0,
+ }
+ expected_consensus_input = [revision_content] * 2
+
+ with mock.patch.object(
+ self.orchestrator.pipeline.llm_runner.consensus,
+ "get_consensus",
+ return_value=(mock_consensus_output, mock_analytics_for_clear_consensus),
+ ) as mock_get_consensus_call:
+ await self.orchestrator.synthesize(["input"], self.db_session)
+ mock_get_consensus_call.assert_called_once_with(expected_consensus_input)
+
+ self.assertEqual(self.mock_llm_client1.call_count, 1)
+ self.assertEqual(self.mock_llm_client2.call_count, 1)
+
+ async def test_synthesize_with_user_provided_example_json(self):
+ user_example_json = json.dumps(
+ {"_type": "DepartmentModel", "name": "HR", "_temp_id": "hr_dept_example"}
+ )
+ input_strings = ["Some HR related text"]
+
+ llm_main_extraction_response = [
+ {
+ "results": [
+ {
+ "_type": "DepartmentModel",
+ "_temp_id": "dept_hr_actual",
+ "name": "Human Resources",
+ }
+ ]
+ }
+ ]
+ self.mock_llm_client1.set_revisions_to_return(llm_main_extraction_response)
+ self.mock_llm_client2.set_revisions_to_return(llm_main_extraction_response)
+
+ with mock.patch.object(
+ self.orchestrator.pipeline.context_preparer,
+ "prepare_example",
+ new_callable=AsyncMock,
+ ) as mock_prepare:
+ mock_prepare.return_value = user_example_json
+ await self.orchestrator.synthesize(
+ input_strings,
+ self.db_session,
+ extraction_example_json=user_example_json,
+ )
+ mock_prepare.assert_called_once()
+
+
+if __name__ == "__main__":
+ unittest.main()
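The integration tests above stub out `get_consensus` and assert on its analytics keys (`revisions_processed`, `unique_paths_considered`, `paths_agreed_by_threshold`). As a hedged sketch of the technique being mocked (my own minimal version, not the library's implementation; the flattening scheme and analytics keys are modeled on the fixture data), a threshold-based consensus over flattened JSON paths could look like:

```python
from collections import Counter


def flatten(obj, prefix=""):
    """Flatten nested dicts/lists into {path: scalar} pairs."""
    out = {}
    if isinstance(obj, dict):
        for k, v in obj.items():
            out.update(flatten(v, f"{prefix}.{k}" if prefix else k))
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            out.update(flatten(v, f"{prefix}[{i}]"))
    else:
        out[prefix] = obj
    return out


def get_consensus(revisions, threshold=0.5):
    """Keep each (path, value) pair backed by strictly more than
    `threshold` of the revisions; return the agreed paths plus analytics."""
    flats = [flatten(r) for r in revisions]
    votes = Counter()
    for flat in flats:
        for path, value in flat.items():
            votes[(path, value)] += 1
    needed = threshold * len(revisions)
    agreed = {}
    for (path, value), n in votes.items():
        if n > needed and path not in agreed:
            agreed[path] = value
    analytics = {
        "revisions_processed": len(revisions),
        "unique_paths_considered": len({p for f in flats for p in f}),
        "paths_agreed_by_threshold": len(agreed),
    }
    return agreed, analytics
```

Each revision votes once per (path, value) pair, which is why identical revisions in the fixture yield `paths_agreed_by_threshold == unique_paths_considered`, while fully disagreeing revisions (as in the no-consensus test) yield zero agreed paths.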
diff --git a/tests/core/workflow_orchestrator/test_workflow_orchestrator_batch.py b/tests/core/workflow_orchestrator/test_workflow_orchestrator_batch.py
new file mode 100644
index 0000000..5f7aac8
--- /dev/null
+++ b/tests/core/workflow_orchestrator/test_workflow_orchestrator_batch.py
@@ -0,0 +1,157 @@
+import unittest
+from unittest import mock
+import logging
+from sqlmodel import Session
+from extrai.core.workflow_orchestrator import WorkflowOrchestrator
+from extrai.core.batch_models import BatchJobStatus, BatchProcessResult
+from tests.core.helpers.orchestrator_test_models import DepartmentModel
+from tests.core.helpers.mock_llm_clients import (
+ MockLLMClientForWorkflow as MockLLMClient,
+)
+
+
+class TestWorkflowOrchestratorBatch(unittest.IsolatedAsyncioTestCase):
+ @mock.patch("extrai.core.workflow_orchestrator.ModelRegistry")
+ def setUp(self, MockModelRegistry):
+ self.mock_llm_client = MockLLMClient()
+ self.root_sqlmodel_class = DepartmentModel
+
+ # Create orchestrator with mocks
+ self.orchestrator = WorkflowOrchestrator(
+ root_sqlmodel_class=self.root_sqlmodel_class,
+ llm_client=self.mock_llm_client,
+ )
+
+ # Mock the internal components so we can verify delegation to them
+ self.orchestrator.batch_pipeline = mock.AsyncMock()
+ self.orchestrator.result_processor = mock.Mock()
+ self.orchestrator.logger = mock.Mock(spec=logging.Logger)
+
+ async def test_synthesize_batch(self):
+ input_strings = ["input1", "input2"]
+ db_session = mock.Mock(spec=Session)
+ expected_batch_id = "batch_123"
+ self.orchestrator.batch_pipeline.submit_batch.return_value = expected_batch_id
+
+ result = await self.orchestrator.synthesize_batch(input_strings, db_session)
+
+ # Verify call with default args
+ self.orchestrator.batch_pipeline.submit_batch.assert_called_once()
+ call_kwargs = self.orchestrator.batch_pipeline.submit_batch.call_args[1]
+ self.assertEqual(call_kwargs["input_strings"], input_strings)
+ self.assertEqual(call_kwargs["db_session"], db_session)
+ self.assertEqual(result, expected_batch_id)
+
+ async def test_get_batch_status(self):
+ batch_id = "batch_123"
+ db_session = mock.Mock(spec=Session)
+ expected_status = BatchJobStatus.COMPLETED
+ self.orchestrator.batch_pipeline.get_status.return_value = expected_status
+
+ result = await self.orchestrator.get_batch_status(batch_id, db_session)
+
+ self.orchestrator.batch_pipeline.get_status.assert_called_once_with(
+ batch_id, db_session
+ )
+ self.assertEqual(result, expected_status)
+
+ async def test_process_batch_success(self):
+ batch_id = "batch_123"
+ db_session = mock.Mock(spec=Session)
+ hydrated_objects = ["obj1"]
+ expected_result = BatchProcessResult(
+ status=BatchJobStatus.COMPLETED, hydrated_objects=hydrated_objects
+ )
+ self.orchestrator.batch_pipeline.process_batch.return_value = expected_result
+
+ result = await self.orchestrator.process_batch(batch_id, db_session)
+
+ self.orchestrator.batch_pipeline.process_batch.assert_called_once_with(
+ batch_id, db_session
+ )
+ self.orchestrator.result_processor.persist.assert_called_once_with(
+ hydrated_objects, db_session
+ )
+ self.assertEqual(result, expected_result)
+
+ async def test_process_batch_not_completed(self):
+ batch_id = "batch_123"
+ db_session = mock.Mock(spec=Session)
+ expected_result = BatchProcessResult(status=BatchJobStatus.PROCESSING)
+ self.orchestrator.batch_pipeline.process_batch.return_value = expected_result
+
+ result = await self.orchestrator.process_batch(batch_id, db_session)
+
+ self.orchestrator.batch_pipeline.process_batch.assert_called_once_with(
+ batch_id, db_session
+ )
+ self.orchestrator.result_processor.persist.assert_not_called()
+ self.assertEqual(result, expected_result)
+
+ async def test_process_batch_persistence_failure(self):
+ batch_id = "batch_123"
+ db_session = mock.Mock(spec=Session)
+ hydrated_objects = ["obj1"]
+ process_result = BatchProcessResult(
+ status=BatchJobStatus.COMPLETED, hydrated_objects=hydrated_objects
+ )
+
+ self.orchestrator.batch_pipeline.process_batch.return_value = process_result
+ self.orchestrator.result_processor.persist.side_effect = Exception(
+ "Persistence Error"
+ )
+
+ with self.assertRaisesRegex(Exception, "Persistence Error"):
+ await self.orchestrator.process_batch(batch_id, db_session)
+
+ self.orchestrator.logger.error.assert_called()
+ self.assertIn(
+ "Extraction successful but persistence failed", process_result.message
+ )
+
+ async def test_monitor_batch_job_counting_transition(self):
+ batch_id = "batch_123"
+ db_session = mock.Mock(spec=Session)
+
+ # Mock status sequence:
+ # 1. COUNTING_READY_TO_PROCESS -> triggers first process_batch
+ # 2. PROCESSING -> waits
+ # 3. READY_TO_PROCESS -> triggers second process_batch
+ self.orchestrator.batch_pipeline.get_status.side_effect = [
+ BatchJobStatus.COUNTING_READY_TO_PROCESS,
+ BatchJobStatus.PROCESSING,
+ BatchJobStatus.READY_TO_PROCESS,
+ ]
+
+ # Mock process results
+ # 1. Result of processing COUNTING_READY: new batch submitted (PROCESSING)
+ process_result_1 = BatchProcessResult(
+ status=BatchJobStatus.PROCESSING,
+ message="Transitioned from counting to extraction",
+ )
+ # 2. Result of processing READY_TO_PROCESS: completed
+ process_result_2 = BatchProcessResult(
+ status=BatchJobStatus.COMPLETED, hydrated_objects=["obj1"]
+ )
+
+ self.orchestrator.batch_pipeline.process_batch.side_effect = [
+ process_result_1,
+ process_result_2,
+ ]
+
+ # Run monitoring with short poll interval
+ result = await self.orchestrator.monitor_batch_job(
+ batch_id, db_session, poll_interval=0.001
+ )
+
+ # Verify final result
+ self.assertEqual(result.status, BatchJobStatus.COMPLETED)
+ self.assertEqual(result.hydrated_objects, ["obj1"])
+
+ # Verify calls
+ self.assertEqual(self.orchestrator.batch_pipeline.get_status.call_count, 3)
+ self.assertEqual(self.orchestrator.batch_pipeline.process_batch.call_count, 2)
+
+
+if __name__ == "__main__":
+ unittest.main()
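`test_monitor_batch_job_counting_transition` above drives `monitor_batch_job` through a counting → extraction → completed sequence. As an illustrative sketch of the poll loop being exercised (assumed `get_status`/`process` callables and string statuses; this is not the library's API), the control flow might be structured like:

```python
import asyncio

TERMINAL = frozenset({"COMPLETED", "FAILED"})


async def monitor(get_status, process, poll_interval=0.01):
    """Poll a batch job: process it whenever it is ready to process,
    and stop once processing yields a terminal status."""
    while True:
        status = await get_status()
        if status.endswith("READY_TO_PROCESS"):
            # Covers both COUNTING_READY_TO_PROCESS and READY_TO_PROCESS.
            result = await process()
            if result in TERMINAL:
                return result
        elif status in TERMINAL:
            return status
        await asyncio.sleep(poll_interval)
```

With the fixture's status sequence, the loop consults the status three times and processes twice, matching the call-count assertions in the test.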
diff --git a/tests/core/workflow_orchestrator/test_workflow_orchestrator_e2e.py b/tests/core/workflow_orchestrator/test_workflow_orchestrator_e2e.py
index 251b049..db9106b 100644
--- a/tests/core/workflow_orchestrator/test_workflow_orchestrator_e2e.py
+++ b/tests/core/workflow_orchestrator/test_workflow_orchestrator_e2e.py
@@ -17,9 +17,8 @@
class TestWorkflowOrchestratorE2E(unittest.IsolatedAsyncioTestCase):
- @mock.patch("extrai.core.workflow_orchestrator.discover_sqlmodels_from_root")
- @mock.patch("extrai.core.workflow_orchestrator.generate_llm_schema_from_models")
- def setUp(self, mock_generate_llm_schema, mock_discover_sqlmodels):
+ @mock.patch("extrai.core.model_registry.SchemaInspector")
+ def setUp(self, MockSchemaInspector):
self.mock_llm_client = MockE2ELLMClient()
self.engine = create_engine("sqlite:///:memory:")
SQLModel.metadata.create_all(self.engine)
@@ -38,8 +37,14 @@ def setUp(self, mock_generate_llm_schema, mock_discover_sqlmodels):
{"schema_for_prompt_e2e": "mock_e2e_prompt_schema"}
)
- mock_discover_sqlmodels.return_value = self.mock_discovered_sqlmodels_e2e
- mock_generate_llm_schema.return_value = self.mock_prompt_llm_schema_str_e2e
+ # Configure the mock inspector instance
+ mock_inspector = MockSchemaInspector.return_value
+ mock_inspector.discover_sqlmodels_from_root.return_value = (
+ self.mock_discovered_sqlmodels_e2e
+ )
+ mock_inspector.generate_llm_schema_from_models.return_value = (
+ self.mock_prompt_llm_schema_str_e2e
+ )
def tearDown(self):
self.db_session.close()
diff --git a/tests/core/workflow_orchestrator/test_workflow_orchestrator_execution.py b/tests/core/workflow_orchestrator/test_workflow_orchestrator_execution.py
deleted file mode 100644
index e95c5e9..0000000
--- a/tests/core/workflow_orchestrator/test_workflow_orchestrator_execution.py
+++ /dev/null
@@ -1,861 +0,0 @@
-# tests/core/test_workflow_orchestrator_execution.py
-
-import unittest
-import json
-from unittest import mock
-from unittest.mock import AsyncMock
-
-from sqlmodel import SQLModel, create_engine, Session as SQLModelSession
-
-from extrai.core.errors import (
- LLMAPICallError,
- LLMOutputParseError,
- LLMOutputValidationError,
- WorkflowError,
- ConsensusProcessError,
- HydrationError,
-)
-from extrai.core.workflow_orchestrator import (
- WorkflowOrchestrator,
- LLMInteractionError,
-)
-from extrai.core.analytics_collector import (
- WorkflowAnalyticsCollector,
-)
-from extrai.core.db_writer import DatabaseWriterError
-
-from extrai.core.example_json_generator import (
- ExampleGenerationError,
-)
-
-
-from tests.core.helpers.orchestrator_test_models import DepartmentModel, EmployeeModel
-from tests.core.helpers.mock_llm_clients import (
- MockLLMClientForWorkflow as MockLLMClient,
-)
-
-
-class TestWorkflowOrchestratorExecution(unittest.IsolatedAsyncioTestCase):
- @mock.patch("extrai.core.workflow_orchestrator.discover_sqlmodels_from_root")
- @mock.patch("extrai.core.workflow_orchestrator.generate_llm_schema_from_models")
- def setUp(self, mock_generate_llm_schema, mock_discover_sqlmodels):
- self.mock_llm_client1 = MockLLMClient()
- self.mock_llm_client2 = MockLLMClient()
-
- self.discovered_sqlmodels_for_execution = [DepartmentModel, EmployeeModel]
- self.prompt_llm_schema_for_execution = json.dumps(
- {"schema_for_prompt": "mock_llm_prompt_schema"}
- )
-
- mock_discover_sqlmodels.return_value = self.discovered_sqlmodels_for_execution
- mock_generate_llm_schema.return_value = self.prompt_llm_schema_for_execution
-
- self.orchestrator = WorkflowOrchestrator(
- root_sqlmodel_class=DepartmentModel,
- llm_client=[self.mock_llm_client1, self.mock_llm_client2],
- num_llm_revisions=2,
- max_validation_retries_per_revision=1,
- )
-
- self.engine = create_engine("sqlite:///:memory:")
- SQLModel.metadata.create_all(self.engine)
- self.db_session: SQLModelSession = SQLModelSession(self.engine)
-
- def tearDown(self):
- self.db_session.close()
- SQLModel.metadata.drop_all(self.engine)
-
- async def test_successful_synthesis_clear_consensus(self):
- dept_rev_content = {
- "_type": "DepartmentModel",
- "_temp_id": "dept1",
- "name": "Engineering",
- }
- emp_rev_content = {
- "_type": "EmployeeModel",
- "_temp_id": "emp1",
- "name": "Jane Doe",
- "department_ref_id": "dept1",
- }
-
- # The content of a single, successful revision is a list of entities.
- revision_content = [dept_rev_content, emp_rev_content]
-
- # The raw output from the LLM is a dict that needs to be unwrapped by the client.
- # The orchestrator will call each client `num_llm_revisions` times in total.
- # With 2 clients and num_llm_revisions=2, each client is called once.
- llm_output_to_unwrap = {"results": revision_content}
- self.mock_llm_client1.set_revisions_to_return([llm_output_to_unwrap])
- self.mock_llm_client2.set_revisions_to_return([llm_output_to_unwrap])
-
- input_strings = ["Some text about Jane Doe in Engineering."]
-
- from extrai.utils.flattening_utils import flatten_json
-
- example_flat_revision = flatten_json(revision_content)
- num_unique_paths = len(example_flat_revision)
-
- # The output of the consensus process is now a flat list of entities.
- mock_consensus_output = revision_content
- mock_analytics_for_clear_consensus = {
- "revisions_processed": 2, # 2 clients * 1 revision each = 2
- "unique_paths_considered": num_unique_paths,
- "paths_agreed_by_threshold": num_unique_paths,
- "paths_resolved_by_conflict_resolver": 0,
- "paths_omitted_due_to_no_consensus_or_resolver_omission": 0,
- }
-
- # The input to the consensus function is a list of revisions.
- # With 2 clients and num_llm_revisions=2, we expect 2 revisions in total.
- expected_consensus_input = [revision_content] * 2
-
- with mock.patch.object(
- self.orchestrator.json_consensus,
- "get_consensus",
- return_value=(mock_consensus_output, mock_analytics_for_clear_consensus),
- ) as mock_get_consensus_call:
- hydrated_objects = await self.orchestrator.synthesize(
- input_strings, self.db_session
- )
- mock_get_consensus_call.assert_called_once_with(expected_consensus_input)
-
- self.assertEqual(self.mock_llm_client1.call_count, 1)
- self.assertEqual(self.mock_llm_client2.call_count, 1)
- self.assertFalse(hasattr(self.mock_llm_client1, "last_formal_json_schema_str"))
- self.assertEqual(self.mock_llm_client1.last_max_validation_retries, 1)
- self.assertIn(
- "mock_llm_prompt_schema", self.mock_llm_client1.last_system_prompt
- )
- self.assertIn(
- "mock_llm_prompt_schema", self.mock_llm_client2.last_system_prompt
- )
-
- self.assertIsInstance(hydrated_objects, list)
- self.assertEqual(len(hydrated_objects), 2)
- self.assertIsInstance(hydrated_objects[0], SQLModel)
- self.assertIsInstance(hydrated_objects[1], SQLModel)
-
- self.assertIs(
- self.mock_llm_client1.last_analytics_collector_passed,
- self.orchestrator.analytics_collector,
- )
- self.assertIs(
- self.mock_llm_client2.last_analytics_collector_passed,
- self.orchestrator.analytics_collector,
- )
- report = self.orchestrator.get_analytics_report()
- self.assertEqual(report["llm_api_call_failures"], 0)
- self.assertEqual(report["total_invalid_parsing_errors"], 0)
- self.assertEqual(report["number_of_consensus_runs"], 1)
-
- self.assertIn("all_consensus_run_details", report)
- self.assertEqual(len(report["all_consensus_run_details"]), 1)
- run_details = report["all_consensus_run_details"][0]
- self.assertEqual(run_details["revisions_processed"], 2)
- self.assertGreater(run_details["unique_paths_considered"], 0)
- self.assertGreaterEqual(run_details["paths_agreed_by_threshold"], 0)
- if run_details["unique_paths_considered"] > 0:
- self.assertAlmostEqual(
- report["average_path_agreement_ratio"],
- run_details["paths_agreed_by_threshold"]
- / run_details["unique_paths_considered"],
- )
-
- async def test_successful_synthesis_without_db_session(self):
- dept_rev_content = {
- "_type": "DepartmentModel",
- "_temp_id": "dept1",
- "name": "Engineering",
- }
- emp_rev_content = {
- "_type": "EmployeeModel",
- "_temp_id": "emp1",
- "name": "Jane Doe",
- "department_ref_id": "dept1",
- }
- revision_content = [dept_rev_content, emp_rev_content]
- llm_output_to_unwrap = {"results": revision_content}
- self.mock_llm_client1.set_revisions_to_return([llm_output_to_unwrap])
- self.mock_llm_client2.set_revisions_to_return([llm_output_to_unwrap])
-
- input_strings = ["Some text about Jane Doe in Engineering."]
-
- mock_consensus_output = revision_content
-
- with mock.patch.object(
- self.orchestrator.json_consensus,
- "get_consensus",
- return_value=(mock_consensus_output, {}),
- ) as mock_get_consensus_call:
- hydrated_objects = await self.orchestrator.synthesize(
- input_strings, db_session_for_hydration=None
- )
- expected_consensus_input = [revision_content] * 2
- mock_get_consensus_call.assert_called_once_with(expected_consensus_input)
-
- self.assertEqual(len(hydrated_objects), 2)
- self.assertIsInstance(hydrated_objects[0], DepartmentModel)
- self.assertIsInstance(hydrated_objects[1], EmployeeModel)
-
- # Verify that the objects are not attached to any session
- from sqlalchemy.orm import object_session
-
- self.assertIsNone(object_session(hydrated_objects[0]))
- self.assertIsNone(object_session(hydrated_objects[1]))
-
- async def test_synthesis_llm_client_raises_parse_error(self):
- self.mock_llm_client1.set_should_raise_exception(
- LLMOutputParseError("Parsing failed", raw_content="bad json")
- )
- self.mock_llm_client1.set_should_raise_exception_for_example_gen(None)
-
- with self.assertRaisesRegex(
- LLMInteractionError, "LLM client operation failed: Parsing failed"
- ):
- await self.orchestrator.synthesize(["Some text"], self.db_session)
-
- self.assertEqual(self.mock_llm_client1.example_gen_call_count, 1)
- self.assertEqual(self.mock_llm_client1.call_count, 1)
-
- report = self.orchestrator.get_analytics_report()
- self.assertEqual(report["llm_output_parse_errors"], 1)
- self.assertEqual(report["total_invalid_parsing_errors"], 1)
- self.assertEqual(report["number_of_consensus_runs"], 0)
-
- async def test_synthesis_llm_client_raises_validation_error(self):
- self.mock_llm_client1.set_should_raise_exception(
- LLMOutputValidationError(
- "Validation failed", parsed_json={"key": "wrong_type"}
- )
- )
- self.mock_llm_client1.set_should_raise_exception_for_example_gen(None)
-
- with self.assertRaisesRegex(
- LLMInteractionError, "LLM client operation failed: Validation failed"
- ):
- await self.orchestrator.synthesize(["Some text"], self.db_session)
-
- self.assertEqual(self.mock_llm_client1.example_gen_call_count, 1)
- self.assertEqual(self.mock_llm_client1.call_count, 1)
-
- report = self.orchestrator.get_analytics_report()
- self.assertEqual(report["llm_output_validation_errors"], 1)
- self.assertEqual(report["total_invalid_parsing_errors"], 1)
- self.assertEqual(report["number_of_consensus_runs"], 0)
-
- async def test_synthesis_llm_client_raises_api_call_error(self):
- self.mock_llm_client1.set_should_raise_exception(
- LLMAPICallError("API unavailable")
- )
- self.mock_llm_client1.set_should_raise_exception_for_example_gen(None)
-
- with self.assertRaisesRegex(
- LLMInteractionError, "LLM client operation failed: API unavailable"
- ):
- await self.orchestrator.synthesize(["Some text"], self.db_session)
-
- self.assertEqual(self.mock_llm_client1.example_gen_call_count, 1)
- self.assertEqual(self.mock_llm_client1.call_count, 1)
-
- report = self.orchestrator.get_analytics_report()
- self.assertEqual(report["llm_api_call_failures"], 1)
- self.assertEqual(report["total_invalid_parsing_errors"], 0)
- self.assertEqual(report["number_of_consensus_runs"], 0)
-
- async def test_synthesis_no_consensus_reached_analytics(self):
- llm_revisions = [
- {
- "results": [
- {
- "_type": "DepartmentModel",
- "name": "Engineering",
- "_temp_id": "d1",
- }
- ]
- },
- {
- "results": [
- {"_type": "DepartmentModel", "name": "Marketing", "_temp_id": "d2"}
- ]
- },
- ]
- # With 2 clients and num_llm_revisions=2, each client will be called once.
- # We set each to return one of the different revisions.
- self.mock_llm_client1.set_revisions_to_return([llm_revisions[0]])
- self.mock_llm_client2.set_revisions_to_return([llm_revisions[1]])
-
- mock_consensus_details = {
- "revisions_processed": 2,
- "unique_paths_considered": 2,
- "paths_agreed_by_threshold": 0,
- "paths_resolved_by_conflict_resolver": 0,
- "paths_omitted_due_to_no_consensus_or_resolver_omission": 2,
- }
-
- with mock.patch.object(
- self.orchestrator.json_consensus,
- "get_consensus",
- return_value=([], mock_consensus_details),
- ) as mock_get_consensus:
- hydrated_objects = await self.orchestrator.synthesize(
- ["Some text"], self.db_session
- )
- # The orchestrator will gather one revision from each of the two clients.
- expected_consensus_input = [
- llm_revisions[0]["results"],
- llm_revisions[1]["results"],
- ]
- mock_get_consensus.assert_called_once()
- actual_call_args = mock_get_consensus.call_args[0][0]
- self.assertCountEqual(actual_call_args, expected_consensus_input)
-
- self.assertEqual(len(hydrated_objects), 0)
-
- report = self.orchestrator.get_analytics_report()
- self.assertEqual(report["llm_api_call_failures"], 0)
- self.assertEqual(report["total_invalid_parsing_errors"], 0)
- self.assertEqual(report["number_of_consensus_runs"], 1)
- self.assertEqual(report["average_path_agreement_ratio"], 0.0)
- self.assertEqual(report["all_consensus_run_details"][0], mock_consensus_details)
-
- async def test_synthesize_with_user_provided_example_json(self):
- user_example_json = json.dumps(
- {"_type": "DepartmentModel", "name": "HR", "_temp_id": "hr_dept_example"}
- )
- input_strings = ["Some HR related text"]
-
- llm_main_extraction_response = [
- {
- "results": [
- {
- "_type": "DepartmentModel",
- "_temp_id": "dept_hr_actual",
- "name": "Human Resources",
- }
- ]
- }
- ]
- self.mock_llm_client1.set_revisions_to_return(llm_main_extraction_response)
- self.mock_llm_client2.set_revisions_to_return(llm_main_extraction_response)
-
- with mock.patch(
- "extrai.core.workflow_orchestrator.ExampleJSONGenerator"
- ) as mocked_example_generator:
- await self.orchestrator.synthesize(
- input_strings,
- self.db_session,
- extraction_example_json=user_example_json,
- )
- mocked_example_generator.assert_not_called()
-
- self.assertIn(user_example_json, self.mock_llm_client1.last_system_prompt)
- if user_example_json:
- self.assertIn(
- "# EXAMPLE OF EXTRACTION:", self.mock_llm_client1.last_system_prompt
- )
-
- @mock.patch("extrai.core.workflow_orchestrator.ExampleJSONGenerator")
- async def test_synthesize_auto_generate_example_json_success(
- self, mocked_example_generator
- ):
- input_strings = ["Some text for engineering department"]
- auto_generated_example_str = json.dumps(
- {"_type": "DepartmentModel", "name": "AutoGeneratedExample"}
- )
-
- mock_generator_instance = mocked_example_generator.return_value
- mock_generator_instance.generate_example = AsyncMock(
- return_value=auto_generated_example_str
- )
-
- llm_main_extraction_response = [
- {
- "results": [
- {
- "_type": "DepartmentModel",
- "_temp_id": "dept_eng_actual",
- "name": "Engineering",
- }
- ]
- }
- ]
- self.mock_llm_client1.set_revisions_to_return(llm_main_extraction_response)
- self.mock_llm_client2.set_revisions_to_return(llm_main_extraction_response)
-
- await self.orchestrator.synthesize(
- input_strings, self.db_session, extraction_example_json=""
- )
-
- mocked_example_generator.assert_called_once_with(
- llm_client=self.mock_llm_client1,
- output_model=self.orchestrator.root_sqlmodel_class,
- analytics_collector=self.orchestrator.analytics_collector,
- max_validation_retries_per_revision=self.orchestrator.max_validation_retries_per_revision,
- logger=self.orchestrator.logger,
- )
- mock_generator_instance.generate_example.assert_called_once()
-
- self.assertIn(
- auto_generated_example_str, self.mock_llm_client2.last_system_prompt
- )
- self.assertIn(
- "# EXAMPLE OF EXTRACTION:", self.mock_llm_client2.last_system_prompt
- )
-
- @mock.patch("extrai.core.workflow_orchestrator.ExampleJSONGenerator")
- async def test_synthesize_auto_generate_example_json_failure(
- self, mocked_example_generator
- ):
- input_strings = ["Some text"]
- mock_generator_instance = mocked_example_generator.return_value
- example_gen_error = ExampleGenerationError("Failed to generate example")
- mock_generator_instance.generate_example = AsyncMock(
- side_effect=example_gen_error
- )
-
- with self.assertRaisesRegex(
- WorkflowError,
- r"Failed to auto-generate extraction example:.*Failed to generate example",
- ):
- await self.orchestrator.synthesize(
- input_strings, self.db_session, extraction_example_json=""
- )
-
- mocked_example_generator.assert_called_once_with(
- llm_client=self.mock_llm_client1,
- output_model=self.orchestrator.root_sqlmodel_class,
- analytics_collector=self.orchestrator.analytics_collector,
- max_validation_retries_per_revision=self.orchestrator.max_validation_retries_per_revision,
- logger=self.orchestrator.logger,
- )
- mock_generator_instance.generate_example.assert_called_once()
- self.assertEqual(self.mock_llm_client1.call_count, 0)
- self.assertEqual(self.mock_llm_client2.call_count, 0)
-
- def test_get_analytics_collector_method(self):
- collector_instance = self.orchestrator.get_analytics_collector()
- self.assertIsInstance(collector_instance, WorkflowAnalyticsCollector)
- self.assertIs(collector_instance, self.orchestrator.analytics_collector)
-
- async def test_synthesis_llm_returns_malformed_revisions_now_caught_by_client(self):
- self.mock_llm_client1.set_should_raise_exception(
- LLMOutputParseError("Client failed to parse", raw_content="not a dict")
- )
- self.mock_llm_client1.set_should_raise_exception_for_example_gen(None)
-
- with self.assertRaisesRegex(
- LLMInteractionError, "LLM client operation failed: Client failed to parse"
- ):
- await self.orchestrator.synthesize(["Some text"], self.db_session)
-
- self.assertEqual(self.mock_llm_client1.example_gen_call_count, 1)
- self.assertEqual(self.mock_llm_client1.call_count, 1)
-
- report = self.orchestrator.get_analytics_report()
- self.assertEqual(report["llm_output_parse_errors"], 1)
- self.assertEqual(report["total_invalid_parsing_errors"], 1)
- self.assertEqual(report["number_of_consensus_runs"], 0)
-
- async def test_synthesize_empty_input_strings(self):
- with self.assertRaisesRegex(ValueError, "Input strings list cannot be empty."):
- await self.orchestrator.synthesize([], self.db_session)
-
- @mock.patch("extrai.core.workflow_orchestrator.ExampleJSONGenerator")
- async def test_synthesize_auto_generate_example_json_raises_generic_exception(
- self, mocked_example_generator
- ):
- mock_generator_instance = mocked_example_generator.return_value
- mock_generator_instance.generate_example = AsyncMock(
- side_effect=Exception("Unexpected boom during example gen")
- )
-
- current_orchestrator = WorkflowOrchestrator(
- root_sqlmodel_class=DepartmentModel,
- llm_client=self.mock_llm_client1,
- num_llm_revisions=self.orchestrator.num_llm_revisions,
- max_validation_retries_per_revision=self.orchestrator.max_validation_retries_per_revision,
- )
- current_orchestrator.analytics_collector.record_custom_event = mock.Mock()
-
- with self.assertRaisesRegex(
- WorkflowError,
- r"An unexpected error occurred during auto-generation of extraction example:.*Unexpected boom during example gen",
- ):
- await current_orchestrator.synthesize(
- ["Some text"], self.db_session, extraction_example_json=""
- )
-
- mocked_example_generator.assert_called_once()
- mock_generator_instance.generate_example.assert_called_once()
- current_orchestrator.analytics_collector.record_custom_event.assert_any_call(
- "example_json_auto_generation_unexpected_failure"
- )
- self.assertEqual(self.mock_llm_client1.call_count, 0)
-
- async def test_synthesis_llm_client_returns_no_revisions(self):
- # Each call to the mock should return an empty list, simulating no content.
- self.mock_llm_client1.set_revisions_to_return([[]] * 2)
- self.mock_llm_client2.set_revisions_to_return([[]] * 2)
- self.mock_llm_client1.set_should_raise_exception_for_example_gen(None)
-
- # The orchestrator should handle cases where all revisions are empty without raising an error.
- hydrated_objects = await self.orchestrator.synthesize(
- ["Some text"], self.db_session
- )
-
- self.assertEqual(len(hydrated_objects), 0)
- self.assertEqual(self.mock_llm_client1.example_gen_call_count, 1)
- self.assertEqual(self.mock_llm_client1.call_count, 1)
- self.assertEqual(self.mock_llm_client2.call_count, 1)
-
- async def test_synthesis_llm_client_returns_no_revisions_at_all(self):
- self.mock_llm_client1.set_revisions_to_return([[]])
- self.mock_llm_client2.set_revisions_to_return([[]])
- self.mock_llm_client1.set_should_raise_exception_for_example_gen(None)
-
- with mock.patch("asyncio.gather", new_callable=AsyncMock) as mock_gather:
- mock_gather.return_value = []
- with self.assertRaisesRegex(
- LLMInteractionError,
- "LLM client returned no revisions despite being requested.",
- ):
- await self.orchestrator.synthesize(["Some text"], self.db_session)
-
- async def test_synthesis_llm_client_returns_malformed_revision_item(self):
- # The mock now expects a list of revisions to return.
- # We simulate one client returning a valid-looking list, and the other returning a malformed one.
- self.mock_llm_client1.set_revisions_to_return([[{"key": "value"}]])
- self.mock_llm_client2.set_revisions_to_return([["not a dict"]])
-
- # Malformed items are now filtered out before consensus, so no error should be raised.
- # The process should complete and return no hydrated objects.
- hydrated_objects = await self.orchestrator.synthesize(
- ["Some text"], self.db_session
- )
- self.assertEqual(len(hydrated_objects), 0)
-
- async def test_synthesis_llm_client_raises_generic_exception(self):
- self.mock_llm_client1.set_should_raise_exception(Exception("LLM Generic Boom!"))
- self.mock_llm_client1.set_should_raise_exception_for_example_gen(None)
-
- with self.assertRaisesRegex(
- LLMInteractionError,
- "An unexpected error occurred during LLM interaction: LLM Generic Boom!",
- ):
- await self.orchestrator.synthesize(["Some text"], self.db_session)
-
- self.assertEqual(self.mock_llm_client1.example_gen_call_count, 1)
- self.assertEqual(self.mock_llm_client1.call_count, 1)
-
- async def test_synthesis_consensus_returns_none(self):
- self.mock_llm_client1.set_revisions_to_return([{"key": "value"}])
- self.mock_llm_client2.set_revisions_to_return([{"key": "value"}])
-
- mock_consensus_details = {"revisions_processed": 1}
-
- with (
- mock.patch.object(
- self.orchestrator.analytics_collector, "record_consensus_run_details"
- ) as mock_record_details,
- mock.patch.object(
- self.orchestrator.json_consensus,
- "get_consensus",
- return_value=(None, mock_consensus_details),
- ) as mock_get_consensus,
- ):
- hydrated_objects = await self.orchestrator.synthesize(
- ["Some text"], self.db_session
- )
-
- mock_get_consensus.assert_called_once()
- mock_record_details.assert_called_once_with(mock_consensus_details)
-
- self.assertEqual(len(hydrated_objects), 0)
-
- async def test_synthesis_consensus_returns_empty_dict(self):
- self.mock_llm_client1.set_revisions_to_return([{"key": "value"}])
- mock_consensus_details = {"revisions_processed": 1}
-
- with mock.patch.object(
- self.orchestrator.json_consensus,
- "get_consensus",
- return_value=([], mock_consensus_details),
- ) as mock_get_consensus:
- hydrated_objects = await self.orchestrator.synthesize(
- ["Some text"], self.db_session
- )
- mock_get_consensus.assert_called_once()
-
- self.assertEqual(len(hydrated_objects), 0)
-
- async def test_synthesis_consensus_results_not_all_dicts(self):
- self.mock_llm_client1.set_revisions_to_return([{"key": "value"}])
- mock_consensus_output = {"results": ["not_a_dict", {"item": 2}]}
- mock_consensus_details = {"revisions_processed": 1}
-
- with mock.patch.object(
- self.orchestrator.json_consensus,
- "get_consensus",
- return_value=(mock_consensus_output["results"], mock_consensus_details),
- ):
- # This scenario now raises a HydrationError because the list contains non-dict items.
- with self.assertRaises(HydrationError):
- await self.orchestrator.synthesize(["Some text"], self.db_session)
-
- async def test_synthesis_consensus_returns_dict(self):
- llm_revisions = [[{"_type": "DepartmentModel", "name": "ConsensusDept"}]]
- self.mock_llm_client1.set_revisions_to_return(llm_revisions)
- self.mock_llm_client2.set_revisions_to_return(llm_revisions)
-
- # Case 1: Consensus returns a single dictionary
- consensus_single_dict = {
- "_type": "DepartmentModel",
- "_temp_id": "cd1",
- "name": "Single Dict Dept",
- }
- # Case 2: Consensus returns a dictionary with a 'results' key
- consensus_dict_with_results = {"results": [consensus_single_dict]}
-
- test_cases = [
- ("single_dict", consensus_single_dict),
- ("dict_with_results", consensus_dict_with_results),
- ]
-
- for name, consensus_output in test_cases:
- with self.subTest(name=name):
- with mock.patch.object(
- self.orchestrator.json_consensus,
- "get_consensus",
- return_value=(consensus_output, {}),
- ):
- hydrated_objects = await self.orchestrator.synthesize(
- ["Some text"], self.db_session
- )
-
- self.assertEqual(len(hydrated_objects), 1)
- self.assertIsInstance(hydrated_objects[0], DepartmentModel)
- self.assertEqual(hydrated_objects[0].name, "Single Dict Dept")
-
- async def test_synthesis_consensus_returns_unexpected_type(self):
- self.mock_llm_client1.set_revisions_to_return([{"key": "value"}])
- mock_consensus_details = {"revisions_processed": 1}
-
- with mock.patch.object(
- self.orchestrator.json_consensus,
- "get_consensus",
- return_value=("unexpected_string_type", mock_consensus_details),
- ):
- with self.assertRaisesRegex(
- ConsensusProcessError,
- "Unexpected type from json_consensus.get_consensus: .",
- ):
- await self.orchestrator.synthesize(["Some text"], self.db_session)
-
- async def test_synthesis_consensus_get_consensus_raises_exception(self):
- self.mock_llm_client1.set_revisions_to_return([{"key": "value"}])
-
- with mock.patch.object(
- self.orchestrator.json_consensus,
- "get_consensus",
- side_effect=Exception("Extrai boom!"),
- ):
- with self.assertRaisesRegex(
- ConsensusProcessError,
- "Failed during JSON consensus processing: Extrai boom!",
- ):
- await self.orchestrator.synthesize(["Some text"], self.db_session)
-
- async def test_synthesis_hydration_fails(self):
- # This test covers hydration failure with and without a db session.
- llm_return = [{"results": [{"_type": "DepartmentModel", "name": "Valid Dept"}]}]
- self.mock_llm_client1.set_revisions_to_return(llm_return)
- self.mock_llm_client2.set_revisions_to_return(llm_return)
-
- mock_consensus_output = [
- {"_type": "DepartmentModel", "_temp_id": "d1", "name": "Dept For Hydration"}
- ]
-
- for session in [self.db_session, None]:
- with self.subTest(session=session):
- with (
- mock.patch.object(
- self.orchestrator.json_consensus,
- "get_consensus",
- return_value=(mock_consensus_output, {}),
- ),
- mock.patch(
- "extrai.core.sqlalchemy_hydrator.SQLAlchemyHydrator.hydrate",
- side_effect=Exception("Hydration boom!"),
- ) as mock_hydrate,
- ):
- with self.assertRaisesRegex(
- HydrationError,
- "Failed during SQLAlchemy object hydration: Hydration boom!",
- ):
- await self.orchestrator.synthesize(
- ["Some text"], db_session_for_hydration=session
- )
- mock_hydrate.assert_called_once()
-
- @mock.patch("extrai.core.workflow_orchestrator.persist_objects")
- async def test_synthesize_and_save_no_hydrated_objects(self, mock_persist_objects):
- with (
- mock.patch.object(
- self.orchestrator, "synthesize", AsyncMock(return_value=[])
- ),
- mock.patch.object(self.db_session, "rollback") as mock_rollback,
- mock.patch.object(self.orchestrator.logger, "info") as mock_logger_info,
- ):
- await self.orchestrator.synthesize_and_save(["some input"], self.db_session)
-
- mock_logger_info.assert_called_with(
- "WorkflowOrchestrator: No objects were hydrated, thus nothing to persist."
- )
- mock_persist_objects.assert_not_called()
- mock_rollback.assert_not_called()
-
- @mock.patch("extrai.core.workflow_orchestrator.persist_objects")
- async def test_synthesize_and_save_persist_raises_db_writer_error(
- self, mock_persist_objects
- ):
- mock_persist_objects.side_effect = DatabaseWriterError("DB write failed")
-
- mock_hydrated_object = DepartmentModel(name="Test Dept")
- with (
- mock.patch.object(
- self.orchestrator,
- "synthesize",
- AsyncMock(return_value=[mock_hydrated_object]),
- ),
- mock.patch.object(self.db_session, "rollback") as mock_rollback,
- ):
- with self.assertRaises(DatabaseWriterError):
- await self.orchestrator.synthesize_and_save(
- ["some input"], self.db_session
- )
-
- mock_persist_objects.assert_called_once_with(
- db_session=self.db_session,
- objects_to_persist=[mock_hydrated_object],
- logger=self.orchestrator.logger,
- )
- mock_rollback.assert_called_once()
-
- @mock.patch("extrai.core.workflow_orchestrator.persist_objects")
- async def test_synthesize_and_save_persist_raises_generic_exception(
- self, mock_persist_objects
- ):
- mock_persist_objects.side_effect = Exception("Generic DB boom")
-
- mock_hydrated_object = DepartmentModel(name="Test Dept")
- with (
- mock.patch.object(
- self.orchestrator,
- "synthesize",
- AsyncMock(return_value=[mock_hydrated_object]),
- ),
- mock.patch.object(self.db_session, "rollback") as mock_rollback,
- ):
- with self.assertRaisesRegex(
- WorkflowError,
- "An unexpected error occurred during database persistence phase: Generic DB boom",
- ):
- await self.orchestrator.synthesize_and_save(
- ["some input"], self.db_session
- )
-
- mock_persist_objects.assert_called_once_with(
- db_session=self.db_session,
- objects_to_persist=[mock_hydrated_object],
- logger=self.orchestrator.logger,
- )
- mock_rollback.assert_called_once()
-
- async def test_synthesize_with_extraction_example_parameters(self):
- """Test different scenarios for extraction_example_object and extraction_example_json."""
- dept1 = DepartmentModel(name="Dept 1")
- dept2 = DepartmentModel(name="Dept 2")
-
- test_cases = [
- {
- "name": "single_object",
- "object": dept1,
- "json_arg": "",
- "expected_json_in_prepare": lambda j: len(json.loads(j)) == 1
- and json.loads(j)[0]["name"] == "Dept 1",
- "expect_warning": False,
- },
- {
- "name": "list_of_objects",
- "object": [dept1, dept2],
- "json_arg": "",
- "expected_json_in_prepare": lambda j: len(json.loads(j)) == 2
- and json.loads(j)[1]["name"] == "Dept 2",
- "expect_warning": False,
- },
- {
- "name": "priority_json_over_object",
- "object": dept1,
- "json_arg": '[{"name": "Override"}]',
- "expected_json_in_prepare": lambda j: j == '[{"name": "Override"}]',
- "expect_warning": False,
- },
- {
- "name": "unsupported_type",
- "object": ["unsupported"],
- "json_arg": "",
- "expected_json_in_prepare": lambda j: j == "",
- "expect_warning": True,
- },
- ]
-
- for case in test_cases:
- with self.subTest(case=case["name"]):
- with (
- mock.patch.object(
- self.orchestrator,
- "_prepare_extraction_example",
- new_callable=AsyncMock,
- ) as mock_prepare,
- mock.patch.object(
- self.orchestrator,
- "_execute_standard_extraction",
- AsyncMock(return_value=[]),
- ),
- mock.patch.object(
- self.orchestrator,
- "_hydrate_results",
- mock.MagicMock(return_value=[]),
- ),
- mock.patch.object(
- self.orchestrator.logger, "warning"
- ) as mock_logger_warning,
- ):
- mock_prepare.return_value = "{}"
-
- await self.orchestrator.synthesize(
- input_strings=["test"],
- db_session_for_hydration=self.db_session,
- extraction_example_object=case["object"],
- extraction_example_json=case["json_arg"],
- )
-
- # Verify warning
- if case["expect_warning"]:
- mock_logger_warning.assert_called()
- args, _ = mock_logger_warning.call_args
- self.assertIn("Skipping unsupported object type", args[0])
- else:
- mock_logger_warning.assert_not_called()
-
- # Verify _prepare_extraction_example argument
- args, _ = mock_prepare.call_args
- actual_json = args[0]
- self.assertTrue(
- case["expected_json_in_prepare"](actual_json),
- f"Failed for case {case['name']}: actual json {actual_json}",
- )
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/core/workflow_orchestrator/test_workflow_orchestrator_init.py b/tests/core/workflow_orchestrator/test_workflow_orchestrator_init.py
index 58ddb61..fefb548 100644
--- a/tests/core/workflow_orchestrator/test_workflow_orchestrator_init.py
+++ b/tests/core/workflow_orchestrator/test_workflow_orchestrator_init.py
@@ -4,13 +4,10 @@
import json
from unittest import mock
-from extrai.core.workflow_orchestrator import (
- WorkflowOrchestrator,
- ConfigurationError,
-)
-from extrai.core.analytics_collector import (
- WorkflowAnalyticsCollector,
-)
+from extrai.core.errors import ConfigurationError
+from extrai.core.workflow_orchestrator import WorkflowOrchestrator
+
+from extrai.core.analytics_collector import WorkflowAnalyticsCollector
from tests.core.helpers.orchestrator_test_models import DepartmentModel, EmployeeModel
from tests.core.helpers.mock_llm_clients import (
MockLLMClientForWorkflow as MockLLMClient,
@@ -26,67 +23,36 @@ def setUp(self):
{"schema_for_prompt": "mock_llm_prompt_schema"}
)
- @mock.patch("extrai.core.workflow_orchestrator.discover_sqlmodels_from_root")
- @mock.patch("extrai.core.workflow_orchestrator.generate_llm_schema_from_models")
- def test_successful_initialization(
- self, mock_generate_llm_schema, mock_discover_sqlmodels
- ):
- mock_discover_sqlmodels.return_value = self.mock_discovered_sqlmodel_classes
- mock_generate_llm_schema.return_value = self.mock_prompt_llm_schema_str
-
+ @mock.patch("extrai.core.workflow_orchestrator.ModelRegistry")
+ def test_successful_initialization(self, MockModelRegistry):
orchestrator = WorkflowOrchestrator(
root_sqlmodel_class=self.root_sqlmodel_class,
llm_client=self.mock_llm_client,
max_validation_retries_per_revision=2,
)
- mock_discover_sqlmodels.assert_called_once_with(self.root_sqlmodel_class)
-
- expected_sqla_models_set = {DepartmentModel, EmployeeModel}
- mock_generate_llm_schema.assert_called_once()
- called_initial_models_list = mock_generate_llm_schema.call_args.kwargs[
- "initial_model_classes"
- ]
- self.assertEqual(set(called_initial_models_list), expected_sqla_models_set)
+ MockModelRegistry.assert_called_once_with(self.root_sqlmodel_class, mock.ANY)
- expected_model_map = {
- "DepartmentModel": DepartmentModel,
- "EmployeeModel": EmployeeModel,
- }
- self.assertEqual(
- orchestrator.model_schema_map_for_hydration, expected_model_map
- )
- self.assertEqual(
- orchestrator.target_json_schema_for_llm, self.mock_prompt_llm_schema_str
- )
- self.assertFalse(hasattr(orchestrator, "formal_json_schema_for_validation"))
- self.assertEqual(orchestrator.max_validation_retries_per_revision, 2)
+ self.assertEqual(orchestrator.config.max_validation_retries_per_revision, 2)
self.assertIsNotNone(orchestrator.analytics_collector)
self.assertIsInstance(
orchestrator.analytics_collector, WorkflowAnalyticsCollector
)
- self.assertFalse(orchestrator.use_hierarchical_extraction) # Default
+ self.assertFalse(orchestrator.config.use_hierarchical_extraction)
- @mock.patch("extrai.core.workflow_orchestrator.discover_sqlmodels_from_root")
- @mock.patch("extrai.core.workflow_orchestrator.generate_llm_schema_from_models")
+ @mock.patch("extrai.core.workflow_orchestrator.ModelRegistry")
def test_initialization_with_hierarchical_extraction_enabled(
- self, mock_generate_llm_schema, mock_discover_sqlmodels
+ self, MockModelRegistry
):
- mock_discover_sqlmodels.return_value = self.mock_discovered_sqlmodel_classes
- mock_generate_llm_schema.return_value = self.mock_prompt_llm_schema_str
+        # The hierarchical-extraction warning, if still emitted, now originates
+        # from ExtractionConfig during initialization, so this test only asserts the flag.
- with mock.patch("logging.Logger.warning") as mock_logger_warning:
- orchestrator = WorkflowOrchestrator(
- root_sqlmodel_class=self.root_sqlmodel_class,
- llm_client=self.mock_llm_client,
- use_hierarchical_extraction=True,
- )
- self.assertTrue(orchestrator.use_hierarchical_extraction)
- mock_logger_warning.assert_called_once_with(
- "Hierarchical extraction is enabled. "
- "This may significantly increase LLM API calls and processing time "
- "based on model complexity and the number of entities."
- )
+ orchestrator = WorkflowOrchestrator(
+ root_sqlmodel_class=self.root_sqlmodel_class,
+ llm_client=self.mock_llm_client,
+ use_hierarchical_extraction=True,
+ )
+ self.assertTrue(orchestrator.config.use_hierarchical_extraction)
def test_init_with_provided_analytics_collector(self):
custom_collector = WorkflowAnalyticsCollector()
@@ -98,8 +64,9 @@ def test_init_with_provided_analytics_collector(self):
self.assertIs(orchestrator.analytics_collector, custom_collector)
def test_init_invalid_max_validation_retries(self):
+ # Validation happens in ExtractionConfig
with self.assertRaisesRegex(
- ConfigurationError, "Max validation retries per revision must be at least 1"
+ ValueError, "max_validation_retries_per_revision must be at least 1"
):
WorkflowOrchestrator(
self.root_sqlmodel_class,
@@ -107,144 +74,104 @@ def test_init_invalid_max_validation_retries(self):
max_validation_retries_per_revision=0,
)
- def test_init_invalid_root_sqlmodel_class(self):
- with self.assertRaisesRegex(
- ConfigurationError, "root_sqlmodel_class must be a valid SQLModel class."
- ):
- WorkflowOrchestrator(None, self.mock_llm_client) # type: ignore
-
- class NotASQLModel:
- pass
+ @mock.patch("extrai.core.workflow_orchestrator.ModelRegistry")
+ def test_init_invalid_root_sqlmodel_class(self, MockModelRegistry):
+ MockModelRegistry.side_effect = ConfigurationError(
+ "root_sqlmodel_class must be a valid SQLModel class."
+ )
with self.assertRaisesRegex(
ConfigurationError, "root_sqlmodel_class must be a valid SQLModel class."
):
- WorkflowOrchestrator(NotASQLModel, self.mock_llm_client) # type: ignore
+ WorkflowOrchestrator(None, self.mock_llm_client) # type: ignore
def test_init_invalid_num_llm_revisions(self):
- with self.assertRaisesRegex(
- ConfigurationError, "Number of LLM revisions must be at least 1."
- ):
+ with self.assertRaisesRegex(ValueError, "num_llm_revisions must be at least 1"):
WorkflowOrchestrator(
self.root_sqlmodel_class, self.mock_llm_client, num_llm_revisions=0
)
def test_init_invalid_consensus_threshold(self):
with self.assertRaisesRegex(
- ConfigurationError,
- "Extrai threshold must be between 0.0 and 1.0 inclusive.",
+ ValueError,
+ "consensus_threshold must be between 0.0 and 1.0",
):
WorkflowOrchestrator(
self.root_sqlmodel_class, self.mock_llm_client, consensus_threshold=-0.1
)
with self.assertRaisesRegex(
- ConfigurationError,
- "Extrai threshold must be between 0.0 and 1.0 inclusive.",
+ ValueError,
+ "consensus_threshold must be between 0.0 and 1.0",
):
WorkflowOrchestrator(
self.root_sqlmodel_class, self.mock_llm_client, consensus_threshold=1.1
)
- @mock.patch("extrai.core.workflow_orchestrator.discover_sqlmodels_from_root")
- def test_init_discover_sqlmodels_fails_generic_exception(
- self, mock_discover_sqlmodels
- ):
- mock_discover_sqlmodels.side_effect = Exception("Discovery boom!")
+ @mock.patch("extrai.core.workflow_orchestrator.ModelRegistry")
+ def test_init_discover_sqlmodels_fails_generic_exception(self, MockModelRegistry):
+ MockModelRegistry.side_effect = ConfigurationError(
+ "Failed to discover SQLModel classes: Discovery boom!"
+ )
+
with self.assertRaisesRegex(
ConfigurationError, "Failed to discover SQLModel classes: Discovery boom!"
):
WorkflowOrchestrator(self.root_sqlmodel_class, self.mock_llm_client)
- @mock.patch("extrai.core.workflow_orchestrator.discover_sqlmodels_from_root")
- def test_init_discover_sqlmodels_returns_empty(self, mock_discover_sqlmodels):
- mock_discover_sqlmodels.return_value = []
- with self.assertRaisesRegex(
- ConfigurationError,
- "No SQLModel classes were discovered from the root model.",
- ):
- WorkflowOrchestrator(self.root_sqlmodel_class, self.mock_llm_client)
+ @mock.patch("extrai.core.workflow_orchestrator.ModelRegistry")
+ def test_init_discover_sqlmodels_returns_empty(self, MockModelRegistry):
+ MockModelRegistry.side_effect = ConfigurationError(
+ "No SQLModel classes were discovered from the root model."
+ )
- @mock.patch("extrai.core.workflow_orchestrator.discover_sqlmodels_from_root")
- @mock.patch("extrai.core.workflow_orchestrator.generate_llm_schema_from_models")
- def test_init_generate_llm_schema_returns_empty_string(
- self, mock_generate_llm_schema, mock_discover_sqlmodels
- ):
- mock_discover_sqlmodels.return_value = self.mock_discovered_sqlmodel_classes
- mock_generate_llm_schema.return_value = ""
with self.assertRaisesRegex(
ConfigurationError,
- r"Generated target_json_schema_for_llm \(prompt schema\) is empty.",
+ "No SQLModel classes were discovered from the root model.",
):
WorkflowOrchestrator(self.root_sqlmodel_class, self.mock_llm_client)
- @mock.patch("extrai.core.workflow_orchestrator.discover_sqlmodels_from_root")
- @mock.patch("extrai.core.workflow_orchestrator.generate_llm_schema_from_models")
- def test_init_generate_llm_schema_returns_invalid_json(
- self, mock_generate_llm_schema, mock_discover_sqlmodels
- ):
- mock_discover_sqlmodels.return_value = self.mock_discovered_sqlmodel_classes
- mock_generate_llm_schema.return_value = "not a valid json"
- with self.assertRaisesRegex(
- ConfigurationError,
- "The internally generated LLM prompt JSON schema is not valid:",
- ):
- WorkflowOrchestrator(self.root_sqlmodel_class, self.mock_llm_client)
+    # Schema-generation failure scenarios are covered by the ModelRegistry tests;
+    # here we only verify that WorkflowOrchestrator propagates the error.
+ @mock.patch("extrai.core.workflow_orchestrator.ModelRegistry")
+ def test_init_generate_llm_schema_fails(self, MockModelRegistry):
+ MockModelRegistry.side_effect = ConfigurationError(
+ "Failed to generate the LLM prompt JSON schema"
+ )
- @mock.patch("extrai.core.workflow_orchestrator.discover_sqlmodels_from_root")
- @mock.patch("extrai.core.workflow_orchestrator.generate_llm_schema_from_models")
- def test_init_generate_llm_schema_fails_generic_exception(
- self, mock_generate_llm_schema, mock_discover_sqlmodels
- ):
- mock_discover_sqlmodels.return_value = self.mock_discovered_sqlmodel_classes
- mock_generate_llm_schema.side_effect = Exception("Schema gen boom!")
with self.assertRaisesRegex(
ConfigurationError,
- "Failed to generate the LLM prompt JSON schema: Schema gen boom!",
+ "Failed to generate the LLM prompt JSON schema",
):
WorkflowOrchestrator(self.root_sqlmodel_class, self.mock_llm_client)
- @mock.patch("extrai.core.workflow_orchestrator.discover_sqlmodels_from_root")
- @mock.patch("extrai.core.workflow_orchestrator.generate_llm_schema_from_models")
- def test_init_with_invalid_llm_client_in_list(
- self, mock_generate_llm_schema, mock_discover_sqlmodels
- ):
- mock_discover_sqlmodels.return_value = self.mock_discovered_sqlmodel_classes
- mock_generate_llm_schema.return_value = self.mock_prompt_llm_schema_str
+ @mock.patch("extrai.core.workflow_orchestrator.ModelRegistry")
+ def test_init_with_invalid_llm_client_in_list(self, MockModelRegistry):
+        # Validation happens in ExtractionPipeline, which is initialized after ModelRegistry.
with self.assertRaisesRegex(
- ConfigurationError,
- "All items in llm_client list must be instances of BaseLLMClient.",
+ ValueError,
+ "All items in llm_client list must be instances of BaseLLMClient",
):
WorkflowOrchestrator(
root_sqlmodel_class=self.root_sqlmodel_class,
llm_client=[self.mock_llm_client, "not a client"],
)
- @mock.patch("extrai.core.workflow_orchestrator.discover_sqlmodels_from_root")
- @mock.patch("extrai.core.workflow_orchestrator.generate_llm_schema_from_models")
- def test_init_with_empty_llm_client_list(
- self, mock_generate_llm_schema, mock_discover_sqlmodels
- ):
- mock_discover_sqlmodels.return_value = self.mock_discovered_sqlmodel_classes
- mock_generate_llm_schema.return_value = self.mock_prompt_llm_schema_str
+ @mock.patch("extrai.core.workflow_orchestrator.ModelRegistry")
+ def test_init_with_empty_llm_client_list(self, MockModelRegistry):
with self.assertRaisesRegex(
- ConfigurationError,
- "llm_client list cannot be empty.",
+ ValueError,
+ "At least one client must be provided",
):
WorkflowOrchestrator(
root_sqlmodel_class=self.root_sqlmodel_class,
llm_client=[],
)
- @mock.patch("extrai.core.workflow_orchestrator.discover_sqlmodels_from_root")
- @mock.patch("extrai.core.workflow_orchestrator.generate_llm_schema_from_models")
- def test_init_with_invalid_llm_client_type(
- self, mock_generate_llm_schema, mock_discover_sqlmodels
- ):
- mock_discover_sqlmodels.return_value = self.mock_discovered_sqlmodel_classes
- mock_generate_llm_schema.return_value = self.mock_prompt_llm_schema_str
+ @mock.patch("extrai.core.workflow_orchestrator.ModelRegistry")
+ def test_init_with_invalid_llm_client_type(self, MockModelRegistry):
with self.assertRaisesRegex(
- ConfigurationError,
- "llm_client must be an instance of BaseLLMClient or a list of them.",
+ ValueError,
+ "llm_client must be an instance of BaseLLMClient or a list of them",
):
WorkflowOrchestrator(
root_sqlmodel_class=self.root_sqlmodel_class,
diff --git a/tests/utils/test_alignment_utils.py b/tests/utils/test_alignment_utils.py
new file mode 100644
index 0000000..20dbff3
--- /dev/null
+++ b/tests/utils/test_alignment_utils.py
@@ -0,0 +1,315 @@
+import unittest
+import io
+import sys
+from extrai.utils.alignment_utils import (
+ normalize_json_revisions,
+ align_entity_arrays,
+ find_best_match,
+ calculate_similarity,
+ compare_values,
+)
+
+
+class TestAlignmentUtils(unittest.TestCase):
+ def test_normalize_json_revisions(self):
+ """Test normalize_json_revisions with various inputs"""
+ cases = [
+ {"name": "empty list", "input": [], "expected": []},
+ {
+ "name": "simple lists reordering",
+ "input": [
+ [{"id": 1, "val": "A"}, {"id": 2, "val": "B"}],
+ [{"id": 2, "val": "B"}, {"id": 1, "val": "A"}],
+ ],
+ "check": lambda res: (
+ len(res) == 2
+ and res[0][0]["id"] == 1
+ and res[1][0]["id"] == 1
+ and res[0][1]["id"] == 2
+ and res[1][1]["id"] == 2
+ ),
+ },
+ {
+ "name": "results wrapper",
+ "input": [
+ {"results": [{"id": 1, "val": "A"}]},
+ {"results": [{"id": 1, "val": "A"}]},
+ ],
+ "check": lambda res: (
+ isinstance(res[0], dict)
+ and "results" in res[0]
+ and res[0]["results"][0]["id"] == 1
+ ),
+ },
+ {
+ "name": "single objects (no-op)",
+ "input": [{"id": 1, "val": "A"}, {"id": 1, "val": "A"}],
+ "expected": [{"id": 1, "val": "A"}, {"id": 1, "val": "A"}],
+ },
+ {
+ "name": "mixed empty revisions",
+ "input": [[{"id": 1}], []],
+ "check": lambda res: len(res) == 2,
+ },
+ {
+ "name": "all empty arrays",
+ "input": [[], [], []],
+ "expected": [[], [], []],
+ },
+ ]
+
+ for case in cases:
+ with self.subTest(case["name"]):
+ result = normalize_json_revisions(case["input"])
+ if "expected" in case:
+ self.assertEqual(result, case["expected"])
+ if "check" in case:
+ self.assertTrue(case["check"](result))
+
+ def test_align_entity_arrays(self):
+ """Test align_entity_arrays with various inputs"""
+ cases = [
+ {"name": "empty arrays", "input": [], "expected": []},
+ {"name": "list of empty arrays", "input": [[], []], "expected": [[], []]},
+ {
+ "name": "same order",
+ "input": [
+ [{"id": 1, "name": "A"}, {"id": 2, "name": "B"}],
+ [{"id": 1, "name": "A"}, {"id": 2, "name": "B"}],
+ ],
+ "check": lambda res: (res[0][0]["id"] == 1 and res[1][0]["id"] == 1),
+ },
+ {
+ "name": "reorder needed",
+ "input": [
+ [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}],
+ [{"id": 2, "name": "Bob"}, {"id": 1, "name": "Alice"}],
+ ],
+ "check": lambda res: (
+ res[0][0]["name"] == "Alice"
+ and res[1][0]["name"] == "Alice"
+ and res[0][1]["name"] == "Bob"
+ and res[1][1]["name"] == "Bob"
+ ),
+ },
+ {
+ "name": "no match found (None)",
+ "input": [
+ [{"id": 1, "val": "X"}],
+ [{"id": 99, "val": "completely different"}],
+ ],
+ "check": lambda res: (res[0][0]["id"] == 1 and res[1][0] is not None),
+ },
+ {
+ "name": "deeply nested objects",
+ "input": [
+ [
+ {"id": 1, "d": {"n": {"v": "deep"}}},
+ {"id": 2, "d": {"n": {"v": "deep"}}},
+ ],
+ [
+ {"id": 2, "d": {"n": {"v": "deep"}}},
+ {"id": 1, "d": {"n": {"v": "deep"}}},
+ ],
+ ],
+ "check": lambda res: (res[0][0]["id"] == 1 and res[1][0]["id"] == 1),
+ },
+ {
+ "name": "lists in objects",
+ "input": [
+ [{"id": 1, "tags": ["a", "b"]}, {"id": 2, "tags": ["x", "y"]}],
+ [{"id": 2, "tags": ["x", "y"]}, {"id": 1, "tags": ["a", "b"]}],
+ ],
+ "check": lambda res: (res[0][0]["id"] == 1 and res[1][0]["id"] == 1),
+ },
+ {
+ "name": "three arrays alignment",
+ "input": [
+ [{"name": "Alice"}, {"name": "Bob"}, {"name": "Charlie"}],
+ [{"name": "Charlie"}, {"name": "Alice"}, {"name": "Bob"}],
+ [{"name": "Bob"}, {"name": "Charlie"}, {"name": "Alice"}],
+ ],
+ "check": lambda res: (
+ all(r[0]["name"] == "Alice" for r in res)
+ and all(r[1]["name"] == "Bob" for r in res)
+ ),
+ },
+ {
+ "name": "preserves reference object identity",
+ "input": [[{"id": 1}, {"id": 2}], [{"id": 2}, {"id": 1}]],
+ "check": lambda res: (
+                    # The first (reference) revision defines the alignment order;
+                    # we only spot-check its first element here.
+ res[0][0]["id"] == 1
+ ),
+ },
+ ]
+
+ for case in cases:
+ with self.subTest(case["name"]):
+ result = align_entity_arrays(case["input"])
+ if "expected" in case:
+ self.assertEqual(result, case["expected"])
+ if "check" in case:
+ self.assertTrue(case["check"](result))
+
+ def test_align_different_lengths_warning(self):
+ """Test arrays with different lengths (should truncate to min)"""
+ arr1 = [{"id": 1}, {"id": 2}, {"id": 3}]
+ arr2 = [{"id": 1}, {"id": 2}]
+
+ captured_output = io.StringIO()
+ sys.stdout = captured_output
+
+ try:
+ aligned = align_entity_arrays([arr1, arr2])
+ finally:
+ sys.stdout = sys.__stdout__
+
+ output = captured_output.getvalue()
+
+ self.assertIn("Warning", output)
+ self.assertIn("[3, 2]", output)
+ self.assertEqual(len(aligned[0]), 2)
+ self.assertEqual(len(aligned[1]), 2)
+
+ def test_find_best_match(self):
+ """Test find_best_match with various scenarios"""
+ cases = [
+ {
+ "name": "exact id",
+ "target": {"id": 5, "name": "Test"},
+ "candidates": [
+ {"id": 1, "name": "Other"},
+ {"id": 5, "name": "Test"},
+ {"id": 3, "name": "Another"},
+ ],
+ "used": set(),
+ "expected": 1,
+ },
+ {
+ "name": "skip used indices",
+ "target": {"id": 5, "name": "Test"},
+ "candidates": [{"id": 5}, {"id": 5}],
+ "used": {0},
+ "expected": 1,
+ },
+ {
+ "name": "all used",
+ "target": {"id": 1},
+ "candidates": [{"id": 1}, {"id": 2}],
+ "used": {0, 1},
+ "expected": -1,
+ },
+ {
+ "name": "by similarity",
+ "target": {"name": "Alice", "age": 30, "city": "NYC"},
+ "candidates": [
+ {"name": "Bob", "age": 25, "city": "LA"},
+ {"name": "Alice", "age": 30, "city": "Boston"}, # Better match
+ {"name": "Charlie", "age": 40, "city": "Chicago"},
+ ],
+ "used": set(),
+ "expected": 1,
+ },
+ ]
+
+ for case in cases:
+ with self.subTest(case["name"]):
+ idx = find_best_match(case["target"], case["candidates"], case["used"])
+ self.assertEqual(idx, case["expected"])
+
+ def test_calculate_similarity(self):
+ """Test calculate_similarity with various scenarios"""
+ cases = [
+ ("non-dict strings", "test", "test", 1.0),
+ ("non-dict mismatch", "test", "other", 0.0),
+ ("both none", None, None, 1.0),
+ ("exact id match", {"id": 5, "o": "d"}, {"id": 5, "o": "v"}, 1.0),
+ (
+ "temp id priority",
+ {"_temp_id": "t1", "id": 1},
+ {"_temp_id": "t1", "id": 2},
+ 1.0,
+ ),
+ ("empty dicts", {}, {}, 1.0),
+ ("no common fields", {"a": 1}, {"b": 2}, 0.0),
+ (
+ "partial match",
+ {"a": 1, "b": 2, "c": 3},
+ {"a": 1, "b": 2, "c": 999},
+ lambda s: 0.5 < s < 1.0,
+ ),
+ ("missing fields", {"a": 1, "b": 2}, {"a": 1}, 0.5),
+ (
+ "numeric similarity",
+ {"value": 100},
+ {"value": 105},
+ lambda s: s > calculate_similarity({"value": 100}, {"value": 1000}),
+ ),
+ ]
+
+ for case in cases:
+ name = case[0]
+ val1 = case[1]
+ val2 = case[2]
+ expected = case[3]
+
+ with self.subTest(name):
+ score = calculate_similarity(val1, val2)
+ if callable(expected):
+ self.assertTrue(expected(score))
+ else:
+ self.assertEqual(score, expected)
+
+ def test_compare_values(self):
+ """Test compare_values with various types and scenarios"""
+ cases = [
+ # Basic types
+ ("none both", None, None, 1.0),
+ ("none one", None, "test", 0.0),
+ ("exact int", 42, 42, 1.0),
+ ("exact str", "test", "test", 1.0),
+ ("exact bool", True, True, 1.0),
+ # Strings
+ ("case insensitive", "Hello", "hello", 1.0),
+ ("string fuzzy", "hello world", "hello word", lambda s: 0.8 < s < 1.0),
+ ("string distinct", "abc", "xyz", lambda s: s < 0.3),
+ # Numbers
+ ("int float equal", 10.0, 10, 1.0),
+ ("close numbers", 100, 110, lambda s: 0.8 < s < 1.0),
+ ("far numbers", 10, 1000, lambda s: s < 0.5),
+ # Booleans
+ ("bool mismatch", True, False, 0.0),
+ # Lists
+ ("empty lists", [], [], 1.0),
+ ("one empty list", [], [1], 0.0),
+ ("similar lists", [1, 2, 3], [1, 2, 999], lambda s: 0.5 < s < 1.0),
+ # Dicts
+ ("nested dicts", {"a": 1}, {"a": 1}, 1.0),
+ ("nested partial", {"a": 1}, {"a": 2}, lambda s: 0.0 < s < 1.0),
+ # Mixed types: mismatched types must score 0.0 (regression cases for the fix)
+ ("int vs string", 1, "1", 0.0),
+ ("list vs dict", [1], {"a": 1}, 0.0),
+ ("bool vs int", True, 1, 0.0),
+ ]
+
+ for case in cases:
+ name = case[0]
+ val1 = case[1]
+ val2 = case[2]
+ expected = case[3]
+
+ with self.subTest(name):
+ score = compare_values(val1, val2)
+ if callable(expected):
+ self.assertTrue(
+ expected(score), f"Score {score} failed check for {name}"
+ )
+ else:
+ self.assertEqual(
+ score, expected, f"Score {score} != {expected} for {name}"
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/utils/test_llm_output_processing.py b/tests/utils/test_llm_output_processing.py
index 5fab1b6..f812bb1 100644
--- a/tests/utils/test_llm_output_processing.py
+++ b/tests/utils/test_llm_output_processing.py
@@ -76,6 +76,10 @@ def make_json_str(data: Any) -> str:
),
# Unwrapping a non-dict from a single key
({"DynamicKey": "not_a_dict"}, "not_a_dict"),
+ # Single dict in list wrapped with priority key
+ ([{"result": {"name": "test"}}], {"name": "test"}),
+ # Single dict in list wrapped with priority key containing priority key
+ ([{"result": {"data": {"name": "test"}}}], {"name": "test"}),
],
)
def test_unwrap_llm_output(input_data, expected_output):
@@ -394,3 +398,15 @@ def test_process_and_validate_raw_json_schema_validation_fails():
process_and_validate_raw_json(
raw_content, "test_schema_fail", target_json_schema=schema
)
+
+
+def test_process_and_validate_raw_json_no_unwrap():
+ """Tests that attempt_unwrap=False prevents unwrapping."""
+ # A wrapped structure that would normally be unwrapped
+ raw_content = make_json_str({"result": {"key": "value"}})
+
+ # With attempt_unwrap=False, it should return the raw dict
+ result = process_and_validate_raw_json(
+ raw_content, "test_no_unwrap", attempt_unwrap=False
+ )
+ assert result == {"result": {"key": "value"}}
diff --git a/tests/utils/test_rate_limiter.py b/tests/utils/test_rate_limiter.py
new file mode 100644
index 0000000..313a878
--- /dev/null
+++ b/tests/utils/test_rate_limiter.py
@@ -0,0 +1,80 @@
+import pytest
+import asyncio
+import time
+from extrai.utils.rate_limiter import AsyncRateLimiter
+
+
+@pytest.mark.asyncio
+async def test_rate_limiter_basic():
+ """Test basic RPM limiting."""
+ limiter = AsyncRateLimiter(max_capacity=2, period=0.5)
+ start = time.monotonic()
+
+ # First 2 should be immediate
+ await limiter.acquire(1)
+ await limiter.acquire(1)
+
+ # 3rd should wait
+ await limiter.acquire(1)
+ duration = time.monotonic() - start
+
+ assert duration >= 0.5
+
+
+@pytest.mark.asyncio
+async def test_rate_limiter_tokens():
+ """Test token-based limiting."""
+ limiter = AsyncRateLimiter(max_capacity=10, period=0.5)
+ start = time.monotonic()
+
+ # Consuming 5 tokens twice
+ await limiter.acquire(5)
+ await limiter.acquire(5)
+
+ # Consuming 1 more token should wait
+ await limiter.acquire(1)
+ duration = time.monotonic() - start
+
+ assert duration >= 0.5
+
+
+@pytest.mark.asyncio
+async def test_rate_limiter_partial_wait():
+ """Test waiting only for necessary capacity to be freed."""
+ limiter = AsyncRateLimiter(max_capacity=10, period=1.0)
+
+ await limiter.acquire(5) # t=0
+ # Wait half the period
+ await asyncio.sleep(0.5)
+ await limiter.acquire(5) # t=0.5
+
+ # Now usage is 10.
+ # We want 5 more.
+ # The first 5 expire at t=1.0. The second 5 expire at t=1.5.
+ # We are currently at t=0.5.
+ # We should wait until t=1.0 (0.5s wait) to free the first 5.
+
+ start = time.monotonic()
+ await limiter.acquire(5)
+ duration = time.monotonic() - start
+
+ # Should be close to 0.5s
+ # Use a range to be safe against system timing jitter
+ assert 0.4 <= duration <= 0.7
+
+
+@pytest.mark.asyncio
+async def test_rate_limiter_context_manager():
+ """Test context manager usage."""
+ limiter = AsyncRateLimiter(max_capacity=1, period=0.2)
+ start = time.monotonic()
+
+ async with limiter:
+ pass
+
+ # Should wait here
+ async with limiter:
+ pass
+
+ duration = time.monotonic() - start
+ assert duration >= 0.2
diff --git a/tests/utils/test_serialization_utils.py b/tests/utils/test_serialization_utils.py
new file mode 100644
index 0000000..14a292a
--- /dev/null
+++ b/tests/utils/test_serialization_utils.py
@@ -0,0 +1,73 @@
+from decimal import Decimal
+from typing import List, Optional
+from sqlmodel import SQLModel, Field, Relationship
+from extrai.utils.serialization_utils import serialize_sqlmodel_with_relationships
+
+
+# Define test models
+class Item(SQLModel, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+ name: str
+ price: Decimal
+ order_id: Optional[int] = Field(default=None, foreign_key="order.id")
+ order: Optional["Order"] = Relationship(back_populates="items")
+
+
+class Order(SQLModel, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+ description: str
+ items: List[Item] = Relationship(back_populates="order")
+
+
+def test_serialize_basic():
+ item = Item(name="Test Item", price=Decimal("10.50"), id=1)
+ serialized = serialize_sqlmodel_with_relationships(item)
+ assert serialized["name"] == "Test Item"
+ # Pydantic v2 default for Decimal serialization in mode='json' is string
+ assert serialized["price"] == "10.50"
+ assert serialized["id"] == 1
+
+
+def test_serialize_with_relationships():
+ order = Order(id=1, description="Test Order")
+ item1 = Item(name="Item 1", price=Decimal("10.50"), id=1)
+ item2 = Item(name="Item 2", price=Decimal("20.00"), id=2)
+ order.items = [item1, item2]
+
+ serialized = serialize_sqlmodel_with_relationships(order)
+
+ assert serialized["id"] == 1
+ assert "items" in serialized
+ assert len(serialized["items"]) == 2
+ assert serialized["items"][0]["name"] == "Item 1"
+ assert serialized["items"][0]["price"] == "10.50"
+ assert serialized["items"][1]["name"] == "Item 2"
+
+
+def test_circular_reference():
+ order = Order(id=1, description="Circular Order")
+ item = Item(name="Circular Item", price=Decimal("5.00"), id=1)
+
+ # Create cycle
+ order.items = [item]
+ item.order = order
+
+ serialized = serialize_sqlmodel_with_relationships(order)
+
+ assert serialized["id"] == 1
+ assert serialized["items"][0]["name"] == "Circular Item"
+ # The 'order' inside the item should be empty dict to avoid recursion
+ assert serialized["items"][0]["order"] == {}
+
+
+def test_transient_relationship_access():
+ """Test that we can access relationships on transient objects correctly."""
+ order = Order(id=1, description="Transient")
+ # items not set
+ serialized = serialize_sqlmodel_with_relationships(order)
+ # Should not crash.
+ # Depending on SQLModel version/config, an unset relationship may be
+ # absent from the output; if present, it should be an empty list.
+ assert serialized["id"] == 1
+ if "items" in serialized:
+ assert serialized["items"] == []
diff --git a/tests/utils/test_type_mapping.py b/tests/utils/test_type_mapping.py
new file mode 100644
index 0000000..6d18be1
--- /dev/null
+++ b/tests/utils/test_type_mapping.py
@@ -0,0 +1,253 @@
+import datetime
+import enum
+from typing import (
+ List,
+ Dict,
+ Optional,
+ Union,
+ Any,
+)
+from extrai.utils.type_mapping import (
+ get_python_type_str_from_pydantic_annotation,
+ map_sql_type_to_llm_type,
+ _handle_list_type,
+ _handle_dict_type,
+ _handle_union_type,
+ _handle_generic_or_unknown_type,
+ _process_union_types,
+)
+
+# --- Test Data ---
+
+
+class MyEnum(enum.Enum):
+ A = 1
+ B = 2
+
+
+class CustomType:
+ pass
+
+
+class SecretStr:  # note: shadowed by a local SecretStr in the fallback test below
+ pass
+
+
+# --- Tests for get_python_type_str_from_pydantic_annotation ---
+
+
+def test_get_python_type_str_base_types():
+ assert get_python_type_str_from_pydantic_annotation(int) == "int"
+ assert get_python_type_str_from_pydantic_annotation(str) == "str"
+ assert get_python_type_str_from_pydantic_annotation(bool) == "bool"
+ assert get_python_type_str_from_pydantic_annotation(float) == "float"
+ assert get_python_type_str_from_pydantic_annotation(datetime.date) == "date"
+ assert get_python_type_str_from_pydantic_annotation(datetime.datetime) == "datetime"
+ assert get_python_type_str_from_pydantic_annotation(bytes) == "bytes"
+ assert get_python_type_str_from_pydantic_annotation(Any) == "any"
+ assert get_python_type_str_from_pydantic_annotation(type(None)) == "none"
+
+
+def test_get_python_type_str_complex_types():
+ # List
+ assert get_python_type_str_from_pydantic_annotation(List[int]) == "list[int]"
+ assert get_python_type_str_from_pydantic_annotation(list) == "list"
+
+ # Dict
+ assert (
+ get_python_type_str_from_pydantic_annotation(Dict[str, int]) == "dict[str,int]"
+ )
+ assert get_python_type_str_from_pydantic_annotation(dict) == "dict"
+
+ # Optional
+ assert get_python_type_str_from_pydantic_annotation(Optional[int]) == "int"
+ assert get_python_type_str_from_pydantic_annotation(Optional[str]) == "str"
+
+ # Union
+ # Note: the implementation sorts union member names, so the string
+ # representation is deterministic regardless of argument order.
+ assert (
+ get_python_type_str_from_pydantic_annotation(Union[int, str])
+ == "union[int,str]"
+ )
+ assert (
+ get_python_type_str_from_pydantic_annotation(Union[str, int])
+ == "union[int,str]"
+ )
+ assert (
+ get_python_type_str_from_pydantic_annotation(Union[int, None]) == "int"
+ ) # Same as Optional
+
+ # Nested
+ assert (
+ get_python_type_str_from_pydantic_annotation(List[Dict[str, Any]])
+ == "list[dict[str,any]]"
+ )
+
+
+def test_get_python_type_str_enum():
+ assert get_python_type_str_from_pydantic_annotation(MyEnum) == "enum"
+
+
+def test_get_python_type_str_custom_and_fallback():
+ # SecretStr is matched by class name: the code checks
+ # hasattr(annotation, "__name__") and name_lower == "secretstr".
+ # The real-world case is pydantic.SecretStr, but a local dummy
+ # class with the same name triggers the same branch without a
+ # pydantic dependency.
+
+ class SecretStr:
+ pass
+
+ assert get_python_type_str_from_pydantic_annotation(SecretStr) == "str"
+
+ # Custom Type
+ assert get_python_type_str_from_pydantic_annotation(CustomType) == "customtype"
+
+ # Fallback path: anything matching no earlier rule is rendered as
+ # str(annotation).lower().replace("typing.", ""), with a leading "~"
+ # stripped. Passing a plain string that matches no other rule
+ # exercises this branch directly.
+ assert get_python_type_str_from_pydantic_annotation("JustAString") == "justastring"
+
+ # Test ForwardRef style string (starts with ~)
+ assert get_python_type_str_from_pydantic_annotation("~ForwardRef") == "forwardref"
+
+
+def test_process_union_types_edge_case():
+ # _process_union_types with empty args cannot be triggered via
+ # get_python_type_str (Union[] is invalid syntax), so we call the
+ # helper directly.
+ assert _process_union_types([], lambda x: x) == "union"
+
+ # Test deduplication and sorting: 'a', 'b', 'a' -> 'a', 'b'
+ assert _process_union_types(["b", "a", "a"], lambda x: x) == "union[a,b]"
+
+ # Test single element after processing
+ assert _process_union_types(["a", "a"], lambda x: x) == "a"
+
+ # Test None filtering ("none" string)
+ assert _process_union_types(["a", "none"], lambda x: x) == "a"
+
+
+# --- Tests for map_sql_type_to_llm_type and helpers ---
+
+
+def test_map_sql_type_simple():
+ assert map_sql_type_to_llm_type("INTEGER", "int") == "integer"
+ assert map_sql_type_to_llm_type("VARCHAR", "str") == "string"
+ assert map_sql_type_to_llm_type("BOOLEAN", "bool") == "boolean"
+ assert map_sql_type_to_llm_type("FLOAT", "float") == "number (float/decimal)"
+ assert map_sql_type_to_llm_type("DATE", "date") == "string (date format)"
+ assert (
+ map_sql_type_to_llm_type("DATETIME", "datetime") == "string (datetime format)"
+ )
+ assert map_sql_type_to_llm_type("BLOB", "bytes") == "string (base64 encoded)"
+ assert map_sql_type_to_llm_type("ENUM", "enum") == "string (enum)"
+ assert map_sql_type_to_llm_type("ANY", "any") == "any"
+ assert map_sql_type_to_llm_type("NONE", "none") == "null"
+
+
+def test_handle_list_type():
+ assert _handle_list_type("list[int]") == "array[integer]"
+ assert _handle_list_type("list[str]") == "array[string]"
+ assert _handle_list_type("notalist") is None
+
+ # Integration via main function
+ assert map_sql_type_to_llm_type("", "list[int]") == "array[integer]"
+
+
+def test_handle_dict_type():
+ assert _handle_dict_type("dict[str,int]") == "object[string,integer]"
+ assert _handle_dict_type("dict[str, str]") == "object[string,string]" # spacing
+ assert _handle_dict_type("notadict") is None
+
+ # Malformed dict string: the inner part is split with split(",", 1);
+ # without a comma, unpacking raises ValueError and "object" is returned.
+ assert _handle_dict_type("dict[int]") == "object"
+
+ # Integration via main function
+ assert map_sql_type_to_llm_type("", "dict[str,int]") == "object[string,integer]"
+
+
+def test_handle_union_type():
+ # int -> integer, str -> string; member names are then sorted, so
+ # argument order does not affect the result.
+ assert _handle_union_type("union[int,str]") == "union[integer,string]"
+
+ # Reversed input order normalizes to the same string
+ assert _handle_union_type("union[str,int]") == "union[integer,string]"
+
+ # Single type in union
+ assert _handle_union_type("union[int]") == "integer"
+
+ # Empty parts -> "any"
+ assert _handle_union_type("union[]") == "any"
+ assert _handle_union_type("union[ ]") == "any"
+
+ assert _handle_union_type("notaunion") is None
+
+ # Integration
+ assert map_sql_type_to_llm_type("", "union[int,str]") == "union[integer,string]"
+
+
+def test_handle_generic_or_unknown_type():
+ # list
+ assert _handle_generic_or_unknown_type("list", "") == "array"
+ # list with text in sql -> None (fallback)
+ assert _handle_generic_or_unknown_type("list", "text[]") is None
+
+ # dict
+ assert _handle_generic_or_unknown_type("dict", "") == "object"
+
+ # unknown
+ assert _handle_generic_or_unknown_type("unknown_stuff", "json") == "object"
+ assert _handle_generic_or_unknown_type("unknown_stuff", "array") == "array"
+ assert _handle_generic_or_unknown_type("unknown_stuff", "other") == "string"
+
+ assert _handle_generic_or_unknown_type("other", "") is None
+
+
+def test_sql_keyword_fallback():
+ # Only if python type not handled above
+ assert map_sql_type_to_llm_type("int", "other") == "integer"
+ assert map_sql_type_to_llm_type("text", "other") == "string"
+ assert map_sql_type_to_llm_type("json", "other") == "object"
+ assert map_sql_type_to_llm_type("array", "other") == "array"
+
+
+def test_final_fallback():
+ assert map_sql_type_to_llm_type("nomatch", "nomatch") == "string"
+
+
+def test_map_sql_type_generic_integration():
+ # Trigger _handle_generic_or_unknown_type via map_sql_type_to_llm_type
+ # "list" -> "array" (if sql type is not text)
+ assert map_sql_type_to_llm_type("", "list") == "array"
+
+ # "unknown" with "array" in sql -> "array"
+ assert map_sql_type_to_llm_type("ARRAY", "unknown_type") == "array"
+
+
+# --- Additional Coverage for Origin Handlers ---
+
+
+def test_origin_handler_list_variations():
+ # Test list vs List origin
+ assert get_python_type_str_from_pydantic_annotation(List[int]) == "list[int]"
+ assert get_python_type_str_from_pydantic_annotation(list) == "list"
+ # The parameterized (List[int]) and bare (list) forms cover both branches of the args check.
+
+
+def test_origin_handler_dict_variations():
+ # Test dict vs Dict origin
+ assert (
+ get_python_type_str_from_pydantic_annotation(Dict[str, int]) == "dict[str,int]"
+ )
+ assert get_python_type_str_from_pydantic_annotation(dict) == "dict"
+
+
+def test_origin_handler_optional_none():
+ # Optional handler: when args[0] is type(None), the result is "none";
+ # Optional[type(None)] collapses to plain NoneType at runtime.
+ assert get_python_type_str_from_pydantic_annotation(Optional[type(None)]) == "none"