From 64c409941aa0864f3a96f970095f721a54defc81 Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Tue, 27 Jan 2026 00:06:26 +0100 Subject: [PATCH 1/9] fixed relation checker --- graflo/db/tigergraph/conn.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/graflo/db/tigergraph/conn.py b/graflo/db/tigergraph/conn.py index d6815c4..e4bffa6 100644 --- a/graflo/db/tigergraph/conn.py +++ b/graflo/db/tigergraph/conn.py @@ -2182,7 +2182,9 @@ def _define_schema_local(self, schema: Schema) -> None: expected_name = v.name expected_vertex_types.add(expected_name) - expected_edge_types = {e.relation for e in edges_to_create if e.relation} + expected_edge_types = { + e.relation_dbname for e in edges_to_create if e.relation + } # Convert to sets for case-insensitive comparison # TigerGraph may capitalize vertex names, so compare case-insensitively From 9e425cee33e3f3e6b6179dc6a3278718434e9434 Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Tue, 27 Jan 2026 11:31:17 +0100 Subject: [PATCH 2/9] fixed resources for vertex-like tables --- graflo/db/postgres/resource_mapping.py | 84 +++++++++++++++++++++++--- graflo/hq/inferencer.py | 4 +- graflo/hq/sanitizer.py | 43 ++++++------- 3 files changed, 95 insertions(+), 36 deletions(-) diff --git a/graflo/db/postgres/resource_mapping.py b/graflo/db/postgres/resource_mapping.py index c682849..b59ef48 100644 --- a/graflo/db/postgres/resource_mapping.py +++ b/graflo/db/postgres/resource_mapping.py @@ -5,6 +5,7 @@ """ import logging +from collections import defaultdict from graflo.architecture.resource import Resource from graflo.architecture.vertex import VertexConfig @@ -14,6 +15,7 @@ detect_separator, split_by_separator, ) +from ...architecture import EdgeConfig logger = logging.getLogger(__name__) @@ -33,20 +35,36 @@ def __init__(self, fuzzy_threshold: float = 0.8): """ self.fuzzy_threshold = fuzzy_threshold - def create_vertex_resource(self, table_name: str, vertex_name: str) -> Resource: + def create_vertex_resource( + self, + table_name: str, + vertex_name: str, + vertex_attribute_mappings: defaultdict[str, dict[str, str]], + ) -> Resource: """Create a Resource for a vertex table. Args: table_name: Name of the PostgreSQL table vertex_name: Name of the vertex type (typically same as table_name) + vertex_attribute_mappings: Dictionary mapping vertex names to field mappings + (original_field -> sanitized_field) for transformations Returns: Resource: Resource configured to ingest vertex data """ - # Create apply list with VertexActor - # The actor wrapper will interpret {"vertex": vertex_name} as VertexActor apply = [{"vertex": vertex_name}] + field_mappings = vertex_attribute_mappings[vertex_name] + if field_mappings: + apply.append( + { + "map": field_mappings, + } + ) + logger.debug( + f"Added field mappings for vertex '{vertex_name}': {field_mappings}" + ) + resource = Resource( resource_name=table_name, apply=apply, @@ -63,6 +81,7 @@ def create_edge_resource( edge_table_info: EdgeTableInfo, vertex_config: VertexConfig, matcher: FuzzyMatcher, + vertex_attribute_mappings: defaultdict[str, dict[str, str]], ) -> Resource: """Create a Resource for an edge table. 
@@ -70,6 +89,8 @@ def create_edge_resource( edge_table_info: Edge table information from introspection vertex_config: Vertex configuration for source/target validation matcher: Optional fuzzy matcher for better performance (with caching enabled) + vertex_attribute_mappings: Dictionary mapping vertex names to field mappings + (original_field -> sanitized_field) for transformations Returns: Resource: Resource configured to ingest edge data @@ -126,19 +147,56 @@ def create_edge_resource( # avoiding attribute collisions between different vertex types apply = [] + # Get all column names from the edge table for mapping + edge_column_names = {col.name for col in edge_table_info.columns} + # First mapping: map source foreign key column to source vertex's primary key field if source_column: + source_map = {source_column: source_pk_field} + # Add attribute mappings for the source vertex + # These mappings transform original field names to sanitized field names + source_attr_mappings = vertex_attribute_mappings[source_table] + # Add mappings for columns that match original field names that were sanitized + for orig_field, sanitized_field in source_attr_mappings.items(): + # Only add mapping if: + # 1. The column exists in the edge table + # 2. It's not already mapped (e.g., as the source_column -> source_pk_field) + # 3. The sanitized field is different from the original (actual sanitization occurred) + if ( + orig_field in edge_column_names + and orig_field != source_column + and orig_field != sanitized_field + ): + source_map[orig_field] = sanitized_field + source_map_config = { "target_vertex": source_table, - "map": {source_column: source_pk_field}, + "map": source_map, } apply.append(source_map_config) # Second mapping: map target foreign key column to target vertex's primary key field if target_column: + target_map = {target_column: target_pk_field} + # Add attribute mappings for the target vertex + # These mappings transform original field names to sanitized field names + target_attr_mappings = vertex_attribute_mappings[target_table] + # Add mappings for columns that match original field names that were sanitized + for orig_field, sanitized_field in target_attr_mappings.items(): + # Only add mapping if: + # 1. The column exists in the edge table + # 2. It's not already mapped (e.g., as the target_column -> target_pk_field) + # 3. The sanitized field is different from the original (actual sanitization occurred) + if ( + orig_field in edge_column_names + and orig_field != target_column + and orig_field != sanitized_field + ): + target_map[orig_field] = sanitized_field + target_map_config = { "target_vertex": target_table, - "map": {target_column: target_pk_field}, + "map": target_map, } apply.append(target_map_config) @@ -222,13 +280,15 @@ def _infer_pk_field_from_column( ) return "id" - def map_tables_to_resources( + def create_resources_from_tables( self, introspection_result: SchemaIntrospectionResult, vertex_config: VertexConfig, + edge_config: EdgeConfig, + vertex_attribute_mappings: defaultdict[str, dict[str, str]], fuzzy_threshold: float | None = None, ) -> list[Resource]: - """Map all PostgreSQL tables to Resources. + """Create Resources from PostgreSQL tables. Creates Resources for both vertex and edge tables, enabling ingestion of the entire database schema. 
@@ -236,8 +296,10 @@ def map_tables_to_resources( Args: introspection_result: Result from PostgresConnection.introspect_schema() vertex_config: Inferred vertex configuration - sanitizer: carries mappiings + edge_config: Inferred edge configuration fuzzy_threshold: Similarity threshold for fuzzy matching (0.0 to 1.0) + vertex_attribute_mappings: Dictionary mapping vertex names to field mappings + (original_field -> sanitized_field) for transformations Returns: list[Resource]: List of Resources for all tables @@ -257,7 +319,9 @@ def map_tables_to_resources( for table_info in vertex_tables: table_name = table_info.name vertex_name = table_name # Use table name as vertex name - resource = self.create_vertex_resource(table_name, vertex_name) + resource = self.create_vertex_resource( + table_name, vertex_name, vertex_attribute_mappings + ) resources.append(resource) # Map edge tables to resources @@ -265,7 +329,7 @@ def map_tables_to_resources( for edge_table_info in edge_tables: try: resource = self.create_edge_resource( - edge_table_info, vertex_config, matcher + edge_table_info, vertex_config, matcher, vertex_attribute_mappings ) resources.append(resource) except ValueError as e: diff --git a/graflo/hq/inferencer.py b/graflo/hq/inferencer.py index 0432d67..76ca6e8 100644 --- a/graflo/hq/inferencer.py +++ b/graflo/hq/inferencer.py @@ -72,9 +72,11 @@ def create_resources( Returns: list[Resource]: List of Resources for PostgreSQL tables """ - return self.mapper.map_tables_to_resources( + return self.mapper.create_resources_from_tables( introspection_result, schema.vertex_config, + schema.edge_config, + vertex_attribute_mappings=self.sanitizer.vertex_attribute_mappings, fuzzy_threshold=self.mapper.fuzzy_threshold, ) diff --git a/graflo/hq/sanitizer.py b/graflo/hq/sanitizer.py index 8f93aa5..881bbdd 100644 --- a/graflo/hq/sanitizer.py +++ b/graflo/hq/sanitizer.py @@ -9,6 +9,7 @@ import logging from collections import Counter from typing import TYPE_CHECKING +from collections import defaultdict from graflo.architecture.edge import Edge from graflo.architecture.schema import Schema @@ -44,7 +45,9 @@ def __init__(self, db_flavor: DBFlavor): """ self.db_flavor = db_flavor self.reserved_words = load_reserved_words(db_flavor) - self.attribute_mappings: dict[str, str] = {} + self.vertex_attribute_mappings: defaultdict[str, dict[str, str]] = defaultdict( + dict + ) self.vertex_mappings: dict[str, str] = {} def sanitize(self, schema: Schema) -> Schema: @@ -84,34 +87,24 @@ def sanitize(self, schema: Schema) -> Schema: for vertex in schema.vertex_config.vertices: for field in vertex.fields: original_name = field.name - if original_name not in self.attribute_mappings: - sanitized_name = sanitize_attribute_name( - original_name, self.reserved_words - ) - if sanitized_name != original_name: - self.attribute_mappings[original_name] = sanitized_name - logger.debug( - f"Sanitizing field name '{original_name}' -> '{sanitized_name}' " - f"in vertex '{vertex.name}'" - ) - else: - self.attribute_mappings[original_name] = original_name - else: - sanitized_name = self.attribute_mappings[original_name] - - # Update field name if it changed + sanitized_name = sanitize_attribute_name( + original_name, self.reserved_words + ) if sanitized_name != original_name: + self.vertex_attribute_mappings[vertex.name][original_name] = ( + sanitized_name + ) + logger.debug( + f"Sanitizing field name '{original_name}' -> '{sanitized_name}' " + f"in vertex '{vertex.name}'" + ) field.name = sanitized_name - # Update index field references 
if they were sanitized for index in vertex.indexes: - updated_fields = [] - for field_name in index.fields: - sanitized_field_name = self.attribute_mappings.get( - field_name, field_name - ) - updated_fields.append(sanitized_field_name) - index.fields = updated_fields + index.fields = [ + self.vertex_attribute_mappings[vertex.name].get(item, item) + for item in index.fields + ] vertex_names = {vertex.dbname for vertex in schema.vertex_config.vertices} From a331227b997957e3597d7850ed977a8cd8d97275 Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Tue, 27 Jan 2026 12:08:16 +0100 Subject: [PATCH 3/9] sanitized None in payload; removed individual upsert --- graflo/db/tigergraph/conn.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/graflo/db/tigergraph/conn.py b/graflo/db/tigergraph/conn.py index e4bffa6..1e76dfb 100644 --- a/graflo/db/tigergraph/conn.py +++ b/graflo/db/tigergraph/conn.py @@ -3014,7 +3014,7 @@ def _generate_upsert_payload( # 4. Format attributes for TigerGraph REST++ API # TigerGraph requires attribute values to be wrapped in {"value": ...} formatted_attributes = { - k: {"value": v} for k, v in clean_record.items() + k: {"value": v} for k, v in clean_record.items() if v is not None } # 5. Add the record attributes to the map using the composite ID as the key @@ -3160,8 +3160,6 @@ def upsert_docs_batch(self, docs, class_name, match_keys, **kwargs): logger.error( f"Error upserting vertices to {class_name}: {result.get('message')}" ) - # Fallback to individual operations - self._fallback_individual_upsert(docs, class_name, match_keys) else: num_vertices = len(payload["vertices"][class_name]) logger.debug( @@ -3171,24 +3169,6 @@ def upsert_docs_batch(self, docs, class_name, match_keys, **kwargs): except Exception as e: logger.error(f"Error upserting vertices to {class_name}: {e}") - # Fallback to individual operations - self._fallback_individual_upsert(docs, class_name, match_keys) - - def _fallback_individual_upsert(self, docs, class_name, match_keys): - """Fallback method for individual vertex upserts.""" - for doc in docs: - try: - vertex_id = self._extract_id(doc, match_keys) - if vertex_id: - clean_doc = self._clean_document(doc) - # Serialize datetime objects before passing to REST API - # REST API expects JSON-serializable data - serialized_doc = json.loads( - json.dumps(clean_doc, default=_json_serializer) - ) - self._upsert_vertex(class_name, vertex_id, serialized_doc) - except Exception as e: - logger.error(f"Error upserting individual vertex {vertex_id}: {e}") def _generate_edge_upsert_payloads( self, From 618137c63673109f0ab5b8ae49ea015c653a4b96 Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Tue, 27 Jan 2026 12:20:46 +0100 Subject: [PATCH 4/9] sanitized None in payload for edges --- graflo/db/tigergraph/conn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graflo/db/tigergraph/conn.py b/graflo/db/tigergraph/conn.py index 1e76dfb..93fa4c8 100644 --- a/graflo/db/tigergraph/conn.py +++ b/graflo/db/tigergraph/conn.py @@ -3014,7 +3014,7 @@ def _generate_upsert_payload( # 4. Format attributes for TigerGraph REST++ API # TigerGraph requires attribute values to be wrapped in {"value": ...} formatted_attributes = { - k: {"value": v} for k, v in clean_record.items() if v is not None + k: {"value": v} for k, v in clean_record.items() if v } # 5. 
Add the record attributes to the map using the composite ID as the key @@ -3226,7 +3226,7 @@ def _generate_edge_upsert_payloads( # Clean and format edge attributes clean_edge_props = self._clean_document(edge_props) formatted_attributes = { - k: {"value": v} for k, v in clean_edge_props.items() + k: {"value": v} for k, v in clean_edge_props.items() if v } # Group by (source_id, target_id, edge_type) From 7a2078272a537576003b23db94ec95916ee8a116 Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Tue, 27 Jan 2026 12:44:30 +0100 Subject: [PATCH 5/9] better gsql query batching --- graflo/db/tigergraph/conn.py | 190 ++++++++++++++++++++++------------- 1 file changed, 122 insertions(+), 68 deletions(-) diff --git a/graflo/db/tigergraph/conn.py b/graflo/db/tigergraph/conn.py index 93fa4c8..6ebf31f 100644 --- a/graflo/db/tigergraph/conn.py +++ b/graflo/db/tigergraph/conn.py @@ -1934,6 +1934,105 @@ def _get_edge_group_create_statement(self, edges: list[Edge]) -> str: # No attributes - FROM/TO section is the last thing return f"ADD DIRECTED EDGE {relation} (\n{all_from_to}\n )" + def _batch_schema_statements( + self, schema_change_stmts: list[str], graph_name: str, max_job_size: int + ) -> list[list[str]]: + """Batch schema change statements into groups that fit within max_job_size. + + Intelligently merges small statements together while ensuring no batch + exceeds the maximum job size limit. + + Args: + schema_change_stmts: List of schema change statements to batch + graph_name: Name of the graph (used for size estimation) + max_job_size: Maximum size in characters for a single job + + Returns: + List of batches, where each batch is a list of statements + """ + if not schema_change_stmts: + return [] + + # Calculate base overhead for a job + # Use worst-case job name length (multi-batch format) for conservative estimation + worst_case_job_name = ( + f"schema_change_{graph_name}_batch_999" # Use large number for worst case + ) + base_template = ( + f"USE GRAPH {graph_name}\n" + f"CREATE SCHEMA_CHANGE JOB {worst_case_job_name} FOR GRAPH {graph_name} {{\n" + f"}}\n" + f"RUN SCHEMA_CHANGE JOB {worst_case_job_name}" + ) + base_overhead = len(base_template) + + # Each statement adds 5 characters: first gets " " (4) + ";" (1), + # subsequent get ";\n " (5) between statements, final ";" (1) is included + # For N statements: 4 (first indent) + (N-1)*5 (separators) + 1 (final semicolon) = 5*N + + def estimate_batch_size(stmts: list[str]) -> int: + """Estimate the total size of a batch of statements.""" + if not stmts: + return base_overhead + total_stmt_size = sum(len(stmt) for stmt in stmts) + return base_overhead + total_stmt_size + 5 * len(stmts) + + # Calculate total estimated size for all statements + num_statements = len(schema_change_stmts) + total_stmt_size = sum(len(stmt) for stmt in schema_change_stmts) + estimated_size = base_overhead + total_stmt_size + 5 * num_statements + + # If everything fits in one batch, return single batch + if estimated_size <= max_job_size: + logger.info( + f"Applying schema change as single job (estimated size: {estimated_size} chars)" + ) + return [schema_change_stmts] + + # Need to split into multiple batches + # Strategy: Use a greedy bin-packing approach that merges small statements + # Start by creating batches, trying to pack as many statements as possible + # into each batch without exceeding max_job_size + + batches: list[list[str]] = [] + + # Sort statements by size (smallest first) to help pack efficiently + # We'll process them in order and try to add to 
existing batches + stmt_with_size = [(stmt, len(stmt)) for stmt in schema_change_stmts] + stmt_with_size.sort(key=lambda x: x[1]) # Sort by statement size + + for stmt, stmt_size in stmt_with_size: + # Calculate overhead for adding this statement: 5 chars (indent + semicolon) + stmt_overhead = 5 + + # Try to add to an existing batch + added = False + for batch in batches: + current_batch_size = estimate_batch_size(batch) + # Check if adding this statement would exceed the limit + if current_batch_size + stmt_size + stmt_overhead <= max_job_size: + batch.append(stmt) + added = True + break + + # If couldn't add to existing batch, create a new one + if not added: + # Check if statement itself is too large + single_stmt_size = estimate_batch_size([stmt]) + if single_stmt_size > max_job_size: + logger.warning( + f"Statement exceeds max_job_size ({single_stmt_size} > {max_job_size}). " + f"Will attempt to execute anyway, but may fail." + ) + batches.append([stmt]) + + logger.info( + f"Large schema detected (estimated size: {estimated_size} chars). " + f"Splitting into {len(batches)} batches." + ) + + return batches + @_wrap_tg_exception def _define_schema_local(self, schema: Schema) -> None: """Define TigerGraph schema locally for the current graph using a SCHEMA_CHANGE job. @@ -1951,14 +2050,15 @@ def _define_schema_local(self, schema: Schema) -> None: vertex_config = schema.vertex_config edge_config = schema.edge_config - schema_change_stmts = [] + vertex_stmts = [] + edge_stmts = [] # Vertices for vertex in vertex_config.vertices: # Validate vertex name _validate_tigergraph_schema_name(vertex.dbname, "vertex") stmt = self._get_vertex_add_statement(vertex, vertex_config) - schema_change_stmts.append(stmt) + vertex_stmts.append(stmt) # Edges - group by relation since TigerGraph requires edges of the same type # to be created in a single statement with multiple FROM/TO pairs @@ -1978,78 +2078,32 @@ def _define_schema_local(self, schema: Schema) -> None: # Create one statement per relation with all FROM/TO pairs for relation, edge_group in edges_by_relation.items(): stmt = self._get_edge_group_create_statement(edge_group) - schema_change_stmts.append(stmt) + edge_stmts.append(stmt) - if not schema_change_stmts: + if not vertex_stmts and not edge_stmts: logger.debug(f"No schema changes to apply for graph '{graph_name}'") return # Estimate the size of the GSQL command to determine if we need to split it # Large SCHEMA_CHANGE JOBs (>30k chars) can cause parser failures with misleading errors # like "Missing return statement" (which is actually a parser size limit issue) - # We'll split into batches based on configurable max_job_size (default: 1000) - MAX_JOB_SIZE = self.config.max_job_size - - # Calculate accurate size estimation - # Actual format: - # USE GRAPH {graph_name} - # CREATE SCHEMA_CHANGE JOB {job_name} FOR GRAPH {graph_name} { - # stmt1; - # stmt2; - # ... 
- # } - # RUN SCHEMA_CHANGE JOB {job_name} - # - # For N statements: - # - Base overhead: USE GRAPH line + CREATE line + closing brace + RUN line + newlines - # - Statement overhead: first gets " " + ";" (5 chars), others get ";\n " (5 chars each) - # - Total: base + sum(len(stmt)) + 5*N - - # Use worst-case job name length (multi-batch format) for conservative estimation - worst_case_job_name = ( - f"schema_change_{graph_name}_batch_999" # Use large number for worst case - ) - base_template = ( - f"USE GRAPH {graph_name}\n" - f"CREATE SCHEMA_CHANGE JOB {worst_case_job_name} FOR GRAPH {graph_name} {{\n" - f"}}\n" - f"RUN SCHEMA_CHANGE JOB {worst_case_job_name}" - ) - base_overhead = len(base_template) - - # Each statement adds 5 characters: first gets " " (4) + ";" (1), - # subsequent get ";\n " (5) between statements, final ";" (1) is included - # For N statements: 4 (first indent) + (N-1)*5 (separators) + 1 (final semicolon) = 5*N - num_statements = len(schema_change_stmts) - total_stmt_size = sum(len(stmt) for stmt in schema_change_stmts) - estimated_size = base_overhead + total_stmt_size + 5 * num_statements - - if estimated_size <= MAX_JOB_SIZE: - # Small enough for a single job - batches = [schema_change_stmts] - logger.info( - f"Applying schema change as single job (estimated size: {estimated_size} chars)" + # We'll split into batches based on configurable max_job_size + # Batch vertices and edges separately, then concatenate + vertex_batches = ( + self._batch_schema_statements( + vertex_stmts, graph_name, self.config.max_job_size ) - else: - # Split into multiple batches - # Calculate how many statements per batch - # For a batch of M statements: base_overhead + sum(len(stmt)) + 5*M <= MAX_JOB_SIZE - # So: sum(len(stmt)) + 5*M <= MAX_JOB_SIZE - base_overhead - # If avg_stmt_size = sum(len(stmt)) / M, then: M * (avg_stmt_size + 5) <= MAX_JOB_SIZE - base_overhead - avg_stmt_size = ( - total_stmt_size / num_statements if num_statements > 0 else 0 - ) - available_space = MAX_JOB_SIZE - base_overhead - stmts_per_batch = max(1, int(available_space / (avg_stmt_size + 5))) - - batches = [] - for i in range(0, len(schema_change_stmts), stmts_per_batch): - batches.append(schema_change_stmts[i : i + stmts_per_batch]) - - logger.info( - f"Large schema detected (estimated size: {estimated_size} chars). " - f"Splitting into {len(batches)} batches of ~{stmts_per_batch} statements each." + if vertex_stmts + else [] + ) + edge_batches = ( + self._batch_schema_statements( + edge_stmts, graph_name, self.config.max_job_size ) + if edge_stmts + else [] + ) + batches = vertex_batches + edge_batches # Execute batches sequentially for batch_idx, batch_stmts in enumerate(batches): @@ -2087,10 +2141,10 @@ def _define_schema_local(self, schema: Schema) -> None: actual_size = len(full_gsql) # Safety check: warn if actual size exceeds limit (indicates estimation error) - if actual_size > MAX_JOB_SIZE: + if actual_size > self.config.max_job_size: logger.warning( - f"Batch {batch_idx + 1} actual size ({actual_size} chars) exceeds limit ({MAX_JOB_SIZE} chars). " - f"This may cause parser errors. Consider reducing MAX_JOB_SIZE or improving estimation." + f"Batch {batch_idx + 1} actual size ({actual_size} chars) exceeds limit ({self.config.max_job_size} chars). " + f"This may cause parser errors. Consider reducing max_job_size or improving estimation." 
) logger.info( From 14f9858ee28a2555c30a49be840a858c8834a1f2 Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Tue, 27 Jan 2026 13:08:15 +0100 Subject: [PATCH 6/9] improved GraphEngine logic --- examples/5-ingest-postgres/ingest.py | 16 +++++-- graflo/cli/ingest.py | 58 +++++++++++++++++-------- graflo/hq/caster.py | 20 ++------- graflo/hq/graph_engine.py | 63 +++++++++++++++++++++------- test/conftest.py | 37 ++++++++++++++-- 5 files changed, 137 insertions(+), 57 deletions(-) diff --git a/examples/5-ingest-postgres/ingest.py b/examples/5-ingest-postgres/ingest.py index 86d1ff7..f269e6b 100644 --- a/examples/5-ingest-postgres/ingest.py +++ b/examples/5-ingest-postgres/ingest.py @@ -88,10 +88,8 @@ logger.warning("Assuming PostgreSQL database is already initialized") # Step 3: Create GraphEngine to orchestrate schema inference, pattern creation, and ingestion -# GraphEngine coordinates all operations: schema inference, pattern mapping, and data ingestion -engine = GraphEngine( - target_db_flavor=db_flavor, ingestion_params=IngestionParams(clean_start=True) -) +# GraphEngine coordinates all operations: schema inference, pattern mapping, schema definition, and data ingestion +engine = GraphEngine(target_db_flavor=db_flavor) # Step 3.1: Infer Schema from PostgreSQL database structure # This automatically detects vertex-like and edge-like tables based on: @@ -118,12 +116,22 @@ # Connection is automatically managed inside create_patterns() patterns = engine.create_patterns(postgres_conf, schema_name="public") +# Step 4.5: Define schema in target database +# This creates/initializes the database schema (if necessary) +# Some databases don't require explicit schema definition, but this ensures proper initialization +engine.define_schema( + schema=schema, + output_config=conn_conf, + clean_start=True, # Clean existing data before defining schema +) + # Step 5: Ingest data using GraphEngine # Note: ingestion will create its own PostgreSQL connections per table internally engine.ingest( schema=schema, output_config=conn_conf, patterns=patterns, + ingestion_params=IngestionParams(clean_start=False), # Schema already defined above ) print("\n" + "=" * 80) diff --git a/graflo/cli/ingest.py b/graflo/cli/ingest.py index 3f8cad6..5dc8c0f 100644 --- a/graflo/cli/ingest.py +++ b/graflo/cli/ingest.py @@ -27,9 +27,11 @@ import click from suthing import FileHandle -from graflo import Caster, DataSourceRegistry, Patterns, Schema -from graflo.db.connection.onto import DBConfig +from graflo import DataSourceRegistry, Patterns, Schema +from graflo.db.connection.onto import DBConfig, DBType from graflo.data_source import DataSourceFactory +from graflo.hq import GraphEngine +from graflo.onto import DBFlavor logger = logging.getLogger(__name__) @@ -136,17 +138,42 @@ def ingest( schema.fetch_resource() + # Determine DB flavor from connection config + db_type = conn_conf.connection_type + # Map DBType to DBFlavor (they have the same values for graph databases) + db_flavor = ( + DBFlavor(db_type.value) + if db_type + in ( + DBType.ARANGO, + DBType.NEO4J, + DBType.TIGERGRAPH, + DBType.FALKORDB, + DBType.MEMGRAPH, + ) + else DBFlavor.ARANGO + ) + + # Create GraphEngine for the full workflow + engine = GraphEngine(target_db_flavor=db_flavor) + # Create ingestion params with CLI arguments from graflo.hq.caster import IngestionParams ingestion_params = IngestionParams( n_cores=n_cores, + batch_size=batch_size, + init_only=init_only, + limit_files=limit_files, ) - caster = Caster( - schema, - 
ingestion_params=ingestion_params, - ) + # Define schema first (if clean_start is requested) + if fresh_start: + engine.define_schema( + schema=schema, + output_config=conn_conf, + clean_start=True, + ) # Validate that either source_path or data_source_config_path is provided if data_source_config_path is None and source_path is None: @@ -174,25 +201,20 @@ def ingest( ) registry.register(data_source, resource_name=resource_name) - # Update ingestion params with runtime options - ingestion_params.clean_start = fresh_start - ingestion_params.batch_size = batch_size - ingestion_params.init_only = init_only + # For data source registry, we need to use Caster directly + # since GraphEngine.ingest() uses patterns, not registry + from graflo.hq.caster import Caster + caster = Caster(schema=schema, ingestion_params=ingestion_params) caster.ingest_data_sources( data_source_registry=registry, conn_conf=conn_conf, ingestion_params=ingestion_params, ) else: - # Fall back to file-based ingestion - # Update ingestion params with runtime options - ingestion_params.clean_start = fresh_start - ingestion_params.batch_size = batch_size - ingestion_params.init_only = init_only - ingestion_params.limit_files = limit_files - - caster.ingest( + # Fall back to file-based ingestion using GraphEngine + engine.ingest( + schema=schema, output_config=conn_conf, patterns=patterns, ingestion_params=ingestion_params, diff --git a/graflo/hq/caster.py b/graflo/hq/caster.py index 21b0242..9eced67 100644 --- a/graflo/hq/caster.py +++ b/graflo/hq/caster.py @@ -439,6 +439,9 @@ def ingest_data_sources( ): """Ingest data from data sources in a registry. + Note: Schema definition should be handled separately via GraphEngine.define_schema() + before calling this method. + Args: data_source_registry: Registry containing data sources mapped to resources conn_conf: Database connection configuration @@ -452,23 +455,6 @@ def ingest_data_sources( self.ingestion_params = ingestion_params init_only = ingestion_params.init_only - # If effective_schema is not set, use schema.general.name as fallback - if conn_conf.can_be_target() and conn_conf.effective_schema is None: - schema_name = self.schema.general.name - # Map to the appropriate field based on DB type - if conn_conf.connection_type == DBType.TIGERGRAPH: - # TigerGraph uses 'schema_name' field - conn_conf.schema_name = schema_name - else: - # ArangoDB, Neo4j use 'database' field (which maps to effective_schema) - conn_conf.database = schema_name - - # init_db() now handles database/schema creation automatically - # It checks if the database exists and creates it if needed - # Uses schema.general.name if database is not set in config - with ConnectionManager(connection_config=conn_conf) as db_client: - db_client.init_db(self.schema, self.ingestion_params.clean_start) - if init_only: logger.info("ingest execution bound to init") sys.exit(0) diff --git a/graflo/hq/graph_engine.py b/graflo/hq/graph_engine.py index e82319f..c5165d2 100644 --- a/graflo/hq/graph_engine.py +++ b/graflo/hq/graph_engine.py @@ -8,8 +8,8 @@ import logging from graflo import Schema -from graflo.db import PostgresConnection -from graflo.db.connection.onto import DBConfig, PostgresConfig +from graflo.db import ConnectionManager, PostgresConnection +from graflo.db.connection.onto import DBConfig, DBType, PostgresConfig from graflo.hq.caster import Caster, IngestionParams from graflo.hq.inferencer import InferenceManager from graflo.hq.resource_mapper import ResourceMapper @@ -22,28 +22,30 @@ class GraphEngine: 
"""Orchestrator for graph database operations. - GraphEngine coordinates schema inference, pattern creation, and data ingestion, - providing a unified interface for working with graph databases. + GraphEngine coordinates schema inference, pattern creation, schema definition, + and data ingestion, providing a unified interface for working with graph databases. + + The typical workflow is: + 1. infer_schema() - Infer schema from source database (if possible) + 2. create_patterns() - Create patterns mapping resources to data sources (if possible) + 3. define_schema() - Define schema in target database (if possible and necessary) + 4. ingest() - Ingest data into the target database Attributes: - inferencer: InferenceManager instance for schema inference - caster: Caster instance for data ingestion + target_db_flavor: Target database flavor for schema sanitization resource_mapper: ResourceMapper instance for pattern creation """ def __init__( self, target_db_flavor: DBFlavor = DBFlavor.ARANGO, - ingestion_params: IngestionParams | None = None, ): """Initialize the GraphEngine. Args: target_db_flavor: Target database flavor for schema sanitization - ingestion_params: IngestionParams instance for controlling ingestion behavior """ self.target_db_flavor = target_db_flavor - self.ingestion_params = ingestion_params or IngestionParams() self.resource_mapper = ResourceMapper() def infer_schema( @@ -89,6 +91,40 @@ def create_patterns( conn=postgres_conn, schema_name=schema_name ) + def define_schema( + self, + schema: Schema, + output_config: DBConfig, + clean_start: bool = False, + ) -> None: + """Define schema in the target database. + + This method handles database/schema creation and initialization. + Some databases don't require explicit schema definition (e.g., Neo4j), + but this method ensures the database is properly initialized. + + Args: + schema: Schema configuration for the graph + output_config: Target database connection configuration + clean_start: Whether to clean the database before defining schema + """ + # If effective_schema is not set, use schema.general.name as fallback + if output_config.can_be_target() and output_config.effective_schema is None: + schema_name = schema.general.name + # Map to the appropriate field based on DB type + if output_config.connection_type == DBType.TIGERGRAPH: + # TigerGraph uses 'schema_name' field + output_config.schema_name = schema_name + else: + # ArangoDB, Neo4j use 'database' field (which maps to effective_schema) + output_config.database = schema_name + + # Initialize database with schema definition + # init_db() handles database/schema creation automatically + # It checks if the database exists and creates it if needed + with ConnectionManager(connection_config=output_config) as db_client: + db_client.init_db(schema, clean_start) + def ingest( self, schema: Schema, @@ -104,13 +140,12 @@ def ingest( patterns: Patterns instance mapping resources to data sources. If None, defaults to empty Patterns() ingestion_params: IngestionParams instance with ingestion configuration. 
- If None, uses the instance's default ingestion_params + If None, uses default IngestionParams() """ - caster = Caster( - schema=schema, ingestion_params=ingestion_params or self.ingestion_params - ) + ingestion_params = ingestion_params or IngestionParams() + caster = Caster(schema=schema, ingestion_params=ingestion_params) caster.ingest( output_config=output_config, patterns=patterns or Patterns(), - ingestion_params=ingestion_params or self.ingestion_params, + ingestion_params=ingestion_params, ) diff --git a/test/conftest.py b/test/conftest.py index 147d77b..fc0893d 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -10,7 +10,6 @@ from graflo.architecture.schema import Schema from graflo.architecture.util import cast_graph_name_to_triple -from graflo.hq.caster import Caster from graflo.util.misc import sorted_dicts from graflo.util.onto import Patterns, FilePattern @@ -71,11 +70,41 @@ def ingest_atomic(conn_conf, current_path, test_db_name, schema_o, mode, n_cores ) patterns.add_file_pattern(resource_name, file_pattern) + # Determine DB flavor from connection config + from graflo.db.connection.onto import DBType + from graflo.hq import GraphEngine from graflo.hq.caster import IngestionParams + from graflo.onto import DBFlavor + + db_type = conn_conf.connection_type + # Map DBType to DBFlavor (they have the same values for graph databases) + db_flavor = ( + DBFlavor(db_type.value) + if db_type + in ( + DBType.ARANGO, + DBType.NEO4J, + DBType.TIGERGRAPH, + DBType.FALKORDB, + DBType.MEMGRAPH, + ) + else DBFlavor.ARANGO + ) + + # Use GraphEngine for the full workflow + engine = GraphEngine(target_db_flavor=db_flavor) + + # Define schema first (with clean_start=True) + engine.define_schema( + schema=schema_o, + output_config=conn_conf, + clean_start=True, + ) - caster = Caster(schema_o) - ingestion_params = IngestionParams(n_cores=n_cores, clean_start=True) - caster.ingest( + # Then ingest data + ingestion_params = IngestionParams(n_cores=n_cores, clean_start=False) + engine.ingest( + schema=schema_o, output_config=conn_conf, patterns=patterns, ingestion_params=ingestion_params, From 0934b309130c6592dd4a151cf98873912dbf746f Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Tue, 27 Jan 2026 15:09:21 +0100 Subject: [PATCH 7/9] removed DBFlavor, a duplicate of DBType --- docs/examples/example-5.md | 3 +- examples/1-ingest-csv/ingest.py | 18 ++++-- examples/2-ingest-self-references/ingest.py | 18 ++++-- examples/3-ingest-csv-edge-weights/ingest.py | 20 ++++--- examples/4-ingest-neo4j/ingest.py | 14 +++-- examples/5-ingest-postgres/ingest.py | 31 ++++------- graflo/__init__.py | 3 +- graflo/architecture/actor_util.py | 5 +- graflo/architecture/edge.py | 13 +++-- graflo/architecture/onto.py | 14 ++--- graflo/architecture/vertex.py | 11 ++-- graflo/cli/ingest.py | 32 +++++------ graflo/db/__init__.py | 3 +- graflo/db/arango/conn.py | 6 +- graflo/db/arango/query.py | 4 +- graflo/db/connection/__init__.py | 3 +- graflo/db/connection/config_mapping.py | 20 ++++++- graflo/db/connection/onto.py | 36 ++---------- graflo/db/falkordb/conn.py | 10 ++-- graflo/db/manager.py | 3 +- graflo/db/memgraph/conn.py | 10 ++-- graflo/db/neo4j/conn.py | 8 ++- graflo/db/postgres/schema_inference.py | 6 +- graflo/db/tigergraph/conn.py | 5 +- graflo/db/util.py | 11 ++-- graflo/hq/caster.py | 27 +-------- graflo/hq/graph_engine.py | 58 +++++++++++++++++++- graflo/hq/inferencer.py | 4 +- graflo/hq/sanitizer.py | 6 +- graflo/onto.py | 44 +++++++-------- test/architecture/test_vertex.py | 6 +- test/conftest.py | 
28 ++++------ test/db/postgres/test_schema_inference.py | 6 +- test/db/tigergraphs/test_reserved_words.py | 6 +- 34 files changed, 261 insertions(+), 231 deletions(-) diff --git a/docs/examples/example-5.md b/docs/examples/example-5.md index ed7b4e6..a845d4e 100644 --- a/docs/examples/example-5.md +++ b/docs/examples/example-5.md @@ -264,9 +264,8 @@ Automatically generate a graflo Schema from your PostgreSQL database. This is th ```python from graflo.hq import GraphEngine -from graflo.onto import DBFlavor +from graflo.onto import DBFlavor, DBType from graflo.db.connection.onto import ArangoConfig, Neo4jConfig, TigergraphConfig, FalkordbConfig, PostgresConfig -from graflo.db import DBType # Connect to target graph database to determine flavor # Choose one of: ArangoConfig, Neo4jConfig, TigergraphConfig, or FalkordbConfig diff --git a/examples/1-ingest-csv/ingest.py b/examples/1-ingest-csv/ingest.py index 9d592c3..726e406 100644 --- a/examples/1-ingest-csv/ingest.py +++ b/examples/1-ingest-csv/ingest.py @@ -1,8 +1,9 @@ import pathlib from suthing import FileHandle -from graflo import Caster, Patterns, Schema +from graflo import Patterns, Schema from graflo.util.onto import FilePattern from graflo.db.connection.onto import ArangoConfig +from graflo.hq import GraphEngine from graflo.hq.caster import IngestionParams schema = Schema.from_dict(FileHandle.load("schema.yaml")) @@ -22,6 +23,9 @@ # database="_system", # ) +# Determine DB type from connection config +db_type = conn_conf.connection_type + # Create Patterns with file patterns patterns = Patterns() patterns.add_file_pattern( @@ -45,10 +49,12 @@ # } # ) -caster = Caster(schema) - - +# Create GraphEngine and define schema + ingest in one operation +engine = GraphEngine(target_db_flavor=db_type) ingestion_params = IngestionParams(clean_start=True) -caster.ingest( - output_config=conn_conf, patterns=patterns, ingestion_params=ingestion_params +engine.define_and_ingest( + schema=schema, + output_config=conn_conf, + patterns=patterns, + ingestion_params=ingestion_params, ) diff --git a/examples/2-ingest-self-references/ingest.py b/examples/2-ingest-self-references/ingest.py index aeb6b11..e1384fd 100644 --- a/examples/2-ingest-self-references/ingest.py +++ b/examples/2-ingest-self-references/ingest.py @@ -1,8 +1,9 @@ import pathlib from suthing import FileHandle -from graflo import Caster, Patterns, Schema +from graflo import Patterns, Schema from graflo.util.onto import FilePattern from graflo.db.connection.onto import ArangoConfig +from graflo.hq import GraphEngine from graflo.hq.caster import IngestionParams schema = Schema.from_dict(FileHandle.load("schema.yaml")) @@ -22,6 +23,9 @@ # database="_system", # ) +# Determine DB type from connection config +db_type = conn_conf.connection_type + # Create Patterns with file patterns patterns = Patterns() patterns.add_file_pattern( @@ -36,10 +40,12 @@ # } # ) -caster = Caster(schema) - - +# Create GraphEngine and define schema + ingest in one operation +engine = GraphEngine(target_db_flavor=db_type) ingestion_params = IngestionParams(clean_start=True) -caster.ingest( - output_config=conn_conf, patterns=patterns, ingestion_params=ingestion_params +engine.define_and_ingest( + schema=schema, + output_config=conn_conf, + patterns=patterns, + ingestion_params=ingestion_params, ) diff --git a/examples/3-ingest-csv-edge-weights/ingest.py b/examples/3-ingest-csv-edge-weights/ingest.py index a0cf9bc..758a744 100644 --- a/examples/3-ingest-csv-edge-weights/ingest.py +++ 
b/examples/3-ingest-csv-edge-weights/ingest.py @@ -1,7 +1,8 @@ from suthing import FileHandle -from graflo import Caster, Patterns, Schema -from graflo.hq.caster import IngestionParams +from graflo import Patterns, Schema from graflo.db.connection.onto import Neo4jConfig +from graflo.hq import GraphEngine +from graflo.hq.caster import IngestionParams import logging @@ -32,6 +33,9 @@ # bolt_port=7688 # ) +# Determine DB type from connection config +db_type = conn_conf.connection_type + # Load patterns from YAML file (same pattern as Schema) patterns = Patterns.from_dict(FileHandle.load("patterns.yaml")) @@ -44,10 +48,12 @@ # FilePattern(regex="^relations.*\.csv$", sub_path=pathlib.Path("."), resource_name="relations") # ) -caster = Caster(schema) - - +# Create GraphEngine and define schema + ingest in one operation +engine = GraphEngine(target_db_flavor=db_type) ingestion_params = IngestionParams(clean_start=True) -caster.ingest( - output_config=conn_conf, patterns=patterns, ingestion_params=ingestion_params +engine.define_and_ingest( + schema=schema, + output_config=conn_conf, + patterns=patterns, + ingestion_params=ingestion_params, ) diff --git a/examples/4-ingest-neo4j/ingest.py b/examples/4-ingest-neo4j/ingest.py index 17ca42a..73c6401 100644 --- a/examples/4-ingest-neo4j/ingest.py +++ b/examples/4-ingest-neo4j/ingest.py @@ -1,8 +1,9 @@ import pathlib from suthing import FileHandle -from graflo import Caster, Patterns, Schema +from graflo import Patterns, Schema from graflo.util.onto import FilePattern from graflo.db.connection.onto import Neo4jConfig +from graflo.hq import GraphEngine from graflo.hq.caster import IngestionParams schema = Schema.from_dict(FileHandle.load("schema.yaml")) @@ -22,6 +23,9 @@ # bolt_port=7688 # ) +# Determine DB type from connection config +db_type = conn_conf.connection_type + # Create Patterns with file patterns patterns = Patterns() patterns.add_file_pattern( @@ -49,14 +53,14 @@ # } # ) -caster = Caster(schema) - - +# Create GraphEngine and define schema + ingest in one operation +engine = GraphEngine(target_db_flavor=db_type) ingestion_params = IngestionParams( clean_start=True, # max_items=5, ) -caster.ingest( +engine.define_and_ingest( + schema=schema, output_config=conn_conf, # Target database config patterns=patterns, # Source data patterns ingestion_params=ingestion_params, diff --git a/examples/5-ingest-postgres/ingest.py b/examples/5-ingest-postgres/ingest.py index f269e6b..90fb26c 100644 --- a/examples/5-ingest-postgres/ingest.py +++ b/examples/5-ingest-postgres/ingest.py @@ -15,8 +15,6 @@ from pathlib import Path from suthing import FileHandle -from graflo.onto import DBFlavor -from graflo.db import DBType from graflo.hq import GraphEngine, IngestionParams from graflo.db.postgres.util import load_schema_from_sql_file from graflo.db.connection.onto import PostgresConfig, TigergraphConfig @@ -62,14 +60,9 @@ conn_conf = TigergraphConfig.from_docker_env() # TigerGraph # conn_conf = FalkordbConfig.from_docker_env() # FalkorDB -# Determine db_flavor from target config +# Determine db_type from target config db_type = conn_conf.connection_type -# Map DBType to DBFlavor (they have the same values) -db_flavor = ( - DBFlavor(db_type.value) - if db_type in (DBType.ARANGO, DBType.NEO4J, DBType.TIGERGRAPH) - else DBFlavor.ARANGO -) + # Step 1.5: Initialize PostgreSQL database with mock schema if needed # This ensures the database has the required tables (users, products, purchases, follows) @@ -89,7 +82,7 @@ # Step 3: Create GraphEngine to orchestrate 
schema inference, pattern creation, and ingestion # GraphEngine coordinates all operations: schema inference, pattern mapping, schema definition, and data ingestion -engine = GraphEngine(target_db_flavor=db_flavor) +engine = GraphEngine(target_db_flavor=db_type) # Step 3.1: Infer Schema from PostgreSQL database structure # This automatically detects vertex-like and edge-like tables based on: @@ -116,22 +109,18 @@ # Connection is automatically managed inside create_patterns() patterns = engine.create_patterns(postgres_conf, schema_name="public") -# Step 4.5: Define schema in target database -# This creates/initializes the database schema (if necessary) +# Step 4.5 & 5: Define schema and ingest data in one operation +# This creates/initializes the database schema and then ingests data # Some databases don't require explicit schema definition, but this ensures proper initialization -engine.define_schema( - schema=schema, - output_config=conn_conf, - clean_start=True, # Clean existing data before defining schema -) - -# Step 5: Ingest data using GraphEngine # Note: ingestion will create its own PostgreSQL connections per table internally -engine.ingest( +engine.define_and_ingest( schema=schema, output_config=conn_conf, patterns=patterns, - ingestion_params=IngestionParams(clean_start=False), # Schema already defined above + ingestion_params=IngestionParams( + clean_start=False + ), # clean_start handled by define_and_ingest + clean_start=True, # Clean existing data before defining schema ) print("\n" + "=" * 80) diff --git a/graflo/__init__.py b/graflo/__init__.py index c1cc26b..0d906fd 100644 --- a/graflo/__init__.py +++ b/graflo/__init__.py @@ -38,7 +38,7 @@ ) from .db import ConnectionManager, ConnectionType from .filter.onto import ComparisonOperator, LogicalOperator -from .onto import AggregationType +from .onto import AggregationType, DBType from .util.onto import FilePattern, Patterns, ResourcePattern, TablePattern __all__ = [ @@ -53,6 +53,7 @@ "DataSourceFactory", "DataSourceRegistry", "DataSourceType", + "DBType", "FileDataSource", "FilePattern", "Index", diff --git a/graflo/architecture/actor_util.py b/graflo/architecture/actor_util.py index d002c4d..2c25e27 100644 --- a/graflo/architecture/actor_util.py +++ b/graflo/architecture/actor_util.py @@ -44,7 +44,8 @@ ) from graflo.architecture.util import project_dict from graflo.architecture.vertex import VertexConfig -from graflo.onto import DBFlavor +from graflo.onto import DBType + logger = logging.getLogger(__name__) @@ -340,7 +341,7 @@ def render_edge( b = project_dict(v, target_index) # For TigerGraph, extracted relations go to weight, not as relation key - is_tigergraph = vertex_config.db_flavor == DBFlavor.TIGERGRAPH + is_tigergraph = vertex_config.db_flavor == DBType.TIGERGRAPH extracted_relation = None # 1. Try to extract relation from data context diff --git a/graflo/architecture/edge.py b/graflo/architecture/edge.py index c62046c..df1b5c8 100644 --- a/graflo/architecture/edge.py +++ b/graflo/architecture/edge.py @@ -2,7 +2,7 @@ This module provides classes and utilities for managing edges in graph databases. It handles edge configuration, weight management, indexing, and relationship operations. -The module supports both ArangoDB and Neo4j through the DBFlavor enum. +The module supports both ArangoDB and Neo4j through the DBType enum. 
Key Components: - Edge: Represents an edge with its source, target, and configuration @@ -28,7 +28,8 @@ Weight, ) from graflo.architecture.vertex import Field, FieldType, VertexConfig, _FieldsType -from graflo.onto import DBFlavor +from graflo.onto import DBType + # Default relation name for TigerGraph edges when relation is not specified DEFAULT_TIGERGRAPH_RELATION = "relates" @@ -249,7 +250,7 @@ def finish_init(self, vertex_config: VertexConfig): self._target = vertex_config.vertex_dbname(self.target) # ArangoDB-specific: set graph_name and database_name only for ArangoDB - if vertex_config.db_flavor == DBFlavor.ARANGO: + if vertex_config.db_flavor == DBType.ARANGO: graph_name = [ vertex_config.vertex_dbname(self.source), vertex_config.vertex_dbname(self.target), @@ -260,7 +261,7 @@ def finish_init(self, vertex_config: VertexConfig): self.database_name = "_".join(graph_name + ["edges"]) # TigerGraph requires named edge types (relations), so assign default if missing - if vertex_config.db_flavor == DBFlavor.TIGERGRAPH and self.relation is None: + if vertex_config.db_flavor == DBType.TIGERGRAPH and self.relation is None: # Use default relation name for TigerGraph # TigerGraph requires all edges to have a named type (relation) self.relation = DEFAULT_TIGERGRAPH_RELATION @@ -270,7 +271,7 @@ def finish_init(self, vertex_config: VertexConfig): # TigerGraph: add relation field to weights if relation_field or relation_from_key is set # This ensures the relation value is included as a typed property in the edge schema - if vertex_config.db_flavor == DBFlavor.TIGERGRAPH: + if vertex_config.db_flavor == DBType.TIGERGRAPH: if self.relation_field is None and self.relation_from_key: # relation_from_key is True but relation_field not set, default to standard name self.relation_field = DEFAULT_TIGERGRAPH_RELATION_WEIGHTNAME @@ -336,7 +337,7 @@ def _init_index(self, index: Index, vc: VertexConfig) -> Index: fields = vc.index(index.name).fields index_fields += [f"{index.name}@{x}" for x in fields] - if not index.exclude_edge_endpoints and vc.db_flavor == DBFlavor.ARANGO: + if not index.exclude_edge_endpoints and vc.db_flavor == DBType.ARANGO: if all([item not in index_fields for item in ["_from", "_to"]]): index_fields = ["_from", "_to"] + index_fields diff --git a/graflo/architecture/onto.py b/graflo/architecture/onto.py index 8581808..6bae8c1 100644 --- a/graflo/architecture/onto.py +++ b/graflo/architecture/onto.py @@ -10,7 +10,7 @@ - Action context for graph transformations The module is designed to be database-agnostic, supporting both ArangoDB and Neo4j through -the DBFlavor enum. It provides a unified interface for working with graph data structures +the DBType enum. It provides a unified interface for working with graph data structures while allowing for database-specific optimizations and features. Key Components: @@ -36,7 +36,8 @@ from dataclass_wizard import JSONWizard, YAMLWizard -from graflo.onto import BaseDataclass, BaseEnum, DBFlavor +from graflo.onto import DBType +from graflo.onto import BaseDataclass, BaseEnum from graflo.util.transform import pick_unique_dict # type for vertex or edge name (index) @@ -154,7 +155,7 @@ def __iter__(self): """Iterate over the indexed fields.""" return iter(self.fields) - def db_form(self, db_type: DBFlavor) -> dict: + def db_form(self, db_type: DBType) -> dict: """Convert index configuration to database-specific format. 
Args: @@ -167,14 +168,11 @@ def db_form(self, db_type: DBFlavor) -> dict: ValueError: If db_type is not supported """ r = self.to_dict() - if db_type == DBFlavor.ARANGO: + if db_type == DBType.ARANGO: _ = r.pop("name") _ = r.pop("exclude_edge_endpoints") - elif db_type == DBFlavor.NEO4J: - pass else: - raise ValueError(f"Unknown db_type {db_type}") - + pass return r diff --git a/graflo/architecture/vertex.py b/graflo/architecture/vertex.py index b8eacee..2d21140 100644 --- a/graflo/architecture/vertex.py +++ b/graflo/architecture/vertex.py @@ -2,7 +2,7 @@ This module provides classes and utilities for managing vertices in graph databases. It handles vertex configuration, field management, indexing, and filtering operations. -The module supports both ArangoDB and Neo4j through the DBFlavor enum. +The module supports both ArangoDB and Neo4j through the DBType enum. Key Components: - Vertex: Represents a vertex with its fields and indexes @@ -23,7 +23,8 @@ from graflo.architecture.onto import Index from graflo.filter.onto import Expression -from graflo.onto import BaseDataclass, BaseEnum, DBFlavor +from graflo.onto import DBType +from graflo.onto import BaseDataclass, BaseEnum logger = logging.getLogger(__name__) @@ -314,7 +315,7 @@ def __post_init__(self): self.fields.append(Field(name=field_name, type=None)) seen_names.add(field_name) - def finish_init(self, db_flavor: DBFlavor): + def finish_init(self, db_flavor: DBType): """Complete initialization of vertex with database-specific field types. Args: @@ -322,7 +323,7 @@ def finish_init(self, db_flavor: DBFlavor): """ self.fields = [ Field(name=f.name, type=FieldType.STRING) - if f.type is None and db_flavor == DBFlavor.TIGERGRAPH + if f.type is None and db_flavor == DBType.TIGERGRAPH else f for f in self.fields ] @@ -345,7 +346,7 @@ class VertexConfig(BaseDataclass): vertices: list[Vertex] blank_vertices: list[str] = dataclasses.field(default_factory=list) force_types: dict[str, list] = dataclasses.field(default_factory=dict) - db_flavor: DBFlavor = DBFlavor.ARANGO + db_flavor: DBType = DBType.ARANGO def __post_init__(self): """Initialize the vertex configuration. 
diff --git a/graflo/cli/ingest.py b/graflo/cli/ingest.py index 5dc8c0f..f2f0c0c 100644 --- a/graflo/cli/ingest.py +++ b/graflo/cli/ingest.py @@ -27,11 +27,10 @@ import click from suthing import FileHandle -from graflo import DataSourceRegistry, Patterns, Schema -from graflo.db.connection.onto import DBConfig, DBType +from graflo import DataSourceRegistry, Patterns, Schema, DBType +from graflo.db.connection.onto import DBConfig from graflo.data_source import DataSourceFactory from graflo.hq import GraphEngine -from graflo.onto import DBFlavor logger = logging.getLogger(__name__) @@ -138,24 +137,21 @@ def ingest( schema.fetch_resource() - # Determine DB flavor from connection config + # Determine DB type from connection config db_type = conn_conf.connection_type - # Map DBType to DBFlavor (they have the same values for graph databases) - db_flavor = ( - DBFlavor(db_type.value) - if db_type - in ( - DBType.ARANGO, - DBType.NEO4J, - DBType.TIGERGRAPH, - DBType.FALKORDB, - DBType.MEMGRAPH, - ) - else DBFlavor.ARANGO - ) + # Ensure it's a graph database (target database) + if db_type not in ( + DBType.ARANGO, + DBType.NEO4J, + DBType.TIGERGRAPH, + DBType.FALKORDB, + DBType.MEMGRAPH, + DBType.NEBULA, + ): + db_type = DBType.ARANGO # Default to ARANGO for non-graph databases # Create GraphEngine for the full workflow - engine = GraphEngine(target_db_flavor=db_flavor) + engine = GraphEngine(target_db_flavor=db_type) # Create ingestion params with CLI arguments from graflo.hq.caster import IngestionParams diff --git a/graflo/db/__init__.py b/graflo/db/__init__.py index 558bc27..bf7787b 100644 --- a/graflo/db/__init__.py +++ b/graflo/db/__init__.py @@ -25,7 +25,7 @@ from .arango.conn import ArangoConnection from .conn import Connection, ConnectionType -from .connection import DBConfig, DBType +from .connection import DBConfig from .falkordb.conn import FalkordbConnection from .manager import ConnectionManager from .memgraph.conn import MemgraphConnection @@ -37,7 +37,6 @@ __all__ = [ "Connection", "ConnectionType", - "DBType", "DBConfig", "ConnectionManager", "ArangoConnection", diff --git a/graflo/db/arango/conn.py b/graflo/db/arango/conn.py index 2dcef3a..51d85f8 100644 --- a/graflo/db/arango/conn.py +++ b/graflo/db/arango/conn.py @@ -42,8 +42,10 @@ from graflo.db.conn import Connection from graflo.db.util import get_data_from_cursor, json_serializer from graflo.filter.onto import Clause -from graflo.onto import AggregationType, DBFlavor +from graflo.onto import AggregationType from graflo.util.transform import pick_unique_dict +from graflo.onto import DBType + from ..connection.onto import ArangoConfig @@ -385,7 +387,7 @@ def _add_index(self, general_collection: Any, index: Index) -> Any | None: Returns: IndexHandle: Handle to the created index, or None if index type is not supported """ - data = index.db_form(DBFlavor.ARANGO) + data = index.db_form(DBType.ARANGO) ih: Any | None = None if index.type == IndexType.PERSISTENT: ih = general_collection.add_index(data) diff --git a/graflo/db/arango/query.py b/graflo/db/arango/query.py index c7b704c..faa7dee 100644 --- a/graflo/db/arango/query.py +++ b/graflo/db/arango/query.py @@ -22,7 +22,7 @@ from arango import ArangoClient from graflo.filter.onto import Expression -from graflo.onto import DBFlavor +from graflo.onto import DBType logger = logging.getLogger(__name__) @@ -166,7 +166,7 @@ def fetch_fields_query( if filters is not None: ff = Expression.from_dict(filters) - extrac_filter_clause = f" && {ff(doc_name='_cdoc', kind=DBFlavor.ARANGO)}" + 
extrac_filter_clause = f" && {ff(doc_name='_cdoc', kind=DBType.ARANGO)}" else: extrac_filter_clause = "" diff --git a/graflo/db/connection/__init__.py b/graflo/db/connection/__init__.py index 40b4c94..0e986c2 100644 --- a/graflo/db/connection/__init__.py +++ b/graflo/db/connection/__init__.py @@ -1,6 +1,5 @@ -from .onto import DBConfig, DBType +from .onto import DBConfig __all__ = [ "DBConfig", - "DBType", ] diff --git a/graflo/db/connection/config_mapping.py b/graflo/db/connection/config_mapping.py index 1eab59d..fbef6ac 100644 --- a/graflo/db/connection/config_mapping.py +++ b/graflo/db/connection/config_mapping.py @@ -3,7 +3,6 @@ from .onto import ( ArangoConfig, DBConfig, - DBType, FalkordbConfig, MemgraphConfig, NebulaConfig, @@ -11,6 +10,7 @@ PostgresConfig, TigergraphConfig, ) +from ... import DBType # Define this mapping in a separate file to avoid circular imports DB_TYPE_MAPPING: Dict[DBType, Type[DBConfig]] = { @@ -22,3 +22,21 @@ DBType.NEBULA: NebulaConfig, DBType.POSTGRES: PostgresConfig, } + + +def get_config_class(db_type: DBType) -> Type[DBConfig]: + """Get the appropriate config class for a database type. + + This factory function breaks the circular dependency by moving the + lookup logic out of the DBType enum. + + Args: + db_type: The database type enum value + + Returns: + The corresponding DBConfig subclass + + Raises: + KeyError: If the db_type is not in the mapping + """ + return DB_TYPE_MAPPING[db_type] diff --git a/graflo/db/connection/onto.py b/graflo/db/connection/onto.py index 4d8053c..3e8256e 100644 --- a/graflo/db/connection/onto.py +++ b/graflo/db/connection/onto.py @@ -2,7 +2,6 @@ import logging import warnings from pathlib import Path -from strenum import StrEnum from typing import Any, Dict, Type, TypeVar from urllib.parse import urlparse @@ -10,42 +9,13 @@ from pydantic import AliasChoices from pydantic_settings import BaseSettings, SettingsConfigDict -from graflo.onto import MetaEnum +from graflo.onto import DBType logger = logging.getLogger(__name__) # Type variable for DBConfig subclasses T = TypeVar("T", bound="DBConfig") - -class DBType(StrEnum, metaclass=MetaEnum): - """Enum representing different types of databases. - - Includes both graph databases and source databases (SQL, NoSQL, etc.). 
- """ - - # Graph databases - ARANGO = "arango" - NEO4J = "neo4j" - TIGERGRAPH = "tigergraph" - FALKORDB = "falkordb" - MEMGRAPH = "memgraph" - NEBULA = "nebula" - - # Source databases (SQL, NoSQL) - POSTGRES = "postgres" - MYSQL = "mysql" - MONGODB = "mongodb" - SQLITE = "sqlite" - - @property - def config_class(self) -> Type["DBConfig"]: - """Get the appropriate config class for this database type.""" - from .config_mapping import DB_TYPE_MAPPING - - return DB_TYPE_MAPPING[self] - - # Databases that can be used as sources (INPUT) SOURCE_DATABASES: set[DBType] = { DBType.ARANGO, # Graph DBs can be sources @@ -411,7 +381,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "DBConfig": config_data["uri"] = f"{protocol}://{hostname}" # Get the appropriate config class and initialize it - config_class = conn_type.config_class + from .config_mapping import get_config_class + + config_class = get_config_class(conn_type) return config_class(**config_data) @classmethod diff --git a/graflo/db/falkordb/conn.py b/graflo/db/falkordb/conn.py index ca9f82e..0867fd7 100644 --- a/graflo/db/falkordb/conn.py +++ b/graflo/db/falkordb/conn.py @@ -38,7 +38,9 @@ from graflo.db.conn import Connection from graflo.db.util import serialize_value from graflo.filter.onto import Expression -from graflo.onto import AggregationType, DBFlavor, ExpressionFlavor +from graflo.onto import AggregationType, ExpressionFlavor +from graflo.onto import DBType + from ..connection.onto import FalkordbConfig @@ -60,7 +62,7 @@ class FalkordbConnection(Connection): _graph_name: Name of the currently selected graph """ - flavor = DBFlavor.FALKORDB + flavor = DBType.FALKORDB # Type annotations for instance attributes client: FalkorDB | None @@ -654,7 +656,7 @@ def fetch_docs( if filters is not None: ff = Expression.from_dict(filters) # Use NEO4J flavor since FalkorDB uses OpenCypher - filter_clause = f"WHERE {ff(doc_name='n', kind=DBFlavor.NEO4J)}" + filter_clause = f"WHERE {ff(doc_name='n', kind=DBType.NEO4J)}" else: filter_clause = "" @@ -877,7 +879,7 @@ def aggregate( # Build filter clause if filters is not None: ff = Expression.from_dict(filters) - filter_clause = f"WHERE {ff(doc_name='n', kind=DBFlavor.NEO4J)}" + filter_clause = f"WHERE {ff(doc_name='n', kind=DBType.NEO4J)}" else: filter_clause = "" diff --git a/graflo/db/manager.py b/graflo/db/manager.py index 785b0da..91bbe9d 100644 --- a/graflo/db/manager.py +++ b/graflo/db/manager.py @@ -24,7 +24,8 @@ """ from graflo.db.arango.conn import ArangoConnection -from graflo.db.connection.onto import DBConfig, DBType, TARGET_DATABASES +from graflo.db.connection.onto import DBConfig, TARGET_DATABASES +from graflo.onto import DBType from graflo.db.falkordb.conn import FalkordbConnection from graflo.db.memgraph.conn import MemgraphConnection from graflo.db.neo4j.conn import Neo4jConnection diff --git a/graflo/db/memgraph/conn.py b/graflo/db/memgraph/conn.py index 14dc582..34fac40 100644 --- a/graflo/db/memgraph/conn.py +++ b/graflo/db/memgraph/conn.py @@ -88,7 +88,9 @@ from graflo.architecture.vertex import VertexConfig from graflo.db.conn import Connection from graflo.filter.onto import Expression -from graflo.onto import AggregationType, DBFlavor, ExpressionFlavor +from graflo.onto import AggregationType, ExpressionFlavor +from graflo.onto import DBType + from ..connection.onto import MemgraphConfig @@ -152,8 +154,8 @@ class MemgraphConnection(Connection): Attributes ---------- - flavor : DBFlavor - Database type identifier (DBFlavor.MEMGRAPH) + flavor : DBType + Database type identifier 
(DBType.MEMGRAPH) config : MemgraphConfig Connection configuration (URI, credentials) conn : mgclient.Connection @@ -171,7 +173,7 @@ class MemgraphConnection(Connection): conn.close() """ - flavor = DBFlavor.MEMGRAPH + flavor = DBType.MEMGRAPH # Type annotations for instance attributes conn: mgclient.Connection | None diff --git a/graflo/db/neo4j/conn.py b/graflo/db/neo4j/conn.py index acd6223..8c9940b 100644 --- a/graflo/db/neo4j/conn.py +++ b/graflo/db/neo4j/conn.py @@ -34,7 +34,9 @@ from graflo.architecture.vertex import VertexConfig from graflo.db.conn import Connection from graflo.filter.onto import Expression -from graflo.onto import AggregationType, DBFlavor, ExpressionFlavor +from graflo.onto import AggregationType, ExpressionFlavor +from graflo.onto import DBType + from ..connection.onto import Neo4jConfig @@ -53,7 +55,7 @@ class Neo4jConnection(Connection): conn: Neo4j session instance """ - flavor = DBFlavor.NEO4J + flavor = DBType.NEO4J def __init__(self, config: Neo4jConfig): """Initialize Neo4j connection. @@ -474,7 +476,7 @@ def fetch_docs( """ if filters is not None: ff = Expression.from_dict(filters) - filter_clause = f"WHERE {ff(doc_name='n', kind=DBFlavor.NEO4J)}" + filter_clause = f"WHERE {ff(doc_name='n', kind=DBType.NEO4J)}" else: filter_clause = "" diff --git a/graflo/db/postgres/schema_inference.py b/graflo/db/postgres/schema_inference.py index 67354cb..2d39b82 100644 --- a/graflo/db/postgres/schema_inference.py +++ b/graflo/db/postgres/schema_inference.py @@ -14,9 +14,9 @@ from graflo.architecture.onto import Index, IndexType from graflo.architecture.schema import Schema, SchemaMetadata from graflo.architecture.vertex import Field, FieldType, Vertex, VertexConfig -from graflo.onto import DBFlavor +from graflo.onto import DBType -from ...architecture.onto_sql import EdgeTableInfo, SchemaIntrospectionResult +from graflo.architecture.onto_sql import EdgeTableInfo, SchemaIntrospectionResult from .conn import PostgresConnection from .types import PostgresTypeMapper @@ -35,7 +35,7 @@ class PostgresSchemaInferencer: def __init__( self, - db_flavor: DBFlavor = DBFlavor.ARANGO, + db_flavor: DBType = DBType.ARANGO, conn: PostgresConnection | None = None, ): """Initialize the schema inferencer. 
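The hunks above swap the old `DBFlavor` enum for the unified `DBType` enum across the connection classes and the PostgreSQL schema inferencer. A minimal sketch of the resulting usage, assuming the enum values introduced later in this series (`arango`, `neo4j`, `postgres`, ...); it is illustrative only, not part of the patch itself:

```python
from graflo.onto import DBType

# DBType now identifies both graph targets and SQL sources.
target = DBType.ARANGO      # graph database used as ingestion target
source = DBType.POSTGRES    # relational database used as a source

# Connection classes tag their backend with the same enum,
# e.g. (as in the diffs above):
#   class Neo4jConnection(Connection):
#       flavor = DBType.NEO4J

# Members are string-valued, so they compare against their raw values.
assert target == "arango" and source.value == "postgres"
```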
diff --git a/graflo/db/tigergraph/conn.py b/graflo/db/tigergraph/conn.py index 6ebf31f..6b86117 100644 --- a/graflo/db/tigergraph/conn.py +++ b/graflo/db/tigergraph/conn.py @@ -50,7 +50,8 @@ ) from graflo.db.util import json_serializer from graflo.filter.onto import Clause, Expression -from graflo.onto import AggregationType, DBFlavor, ExpressionFlavor +from graflo.onto import AggregationType, ExpressionFlavor +from graflo.onto import DBType from graflo.util.transform import pick_unique_dict from urllib.parse import quote @@ -229,7 +230,7 @@ class TigerGraphConnection(Connection): - Version is auto-detected, or can be manually specified in config """ - flavor = DBFlavor.TIGERGRAPH + flavor = DBType.TIGERGRAPH def __init__(self, config: TigergraphConfig): super().__init__() diff --git a/graflo/db/util.py b/graflo/db/util.py index 88108fc..ebf8bdc 100644 --- a/graflo/db/util.py +++ b/graflo/db/util.py @@ -11,15 +11,14 @@ - sanitize_attribute_name: Sanitize attribute names to avoid reserved words Example: - >>> # ArangoDB-specific AQL query (collection is ArangoDB terminology) + >>> from graflo.onto import DBType >>> # ArangoDB-specific AQL query (collection is ArangoDB terminology) >>> cursor = db.execute("FOR doc IN vertex_class RETURN doc") >>> batch = get_data_from_cursor(cursor, limit=100) >>> # Serialize datetime objects in a document >>> doc = {"id": 1, "created_at": datetime.now()} >>> serialized = serialize_document(doc) >>> # Sanitize reserved words - >>> from graflo.onto import DBFlavor - >>> reserved = load_reserved_words(DBFlavor.TIGERGRAPH) + >>> reserved = load_reserved_words(DBType.TIGERGRAPH) >>> sanitized = sanitize_attribute_name("SELECT", reserved) """ @@ -31,7 +30,7 @@ from arango.exceptions import CursorNextError -from graflo.onto import DBFlavor +from graflo.db.connection.onto import DBType logger = logging.getLogger(__name__) @@ -180,7 +179,7 @@ def json_serializer(obj): return serialized -def load_reserved_words(db_flavor: DBFlavor) -> set[str]: +def load_reserved_words(db_flavor: DBType) -> set[str]: """Load reserved words for a given database flavor. Args: @@ -190,7 +189,7 @@ def load_reserved_words(db_flavor: DBFlavor) -> set[str]: Set of reserved words (uppercase) for the database flavor. Returns empty set if no reserved words file exists or for unsupported flavors. """ - if db_flavor != DBFlavor.TIGERGRAPH: + if db_flavor != DBType.TIGERGRAPH: # Currently only TigerGraph has reserved words defined return set() diff --git a/graflo/hq/caster.py b/graflo/hq/caster.py index 9eced67..dd0ca93 100644 --- a/graflo/hq/caster.py +++ b/graflo/hq/caster.py @@ -36,8 +36,8 @@ DataSourceRegistry, ) from graflo.data_source.sql import SQLConfig, SQLDataSource -from graflo.db import DBType, ConnectionManager, DBConfig -from graflo.onto import DBFlavor +from graflo.db import ConnectionManager +from graflo.db.connection.onto import DBConfig from graflo.util.chunker import ChunkerType from graflo.util.onto import FilePattern, Patterns, ResourceType, TablePattern @@ -498,27 +498,6 @@ def ingest_data_sources( ) logger.info(f"Processing took {klepsidra.elapsed:.1f} sec") - @staticmethod - def _get_db_flavor_from_config(output_config: DBConfig) -> DBFlavor: - """Convert DBConfig connection type to DBFlavor. 
- - Args: - output_config: Database configuration - - Returns: - DBFlavor enum value corresponding to the database type - """ - db_type = output_config.connection_type - if db_type == DBType.ARANGO: - return DBFlavor.ARANGO - elif db_type == DBType.NEO4J: - return DBFlavor.NEO4J - elif db_type == DBType.TIGERGRAPH: - return DBFlavor.TIGERGRAPH - else: - # Default to ARANGO for unknown types - return DBFlavor.ARANGO - def _register_file_sources( self, registry: DataSourceRegistry, @@ -706,7 +685,7 @@ def ingest( ingestion_params = ingestion_params or IngestionParams() # Initialize vertex config with correct field types based on database type - db_flavor = self._get_db_flavor_from_config(output_config) + db_flavor = output_config.connection_type self.schema.vertex_config.db_flavor = db_flavor self.schema.vertex_config.finish_init() # Initialize edge config after vertex config is fully initialized diff --git a/graflo/hq/graph_engine.py b/graflo/hq/graph_engine.py index c5165d2..61f8cea 100644 --- a/graflo/hq/graph_engine.py +++ b/graflo/hq/graph_engine.py @@ -8,12 +8,12 @@ import logging from graflo import Schema +from graflo.onto import DBType from graflo.db import ConnectionManager, PostgresConnection -from graflo.db.connection.onto import DBConfig, DBType, PostgresConfig +from graflo.db.connection.onto import DBConfig, PostgresConfig from graflo.hq.caster import Caster, IngestionParams from graflo.hq.inferencer import InferenceManager from graflo.hq.resource_mapper import ResourceMapper -from graflo.onto import DBFlavor from graflo.util.onto import Patterns logger = logging.getLogger(__name__) @@ -38,7 +38,7 @@ class GraphEngine: def __init__( self, - target_db_flavor: DBFlavor = DBFlavor.ARANGO, + target_db_flavor: DBType = DBType.ARANGO, ): """Initialize the GraphEngine. @@ -125,6 +125,58 @@ def define_schema( with ConnectionManager(connection_config=output_config) as db_client: db_client.init_db(schema, clean_start) + def define_and_ingest( + self, + schema: Schema, + output_config: DBConfig, + patterns: "Patterns | None" = None, + ingestion_params: IngestionParams | None = None, + clean_start: bool | None = None, + ) -> None: + """Define schema and ingest data into the graph database in one operation. + + This is a convenience method that chains define_schema() and ingest(). + It's the recommended way to set up and populate a graph database. + + Args: + schema: Schema configuration for the graph + output_config: Target database connection configuration + patterns: Patterns instance mapping resources to data sources. + If None, defaults to empty Patterns() + ingestion_params: IngestionParams instance with ingestion configuration. + If None, uses default IngestionParams() + clean_start: Whether to clean the database before defining schema. + If None, uses ingestion_params.clean_start if provided, otherwise False. + Note: If clean_start is True, ingestion_params.clean_start will be + set to False to avoid double-cleaning. 
+ """ + ingestion_params = ingestion_params or IngestionParams() + + # Determine clean_start value: explicit parameter > ingestion_params > False + if clean_start is None: + clean_start = ingestion_params.clean_start + + # Define schema first + self.define_schema( + schema=schema, + output_config=output_config, + clean_start=clean_start, + ) + + # If we cleaned during schema definition, don't clean again during ingestion + if clean_start: + ingestion_params = IngestionParams( + **{**ingestion_params.model_dump(), "clean_start": False} + ) + + # Then ingest data + self.ingest( + schema=schema, + output_config=output_config, + patterns=patterns, + ingestion_params=ingestion_params, + ) + def ingest( self, schema: Schema, diff --git a/graflo/hq/inferencer.py b/graflo/hq/inferencer.py index 76ca6e8..7f5be22 100644 --- a/graflo/hq/inferencer.py +++ b/graflo/hq/inferencer.py @@ -1,9 +1,9 @@ from graflo import Schema +from graflo.onto import DBType from graflo.architecture import Resource from graflo.db import PostgresConnection from graflo.db.postgres import PostgresSchemaInferencer, PostgresResourceMapper from graflo.hq.sanitizer import SchemaSanitizer -from graflo.onto import DBFlavor import logging logger = logging.getLogger(__name__) @@ -15,7 +15,7 @@ class InferenceManager: def __init__( self, conn: PostgresConnection, - target_db_flavor: DBFlavor = DBFlavor.ARANGO, + target_db_flavor: DBType = DBType.ARANGO, fuzzy_threshold: float = 0.8, ): """Initialize the PostgreSQL inference manager. diff --git a/graflo/hq/sanitizer.py b/graflo/hq/sanitizer.py index 881bbdd..5f64305 100644 --- a/graflo/hq/sanitizer.py +++ b/graflo/hq/sanitizer.py @@ -14,7 +14,7 @@ from graflo.architecture.edge import Edge from graflo.architecture.schema import Schema from graflo.architecture.vertex import Field -from graflo.onto import DBFlavor +from graflo.onto import DBType from graflo.db.util import load_reserved_words, sanitize_attribute_name @@ -37,7 +37,7 @@ class SchemaSanitizer: - Applying field index mappings to resources """ - def __init__(self, db_flavor: DBFlavor): + def __init__(self, db_flavor: DBType): """Initialize the schema sanitizer. Args: @@ -149,7 +149,7 @@ def sanitize(self, schema: Schema) -> Schema: str, dict[str, str] ] = {} # vertex_name -> {old_field: new_field} - if schema.vertex_config.db_flavor == DBFlavor.TIGERGRAPH: + if schema.vertex_config.db_flavor == DBType.TIGERGRAPH: # Group edges by relation edges_by_relation: dict[str | None, list[Edge]] = {} for edge in schema.edge_config.edges: diff --git a/graflo/onto.py b/graflo/onto.py index 45e6436..c1e7931 100644 --- a/graflo/onto.py +++ b/graflo/onto.py @@ -7,7 +7,6 @@ Key Components: - BaseEnum: Base class for string-based enumerations with flexible membership testing - BaseDataclass: Base class for dataclasses with JSON/YAML serialization support - - DBFlavor: Enum for supported database types (ArangoDB, Neo4j) - ExpressionFlavor: Enum for expression language types - AggregationType: Enum for supported aggregation operations @@ -97,28 +96,6 @@ def base_enum_representer(dumper, data): _register_yaml_representer() -class DBFlavor(BaseEnum): - """Supported database types. - - This enum defines the supported graph database types in the system. 
- - Attributes: - ARANGO: ArangoDB database - NEO4J: Neo4j database - TIGERGRAPH: TigerGraph database - FALKORDB: FalkorDB database (Redis-based graph database using Cypher) - MEMGRAPH: Memgraph database (in-memory graph database using Cypher) - NEBULA: NebulaGraph database - """ - - ARANGO = "arango" - NEO4J = "neo4j" - TIGERGRAPH = "tigergraph" - FALKORDB = "falkordb" - MEMGRAPH = "memgraph" - NEBULA = "nebula" - - class ExpressionFlavor(BaseEnum): """Supported expression language types. @@ -321,3 +298,24 @@ def get_fields_members(cls): list[str]: List of public field names """ return [k for k in cls.__annotations__ if not k.startswith("_")] + + +class DBType(StrEnum, metaclass=MetaEnum): + """Enum representing different types of databases. + + Includes both graph databases and source databases (SQL, NoSQL, etc.). + """ + + # Graph databases + ARANGO = "arango" + NEO4J = "neo4j" + TIGERGRAPH = "tigergraph" + FALKORDB = "falkordb" + MEMGRAPH = "memgraph" + NEBULA = "nebula" + + # Source databases (SQL, NoSQL) + POSTGRES = "postgres" + MYSQL = "mysql" + MONGODB = "mongodb" + SQLITE = "sqlite" diff --git a/test/architecture/test_vertex.py b/test/architecture/test_vertex.py index e548550..38e5a14 100644 --- a/test/architecture/test_vertex.py +++ b/test/architecture/test_vertex.py @@ -5,7 +5,7 @@ import pytest from graflo.architecture.vertex import Field, FieldType, Vertex, VertexConfig -from graflo.onto import DBFlavor +from graflo.onto import DBType logger = logging.getLogger(__name__) @@ -306,7 +306,7 @@ def test_get_fields_with_defaults_tigergraph(): ], ) - vertex.finish_init(DBFlavor.TIGERGRAPH) + vertex.finish_init(DBType.TIGERGRAPH) # For TigerGraph, None types should default to STRING fields = vertex.get_fields() assert len(fields) == 4 @@ -375,7 +375,7 @@ def test_vertex_config_fields_with_db_flavor(): assert fields[1].type is None # Preserved # Set db_flavor and call finish_init on config - config.db_flavor = DBFlavor.TIGERGRAPH + config.db_flavor = DBType.TIGERGRAPH config.finish_init() # With TigerGraph, should get fields with defaults applied fields = config.fields("user") diff --git a/test/conftest.py b/test/conftest.py index fc0893d..ece67ad 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -71,28 +71,24 @@ def ingest_atomic(conn_conf, current_path, test_db_name, schema_o, mode, n_cores patterns.add_file_pattern(resource_name, file_pattern) # Determine DB flavor from connection config - from graflo.db.connection.onto import DBType + from graflo.onto import DBType from graflo.hq import GraphEngine from graflo.hq.caster import IngestionParams - from graflo.onto import DBFlavor db_type = conn_conf.connection_type - # Map DBType to DBFlavor (they have the same values for graph databases) - db_flavor = ( - DBFlavor(db_type.value) - if db_type - in ( - DBType.ARANGO, - DBType.NEO4J, - DBType.TIGERGRAPH, - DBType.FALKORDB, - DBType.MEMGRAPH, - ) - else DBFlavor.ARANGO - ) + # Ensure it's a graph database (default to ARANGO if not) + if db_type not in ( + DBType.ARANGO, + DBType.NEO4J, + DBType.TIGERGRAPH, + DBType.FALKORDB, + DBType.MEMGRAPH, + DBType.NEBULA, + ): + db_type = DBType.ARANGO # Use GraphEngine for the full workflow - engine = GraphEngine(target_db_flavor=db_flavor) + engine = GraphEngine(target_db_flavor=db_type) # Define schema first (with clean_start=True) engine.define_schema( diff --git a/test/db/postgres/test_schema_inference.py b/test/db/postgres/test_schema_inference.py index 94b0e43..3b13428 100644 --- a/test/db/postgres/test_schema_inference.py +++ 
b/test/db/postgres/test_schema_inference.py @@ -10,14 +10,14 @@ from unittest.mock import patch from graflo.hq import GraphEngine -from graflo.onto import DBFlavor +from graflo.onto import DBType def test_infer_schema_from_postgres(conn_conf, load_mock_schema): """Test that infer_schema_from_postgres correctly infers schema from PostgreSQL.""" _ = load_mock_schema # Ensure schema is loaded - engine = GraphEngine(target_db_flavor=DBFlavor.ARANGO) + engine = GraphEngine(target_db_flavor=DBType.ARANGO) schema = engine.infer_schema(conn_conf, schema_name="public") # Verify schema structure @@ -185,7 +185,7 @@ def test_infer_schema_with_pg_catalog_fallback(conn_conf, load_mock_schema): PostgresConnection, "_check_information_schema_reliable", return_value=False ): # Test that infer_schema_from_postgres works with pg_catalog fallback - engine = GraphEngine(target_db_flavor=DBFlavor.ARANGO) + engine = GraphEngine(target_db_flavor=DBType.ARANGO) schema = engine.infer_schema(conn_conf, schema_name="public") # Verify schema structure diff --git a/test/db/tigergraphs/test_reserved_words.py b/test/db/tigergraphs/test_reserved_words.py index 88d2693..2fe4153 100644 --- a/test/db/tigergraphs/test_reserved_words.py +++ b/test/db/tigergraphs/test_reserved_words.py @@ -20,7 +20,7 @@ import pytest -from graflo.onto import DBFlavor +from graflo.onto import DBType from test.conftest import fetch_schema_obj from graflo.hq.sanitizer import SchemaSanitizer @@ -44,7 +44,7 @@ def test_vertex_name_sanitization_for_tigergraph(schema_with_reserved_words): """Test that vertex names with reserved words are sanitized for TigerGraph.""" schema = schema_with_reserved_words - sanitizer = SchemaSanitizer(DBFlavor.TIGERGRAPH) + sanitizer = SchemaSanitizer(DBType.TIGERGRAPH) sanitized_schema = sanitizer.sanitize(schema) @@ -61,7 +61,7 @@ def test_edges_sanitization_for_tigergraph(schema_with_incompatible_edges): """Test that vertex names with reserved words are sanitized for TigerGraph.""" schema = schema_with_incompatible_edges - sanitizer = SchemaSanitizer(DBFlavor.TIGERGRAPH) + sanitizer = SchemaSanitizer(DBType.TIGERGRAPH) sanitized_schema = sanitizer.sanitize(schema) From 8a387998a92692942e7c35ea97168c7457994379 Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Tue, 27 Jan 2026 16:13:00 +0100 Subject: [PATCH 8/9] fixed tests; updated docs --- README.md | 44 ++- docs/examples/example-5.md | 55 +-- docs/getting_started/quickstart.md | 43 ++- docs/index.md | 8 + .../5-ingest-postgres/generated-schema.yaml | 3 +- graflo/hq/caster.py | 318 +++++++++++------- ...a.institution.yaml => oa-institution.yaml} | 0 test/conftest.py | 11 - test/data/oa-institution/__init__.py | 0 .../oa.institutions.json} | 0 test/data_source/test_api_data_source.py | 13 +- test/db/arangos/test_ingest_relation.py | 19 +- ...ents.yaml => oa-institution_contents.yaml} | 0 ...n_sizes.yaml => oa-institution_sizes.yaml} | 0 test/test_caster.py | 17 +- 15 files changed, 319 insertions(+), 212 deletions(-) rename test/config/schema/{oa.institution.yaml => oa-institution.yaml} (100%) create mode 100644 test/data/oa-institution/__init__.py rename test/data/{json/oa.institution.json => oa-institution/oa.institutions.json} (100%) rename test/ref/db/{oa_relation_contents.yaml => oa-institution_contents.yaml} (100%) rename test/ref/db/{oa_relation_sizes.yaml => oa-institution_sizes.yaml} (100%) diff --git a/README.md b/README.md index c37503e..2f9e033 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ Resources are your data sources that can be: - Infer 
edge configurations from foreign key relationships - Create Resource mappings from PostgreSQL tables automatically - Direct database access - ingest data without exporting to files first +- **Async ingestion**: Efficient async/await-based ingestion pipeline for better performance - **Parallel processing**: Use as many cores as you have - **Database support**: Ingest into ArangoDB, Neo4j, **TigerGraph**, **FalkorDB**, and **Memgraph** using the same API (database agnostic). Source data from PostgreSQL and other SQL databases. - **Server-side filtering**: Efficient querying with server-side filtering support (TigerGraph REST++ API) @@ -123,46 +124,65 @@ patterns.add_file_pattern( schema.fetch_resource() from graflo.hq.caster import IngestionParams +from graflo.hq import GraphEngine -caster = Caster(schema) - +# Option 1: Use GraphEngine for schema definition and ingestion (recommended) +engine = GraphEngine() ingestion_params = IngestionParams( clean_start=False, # Set to True to wipe existing database # max_items=1000, # Optional: limit number of items to process # batch_size=10000, # Optional: customize batch size ) -caster.ingest( +engine.define_and_ingest( + schema=schema, output_config=conn_conf, # Target database config patterns=patterns, # Source data patterns ingestion_params=ingestion_params, + clean_start=False, # Set to True to wipe existing database ) + +# Option 2: Use Caster directly (schema must be defined separately) +# from graflo.hq import GraphEngine +# engine = GraphEngine() +# engine.define_schema(schema=schema, output_config=conn_conf, clean_start=False) +# +# caster = Caster(schema) +# caster.ingest( +# output_config=conn_conf, +# patterns=patterns, +# ingestion_params=ingestion_params, +# ) ``` ### PostgreSQL Schema Inference ```python -from graflo.db.postgres import PostgresConnection from graflo.hq import GraphEngine -from graflo.db.connection.onto import PostgresConfig +from graflo.db.connection.onto import PostgresConfig, ArangoConfig from graflo import Caster -from graflo.onto import DBFlavor +from graflo.onto import DBType # Connect to PostgreSQL postgres_config = PostgresConfig.from_docker_env() # or PostgresConfig.from_env() -postgres_conn = PostgresConnection(postgres_config) # Create GraphEngine and infer schema from PostgreSQL 3NF database -engine = GraphEngine(target_db_flavor=DBFlavor.ARANGO) +# Connection is automatically managed inside infer_schema() +engine = GraphEngine(target_db_flavor=DBType.ARANGO) schema = engine.infer_schema( - postgres_conn, + postgres_config, schema_name="public", # PostgreSQL schema name ) -# Close PostgreSQL connection -postgres_conn.close() +# Define schema in target database (optional, can also use define_and_ingest) +target_config = ArangoConfig.from_docker_env() +engine.define_schema( + schema=schema, + output_config=target_config, + clean_start=False, +) -# Use the inferred schema with Caster +# Use the inferred schema with Caster for ingestion caster = Caster(schema) # ... continue with ingestion ``` diff --git a/docs/examples/example-5.md b/docs/examples/example-5.md index a845d4e..65e6483 100644 --- a/docs/examples/example-5.md +++ b/docs/examples/example-5.md @@ -264,25 +264,20 @@ Automatically generate a graflo Schema from your PostgreSQL database. 
This is th ```python from graflo.hq import GraphEngine -from graflo.onto import DBFlavor, DBType +from graflo.onto import DBType from graflo.db.connection.onto import ArangoConfig, Neo4jConfig, TigergraphConfig, FalkordbConfig, PostgresConfig -# Connect to target graph database to determine flavor +# Connect to target graph database to determine database type # Choose one of: ArangoConfig, Neo4jConfig, TigergraphConfig, or FalkordbConfig target_config = ArangoConfig.from_docker_env() # or Neo4jConfig, TigergraphConfig, FalkordbConfig -# Determine db_flavor from target config +# Get database type from target config db_type = target_config.connection_type -db_flavor = ( - DBFlavor(db_type.value) - if db_type in (DBType.ARANGO, DBType.NEO4J, DBType.TIGERGRAPH, DBType.FALKORDB) - else DBFlavor.ARANGO -) # Create GraphEngine and infer schema automatically # Connection is automatically managed inside infer_schema() postgres_conf = PostgresConfig.from_docker_env() -engine = GraphEngine(target_db_flavor=db_flavor) +engine = GraphEngine(target_db_flavor=db_type) schema = engine.infer_schema( postgres_conf, schema_name="public", # PostgreSQL schema name @@ -372,26 +367,22 @@ Finally, ingest the data from PostgreSQL into your target graph database. This i 5. **Graph Database Storage**: Data is written to the target graph database (ArangoDB/Neo4j/TigerGraph) using database-specific APIs for optimal performance. The system handles duplicates and updates based on indexes. ```python -from graflo import Caster - -# Create Caster with inferred schema -caster = Caster(schema) - -# Ingest data from PostgreSQL into graph database +from graflo.hq import GraphEngine from graflo.hq.caster import IngestionParams +# Use GraphEngine for schema definition and ingestion +engine = GraphEngine() ingestion_params = IngestionParams( clean_start=True, # Clear existing data first ) -caster.ingest( +engine.define_and_ingest( + schema=schema, output_config=target_config, # Target graph database config patterns=patterns, # PostgreSQL table patterns ingestion_params=ingestion_params, + clean_start=True, # Clear existing data first ) - -# Cleanup -postgres_conn.close() ``` ## Complete Example @@ -403,11 +394,10 @@ import logging from pathlib import Path import yaml -from graflo import Caster -from graflo.onto import DBFlavor -from graflo.db import DBType +from graflo.onto import DBType from graflo.hq import GraphEngine from graflo.db.connection.onto import ArangoConfig, PostgresConfig +from graflo.hq.caster import IngestionParams logger = logging.getLogger(__name__) @@ -426,14 +416,9 @@ target_config = ArangoConfig.from_docker_env() # or Neo4jConfig, TigergraphConf # Step 4: Infer Schema from PostgreSQL database structure # Connection is automatically managed inside infer_schema() db_type = target_config.connection_type -db_flavor = ( - DBFlavor(db_type.value) - if db_type in (DBType.ARANGO, DBType.NEO4J, DBType.TIGERGRAPH) - else DBFlavor.ARANGO -) # Create GraphEngine and infer schema -engine = GraphEngine(target_db_flavor=db_flavor) +engine = GraphEngine(target_db_flavor=db_type) schema = engine.infer_schema( postgres_conf, schema_name="public", @@ -449,24 +434,20 @@ logger.info(f"Inferred schema saved to {schema_output_file}") engine = GraphEngine() patterns = engine.create_patterns(postgres_conf, schema_name="public") -# Step 7: Create Caster and ingest data -from graflo.hq.caster import IngestionParams - -caster = Caster(schema) - +# Step 7: Define schema and ingest data ingestion_params = IngestionParams( 
clean_start=True, # Clear existing data first ) -caster.ingest( +# Use GraphEngine to define schema and ingest data +engine.define_and_ingest( + schema=schema, output_config=target_config, patterns=patterns, ingestion_params=ingestion_params, + clean_start=True, # Clear existing data first ) -# Cleanup -postgres_conn.close() - print("\n" + "=" * 80) print("Ingestion complete!") print("=" * 80) diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index 3de534b..edd475c 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -10,7 +10,7 @@ This guide will help you get started with graflo by showing you how to transform - `DataSource` defines where data comes from (files, APIs, SQL databases, in-memory objects). - Class `Patterns` manages the mapping of resources to their physical data sources (files or PostgreSQL tables). It efficiently handles PostgreSQL connections by grouping tables that share the same connection configuration. - `DataSourceRegistry` maps DataSources to Resources (many DataSources can map to the same Resource). -1- Database backend configurations use Pydantic `BaseSettings` with environment variable support. Use `ArangoConfig`, `Neo4jConfig`, `TigergraphConfig`, `FalkordbConfig`, `MemgraphConfig`, or `PostgresConfig` directly, or load from docker `.env` files using `from_docker_env()`. All configs inherit from `DBConfig` and support unified `database`/`schema_name` structure with `effective_database` and `effective_schema` properties for database-agnostic access. If `effective_schema` is not set, `Caster` automatically uses `Schema.general.name` as fallback. +1- Database backend configurations use Pydantic `BaseSettings` with environment variable support. Use `ArangoConfig`, `Neo4jConfig`, `TigergraphConfig`, `FalkordbConfig`, `MemgraphConfig`, or `PostgresConfig` directly, or load from docker `.env` files using `from_docker_env()`. All configs inherit from `DBConfig` and support unified `database`/`schema_name` structure with `effective_database` and `effective_schema` properties for database-agnostic access. If `effective_schema` is not set, `GraphEngine.define_schema()` automatically uses `Schema.general.name` as fallback. ## Basic Example @@ -77,16 +77,32 @@ patterns = Patterns( ) from graflo.hq.caster import IngestionParams +from graflo.hq import GraphEngine +# Option 1: Use GraphEngine for schema definition and ingestion (recommended) +engine = GraphEngine() ingestion_params = IngestionParams( clean_start=False, # Set to True to wipe existing database ) -caster.ingest( +engine.define_and_ingest( + schema=schema, output_config=conn_conf, # Target database config patterns=patterns, # Source data patterns ingestion_params=ingestion_params, + clean_start=False, # Set to True to wipe existing database ) + +# Option 2: Use Caster directly (schema must be defined separately) +# engine = GraphEngine() +# engine.define_schema(schema=schema, output_config=conn_conf, clean_start=False) +# +# caster = Caster(schema) +# caster.ingest( +# output_config=conn_conf, +# patterns=patterns, +# ingestion_params=ingestion_params, +# ) ``` Here `schema` defines the graph and the mapping the sources to vertices and edges (refer to [Schema](../concepts/index.md#schema) for details on schema and its components). 
@@ -138,23 +154,23 @@ patterns = Patterns( # Ingest from graflo.db.connection.onto import ArangoConfig +from graflo.hq import GraphEngine arango_config = ArangoConfig.from_docker_env() # Target graph database -caster = Caster(schema) - -from graflo.hq.caster import IngestionParams +# Use GraphEngine for schema definition and ingestion +engine = GraphEngine() ingestion_params = IngestionParams( clean_start=False, # Set to True to wipe existing database ) -caster.ingest( +engine.define_and_ingest( + schema=schema, output_config=arango_config, # Target graph database patterns=patterns, # Source PostgreSQL tables ingestion_params=ingestion_params, + clean_start=False, # Set to True to wipe existing database ) - -pg_conn.close() ``` ## Using API Data Sources @@ -189,9 +205,18 @@ registry.register(api_source, resource_name="users") # Ingest from graflo.hq.caster import IngestionParams +from graflo.hq import GraphEngine -caster = Caster(schema) +# Define schema first (required before ingestion) +engine = GraphEngine() +engine.define_schema( + schema=schema, + output_config=conn_conf, + clean_start=False, +) +# Then ingest using Caster +caster = Caster(schema) ingestion_params = IngestionParams() # Use default parameters caster.ingest_data_sources( diff --git a/docs/index.md b/docs/index.md index d2d8efc..00dec79 100644 --- a/docs/index.md +++ b/docs/index.md @@ -49,6 +49,13 @@ Resources define how data is transformed into a graph (semantic mapping). They w - **Table-like processing**: CSV files, SQL tables, API responses - **JSON-like processing**: JSON files, nested data structures, hierarchical API responses +### GraphEngine +The `GraphEngine` orchestrates graph database operations, providing a unified interface for: +- Schema inference from PostgreSQL databases +- Schema definition in target graph databases (moved from Caster) +- Pattern creation from data sources +- Data ingestion with async support + ## Key Features - **🚀 PostgreSQL Schema Inference**: **Automatically generate schemas from normalized PostgreSQL databases (3NF)** - No manual schema definition needed! @@ -71,6 +78,7 @@ Resources define how data is transformed into a graph (semantic mapping). 
They w - Vertex fields support types (INT, FLOAT, STRING, DATETIME, BOOL) for better validation - Edge weight fields can specify types for improved type safety - Backward compatible: fields without types default to None (suitable for databases like ArangoDB) +- **Async Ingestion**: Efficient async/await-based ingestion pipeline for better performance - **Parallel Processing**: Efficient processing with multi-threading - **Database Integration**: Seamless integration with Neo4j, ArangoDB, TigerGraph, FalkorDB, Memgraph, and PostgreSQL (as source) - **Advanced Filtering**: Powerful filtering capabilities for data transformation with server-side filtering support diff --git a/examples/5-ingest-postgres/generated-schema.yaml b/examples/5-ingest-postgres/generated-schema.yaml index cf96ab0..64b3a49 100644 --- a/examples/5-ingest-postgres/generated-schema.yaml +++ b/examples/5-ingest-postgres/generated-schema.yaml @@ -45,7 +45,8 @@ resources: resource_name: purchases transforms: {} vertex_config: - db_flavor: tigergraph + db_flavor: !!python/object/apply:graflo.onto.DBType + - tigergraph vertices: - dbname: products fields: diff --git a/graflo/hq/caster.py b/graflo/hq/caster.py index dd0ca93..09b9d7d 100644 --- a/graflo/hq/caster.py +++ b/graflo/hq/caster.py @@ -14,13 +14,10 @@ >>> caster.ingest(path="data/", conn_conf=db_config) """ +import asyncio import logging -import multiprocessing as mp -import queue import re import sys -from concurrent.futures import ThreadPoolExecutor -from functools import partial from pathlib import Path from typing import Any, cast @@ -28,6 +25,7 @@ from pydantic import BaseModel from suthing import Timer +from graflo.architecture.edge import Edge from graflo.architecture.onto import EncodingType, GraphContainer from graflo.architecture.schema import Schema from graflo.data_source import ( @@ -55,6 +53,9 @@ class IngestionParams(BaseModel): dry: Whether to perform a dry run (no database changes) init_only: Whether to only initialize the database without ingestion limit_files: Optional limit on number of files to process + max_concurrent_db_ops: Maximum number of concurrent database operations (for vertices/edges). + If None, uses n_cores. Set to 1 to prevent deadlocks in databases that don't handle + concurrent transactions well (e.g., Neo4j). Database-independent setting. """ clean_start: bool = False @@ -64,6 +65,7 @@ class IngestionParams(BaseModel): dry: bool = False init_only: bool = False limit_files: int | None = None + max_concurrent_db_ops: int | None = None class Caster: @@ -144,7 +146,7 @@ def discover_files( return files - def cast_normal_resource( + async def cast_normal_resource( self, data, resource_name: str | None = None ) -> GraphContainer: """Cast data into a graph container using a resource. 
@@ -158,18 +160,19 @@ def cast_normal_resource( """ rr = self.schema.fetch_resource(resource_name) - with ThreadPoolExecutor(max_workers=self.ingestion_params.n_cores) as executor: - docs = list( - executor.map( - lambda doc: rr(doc), - data, - ) - ) + # Process documents in parallel using asyncio + semaphore = asyncio.Semaphore(self.ingestion_params.n_cores) + + async def process_doc(doc): + async with semaphore: + return await asyncio.to_thread(rr, doc) + + docs = await asyncio.gather(*[process_doc(doc) for doc in data]) graph = GraphContainer.from_docs_list(docs) return graph - def process_batch( + async def process_batch( self, batch, resource_name: str | None, @@ -182,12 +185,12 @@ def process_batch( resource_name: Optional name of the resource to use conn_conf: Optional database connection configuration """ - gc = self.cast_normal_resource(batch, resource_name=resource_name) + gc = await self.cast_normal_resource(batch, resource_name=resource_name) if conn_conf is not None: - self.push_db(gc=gc, conn_conf=conn_conf, resource_name=resource_name) + await self.push_db(gc=gc, conn_conf=conn_conf, resource_name=resource_name) - def process_data_source( + async def process_data_source( self, data_source: AbstractDataSource, resource_name: str | None = None, @@ -211,11 +214,11 @@ def process_data_source( for batch in data_source.iter_batches( batch_size=self.ingestion_params.batch_size, limit=limit ): - self.process_batch( + await self.process_batch( batch, resource_name=actual_resource_name, conn_conf=conn_conf ) - def process_resource( + async def process_resource( self, resource_instance: ( Path | str | list[dict] | list[list] | pd.DataFrame | dict[str, Any] @@ -281,13 +284,13 @@ def process_resource( data_source.resource_name = resource_name # Process using the data source - self.process_data_source( + await self.process_data_source( data_source=data_source, resource_name=resource_name, conn_conf=conn_conf, ) - def push_db( + async def push_db( self, gc: GraphContainer, conn_conf: DBConfig, @@ -302,105 +305,181 @@ def push_db( """ vc = self.schema.vertex_config resource = self.schema.fetch_resource(resource_name) - with ConnectionManager(connection_config=conn_conf) as db_client: - for vcol, data in gc.vertices.items(): - # blank nodes: push and get back their keys {"_key": ...} - if vcol in vc.blank_vertices: - query0 = db_client.insert_return_batch(data, vc.vertex_dbname(vcol)) - cursor = db_client.execute(query0) - gc.vertices[vcol] = [item for item in cursor] - else: - db_client.upsert_docs_batch( - data, - vc.vertex_dbname(vcol), - vc.index(vcol), - update_keys="doc", - filter_uniques=True, - dry=self.ingestion_params.dry, - ) - # update edge misc with blank node edges - for vcol in vc.blank_vertices: - for edge_id, edge in self.schema.edge_config.edges_items(): - vfrom, vto, relation = edge_id - if vcol == vfrom or vcol == vto: - if edge_id not in gc.edges: - gc.edges[edge_id] = [] - gc.edges[edge_id].extend( - [ - (x, y, {}) - for x, y in zip(gc.vertices[vfrom], gc.vertices[vto]) - ] - ) - - with ConnectionManager(connection_config=conn_conf) as db_client: - # currently works only on item level - for edge in resource.extra_weights: - if edge.weights is None: - continue - for weight in edge.weights.vertices: - if weight.name in vc.vertex_set: - index_fields = vc.index(weight.name) - - if not self.ingestion_params.dry and weight.name in gc.vertices: - weights_per_item = db_client.fetch_present_documents( - class_name=vc.vertex_dbname(weight.name), - batch=gc.vertices[weight.name], 
- match_keys=index_fields.fields, - keep_keys=weight.fields, + # Push vertices in parallel (with configurable concurrency control to prevent deadlocks) + # Some databases can deadlock when multiple transactions modify the same nodes + # Use a semaphore to limit concurrent operations based on max_concurrent_db_ops + max_concurrent = ( + self.ingestion_params.max_concurrent_db_ops + if self.ingestion_params.max_concurrent_db_ops is not None + else self.ingestion_params.n_cores + ) + vertex_semaphore = asyncio.Semaphore(max_concurrent) + + async def push_vertex(vcol: str, data: list[dict]): + async with vertex_semaphore: + + def _push_vertex_sync(): + with ConnectionManager(connection_config=conn_conf) as db_client: + # blank nodes: push and get back their keys {"_key": ...} + if vcol in vc.blank_vertices: + query0 = db_client.insert_return_batch( + data, vc.vertex_dbname(vcol) ) + cursor = db_client.execute(query0) + return vcol, [item for item in cursor] + else: + db_client.upsert_docs_batch( + data, + vc.vertex_dbname(vcol), + vc.index(vcol), + update_keys="doc", + filter_uniques=True, + dry=self.ingestion_params.dry, + ) + return vcol, None + + return await asyncio.to_thread(_push_vertex_sync) - for j, item in enumerate(gc.linear): - weights = weights_per_item[j] + # Process all vertices in parallel (with semaphore limiting concurrency for Neo4j) + vertex_results = await asyncio.gather( + *[push_vertex(vcol, data) for vcol, data in gc.vertices.items()] + ) - for ee in item[edge.edge_id]: - weight_collection_attached = { - weight.cfield(k): v - for k, v in weights[0].items() - } - ee.update(weight_collection_attached) - else: - logger.error(f"{weight.name} not a valid vertex") + # Update blank vertices with returned keys + for vcol, result in vertex_results: + if result is not None: + gc.vertices[vcol] = result - with ConnectionManager(connection_config=conn_conf) as db_client: + # update edge misc with blank node edges + for vcol in vc.blank_vertices: for edge_id, edge in self.schema.edge_config.edges_items(): - for ee in gc.loop_over_relations(edge_id): - _, _, relation = ee - if not self.ingestion_params.dry: - data = gc.edges[ee] - db_client.insert_edges_batch( - docs_edges=data, - source_class=vc.vertex_dbname(edge.source), - target_class=vc.vertex_dbname(edge.target), - relation_name=relation, - match_keys_source=vc.index(edge.source).fields, - match_keys_target=vc.index(edge.target).fields, - filter_uniques=False, - dry=self.ingestion_params.dry, - collection_name=edge.database_name, - ) - - def process_with_queue(self, tasks: mp.Queue, conn_conf: DBConfig | None = None): + vfrom, vto, relation = edge_id + if vcol == vfrom or vcol == vto: + if edge_id not in gc.edges: + gc.edges[edge_id] = [] + gc.edges[edge_id].extend( + [ + (x, y, {}) + for x, y in zip(gc.vertices[vfrom], gc.vertices[vto]) + ] + ) + + # Process extra weights + async def process_extra_weights(): + def _process_extra_weights_sync(): + with ConnectionManager(connection_config=conn_conf) as db_client: + # currently works only on item level + for edge in resource.extra_weights: + if edge.weights is None: + continue + for weight in edge.weights.vertices: + if weight.name in vc.vertex_set: + index_fields = vc.index(weight.name) + + if ( + not self.ingestion_params.dry + and weight.name in gc.vertices + ): + weights_per_item = ( + db_client.fetch_present_documents( + class_name=vc.vertex_dbname(weight.name), + batch=gc.vertices[weight.name], + match_keys=index_fields.fields, + keep_keys=weight.fields, + ) + ) + + for j, 
item in enumerate(gc.linear): + weights = weights_per_item[j] + + for ee in item[edge.edge_id]: + weight_collection_attached = { + weight.cfield(k): v + for k, v in weights[0].items() + } + ee.update(weight_collection_attached) + else: + logger.error(f"{weight.name} not a valid vertex") + + await asyncio.to_thread(_process_extra_weights_sync) + + await process_extra_weights() + + # Push edges in parallel (with configurable concurrency control to prevent deadlocks) + # Some databases can deadlock when multiple transactions modify the same nodes/relationships + # Use a semaphore to limit concurrent operations based on max_concurrent_db_ops + edge_semaphore = asyncio.Semaphore(max_concurrent) + + async def push_edge(edge_id: tuple, edge: Edge): + async with edge_semaphore: + + def _push_edge_sync(): + with ConnectionManager(connection_config=conn_conf) as db_client: + for ee in gc.loop_over_relations(edge_id): + _, _, relation = ee + if not self.ingestion_params.dry: + data = gc.edges[ee] + db_client.insert_edges_batch( + docs_edges=data, + source_class=vc.vertex_dbname(edge.source), + target_class=vc.vertex_dbname(edge.target), + relation_name=relation, + match_keys_source=vc.index(edge.source).fields, + match_keys_target=vc.index(edge.target).fields, + filter_uniques=False, + dry=self.ingestion_params.dry, + collection_name=edge.database_name, + ) + + await asyncio.to_thread(_push_edge_sync) + + # Process all edges in parallel (with semaphore limiting concurrency for Neo4j) + await asyncio.gather( + *[ + push_edge(edge_id, edge) + for edge_id, edge in self.schema.edge_config.edges_items() + ] + ) + + async def process_with_queue( + self, tasks: asyncio.Queue, conn_conf: DBConfig | None = None + ): """Process tasks from a queue. Args: - tasks: Queue of tasks to process + tasks: Async queue of tasks to process conn_conf: Optional database connection configuration """ + # Sentinel value to signal completion + SENTINEL = None + while True: try: - task = tasks.get_nowait() + # Get task from queue (will wait if queue is empty) + task = await tasks.get() + + # Check for sentinel value + if task is SENTINEL: + tasks.task_done() + break + # Support both (Path, str) tuples and DataSource instances if isinstance(task, tuple) and len(task) == 2: filepath, resource_name = task - self.process_resource( + await self.process_resource( resource_instance=filepath, resource_name=resource_name, conn_conf=conn_conf, ) elif isinstance(task, AbstractDataSource): - self.process_data_source(data_source=task, conn_conf=conn_conf) - except queue.Empty: + await self.process_data_source( + data_source=task, conn_conf=conn_conf + ) + tasks.task_done() + except Exception as e: + logger.error(f"Error processing task: {e}", exc_info=True) + tasks.task_done() break @staticmethod @@ -431,7 +510,7 @@ def normalize_resource( rows_dressed = [{k: v for k, v in zip(columns, item)} for item in _data] return rows_dressed - def ingest_data_sources( + async def ingest_data_sources( self, data_source_registry: DataSourceRegistry, conn_conf: DBConfig, @@ -471,29 +550,26 @@ def ingest_data_sources( with Timer() as klepsidra: if self.ingestion_params.n_cores > 1: - queue_tasks: mp.Queue = mp.Queue() + # Use asyncio for parallel processing + queue_tasks: asyncio.Queue = asyncio.Queue() for item in tasks: - queue_tasks.put(item) + await queue_tasks.put(item) - func = partial( - self.process_with_queue, - conn_conf=conn_conf, - ) - assert mp.get_start_method() == "fork", ( - "Requires 'forking' operating system" - ) + # Add sentinel values 
to signal workers to stop + for _ in range(self.ingestion_params.n_cores): + await queue_tasks.put(None) - processes = [] + # Create worker tasks + worker_tasks = [ + self.process_with_queue(queue_tasks, conn_conf=conn_conf) + for _ in range(self.ingestion_params.n_cores) + ] - for w in range(self.ingestion_params.n_cores): - p = mp.Process(target=func, args=(queue_tasks,)) - processes.append(p) - p.start() - for p in processes: - p.join() + # Run all workers in parallel + await asyncio.gather(*worker_tasks) else: for data_source in tasks: - self.process_data_source( + await self.process_data_source( data_source=data_source, conn_conf=conn_conf ) logger.info(f"Processing took {klepsidra.elapsed:.1f} sec") @@ -695,8 +771,10 @@ def ingest( registry = self._build_registry_from_patterns(patterns, ingestion_params) # Ingest data sources - self.ingest_data_sources( - data_source_registry=registry, - conn_conf=output_config, - ingestion_params=ingestion_params, + asyncio.run( + self.ingest_data_sources( + data_source_registry=registry, + conn_conf=output_config, + ingestion_params=ingestion_params, + ) ) diff --git a/test/config/schema/oa.institution.yaml b/test/config/schema/oa-institution.yaml similarity index 100% rename from test/config/schema/oa.institution.yaml rename to test/config/schema/oa-institution.yaml diff --git a/test/conftest.py b/test/conftest.py index ece67ad..7b4e7ec 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -71,21 +71,10 @@ def ingest_atomic(conn_conf, current_path, test_db_name, schema_o, mode, n_cores patterns.add_file_pattern(resource_name, file_pattern) # Determine DB flavor from connection config - from graflo.onto import DBType from graflo.hq import GraphEngine from graflo.hq.caster import IngestionParams db_type = conn_conf.connection_type - # Ensure it's a graph database (default to ARANGO if not) - if db_type not in ( - DBType.ARANGO, - DBType.NEO4J, - DBType.TIGERGRAPH, - DBType.FALKORDB, - DBType.MEMGRAPH, - DBType.NEBULA, - ): - db_type = DBType.ARANGO # Use GraphEngine for the full workflow engine = GraphEngine(target_db_flavor=db_type) diff --git a/test/data/oa-institution/__init__.py b/test/data/oa-institution/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/data/json/oa.institution.json b/test/data/oa-institution/oa.institutions.json similarity index 100% rename from test/data/json/oa.institution.json rename to test/data/oa-institution/oa.institutions.json diff --git a/test/data_source/test_api_data_source.py b/test/data_source/test_api_data_source.py index c183d07..c519085 100644 --- a/test/data_source/test_api_data_source.py +++ b/test/data_source/test_api_data_source.py @@ -1,5 +1,6 @@ """Tests for API data source implementation.""" +import asyncio import logging from os.path import dirname, realpath @@ -54,7 +55,9 @@ def test_api_data_source_basic( # Create caster and process caster = Caster(schema, n_cores=1) - caster.process_data_source(data_source=api_source, resource_name=resource_name) + asyncio.run( + caster.process_data_source(data_source=api_source, resource_name=resource_name) + ) # Verify we got data # Note: This is a basic test - full verification would require database connection @@ -89,9 +92,11 @@ def test_api_data_source_via_process_resource( }, } - caster.process_resource( - resource_instance=resource_config, - resource_name=resource_name, + asyncio.run( + caster.process_resource( + resource_instance=resource_config, + resource_name=resource_name, + ) ) # Test passes if no exceptions are raised diff 
--git a/test/db/arangos/test_ingest_relation.py b/test/db/arangos/test_ingest_relation.py index 6e45306..55db5dc 100644 --- a/test/db/arangos/test_ingest_relation.py +++ b/test/db/arangos/test_ingest_relation.py @@ -1,9 +1,7 @@ -from pathlib import Path +from test.conftest import ingest_atomic from test.db.arangos.conftest import verify_from_db -from suthing import FileHandle - -from graflo import Caster, ConnectionManager +from graflo import ConnectionManager def test_ingest( @@ -14,22 +12,17 @@ def test_ingest( test_db_name, reset, ): + m = "oa-institution" _ = create_db - schema_o = schema_obj("oa.institution") - j_resource = FileHandle.load(Path(current_path) / "data/json/oa.institution.json") - - conn_conf.database = test_db_name + schema_o = schema_obj(m) - with ConnectionManager(connection_config=conn_conf) as db_client: - db_client.init_db(schema_o, clean_start=True) - caster = Caster(schema_o) - caster.process_resource(j_resource, "institutions", conn_conf=conn_conf) + ingest_atomic(conn_conf, current_path, test_db_name, schema_o, mode=m) verify_from_db( conn_conf, current_path, test_db_name, - mode="oa_relation", + mode=m, reset=reset, ) diff --git a/test/ref/db/oa_relation_contents.yaml b/test/ref/db/oa-institution_contents.yaml similarity index 100% rename from test/ref/db/oa_relation_contents.yaml rename to test/ref/db/oa-institution_contents.yaml diff --git a/test/ref/db/oa_relation_sizes.yaml b/test/ref/db/oa-institution_sizes.yaml similarity index 100% rename from test/ref/db/oa_relation_sizes.yaml rename to test/ref/db/oa-institution_sizes.yaml diff --git a/test/test_caster.py b/test/test_caster.py index bb16034..df4ee44 100644 --- a/test/test_caster.py +++ b/test/test_caster.py @@ -1,3 +1,4 @@ +import asyncio import logging import os import pathlib @@ -52,8 +53,10 @@ def cast(modes, schema_obj, current_path, level, reset, n_cores=1): f"{mode}.{ext}.gz", ) - caster.process_resource( - resource_instance=pathlib.Path(fname), resource_name=resource_name + asyncio.run( + caster.process_resource( + resource_instance=pathlib.Path(fname), resource_name=resource_name + ) ) else: data_obj = FileHandle.load( @@ -62,12 +65,16 @@ def cast(modes, schema_obj, current_path, level, reset, n_cores=1): ) if level == 1: - caster.process_resource( - resource_instance=data_obj, resource_name=resource_name + asyncio.run( + caster.process_resource( + resource_instance=data_obj, resource_name=resource_name + ) ) elif level == 2: data = caster.normalize_resource(data_obj) - graph = caster.cast_normal_resource(data, resource_name=resource_name) + graph = asyncio.run( + caster.cast_normal_resource(data, resource_name=resource_name) + ) graph.pick_unique() From 4a523d0c71d73ce575316b6d5505e1eae7b0331f Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Tue, 27 Jan 2026 16:14:32 +0100 Subject: [PATCH 9/9] version bump --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 75f291a..b4c6d16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ description = "A framework for transforming tabular (CSV, SQL) and hierarchical name = "graflo" readme = "README.md" requires-python = ">=3.11" -version = "1.4.3" +version = "1.4.4" [project.optional-dependencies] plot = [ diff --git a/uv.lock b/uv.lock index c1c739e..ecb38e7 100644 --- a/uv.lock +++ b/uv.lock @@ -348,7 +348,7 @@ wheels = [ [[package]] name = "graflo" -version = "1.4.3" +version = "1.4.4" source = { editable = "." 
} dependencies = [ { name = "click" },