From 53c3f2d5b4486d7c70a44983a907196119a88002 Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Mon, 2 Feb 2026 15:31:16 +0100 Subject: [PATCH 1/7] WIP: dataclasses -> pydantic --- .cursorignore | 2 + .gitignore | 3 +- graflo/architecture/actor.py | 276 ++++++------- graflo/architecture/actor_config.py | 366 +++++++++++++++++ graflo/architecture/base.py | 117 ++++++ graflo/architecture/edge.py | 160 +++----- graflo/architecture/onto.py | 129 +++--- graflo/architecture/resource.py | 161 +++++--- graflo/architecture/schema.py | 114 ++++-- graflo/architecture/transform.py | 200 ++++----- graflo/architecture/vertex.py | 413 +++++++++---------- graflo/data_source/factory.py | 2 - graflo/db/arango/conn.py | 8 + graflo/db/arango/query.py | 10 +- graflo/db/arango/util.py | 13 +- graflo/db/conn.py | 23 +- graflo/db/falkordb/conn.py | 9 +- graflo/db/memgraph/conn.py | 8 +- graflo/db/neo4j/conn.py | 6 +- graflo/db/postgres/resource_mapping.py | 4 +- graflo/db/postgres/schema_inference.py | 13 +- graflo/db/tigergraph/conn.py | 16 +- graflo/filter/onto.py | 543 +++++++++---------------- graflo/hq/inferencer.py | 2 +- graflo/hq/sanitizer.py | 13 +- graflo/onto.py | 35 +- test/architecture/conftest.py | 10 +- test/architecture/test_actor.py | 35 ++ test/architecture/test_edge.py | 2 +- test/architecture/test_vertex.py | 38 +- test/test_filters.py | 10 +- test/test_filters_python.py | 6 +- 32 files changed, 1561 insertions(+), 1186 deletions(-) create mode 100644 graflo/architecture/actor_config.py create mode 100644 graflo/architecture/base.py diff --git a/.cursorignore b/.cursorignore index ba81bf62..e65559c3 100644 --- a/.cursorignore +++ b/.cursorignore @@ -28,6 +28,8 @@ !sh/ !sh/**/*.sh +!planning +!planning/*MD !.github/** !.env.example diff --git a/.gitignore b/.gitignore index bc8b3ff6..efc70c67 100644 --- a/.gitignore +++ b/.gitignore @@ -76,4 +76,5 @@ target/ #*/**/*png */**/*pdf -site/ \ No newline at end of file +site/ +planning/ \ No newline at end of file diff --git a/graflo/architecture/actor.py b/graflo/architecture/actor.py index 35b3ddee..7e2a401f 100644 --- a/graflo/architecture/actor.py +++ b/graflo/architecture/actor.py @@ -27,7 +27,18 @@ from functools import reduce from pathlib import Path from types import MappingProxyType -from typing import Any, Callable, Generic, Type, TypeVar +from typing import Any, Callable, Type + +from graflo.architecture.actor_config import ( + ActorConfig, + DescendActorConfig, + EdgeActorConfig, + TransformActorConfig, + VertexActorConfig, + parse_root_config, + normalize_actor_step, + validate_actor_step, +) from graflo.architecture.actor_util import ( add_blank_collections, @@ -203,23 +214,19 @@ class VertexActor(Actor): vertex_config: Configuration for the vertex """ - def __init__( - self, - vertex: str, - keep_fields: tuple[str, ...] | None = None, - **kwargs, - ): - """Initialize the vertex actor. - - Args: - vertex: Name of the vertex - keep_fields: Optional tuple of fields to keep - **kwargs: Additional initialization parameters - """ - self.name = vertex - self.keep_fields: tuple[str, ...] | None = keep_fields + def __init__(self, config: VertexActorConfig): + """Initialize the vertex actor from config.""" + self.name = config.vertex + self.keep_fields: tuple[str, ...] 
| None = ( + tuple(config.keep_fields) if config.keep_fields else None + ) self.vertex_config: VertexConfig + @classmethod + def from_config(cls, config: VertexActorConfig) -> VertexActor: + """Create a VertexActor from a VertexActorConfig.""" + return cls(config) + def fetch_important_items(self) -> dict[str, Any]: """Get important items for string representation. @@ -399,18 +406,18 @@ class EdgeActor(Actor): vertex_config: Vertex configuration """ - def __init__( - self, - **kwargs: Any, - ): - """Initialize the edge actor. - - Args: - **kwargs: Edge configuration parameters - """ + def __init__(self, config: EdgeActorConfig): + """Initialize the edge actor from config.""" + kwargs = config.model_dump(by_alias=False, exclude_none=True) + kwargs.pop("type", None) self.edge = Edge.from_dict(kwargs) self.vertex_config: VertexConfig + @classmethod + def from_config(cls, config: EdgeActorConfig) -> EdgeActor: + """Create an EdgeActor from an EdgeActorConfig.""" + return cls(config) + def fetch_important_items(self) -> dict[str, Any]: """Get important items for string representation. @@ -482,26 +489,30 @@ class TransformActor(Actor): both simple and complex transformation scenarios. Attributes: - _kwargs: Original initialization parameters - vertex: Optional target vertex + _kwargs: Config dump for init_transforms (module, foo, input, output) + vertex: Optional target vertex (to_vertex) transforms: Dictionary of available transforms name: Transform name params: Transform parameters t: Transform instance """ - def __init__(self, **kwargs: Any): - """Initialize the transform actor. - - Args: - **kwargs: Transform configuration parameters - """ - self._kwargs = kwargs - self.vertex: str | None = kwargs.pop("target_vertex", None) - self.transforms: dict[str, ProtoTransform] - self.name: str | None = kwargs.get("name", None) - self.params: dict[str, Any] = kwargs.get("params", {}) - self.t: Transform = Transform(**kwargs) + def __init__(self, config: TransformActorConfig): + """Initialize the transform actor from config.""" + self._kwargs = config.model_dump(by_alias=True) + self.vertex = config.to_vertex + self.transforms = {} + self.name = config.name + self.params = config.params + self.t: Transform = Transform( + map=config.map or {}, + name=config.name, + params=config.params, + module=config.module, + foo=config.foo, + input=tuple(config.input) if config.input else (), + output=tuple(config.output) if config.output else (), + ) def fetch_important_items(self) -> dict[str, Any]: """Get important items for string representation. @@ -513,6 +524,11 @@ def fetch_important_items(self) -> dict[str, Any]: items.update({"t.input": self.t.input, "t.output": self.t.output}) return items + @classmethod + def from_config(cls, config: TransformActorConfig) -> TransformActor: + """Create a TransformActor from a TransformActorConfig.""" + return cls(config) + def init_transforms(self, **kwargs: Any) -> None: """Initialize available transforms. @@ -561,7 +577,7 @@ def finish_init(self, **kwargs: Any) -> None: ): self.t.input = pt.input self.t.output = pt.output - self.t.__post_init__() + self.t._refresh_derived() def _extract_doc(self, nargs: tuple[Any, ...], **kwargs: Any) -> dict[str, Any]: """Extract document from arguments. 
@@ -654,7 +670,11 @@ class DescendActor(Actor): """ def __init__( - self, key: str | None, descendants_kwargs: list, any_key: bool = False, **kwargs + self, + key: str | None, + any_key: bool = False, + *, + _descendants: list[ActorWrapper] | None = None, ): """Initialize the descend actor. @@ -662,17 +682,13 @@ def __init__( key: Optional key for accessing nested data. If provided, only this key will be processed. Mutually exclusive with `any_key`. any_key: If True, processes all keys in a dictionary instead of a specific key. - When enabled, iterates over all key-value pairs in the document dictionary. - Mutually exclusive with `key`. - descendants_kwargs: List of child actor configurations - **kwargs: Additional initialization parameters + _descendants: Pre-built list of child ActorWrappers (from config). """ self.key = key self.any_key = any_key - self._descendants: list[ActorWrapper] = [] - for descendant_kwargs in descendants_kwargs: - self._descendants += [ActorWrapper(**descendant_kwargs, **kwargs)] - # Sort descendants once after initialization + self._descendants: list[ActorWrapper] = ( + list(_descendants) if _descendants else [] + ) self._descendants.sort(key=lambda x: _NodeTypePriority[type(x.actor)]) def fetch_important_items(self): @@ -713,6 +729,12 @@ def descendants(self) -> list[ActorWrapper]: """ return self._descendants + @classmethod + def from_config(cls, config: DescendActorConfig) -> DescendActor: + """Create a DescendActor from a DescendActorConfig.""" + wrappers = [ActorWrapper.from_config(c) for c in config.pipeline] + return cls(key=config.into, any_key=config.any_key, _descendants=wrappers) + def init_transforms(self, **kwargs: Any) -> None: """Initialize transforms for all descendants. @@ -788,7 +810,9 @@ def finish_init(self, **kwargs: Any) -> None: skip_vertex = len(transform_targets) >= 2 if intersection and v.name not in present_vertices: if not skip_vertex or (v.name in transform_targets and skip_vertex): - new_descendant = ActorWrapper(vertex=v.name) + new_descendant = ActorWrapper.from_config( + VertexActorConfig(vertex=v.name) + ) new_descendant.finish_init(**kwargs) self.add_descendant(new_descendant) @@ -916,10 +940,8 @@ def fetch_actors(self, level, edges): } ) -A = TypeVar("A", bound=Actor) - -class ActorWrapper(Generic[A]): +class ActorWrapper: """Wrapper class for managing actor instances. This class provides a unified interface for creating and managing different types @@ -932,35 +954,22 @@ class ActorWrapper(Generic[A]): """ def __init__(self, *args: Any, **kwargs: Any) -> None: - """Initialize the actor wrapper. + """Initialize the actor wrapper from config only. - Args: - *args: Positional arguments for actor initialization - **kwargs: Keyword arguments for actor initialization + Accepts the same shapes as parse_root_config: + - Single step dict or **kwargs: e.g. 
ActorWrapper(vertex="user") + - Pipeline: ActorWrapper(pipeline=[...]) or ActorWrapper(*list_of_steps) Raises: - ValueError: If unable to initialize an actor + ValueError: If input does not validate as ActorConfig """ - self.actor: Actor - self.vertex_config: VertexConfig - self.edge_config: EdgeConfig - self.edge_greedy: bool = True - self.target_vertices: set[str] = set() - - # Try initialization methods in order - # Make a single copy of kwargs to avoid mutation issues - # (only _try_init_descend modifies kwargs, but we use copy for all for consistency) - kwargs_copy = kwargs.copy() - if self._try_init_descend(*args, **kwargs_copy): - pass - elif self._try_init_transform(**kwargs_copy): - pass - elif self._try_init_vertex(**kwargs_copy): - pass - elif self._try_init_edge(**kwargs_copy): - pass - else: - raise ValueError(f"Not able to init ActorWrapper with {kwargs}") + config = parse_root_config(*args, **kwargs) + w = ActorWrapper.from_config(config) + self.actor = w.actor + self.vertex_config = w.vertex_config + self.edge_config = w.edge_config + self.edge_greedy = w.edge_greedy + self.target_vertices = w.target_vertices def init_transforms(self, **kwargs: Any) -> None: """Initialize transforms for the wrapped actor. @@ -1018,91 +1027,34 @@ def count(self): """ return self.actor.count() - def _try_init_descend(self, *args: Any, **kwargs: Any) -> bool: - """Try to initialize a descend actor. - - Args: - *args: Positional arguments - **kwargs: Keyword arguments (may be modified) - - Returns: - bool: True if successful, False otherwise - """ - # Check if we have the required arguments before modifying kwargs - has_apply = "apply" in kwargs - has_args = len(args) > 0 - if not (has_apply or has_args): - return False - - # Now safe to pop from kwargs - descend_key = kwargs.pop(ActorConstants.DESCEND_KEY, None) - descendants = kwargs.pop("apply", None) - - if descendants is not None: - descendants = ( - descendants if isinstance(descendants, list) else [descendants] - ) - elif len(args) > 0: - descendants = list(args) + @classmethod + def from_config(cls, config: ActorConfig) -> ActorWrapper: + """Create an ActorWrapper from a validated ActorConfig (Pydantic model).""" + if isinstance(config, VertexActorConfig): + actor = VertexActor.from_config(config) + elif isinstance(config, TransformActorConfig): + actor = TransformActor.from_config(config) + elif isinstance(config, EdgeActorConfig): + actor = EdgeActor.from_config(config) + elif isinstance(config, DescendActorConfig): + actor = DescendActor.from_config(config) else: - return False - - try: - self.actor = DescendActor( - descend_key, descendants_kwargs=descendants, **kwargs + raise ValueError( + f"Expected VertexActorConfig, TransformActorConfig, EdgeActorConfig, or DescendActorConfig, got {type(config)}" ) - return True - except (TypeError, ValueError, AttributeError) as e: - logger.debug(f"Failed to initialize DescendActor: {e}") - return False - - def _try_init_transform(self, **kwargs: Any) -> bool: - """Try to initialize a transform actor. - - Args: - **kwargs: Keyword arguments - - Returns: - bool: True if successful, False otherwise - """ - try: - self.actor = TransformActor(**kwargs) - return True - except (TypeError, ValueError, AttributeError) as e: - logger.debug(f"Failed to initialize TransformActor: {e}") - return False - - def _try_init_vertex(self, **kwargs: Any) -> bool: - """Try to initialize a vertex actor. 
- - Args: - **kwargs: Keyword arguments - - Returns: - bool: True if successful, False otherwise - """ - try: - self.actor = VertexActor(**kwargs) - return True - except (TypeError, ValueError, AttributeError) as e: - logger.debug(f"Failed to initialize VertexActor: {e}") - return False + wrapper = cls.__new__(cls) + wrapper.actor = actor + wrapper.vertex_config = VertexConfig(vertices=[]) + wrapper.edge_config = EdgeConfig() + wrapper.edge_greedy = True + wrapper.target_vertices = set() + return wrapper - def _try_init_edge(self, **kwargs: Any) -> bool: - """Try to initialize an edge actor. - - Args: - **kwargs: Keyword arguments - - Returns: - bool: True if successful, False otherwise - """ - try: - self.actor = EdgeActor(**kwargs) - return True - except (TypeError, ValueError, AttributeError) as e: - logger.debug(f"Failed to initialize EdgeActor: {e}") - return False + @classmethod + def _from_step(cls, step: dict[str, Any]) -> ActorWrapper: + """Build an ActorWrapper from a single pipeline step dict (normalize + validate + from_config).""" + config = validate_actor_step(normalize_actor_step(dict(step))) + return cls.from_config(config) def __call__( self, @@ -1283,9 +1235,9 @@ def collect_actors(self) -> list[Actor]: def find_descendants( self, - predicate: Callable[[ActorWrapper[Any]], bool] | None = None, + predicate: Callable[[ActorWrapper], bool] | None = None, *, - actor_type: type[A] | None = None, + actor_type: type[Actor] | None = None, **attr_in: Any, ) -> list[ActorWrapper]: """Find all descendant ActorWrappers matching the given criteria. diff --git a/graflo/architecture/actor_config.py b/graflo/architecture/actor_config.py new file mode 100644 index 00000000..93f63a1a --- /dev/null +++ b/graflo/architecture/actor_config.py @@ -0,0 +1,366 @@ +"""Pydantic models for actor configuration. + +These models define the structure of YAML configuration files for the +actor-based graph transformation system. They provide validation, +type safety, and explicit format support (pipeline, transform/map/to_vertex, +create_edge/edge with from/to). + +These replace the implicit type inference in ActorWrapper.__init__() +with explicit Pydantic discriminated unions. +""" + +from __future__ import annotations + +import logging +from typing import Annotated, Any, Literal, cast + +from pydantic import Field, TypeAdapter, model_validator + +from graflo.architecture.base import ConfigBaseModel + +logger = logging.getLogger(__name__) + + +def _steps_list(val: Any) -> list[Any]: + """Ensure value is a list of steps (single dict becomes [dict]).""" + return val if isinstance(val, list) else [val] + + +def normalize_actor_step(data: dict[str, Any]) -> dict[str, Any]: + """Normalize a raw step dict so it has 'type' and flat structure for validation. 
+ + Supports explicit format: + - {"vertex": "user"} -> {"type": "vertex", "vertex": "user"} + - {"transform": {"map": {...}, "to_vertex": "x"}} -> {"type": "transform", "map": {...}, "to_vertex": "x"} + - {"edge": {"from": "a", "to": "b"}} or {"create_edge": {...}} -> {"type": "edge", "from": "a", "to": "b"} + - {"descend": {"into": "k", "pipeline": [...]}} or {"apply": [...]} / {"pipeline": [...]} + """ + if not isinstance(data, dict): + return data + data = dict(data) + if "type" in data: + return data + + if "vertex" in data: + data["type"] = "vertex" + return data + + if "transform" in data: + inner = data.pop("transform") + if isinstance(inner, dict): + data.update(inner) + data["type"] = "transform" + return data + + if "edge" in data: + inner = data.pop("edge") + if isinstance(inner, dict): + data.update(inner) + data["type"] = "edge" + return data + if ("source" in data or "from" in data) and ("target" in data or "to" in data): + data = dict(data) + data["type"] = "edge" + return data + if "create_edge" in data: + inner = data.pop("create_edge") + if isinstance(inner, dict): + data.update(inner) + data["type"] = "edge" + return data + + if "descend" in data: + inner = data.pop("descend") + if isinstance(inner, dict): + if "pipeline" in inner: + inner["pipeline"] = [ + normalize_actor_step(s) for s in _steps_list(inner["pipeline"]) + ] + elif "apply" in inner: + inner["pipeline"] = [ + normalize_actor_step(s) for s in _steps_list(inner["apply"]) + ] + del inner["apply"] + data.update(inner) + data["type"] = "descend" + if "pipeline" not in data and "apply" in data: + data["pipeline"] = [ + normalize_actor_step(s) for s in _steps_list(data["apply"]) + ] + del data["apply"] + return data + + if "apply" in data: + data["type"] = "descend" + data["pipeline"] = [normalize_actor_step(s) for s in _steps_list(data["apply"])] + del data["apply"] + return data + if "pipeline" in data: + data["type"] = "descend" + data["pipeline"] = [ + normalize_actor_step(s) for s in _steps_list(data["pipeline"]) + ] + return data + + # Minimal transform step: only "name" or only "map" + if "type" not in data and ("name" in data or "map" in data): + data = dict(data) + data["type"] = "transform" + return data + + return data + + +class VertexActorConfig(ConfigBaseModel): + """Configuration for a VertexActor.""" + + type: Literal["vertex"] = Field( + default="vertex", description="Actor type discriminator" + ) + vertex: str = Field(..., description="Name of the vertex type to create") + keep_fields: list[str] | None = Field( + default=None, description="Optional list of fields to keep" + ) + + @model_validator(mode="before") + @classmethod + def set_type(cls, data: Any) -> Any: + if isinstance(data, dict) and "vertex" in data and "type" not in data: + data = dict(data) + data["type"] = "vertex" + return data + + +class TransformActorConfig(ConfigBaseModel): + """Configuration for a TransformActor. + + Explicit format: transform with map and to_vertex (target vertex for output). 
+ """ + + type: Literal["transform"] = Field( + default="transform", description="Actor type discriminator" + ) + map: dict[str, str] | None = Field( + default=None, description="Field mapping: output_key -> input_key" + ) + to_vertex: str | None = Field( + default=None, + alias="target_vertex", + description="Target vertex to send transformed output to", + ) + name: str | None = Field(default=None, description="Named transform function") + params: dict[str, Any] = Field( + default_factory=dict, description="Transform function parameters" + ) + module: str | None = Field( + default=None, description="Module containing transform function" + ) + foo: str | None = Field( + default=None, description="Transform function name in module" + ) + input: list[str] | None = Field( + default=None, description="Input field names for functional transform" + ) + output: list[str] | None = Field( + default=None, description="Output field names for functional transform" + ) + + @model_validator(mode="before") + @classmethod + def set_type_and_flatten(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + data = dict(data) + if "transform" in data and "type" not in data: + inner = data.pop("transform") + if isinstance(inner, dict): + data.update(inner) + data["type"] = "transform" + return data + + +class EdgeActorConfig(ConfigBaseModel): + """Configuration for an EdgeActor. Supports 'from'/'to' and 'source'/'target'.""" + + type: Literal["edge"] = Field( + default="edge", description="Actor type discriminator" + ) + source: str = Field(..., alias="from", description="Source vertex type name") + target: str = Field(..., alias="to", description="Target vertex type name") + match_source: str | None = Field( + default=None, description="Field for matching source vertices" + ) + match_target: str | None = Field( + default=None, description="Field for matching target vertices" + ) + weights: dict[str, list[str]] | None = Field( + default=None, description="Weight configuration" + ) + indexes: list[dict[str, Any]] | None = Field( + default=None, description="Index configuration" + ) + relation: str | None = Field( + default=None, description="Relation name (e.g. for Neo4j)" + ) + relation_field: str | None = Field( + default=None, description="Field to extract relation from" + ) + relation_from_key: bool = Field( + default=False, description="Extract relation from location key" + ) + exclude_target: str | None = Field( + default=None, description="Exclude target from edge rendering" + ) + exclude_source: str | None = Field( + default=None, description="Exclude source from edge rendering" + ) + match: str | None = Field(default=None, description="Match discriminant") + + @model_validator(mode="before") + @classmethod + def set_type_and_flatten(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + data = dict(data) + for key in ("edge", "create_edge"): + if key in data and "type" not in data: + inner = data.pop(key) + if isinstance(inner, dict): + data.update(inner) + data["type"] = "edge" + break + if "source" in data or "from" in data: + if "target" in data or "to" in data and "type" not in data: + data["type"] = "edge" + return data + + +class DescendActorConfig(ConfigBaseModel): + """Configuration for a DescendActor. 
Uses 'pipeline' (alias 'apply') and optional 'into' (alias 'key').""" + + type: Literal["descend"] = Field( + default="descend", description="Actor type discriminator" + ) + into: str | None = Field( + default=None, alias="key", description="Key to descend into" + ) + any_key: bool = Field(default=False, description="Process all keys") + pipeline: list["ActorConfig"] = Field( + default_factory=list, + alias="apply", + description="Pipeline of actors to apply to nested data", + ) + + @model_validator(mode="before") + @classmethod + def set_type_and_normalize(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + data = dict(data) + if ( + "into" in data + or "key" in data + or "any_key" in data + or "descend" in data + or "apply" in data + or "pipeline" in data + ) and "type" not in data: + data["type"] = "descend" + if "apply" in data and "pipeline" not in data: + data["pipeline"] = data["apply"] + if "descend" in data: + inner = data.pop("descend") + if isinstance(inner, dict): + data.update(inner) + if "pipeline" in data: + data["pipeline"] = [ + normalize_actor_step(s) for s in _steps_list(data["pipeline"]) + ] + return data + + +# Discriminated union for parsing a single pipeline step (used in ActorWrapper and Resource) +ActorConfig = Annotated[ + VertexActorConfig | TransformActorConfig | EdgeActorConfig | DescendActorConfig, + Field(discriminator="type"), +] + +DescendActorConfig.model_rebuild() + +# TypeAdapter for validating a single pipeline step (union type has no model_validate) +_actor_config_adapter: TypeAdapter[ + VertexActorConfig | TransformActorConfig | EdgeActorConfig | DescendActorConfig +] = TypeAdapter( + VertexActorConfig | TransformActorConfig | EdgeActorConfig | DescendActorConfig +) + + +# Keys to strip from step dicts (runtime or resource-level, not part of ActorConfig) +_STEP_STRIP_KEYS = frozenset( + { + "vertex_config", + "edge_config", + "edge_greedy", + "transforms", + "resource_name", + } +) + + +def validate_actor_step( + data: dict[str, Any], +) -> VertexActorConfig | TransformActorConfig | EdgeActorConfig | DescendActorConfig: + """Validate a normalized step dict as ActorConfig (discriminated union).""" + return _actor_config_adapter.validate_python(data) + + +def parse_root_config( + *args: Any, + **kwargs: Any, +) -> VertexActorConfig | TransformActorConfig | EdgeActorConfig | DescendActorConfig: + """Parse root input into a single ActorConfig (single step or descend pipeline). + + Accepts the same shapes as ActorWrapper: + - Single step dict: e.g. {"vertex": "user"} or **kwargs + - Pipeline: list of steps, or kwargs with "apply"/"pipeline" + + Returns: + Validated ActorConfig. For pipeline input, returns a DescendActorConfig + with into=None and pipeline=[...]. 
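# A minimal usage sketch for the step parsing above; the vertex/edge names are
# hypothetical and only illustrate the expected normalize/validate flow.
from graflo.architecture.actor_config import (
    EdgeActorConfig,
    normalize_actor_step,
    validate_actor_step,
)

raw_step = {"edge": {"from": "person", "to": "company", "relation": "works_at"}}
step = normalize_actor_step(raw_step)
# -> {"from": "person", "to": "company", "relation": "works_at", "type": "edge"}
config = validate_actor_step(step)
assert isinstance(config, EdgeActorConfig)
assert (config.source, config.target) == ("person", "company")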
+ """ + pipeline: list[dict[str, Any]] | None = None + single: dict[str, Any] | None = None + + if kwargs and ("apply" in kwargs or "pipeline" in kwargs): + raw = kwargs.get("pipeline") or kwargs.get("apply") + if raw is not None: + pipeline = cast( + list[dict[str, Any]], + list(raw) if isinstance(raw, list) else [raw], + ) + elif args: + if len(args) == 1 and isinstance(args[0], list): + pipeline = list(args[0]) + elif len(args) == 1 and isinstance(args[0], dict): + single = dict(args[0]) + elif args and all(isinstance(a, dict) for a in args): + pipeline = [dict(a) for a in args] + + if pipeline is not None: + configs = [ + _actor_config_adapter.validate_python(normalize_actor_step(s)) + for s in pipeline + ] + return DescendActorConfig.model_validate( + { + "type": "descend", + "into": None, + "any_key": False, + "pipeline": configs, + } + ) + if single is not None: + step_dict = {k: v for k, v in single.items() if k not in _STEP_STRIP_KEYS} + return _actor_config_adapter.validate_python(normalize_actor_step(step_dict)) + step_kwargs = {k: v for k, v in kwargs.items() if k not in _STEP_STRIP_KEYS} + return _actor_config_adapter.validate_python(normalize_actor_step(step_kwargs)) diff --git a/graflo/architecture/base.py b/graflo/architecture/base.py new file mode 100644 index 00000000..41cd88c8 --- /dev/null +++ b/graflo/architecture/base.py @@ -0,0 +1,117 @@ +"""Base model for Graflo configuration classes with YAML support.""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Self + +import yaml +from pydantic import BaseModel, ConfigDict + + +class ConfigBaseModel(BaseModel): + """Base model for all Graflo configuration classes. + + Provides YAML serialization/deserialization and standard configuration + for all Pydantic models in the system. + + This replaces the JSONWizard/YAMLWizard functionality from dataclass-wizard + with Pydantic's superior validation and type safety. + """ + + model_config = ConfigDict( + populate_by_name=True, + extra="forbid", + use_enum_values=True, + validate_assignment=True, + ) + + @classmethod + def from_yaml(cls, path: str) -> Self: + """Load a single instance from a YAML file.""" + with open(path) as f: + data = yaml.safe_load(f) + return cls.model_validate(data) + + @classmethod + def from_yaml_list(cls, path: str) -> list[Self]: + """Load a list of instances from a YAML file.""" + with open(path) as f: + data = yaml.safe_load(f) + if not isinstance(data, list): + raise ValueError(f"Expected list in YAML file, got {type(data)}") + return [cls.model_validate(item) for item in data] + + @classmethod + def from_dict(cls, data: dict[str, Any] | list[Any]) -> Self: + """Load from a dictionary (or list for root model).""" + return cls.model_validate(data) + + def to_yaml(self, path: str, **kwargs: Any) -> None: + """Save instance to a YAML file.""" + with open(path, "w") as f: + yaml.safe_dump( + self.model_dump(by_alias=True, exclude_none=True), + f, + default_flow_style=False, + sort_keys=False, + **kwargs, + ) + + def to_yaml_str(self, **kwargs: Any) -> str: + """Convert instance to a YAML string.""" + return yaml.safe_dump( + self.model_dump(by_alias=True, exclude_none=True), + default_flow_style=False, + sort_keys=False, + **kwargs, + ) + + def to_dict(self, **kwargs: Any) -> dict[str, Any]: + """Convert instance to a dictionary. + + Supports skip_defaults=True (mapped to exclude_defaults) for backward + compatibility with dataclass-wizard style APIs. 
+ """ + if kwargs.get("skip_defaults"): + kwargs = dict(kwargs) + kwargs.pop("skip_defaults", None) + kwargs["exclude_defaults"] = True + return self.model_dump(by_alias=True, exclude_none=True, **kwargs) + + def update(self, other: Self) -> None: + """Update this instance with values from another instance of the same type. + + Performs in-place merge: lists are concatenated, dicts/sets are merged, + nested ConfigBaseModel instances are updated recursively. None values + in other do not overwrite existing values. + + Args: + other: Another instance of the same type to copy from + + Raises: + TypeError: If other is not an instance of the same type + """ + if type(other) is not type(self): + raise TypeError( + f"Expected {type(self).__name__} instance, got {type(other).__name__}" + ) + for name in self.model_fields: + current = getattr(self, name) + other_val = getattr(other, name) + if other_val is None: + continue + if isinstance(other_val, list): + setattr(self, name, current + deepcopy(other_val)) + elif isinstance(other_val, set): + setattr(self, name, current | deepcopy(other_val)) + elif isinstance(other_val, dict): + setattr(self, name, {**current, **deepcopy(other_val)}) + elif isinstance(other_val, ConfigBaseModel): + if current is not None: + current.update(other_val) + else: + setattr(self, name, deepcopy(other_val)) + else: + if current is None: + setattr(self, name, other_val) diff --git a/graflo/architecture/edge.py b/graflo/architecture/edge.py index df1b5c8f..92941170 100644 --- a/graflo/architecture/edge.py +++ b/graflo/architecture/edge.py @@ -17,17 +17,23 @@ from __future__ import annotations -import dataclasses -from typing import Union +from typing import Any +from pydantic import ( + Field as PydanticField, + PrivateAttr, + field_validator, + model_validator, +) + +from graflo.architecture.base import ConfigBaseModel from graflo.architecture.onto import ( - BaseDataclass, EdgeId, EdgeType, Index, Weight, ) -from graflo.architecture.vertex import Field, FieldType, VertexConfig, _FieldsType +from graflo.architecture.vertex import Field, FieldType, VertexConfig from graflo.onto import DBType @@ -38,8 +44,21 @@ DEFAULT_TIGERGRAPH_RELATION_WEIGHTNAME = "relation" -@dataclasses.dataclass -class WeightConfig(BaseDataclass): +def _normalize_direct_item(item: str | Field | dict[str, Any]) -> Field: + """Convert a single direct field item (str, Field, or dict) to Field.""" + if isinstance(item, Field): + return item + if isinstance(item, str): + return Field(name=item, type=None) + if isinstance(item, dict): + name = item.get("name") + if name is None: + raise ValueError(f"Field dict must have 'name' key: {item}") + return Field(name=name, type=item.get("type")) + raise TypeError(f"Field must be str, Field, or dict, got {type(item)}") + + +class WeightConfig(ConfigBaseModel): """Configuration for edge weights and relationships. This class manages the configuration of weights and relationships for edges, @@ -48,7 +67,7 @@ class WeightConfig(BaseDataclass): Attributes: vertices: List of weight configurations direct: List of direct field mappings. Can be specified as strings, Field objects, or dicts. - Will be normalized to Field objects internally in __post_init__. + Will be normalized to Field objects by the validator. After initialization, this is always list[Field] (type checker sees this). Examples: @@ -68,45 +87,15 @@ class WeightConfig(BaseDataclass): ... 
]) """ - vertices: list[Weight] = dataclasses.field(default_factory=list) - # Internal representation: After __post_init__, this is always list[Field] - # Input types: Accepts list[str], list[Field], or list[dict] at construction - # The _FieldsType allows flexible input but normalizes to list[Field] internally - direct: _FieldsType = dataclasses.field(default_factory=list) - - def _normalize_fields( - self, fields: list[str] | list[Field] | list[dict] - ) -> list[Field]: - """Normalize fields to Field objects. - - Converts strings, Field objects, or dicts to Field objects. - Field objects behave like strings for backward compatibility. - - Args: - fields: List of strings, Field objects, or dicts + vertices: list[Weight] = PydanticField(default_factory=list) + direct: list[Field] = PydanticField(default_factory=list) - Returns: - list[Field]: Normalized list of Field objects (preserving order) - """ - normalized = [] - for field in fields: - if isinstance(field, Field): - normalized.append(field) - elif isinstance(field, str): - # Backward compatibility: string becomes Field with None type - # (most databases like ArangoDB don't require types) - normalized.append(Field(name=field, type=None)) - elif isinstance(field, dict): - # From dict (e.g., from YAML/JSON) - # Extract name and optional type - name = field.get("name") - if name is None: - raise ValueError(f"Field dict must have 'name' key: {field}") - field_type = field.get("type") - normalized.append(Field(name=name, type=field_type)) - else: - raise TypeError(f"Field must be str, Field, or dict, got {type(field)}") - return normalized + @field_validator("direct", mode="before") + @classmethod + def normalize_direct(cls, v: Any) -> Any: + if not isinstance(v, list): + return v + return [_normalize_direct_item(item) for item in v] @property def direct_names(self) -> list[str]: @@ -117,49 +106,8 @@ def direct_names(self) -> list[str]: """ return [field.name for field in self.direct] - def __post_init__(self): - """Initialize the weight configuration after dataclass initialization. - - Normalizes direct fields to Field objects. Field objects behave like strings, - maintaining backward compatibility. - - After this method, self.direct is always list[Field], regardless of input type. - """ - # Normalize direct fields to Field objects (preserve order) - # This converts str, Field, or dict inputs to list[Field] - self.direct = self._normalize_fields(self.direct) - - @classmethod - def from_dict(cls, data: dict): - """Create WeightConfig from dictionary, handling field normalization. - - Overrides parent to properly handle direct fields that may be strings, dicts, or Field objects. - JSONWizard may incorrectly deserialize dicts in direct, so we need to handle them manually. 
- - Args: - data: Dictionary containing weight config data - - Returns: - WeightConfig: New WeightConfig instance with direct normalized to list[Field] - """ - # Extract and preserve direct fields before JSONWizard processes them - direct_data = data.get("direct", []) - # Create a copy without direct to let JSONWizard handle the rest - data_copy = {k: v for k, v in data.items() if k != "direct"} - - # Call parent from_dict (JSONWizard) - instance = super().from_dict(data_copy) - # Now manually set direct (could be strings, dicts, or already Field objects) - # __post_init__ will normalize them to list[Field] - instance.direct = direct_data - # Trigger normalization - this ensures direct is always list[Field] after init - instance.direct = instance._normalize_fields(instance.direct) - return instance - - -@dataclasses.dataclass -class Edge(BaseDataclass): +class Edge(ConfigBaseModel): """Represents an edge in the graph database. An edge connects two vertices and can have various configurations for @@ -184,18 +132,12 @@ class Edge(BaseDataclass): source: str target: str - indexes: list[Index] = dataclasses.field(default_factory=list) - weights: Union[WeightConfig, None] = ( - None # Using Union for dataclass_wizard compatibility - ) + indexes: list[Index] = PydanticField(default_factory=list, alias="index") + weights: WeightConfig | None = None # relation represents Class in neo4j, for arango it becomes a weight relation: str | None = None - _relation_dbname: str | None = dataclasses.field( - default=None, - repr=False, - metadata={"dump": False}, - ) + _relation_dbname: str | None = PrivateAttr(default=None) relation_field: str | None = None relation_from_key: bool = False @@ -219,11 +161,8 @@ class Edge(BaseDataclass): None # ArangoDB-specific: edge collection name (set in finish_init) ) - def __post_init__(self): - """Initialize the edge after dataclass initialization.""" - - self._source: str | None = None - self._target: str | None = None + _source: str | None = PrivateAttr(default=None) + _target: str | None = PrivateAttr(default=None) @property def relation_dbname(self) -> str | None: @@ -362,9 +301,10 @@ def edge_id(self) -> EdgeId: """ return self.source, self.target, self.purpose + # update() inherited from ConfigBaseModel; docstring: same as base, in-place merge. + -@dataclasses.dataclass -class EdgeConfig(BaseDataclass): +class EdgeConfig(ConfigBaseModel): """Configuration for managing collections of edges. This class manages a collection of edges, providing methods for accessing @@ -374,14 +314,14 @@ class EdgeConfig(BaseDataclass): edges: List of edge configurations """ - edges: list[Edge] = dataclasses.field(default_factory=list) + edges: list[Edge] = PydanticField(default_factory=list) + _edges_map: dict[EdgeId, Edge] = PrivateAttr() - def __post_init__(self): - """Initialize the edge configuration. - - Creates internal mapping of edge IDs to edge configurations. - """ - self._edges_map: dict[EdgeId, Edge] = {e.edge_id: e for e in self.edges} + @model_validator(mode="after") + def _build_edges_map(self) -> EdgeConfig: + """Build internal mapping of edge IDs to edge configurations.""" + object.__setattr__(self, "_edges_map", {e.edge_id: e for e in self.edges}) + return self def finish_init(self, vc: VertexConfig): """Complete initialization of all edges with vertex configuration. 
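# A short sketch of the 'direct' normalization above: entries given as strings or
# dicts come back as Field objects (the field names here are hypothetical).
from graflo.architecture.edge import WeightConfig
from graflo.architecture.vertex import Field

wc = WeightConfig(direct=["score", {"name": "year"}])
assert wc.direct_names == ["score", "year"]
assert all(isinstance(f, Field) for f in wc.direct)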
diff --git a/graflo/architecture/onto.py b/graflo/architecture/onto.py index 6bae8c17..1a320954 100644 --- a/graflo/architecture/onto.py +++ b/graflo/architecture/onto.py @@ -28,16 +28,15 @@ from __future__ import annotations -import dataclasses import logging -from abc import ABCMeta from collections import defaultdict from typing import Any, TypeAlias -from dataclass_wizard import JSONWizard, YAMLWizard +from pydantic import ConfigDict, Field, model_validator +from graflo.architecture.base import ConfigBaseModel from graflo.onto import DBType -from graflo.onto import BaseDataclass, BaseEnum +from graflo.onto import BaseEnum from graflo.util.transform import pick_unique_dict # type for vertex or edge name (index) @@ -91,9 +90,8 @@ class EdgeType(BaseEnum): DIRECT = "direct" -@dataclasses.dataclass -class ABCFields(BaseDataclass, metaclass=ABCMeta): - """Abstract base class for entities that have fields. +class ABCFields(ConfigBaseModel): + """Base model for entities that have fields. Attributes: name: Optional name of the entity @@ -101,7 +99,7 @@ class ABCFields(BaseDataclass, metaclass=ABCMeta): """ name: str | None = None - fields: list[str] = dataclasses.field(default_factory=list) + fields: list[str] = Field(default_factory=list) keep_vertex_name: bool = True def cfield(self, x: str) -> str: @@ -116,7 +114,6 @@ def cfield(self, x: str) -> str: return f"{self.name}@{x}" if self.keep_vertex_name else x -@dataclasses.dataclass class Weight(ABCFields): """Defines weight configuration for edges. @@ -125,12 +122,11 @@ class Weight(ABCFields): filter: Dictionary of filter conditions for weights """ - map: dict = dataclasses.field(default_factory=dict) - filter: dict = dataclasses.field(default_factory=dict) + map: dict = Field(default_factory=dict) + filter: dict = Field(default_factory=dict) -@dataclasses.dataclass -class Index(BaseDataclass): +class Index(ConfigBaseModel): """Configuration for database indexes. Attributes: @@ -144,7 +140,7 @@ class Index(BaseDataclass): """ name: str | None = None - fields: list[str] = dataclasses.field(default_factory=list) + fields: list[str] = Field(default_factory=list) unique: bool = True type: IndexType = IndexType.PERSISTENT deduplicate: bool = True @@ -167,12 +163,10 @@ def db_form(self, db_type: DBType) -> dict: Raises: ValueError: If db_type is not supported """ - r = self.to_dict() + r = dict(self.to_dict()) if db_type == DBType.ARANGO: - _ = r.pop("name") - _ = r.pop("exclude_edge_endpoints") - else: - pass + r.pop("name", None) + r.pop("exclude_edge_endpoints", None) return r @@ -190,8 +184,7 @@ def __iter__(self): yield key, self._dictlike.edges[key] -@dataclasses.dataclass -class GraphContainer(BaseDataclass): +class GraphContainer(ConfigBaseModel): """Container for graph data including vertices and edges. 
Attributes: @@ -200,12 +193,11 @@ class GraphContainer(BaseDataclass): linear: List of default dictionaries containing linear data """ - vertices: dict[str, list] - edges: dict[tuple[str, str, str | None], list] - linear: list[defaultdict[str | tuple[str, str, str | None], list[Any]]] - - def __post_init__(self): - pass + vertices: dict[str, list] = Field(default_factory=dict) + edges: dict[tuple[str, str, str | None], list] = Field(default_factory=dict) + linear: list[defaultdict[str | tuple[str, str, str | None], list[Any]]] = Field( + default_factory=list + ) def items(self): """Get an ItemsView of the container's contents.""" @@ -293,8 +285,7 @@ def dd_factory() -> defaultdict[GraphEntity, list]: return defaultdict(list) -@dataclasses.dataclass(kw_only=True) -class VertexRep(BaseDataclass): +class VertexRep(ConfigBaseModel): """Context for graph transformation actions. Attributes: @@ -302,21 +293,40 @@ class VertexRep(BaseDataclass): ctx: context (for edge definition upstream """ + model_config = ConfigDict(kw_only=True) # type: ignore[assignment] + vertex: dict ctx: dict -@dataclasses.dataclass(frozen=True, eq=True) -class LocationIndex(JSONWizard, YAMLWizard): - path: tuple[str | int | None, ...] = dataclasses.field(default_factory=tuple) +class LocationIndex(ConfigBaseModel): + """Immutable location index for nested graph traversal.""" + + model_config = ConfigDict(frozen=True) + + path: tuple[str | int | None, ...] = Field(default_factory=tuple) + + @model_validator(mode="before") + @classmethod + def accept_tuple(cls, data: Any) -> Any: + """Accept a single tuple as positional path (e.g. LocationIndex((0,))).""" + if isinstance(data, tuple): + return {"path": data} + return data + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Allow LocationIndex((0,)) or LocationIndex(path=(0,)).""" + if args and len(args) == 1 and isinstance(args[0], tuple) and not kwargs: + kwargs = {"path": args[0]} + super().__init__(**kwargs) def extend(self, extension: tuple[str | int | None, ...]) -> LocationIndex: - return LocationIndex((*self.path, *extension)) + return LocationIndex(path=(*self.path, *extension)) - def depth(self): + def depth(self) -> int: return len(self.path) - def congruence_measure(self, other: LocationIndex): + def congruence_measure(self, other: LocationIndex) -> int: neq_position = 0 for step_a, step_b in zip(self.path, other.path): if step_a != step_b: @@ -331,44 +341,47 @@ def filter(self, lindex_list: list[LocationIndex]) -> list[LocationIndex]: if t.depth() >= self.depth() and t.path[: self.depth()] == self.path ] - def __lt__(self, other: LocationIndex): + def __lt__(self, other: LocationIndex) -> bool: return len(self.path) < len(other.path) - def __contains__(self, item): + def __contains__(self, item: object) -> bool: return item in self.path - def __len__(self): + def __len__(self) -> int: return len(self.path) def __iter__(self): return iter(self.path) - def __getitem__(self, item): + def __getitem__(self, item: int | slice): return self.path[item] -@dataclasses.dataclass(kw_only=True) -class ActionContext(BaseDataclass): +def _default_dict_list() -> defaultdict[GraphEntity, list]: + return defaultdict(list) + + +def _default_dict_transforms() -> defaultdict[LocationIndex, list[dict]]: + return defaultdict(list) + + +class ActionContext(ConfigBaseModel): """Context for graph transformation actions. 
Attributes: - acc_vertex: Local accumulation of vertices - acc_global: Global accumulation of graph entities - buffer_vertex: Buffer for vertex data - buffer_transforms: Buffer for transforms data + acc_vertex: Local accumulation of vertices (defaultdict[str, defaultdict[LocationIndex, list]]) + acc_global: Global accumulation of graph entities (defaultdict[GraphEntity, list]) + buffer_vertex: Buffer for vertex data (defaultdict[GraphEntity, list]) + buffer_transforms: Buffer for transforms data (defaultdict[LocationIndex, list[dict]]) target_vertices: Set of target vertex names indicating user intention """ - acc_vertex: defaultdict[str, defaultdict[LocationIndex, list]] = dataclasses.field( - default_factory=outer_factory - ) - acc_global: defaultdict[GraphEntity, list] = dataclasses.field( - default_factory=dd_factory - ) - buffer_vertex: defaultdict[GraphEntity, list] = dataclasses.field( - default_factory=lambda: defaultdict(list) - ) - buffer_transforms: defaultdict[LocationIndex, list[dict]] = dataclasses.field( - default_factory=lambda: defaultdict(list) - ) - target_vertices: set[str] = dataclasses.field(default_factory=set) + model_config = ConfigDict(kw_only=True) # type: ignore[assignment] + + # Pydantic cannot schema nested defaultdict with custom key types (e.g. LocationIndex), + # so we use Any; runtime type is as documented in Attributes + acc_vertex: Any = Field(default_factory=outer_factory) + acc_global: Any = Field(default_factory=dd_factory) + buffer_vertex: Any = Field(default_factory=_default_dict_list) + buffer_transforms: Any = Field(default_factory=_default_dict_transforms) + target_vertices: set[str] = Field(default_factory=set) diff --git a/graflo/architecture/resource.py b/graflo/architecture/resource.py index cbb805df..25244248 100644 --- a/graflo/architecture/resource.py +++ b/graflo/architecture/resource.py @@ -19,22 +19,22 @@ Example: >>> resource = Resource( ... resource_name="users", - ... apply=[VertexActor("user"), EdgeActor("follows")], + ... apply=[{"vertex": "user"}, {"edge": {"from": "user", "to": "user"}}], ... encoding=EncodingType.UTF_8 ... ) >>> result = resource(doc) """ -import dataclasses +from __future__ import annotations + import logging from collections import defaultdict -from typing import Callable +from typing import Any, Callable -from dataclass_wizard import JSONWizard +from pydantic import AliasChoices, Field as PydanticField, PrivateAttr, model_validator -from graflo.architecture.actor import ( - ActorWrapper, -) +from graflo.architecture.actor import ActorWrapper +from graflo.architecture.base import ConfigBaseModel from graflo.architecture.edge import Edge, EdgeConfig from graflo.architecture.onto import ( ActionContext, @@ -42,70 +42,101 @@ GraphEntity, ) from graflo.architecture.transform import ProtoTransform -from graflo.architecture.vertex import ( - VertexConfig, -) -from graflo.onto import BaseDataclass +from graflo.architecture.vertex import VertexConfig logger = logging.getLogger(__name__) -@dataclasses.dataclass(kw_only=True) -class Resource(BaseDataclass, JSONWizard): +class Resource(ConfigBaseModel): """Resource configuration and processing. - This class represents a data resource that can be processed and transformed - into graph structures. It manages the processing pipeline through actors - and handles data encoding, transformation, and mapping. 
- - Attributes: - resource_name: Name of the resource - apply: List of actors to apply in sequence - encoding: Data encoding type (default: UTF_8) - merge_collections: List of collections to merge - extra_weights: List of additional edge weights - types: Dictionary of field type mappings - root: Root actor wrapper for processing - vertex_config: Configuration for vertices - edge_config: Configuration for edges + Represents a data resource that can be processed and transformed into graph + structures. Manages the processing pipeline through actors and handles data + encoding, transformation, and mapping. Suitable for LLM-generated schema + constituents. """ - resource_name: str - apply: list - encoding: EncodingType = EncodingType.UTF_8 - merge_collections: list[str] = dataclasses.field(default_factory=list) - extra_weights: list[Edge] = dataclasses.field(default_factory=list) - types: dict[str, str] = dataclasses.field(default_factory=dict) - edge_greedy: bool = True - - def __post_init__(self): - """Initialize the resource after dataclass initialization. - - Sets up the actor wrapper and type mappings. Evaluates type expressions - for field type casting. - - Raises: - Exception: If type evaluation fails for any field - """ - self.root = ActorWrapper(*self.apply) - self._types: dict[str, Callable] = dict() - self.vertex_config: VertexConfig - self.edge_config: EdgeConfig + model_config = {"extra": "forbid"} + + resource_name: str = PydanticField( + ..., + description="Name of the resource (e.g. table or file identifier).", + ) + apply: list[dict[str, Any]] = PydanticField( + ..., + description="Pipeline of actor steps to apply in sequence (vertex, edge, transform, descend). " + 'Each step is a dict, e.g. {"vertex": "user"} or {"edge": {"from": "a", "to": "b"}}.', + validation_alias=AliasChoices("apply", "pipeline"), + ) + encoding: EncodingType = PydanticField( + default=EncodingType.UTF_8, + description="Character encoding for input/output (e.g. utf-8, ISO-8859-1).", + ) + merge_collections: list[str] = PydanticField( + default_factory=list, + description="List of collection names to merge when writing to the graph.", + ) + extra_weights: list[Edge] = PydanticField( + default_factory=list, + description="Additional edge weight configurations for this resource.", + ) + types: dict[str, str] = PydanticField( + default_factory=dict, + description='Field name to Python type expression for casting (e.g. 
{"amount": "float"}).', + ) + edge_greedy: bool = PydanticField( + default=True, + description="If True, emit edges as soon as source/target vertices exist; if False, wait for explicit targets.", + ) + + _root: ActorWrapper = PrivateAttr() + _types: dict[str, Callable[..., Any]] = PrivateAttr(default_factory=dict) + _vertex_config: VertexConfig = PrivateAttr() + _edge_config: EdgeConfig = PrivateAttr() + + @model_validator(mode="after") + def _build_root_and_types(self) -> Resource: + """Build root ActorWrapper from apply and evaluate type expressions.""" + object.__setattr__(self, "_root", ActorWrapper(*self.apply)) + object.__setattr__(self, "_types", {}) for k, v in self.types.items(): try: self._types[k] = eval(v) except Exception as ex: logger.error( - f"For resource {self.name} for field {k} failed to cast type {v} : {ex}" + "For resource %s for field %s failed to cast type %s : %s", + self.name, + k, + v, + ex, ) + # Placeholders until finish_init is called by Schema + object.__setattr__( + self, + "_vertex_config", + VertexConfig(vertices=[]), + ) + object.__setattr__(self, "_edge_config", EdgeConfig()) + return self @property - def name(self): - """Get the resource name. + def vertex_config(self) -> VertexConfig: + """Vertex configuration (set by Schema.finish_init).""" + return self._vertex_config - Returns: - str: Name of the resource - """ + @property + def edge_config(self) -> EdgeConfig: + """Edge configuration (set by Schema.finish_init).""" + return self._edge_config + + @property + def root(self) -> ActorWrapper: + """Root actor wrapper for the processing pipeline.""" + return self._root + + @property + def name(self) -> str: + """Resource name (alias for resource_name).""" return self.resource_name def finish_init( @@ -113,21 +144,21 @@ def finish_init( vertex_config: VertexConfig, edge_config: EdgeConfig, transforms: dict[str, ProtoTransform], - ): + ) -> None: """Complete resource initialization. Initializes the resource with vertex and edge configurations, - and sets up the processing pipeline. + and sets up the processing pipeline. Called by Schema after load. Args: vertex_config: Configuration for vertices edge_config: Configuration for edges transforms: Dictionary of available transforms """ - self.vertex_config = vertex_config - self.edge_config = edge_config + object.__setattr__(self, "_vertex_config", vertex_config) + object.__setattr__(self, "_edge_config", edge_config) - logger.debug(f"total resource actor count : {self.root.count()}") + logger.debug("total resource actor count : %s", self.root.count()) self.root.finish_init( vertex_config=vertex_config, transforms=transforms, @@ -135,7 +166,9 @@ def finish_init( edge_greedy=self.edge_greedy, ) - logger.debug(f"total resource actor count (after 2 finit): {self.root.count()}") + logger.debug( + "total resource actor count (after 2 finit): %s", self.root.count() + ) for e in self.extra_weights: e.finish_init(vertex_config) @@ -154,10 +187,6 @@ def __call__(self, doc: dict) -> defaultdict[GraphEntity, list]: acc = self.root.normalize_ctx(ctx) return acc - def count(self): - """Get the total number of actors in the resource. 
- - Returns: - int: Number of actors - """ + def count(self) -> int: + """Total number of actors in the resource pipeline.""" return self.root.count() diff --git a/graflo/architecture/schema.py b/graflo/architecture/schema.py index c3ff58b3..729b200c 100644 --- a/graflo/architecture/schema.py +++ b/graflo/architecture/schema.py @@ -27,66 +27,104 @@ >>> resource = schema.fetch_resource("users") """ -import dataclasses +from __future__ import annotations + import logging from collections import Counter +from typing import Any + +from pydantic import ( + Field as PydanticField, + PrivateAttr, + field_validator, + model_validator, +) from graflo.architecture.actor import EdgeActor, TransformActor, VertexActor +from graflo.architecture.base import ConfigBaseModel from graflo.architecture.edge import EdgeConfig from graflo.architecture.resource import Resource from graflo.architecture.transform import ProtoTransform from graflo.architecture.vertex import VertexConfig -from graflo.onto import BaseDataclass logger = logging.getLogger(__name__) -@dataclasses.dataclass -class SchemaMetadata(BaseDataclass): +class SchemaMetadata(ConfigBaseModel): """Schema metadata and versioning information. - This class holds metadata about the schema, including its name and version. - It's used for schema identification and versioning. - - Attributes: - name: Name of the schema - version: Optional version string of the schema + Holds metadata about the schema, including its name and version. + Used for schema identification and versioning. Suitable for LLM-generated + schema constituents. """ - name: str - version: str | None = None + name: str = PydanticField( + ..., + description="Name of the schema (e.g. graph or database identifier).", + ) + version: str | None = PydanticField( + default=None, + description="Optional version string of the schema (e.g. semantic version).", + ) -@dataclasses.dataclass -class Schema(BaseDataclass): +class Schema(ConfigBaseModel): """Graph database schema configuration. - This class represents the complete schema configuration for a graph database. - It manages resources, vertex configurations, edge configurations, and transforms. - - Attributes: - general: Schema metadata and versioning information - vertex_config: Configuration for vertex collections - edge_config: Configuration for edge collections - resources: List of resource definitions - transforms: Dictionary of available transforms - _resources: Internal mapping of resource names to resources + Represents the complete schema configuration for a graph database. + Manages resources, vertex configurations, edge configurations, and transforms. + Suitable for LLM-generated schema constituents. """ - general: SchemaMetadata - vertex_config: VertexConfig - edge_config: EdgeConfig - resources: list[Resource] - transforms: dict[str, ProtoTransform] = dataclasses.field(default_factory=dict) - - def __post_init__(self): - """Initialize the schema after dataclass initialization. - - Sets up transforms, initializes edge configuration, and validates - resource names for uniqueness. 
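# A short sketch of loading a schema, mirroring the schema.py module docstring;
# the "schema.yaml" path and "users" resource name are hypothetical.
from graflo.architecture.schema import Schema

schema = Schema.from_yaml("schema.yaml")  # from_yaml inherited from ConfigBaseModel
resource = schema.fetch_resource("users")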
+ general: SchemaMetadata = PydanticField( + ..., + description="Schema metadata and versioning (name, version).", + ) + vertex_config: VertexConfig = PydanticField( + ..., + description="Configuration for vertex collections (vertices, fields, indexes).", + ) + edge_config: EdgeConfig = PydanticField( + ..., + description="Configuration for edge collections (edges, weights).", + ) + resources: list[Resource] = PydanticField( + default_factory=list, + description="List of resource definitions (data pipelines mapping to vertices/edges).", + ) + transforms: dict[str, ProtoTransform] = PydanticField( + default_factory=dict, + description="Dictionary of named transforms available to resources (name -> ProtoTransform).", + ) + + _resources: dict[str, Resource] = PrivateAttr() + + @field_validator("resources", mode="before") + @classmethod + def _coerce_resources_list(cls, v: Any) -> Any: + """Accept empty dict as empty list for backward compatibility.""" + if isinstance(v, dict) and len(v) == 0: + return [] + return v + + @model_validator(mode="after") + def _init_schema(self) -> Schema: + """Set transform names, finish edge/resource init, and build resource name map.""" + self.finish_init() + return self + + def finish_init(self) -> None: + """Complete schema initialization after construction or resource updates. + + Sets transform names, initializes edge configuration with vertex config, + calls finish_init on each resource, validates unique resource names, + and builds the internal _resources name-to-Resource mapping. + + Call this after assigning to resources (e.g. when inferring resources + from a database) so that _resources and resource pipelines are correct. Raises: - ValueError: If duplicate resource names are found + ValueError: If duplicate resource names are found. """ for name, t in self.transforms.items(): t.name = name @@ -105,9 +143,7 @@ def __post_init__(self): for k, v in c.items(): if v > 1: raise ValueError(f"resource name {k} used {v} times") - self._resources: dict[str, Resource] = {} - for r in self.resources: - self._resources[r.name] = r + object.__setattr__(self, "_resources", {r.name: r for r in self.resources}) def fetch_resource(self, name: str | None = None) -> Resource: """Fetch a resource by name or get the first available resource. diff --git a/graflo/architecture/transform.py b/graflo/architecture/transform.py index 015f21f8..a32c0e19 100644 --- a/graflo/architecture/transform.py +++ b/graflo/architecture/transform.py @@ -28,25 +28,41 @@ from __future__ import annotations -import dataclasses import importlib import logging from copy import deepcopy -from typing import Any +from typing import Any, Self -from graflo.onto import BaseDataclass +from pydantic import Field, PrivateAttr, model_validator + +from graflo.architecture.base import ConfigBaseModel logger = logging.getLogger(__name__) +def _tuple_it(x: str | list[str] | tuple[str, ...]) -> tuple[str, ...]: + """Convert input to tuple format. + + Args: + x: Input to convert (string, list, or tuple) + + Returns: + tuple: Converted tuple + """ + if isinstance(x, str): + x = [x] + if isinstance(x, list): + x = tuple(x) + return x + + class TransformException(BaseException): """Base exception for transform-related errors.""" pass -@dataclasses.dataclass -class ProtoTransform(BaseDataclass): +class ProtoTransform(ConfigBaseModel): """Base class for transform definitions. 
This class provides the foundation for data transformations, supporting both @@ -64,63 +80,47 @@ class ProtoTransform(BaseDataclass): name: str | None = None module: str | None = None - params: dict[str, Any] = dataclasses.field(default_factory=dict) + params: dict[str, Any] = Field(default_factory=dict) foo: str | None = None - input: str | list[str] | tuple[str, ...] = dataclasses.field(default_factory=tuple) - output: str | list[str] | tuple[str, ...] = dataclasses.field(default_factory=tuple) - - def __post_init__(self): - """Initialize the transform after dataclass initialization. - - Sets up the transform function and input/output field specifications. - """ - self._foo = None - self._init_foo() - - self.input = self._tuple_it(self.input) - - if not self.output: - self.output = self.input - self.output = self._tuple_it(self.output) - - @staticmethod - def _tuple_it(x): - """Convert input to tuple format. - - Args: - x: Input to convert (string, list, or tuple) - - Returns: - tuple: Converted tuple - """ - if isinstance(x, str): - x = [x] - if isinstance(x, list): - x = tuple(x) - return x - - def _init_foo(self): - """Initialize the transform function from module. - - Imports the specified module and gets the transform function. - - Raises: - TypeError: If module import fails - ValueError: If function lookup fails - """ + input: tuple[str, ...] = Field(default_factory=tuple) + output: tuple[str, ...] = Field(default_factory=tuple) + + _foo: Any = PrivateAttr(default=None) + + @model_validator(mode="before") + @classmethod + def _normalize_input_output(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + data = dict(data) + for key in ("input", "output"): + if key in data and data[key] is not None: + data[key] = _tuple_it(data[key]) + return data + + @model_validator(mode="after") + def _init_foo_and_output(self) -> Self: if self.module is not None and self.foo is not None: try: _module = importlib.import_module(self.module) except Exception as e: raise TypeError(f"Provided module {self.module} is not valid: {e}") try: - self._foo = getattr(_module, self.foo) + object.__setattr__(self, "_foo", getattr(_module, self.foo)) except Exception as e: raise ValueError( f"Could not instantiate transform function. Exception: {e}" ) + if not self.output and self.input: + object.__setattr__(self, "output", self.input) + return self + + @classmethod + def get_fields_members(cls) -> list[str]: + """Get list of field members (public model fields).""" + return list(cls.model_fields.keys()) - def __lt__(self, other): + def __lt__(self, other: object) -> bool: """Compare transforms for ordering. Args: @@ -129,12 +129,13 @@ def __lt__(self, other): Returns: bool: True if this transform should be ordered before other """ + if not isinstance(other, ProtoTransform): + return NotImplemented if self._foo is None and other._foo is not None: return True return False -@dataclasses.dataclass(kw_only=True) class Transform(ProtoTransform): """Concrete transform implementation. @@ -148,65 +149,62 @@ class Transform(ProtoTransform): functional_transform: Whether this is a functional transform """ - fields: str | list[str] | tuple[str, ...] = dataclasses.field(default_factory=tuple) - map: dict[str, str] = dataclasses.field(default_factory=dict) - switch: dict[str, Any] = dataclasses.field(default_factory=dict) - - def __post_init__(self): - """Initialize the transform after dataclass initialization. - - Sets up field specifications and validates transform configuration. 
- - Raises: - ValueError: If transform configuration is invalid - """ - super().__post_init__() - self.functional_transform = self._foo is not None - - # Normalize containers - self.fields = self._tuple_it(self.fields) - self.input = self._tuple_it(self.input) - self.output = self._tuple_it(self.output) - - # Derive relationships between map, input, output, and fields. + fields: tuple[str, ...] = Field(default_factory=tuple) + map: dict[str, str] = Field(default_factory=dict) + switch: dict[str, Any] = Field(default_factory=dict) + + functional_transform: bool = False + + @model_validator(mode="before") + @classmethod + def _normalize_fields(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + data = dict(data) + if "fields" in data and data["fields"] is not None: + data["fields"] = _tuple_it(data["fields"]) + return data + + @model_validator(mode="after") + def _init_derived(self) -> Self: + object.__setattr__(self, "functional_transform", self._foo is not None) self._init_input_from_fields() self._init_io_from_map() self._init_from_switch() self._default_output_from_input() self._init_map_from_io() - self._validate_configuration() + return self def _init_input_from_fields(self) -> None: """Populate input from fields when provided.""" if self.fields and not self.input: - self.input = self.fields + object.__setattr__(self, "input", self.fields) - def _init_io_from_map(self, force_init=False) -> None: + def _init_io_from_map(self, force_init: bool = False) -> None: """Populate input/output tuples from an explicit map.""" if not self.map: return if force_init or (not self.input and not self.output): input_fields, output_fields = zip(*self.map.items()) - self.input = tuple(input_fields) - self.output = tuple(output_fields) + object.__setattr__(self, "input", tuple(input_fields)) + object.__setattr__(self, "output", tuple(output_fields)) elif not self.input: - self.input = tuple(self.map.keys()) + object.__setattr__(self, "input", tuple(self.map.keys())) elif not self.output: - self.output = tuple(self.map.values()) + object.__setattr__(self, "output", tuple(self.map.values())) def _init_from_switch(self) -> None: """Fallback initialization using switch definitions.""" if self.switch and not self.input and not self.output: - self.input = tuple(self.switch) - # We rely on the first switch entry to infer the output shape. + object.__setattr__(self, "input", tuple(self.switch)) first_key = self.input[0] - self.output = self._tuple_it(self.switch[first_key]) + object.__setattr__(self, "output", _tuple_it(self.switch[first_key])) def _default_output_from_input(self) -> None: """Ensure output mirrors input when not explicitly provided.""" - if not self.output: - self.output = self.input + if not self.output and self.input: + object.__setattr__(self, "output", self.input) def _init_map_from_io(self) -> None: """Derive map from input/output when possible.""" @@ -214,7 +212,9 @@ def _init_map_from_io(self) -> None: return if len(self.input) != len(self.output): return - self.map = {src: dst for src, dst in zip(self.input, self.output)} + object.__setattr__( + self, "map", {src: dst for src, dst in zip(self.input, self.output)} + ) def _validate_configuration(self) -> None: """Validate that the transform has enough information to operate.""" @@ -224,7 +224,11 @@ def _validate_configuration(self) -> None: "constructor." ) - def __call__(self, *nargs, **kwargs): + def _refresh_derived(self) -> None: + """Re-run derived state (e.g. 
map from input/output) after mutating attributes.""" + self._init_map_from_io() + + def __call__(self, *nargs: Any, **kwargs: Any) -> dict[str, Any] | Any: """Execute the transform. Args: @@ -239,7 +243,7 @@ def __call__(self, *nargs, **kwargs): if isinstance(input_doc, dict): output_values = [input_doc[k] for k in self.input] else: - output_values = nargs + output_values = list(nargs) else: if nargs and isinstance(input_doc := nargs[0], dict): new_args = [input_doc[k] for k in self.input] @@ -258,7 +262,7 @@ def is_mapping(self) -> bool: """True when the transform is pure mapping (no function).""" return self._foo is None - def _dress_as_dict(self, transform_result) -> dict[str, Any]: + def _dress_as_dict(self, transform_result: Any) -> dict[str, Any]: """Convert transform result to dictionary format. Args: @@ -285,14 +289,18 @@ def is_dummy(self) -> bool: """ return (self.name is not None) and (not self.map and self._foo is None) - def update(self, t: Transform) -> Transform: - """Update this transform with another transform's configuration. + def merge_from(self, t: Transform) -> Transform: + """Merge another transform's configuration into a copy of it. + + Returns a new Transform with values from self overriding t where set. + Does not override ConfigBaseModel.update (in-place); use this for + copy-and-merge semantics. Args: - t: Transform to update from + t: Transform to merge from Returns: - Transform: Updated transform + Transform: New transform with merged configuration """ t_copy = deepcopy(t) if self.input: @@ -300,8 +308,8 @@ def update(self, t: Transform) -> Transform: if self.output: t_copy.output = self.output if self.params: - t_copy.params.update(self.params) - t_copy.__post_init__() + t_copy.params = {**t_copy.params, **self.params} + t_copy._refresh_derived() return t_copy def get_barebone( @@ -316,7 +324,7 @@ def get_barebone( tuple[Transform | None, Transform | None]: Updated self transform and transform to store in library """ - self_param = self.to_dict(skip_defaults=True) + self_param = self.to_dict(exclude_defaults=True) if self.foo is not None: # self will be the lib transform return None, self @@ -324,7 +332,7 @@ def get_barebone( # init self from other self_param.pop("foo", None) self_param.pop("module", None) - other_param = other.to_dict(skip_defaults=True) + other_param = other.to_dict(exclude_defaults=True) other_param.update(self_param) return Transform(**other_param), None else: diff --git a/graflo/architecture/vertex.py b/graflo/architecture/vertex.py index f913be1e..248e34dd 100644 --- a/graflo/architecture/vertex.py +++ b/graflo/architecture/vertex.py @@ -15,19 +15,32 @@ >>> field_names = config.fields_names("user") # Returns list[str] """ +from __future__ import annotations + import ast -import dataclasses import json import logging -from typing import TYPE_CHECKING, Union +from typing import Any + +from pydantic import ( + ConfigDict, + Field as PydanticField, + PrivateAttr, + field_validator, + model_validator, +) +from graflo.architecture.base import ConfigBaseModel from graflo.architecture.onto import Index -from graflo.filter.onto import Expression +from graflo.filter.onto import Clause from graflo.onto import DBType -from graflo.onto import BaseDataclass, BaseEnum +from graflo.onto import BaseEnum logger = logging.getLogger(__name__) +# Type accepted for fields before normalization (for use by Edge/WeightConfig) +FieldsInputType = list[str] | list["Field"] | list[dict[str, Any]] + class FieldType(BaseEnum): """Supported field types for graph 
databases. @@ -54,23 +67,7 @@ class FieldType(BaseEnum): DATETIME = "DATETIME" -if TYPE_CHECKING: - # For type checking: after __post_init__, fields is always list[Field] - # Using string literal to avoid forward reference issues - _FieldsType = list["Field"] - # For type checking: allow FieldType, str, or None at construction time - # Strings are converted to FieldType enum in __post_init__ - _FieldTypeType = FieldType | str | None -else: - # For runtime: accept flexible input types, will be normalized in __post_init__ - # Use Union for runtime since we can't use | with string literals - _FieldsType = list[Union[str, "Field", dict]] - # For runtime: accept FieldType, str, or None (strings converted in __post_init__) - _FieldTypeType = Union[FieldType, str, None] - - -@dataclasses.dataclass -class Field(BaseDataclass): +class Field(ConfigBaseModel): """Represents a typed field in a vertex. Field objects behave like strings for backward compatibility. They can be used @@ -80,45 +77,37 @@ class Field(BaseDataclass): Attributes: name: Name of the field type: Optional type of the field. Can be FieldType enum, str, or None at construction. - Strings are converted to FieldType enum in __post_init__. - After initialization, this is always FieldType | None (type checker sees this). + Strings are converted to FieldType enum by the validator. None is allowed (most databases like ArangoDB don't require types). Defaults to None. """ - name: str - type: _FieldTypeType = None + model_config = ConfigDict(extra="forbid") - def __post_init__(self): - """Validate and normalize type if specified. + name: str + type: FieldType | None = None - This method handles type normalization AFTER a Field object has been created. - It converts string types to FieldType enum and validates the type. - This is separate from _normalize_fields() which handles the creation of Field - objects from various input formats (str/dict/Field). - """ - if self.type is not None: - # Convert string to FieldType enum if it's a string - if isinstance(self.type, str): - type_upper = self.type.upper() - # Validate and convert to FieldType enum - if type_upper not in FieldType: - allowed_types = sorted(ft.value for ft in FieldType) - raise ValueError( - f"Field type '{self.type}' is not allowed. " - f"Allowed types are: {', '.join(allowed_types)}" - ) - self.type = FieldType(type_upper) - # If it's already a FieldType, validate it's a valid enum member - elif isinstance(self.type, FieldType): - # Already a FieldType enum, no conversion needed - pass - else: + @field_validator("type", mode="before") + @classmethod + def normalize_type(cls, v: Any) -> FieldType | None: + if v is None: + return None + if isinstance(v, FieldType): + return v + if isinstance(v, str): + type_upper = v.upper() + if type_upper not in FieldType: allowed_types = sorted(ft.value for ft in FieldType) raise ValueError( - f"Field type must be FieldType enum, str, or None, got {type(self.type)}. " + f"Field type '{v}' is not allowed. " f"Allowed types are: {', '.join(allowed_types)}" ) + return FieldType(type_upper) + allowed_types = sorted(ft.value for ft in FieldType) + raise ValueError( + f"Field type must be FieldType enum, str, or None, got {type(v)}. 
" + f"Allowed types are: {', '.join(allowed_types)}" + ) def __str__(self) -> str: """Return field name as string for backward compatibility.""" @@ -134,7 +123,7 @@ def __hash__(self) -> int: """Hash by name only, allowing Field objects to work in sets and as dict keys.""" return hash(self.name) - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: """Compare equal to strings with same name, or other Field objects with same name.""" if isinstance(other, Field): return self.name == other.name @@ -142,16 +131,48 @@ def __eq__(self, other) -> bool: return self.name == other return False - def __ne__(self, other) -> bool: + def __ne__(self, other: object) -> bool: """Compare not equal.""" return not self.__eq__(other) - # Field objects are hashable (via __hash__) and comparable to strings (via __eq__) - # This allows them to work in sets, as dict keys, and in membership tests - -@dataclasses.dataclass -class Vertex(BaseDataclass): +def _parse_string_to_dict(field_str: str) -> dict | None: + """Parse a string that might be a JSON or Python dict representation.""" + try: + parsed = json.loads(field_str) + return parsed if isinstance(parsed, dict) else None + except json.JSONDecodeError: + pass + try: + parsed = ast.literal_eval(field_str) + return parsed if isinstance(parsed, dict) else None + except (ValueError, SyntaxError): + return None + + +def _dict_to_field(field_dict: dict[str, Any]) -> Field: + """Convert a dict to a Field object.""" + name = field_dict.get("name") + if name is None: + raise ValueError(f"Field dict must have 'name' key: {field_dict}") + return Field(name=name, type=field_dict.get("type")) + + +def _normalize_fields_item(item: str | Field | dict[str, Any]) -> Field: + """Convert a single field item (str, Field, or dict) to Field.""" + if isinstance(item, Field): + return item + if isinstance(item, dict): + return _dict_to_field(item) + if isinstance(item, str): + parsed_dict = _parse_string_to_dict(item) + if parsed_dict: + return _dict_to_field(parsed_dict) + return Field(name=item, type=None) + raise TypeError(f"Field must be str, Field, or dict, got {type(item)}") + + +class Vertex(ConfigBaseModel): """Represents a vertex in the graph database. A vertex is a fundamental unit in the graph that can have fields, indexes, @@ -162,8 +183,7 @@ class Vertex(BaseDataclass): Attributes: name: Name of the vertex fields: List of field names (str), Field objects, or dicts. - Will be normalized to Field objects internally in __post_init__. - After initialization, this is always list[Field] (type checker sees this). + Will be normalized to Field objects by the validator. indexes: List of indexes for the vertex filters: List of filter expressions dbname: Optional database name (defaults to vertex name) @@ -185,136 +205,80 @@ class Vertex(BaseDataclass): ... ]) """ + # Allow extra keys when loading from YAML (e.g. 
transforms, other runtime keys) + model_config = ConfigDict(extra="ignore") + name: str - fields: _FieldsType = dataclasses.field(default_factory=list) - indexes: list[Index] = dataclasses.field(default_factory=list) - filters: list[Expression] = dataclasses.field(default_factory=list) + fields: list[Field] = PydanticField(default_factory=list) + indexes: list[Index] = PydanticField(default_factory=list) + filters: list[Any] = PydanticField( + default_factory=list + ) # items become Clause via convert_to_expressions dbname: str | None = None - @staticmethod - def _parse_string_to_dict(field_str: str) -> dict | None: - """Parse a string that might be a JSON or Python dict representation. - - Args: - field_str: String that might be a dict representation - - Returns: - dict if successfully parsed as dict, None otherwise - """ - # Try JSON first (handles double-quoted strings) - try: - parsed = json.loads(field_str) - return parsed if isinstance(parsed, dict) else None - except json.JSONDecodeError: - pass - - # Try Python literal eval (handles single-quoted strings) - try: - parsed = ast.literal_eval(field_str) - return parsed if isinstance(parsed, dict) else None - except (ValueError, SyntaxError): - return None - - @staticmethod - def _dict_to_field(field_dict: dict) -> Field: - """Convert a dict to a Field object. - - Args: - field_dict: Dictionary with 'name' key and optional 'type' key - - Returns: - Field object - - Raises: - ValueError: If dict doesn't have 'name' key - """ - name = field_dict.get("name") - if name is None: - raise ValueError(f"Field dict must have 'name' key: {field_dict}") - return Field(name=name, type=field_dict.get("type")) - - def _normalize_fields( - self, fields: list[str] | list[Field] | list[dict] - ) -> list[Field]: - """Normalize fields to Field objects. - - Converts strings, Field objects, or dicts to Field objects. - Handles the case where dataclass_wizard may have converted dicts to JSON strings. - Field objects behave like strings for backward compatibility. 
- - Args: - fields: List of strings, Field objects, or dicts - - Returns: - list[Field]: Normalized list of Field objects (preserving order) - """ - normalized = [] - for field in fields: - if isinstance(field, Field): - normalized.append(field) - elif isinstance(field, dict): - normalized.append(self._dict_to_field(field)) - elif isinstance(field, str): - # Try to parse as dict (JSON or Python literal) - parsed_dict = self._parse_string_to_dict(field) - if parsed_dict: - normalized.append(self._dict_to_field(parsed_dict)) - else: - # Plain field name - normalized.append(Field(name=field, type=None)) + @field_validator("fields", mode="before") + @classmethod + def convert_to_fields(cls, v: Any) -> Any: + if not isinstance(v, list): + raise ValueError("fields must be a list") + return [_normalize_fields_item(item) for item in v] + + @field_validator("indexes", mode="before") + @classmethod + def convert_to_indexes(cls, v: Any) -> Any: + if not isinstance(v, list): + return v + result = [] + for item in v: + if isinstance(item, dict): + result.append(Index.model_validate(item)) + else: + result.append(item) + return result + + @field_validator("filters", mode="before") + @classmethod + def convert_to_expressions(cls, v: Any) -> Any: + if not isinstance(v, list): + return v + result: list[Any] = [] + for item in v: + if isinstance(item, dict): + result.append(Clause.from_dict(item)) else: - raise TypeError(f"Field must be str, Field, or dict, got {type(field)}") - return normalized + result.append(item) + return result + + @model_validator(mode="after") + def set_dbname_and_indexes(self) -> "Vertex": + if self.dbname is None: + object.__setattr__(self, "dbname", self.name) + indexes = list(self.indexes) + if not indexes: + object.__setattr__( + self, + "indexes", + [Index(fields=[f.name for f in self.fields])], + ) + else: + seen_names = {f.name for f in self.fields} + new_fields = list(self.fields) + for idx in indexes: + for field_name in idx.fields: + if field_name not in seen_names: + new_fields.append(Field(name=field_name, type=None)) + seen_names.add(field_name) + object.__setattr__(self, "fields", new_fields) + return self @property def field_names(self) -> list[str]: - """Get list of field names (as strings). - - Returns: - list[str]: List of field names - """ + """Get list of field names (as strings).""" return [field.name for field in self.fields] def get_fields(self) -> list[Field]: return self.fields - def __post_init__(self): - """Initialize the vertex after dataclass initialization. - - Sets the database name if not provided, normalizes fields to Field objects, - and updates fields based on indexes. Field objects behave like strings, - maintaining backward compatibility. 
- """ - if self.dbname is None: - self.dbname = self.name - - # Normalize fields to Field objects (preserve order) - self.fields = self._normalize_fields(self.fields) - - # Normalize indexes to Index objects if they're dicts - normalized_indexes = [] - for idx in self.indexes: - if isinstance(idx, dict): - normalized_indexes.append(Index.from_dict(idx)) - else: - normalized_indexes.append(idx) - self.indexes = normalized_indexes - - if not self.indexes: - # Index expects list[str], but Field objects convert to strings automatically - # via __str__, so we extract names - self.indexes = [Index(fields=self.field_names)] - - # Collect field names from existing fields (preserve order) - seen_names = {f.name for f in self.fields} - # Add index fields that aren't already present (preserve original order, append new) - for idx in self.indexes: - for field_name in idx.fields: - if field_name not in seen_names: - # Add new field, preserving order by adding to end - self.fields.append(Field(name=field_name, type=None)) - seen_names.add(field_name) - def finish_init(self, db_flavor: DBType): """Complete initialization of vertex with database-specific field types. @@ -329,8 +293,7 @@ def finish_init(self, db_flavor: DBType): ] -@dataclasses.dataclass -class VertexConfig(BaseDataclass): +class VertexConfig(ConfigBaseModel): """Configuration for managing vertices. This class manages vertices, providing methods for accessing @@ -343,31 +306,35 @@ class VertexConfig(BaseDataclass): db_flavor: Database flavor (ARANGO or NEO4J) """ + # Allow extra keys when loading from YAML (e.g. vertex_config wrapper key) + model_config = ConfigDict(extra="ignore") + vertices: list[Vertex] - blank_vertices: list[str] = dataclasses.field(default_factory=list) - force_types: dict[str, list] = dataclasses.field(default_factory=dict) + blank_vertices: list[str] = PydanticField(default_factory=list) + force_types: dict[str, list] = PydanticField(default_factory=dict) db_flavor: DBType = DBType.ARANGO - def __post_init__(self): - """Initialize the vertex configuration. - - Creates internal mappings and validates blank vertices. 
- - Raises: - ValueError: If blank vertices are not defined in the configuration - """ - self._vertices_map: dict[str, Vertex] = { - item.name: item for item in self.vertices - } - - # TODO replace by types - # vertex_name -> [numeric fields] - self._vertex_numeric_fields_map = {} + _vertices_map: dict[str, Vertex] | None = PrivateAttr(default=None) + _vertex_numeric_fields_map: dict[str, object] | None = PrivateAttr(default=None) + @model_validator(mode="after") + def build_vertices_map_and_validate_blank(self) -> "VertexConfig": + object.__setattr__( + self, + "_vertices_map", + {item.name: item for item in self.vertices}, + ) + object.__setattr__(self, "_vertex_numeric_fields_map", {}) if set(self.blank_vertices) - set(self.vertex_set): raise ValueError( f" Blank vertices {self.blank_vertices} are not defined as vertices" ) + return self + + def _get_vertices_map(self) -> dict[str, Vertex]: + """Return the vertices map (set by model validator).""" + assert self._vertices_map is not None, "VertexConfig not fully initialized" + return self._vertices_map @property def vertex_set(self): @@ -376,7 +343,7 @@ def vertex_set(self): Returns: set[str]: Set of vertex names """ - return set(self._vertices_map.keys()) + return set(self._get_vertices_map().keys()) @property def vertex_list(self): @@ -385,7 +352,7 @@ def vertex_list(self): Returns: list[Vertex]: List of vertex configurations """ - return list(self._vertices_map.values()) + return list(self._get_vertices_map().values()) def _get_vertex_by_name_or_dbname(self, identifier: str) -> Vertex: """Get vertex by name or dbname. @@ -399,18 +366,19 @@ def _get_vertex_by_name_or_dbname(self, identifier: str) -> Vertex: Raises: KeyError: If vertex is not found by name or dbname """ + m = self._get_vertices_map() # First try by name (most common case) - if identifier in self._vertices_map: - return self._vertices_map[identifier] + if identifier in m: + return m[identifier] # Try by dbname - for vertex in self._vertices_map.values(): + for vertex in m.values(): if vertex.dbname == identifier: return vertex # Not found - available_names = list(self._vertices_map.keys()) - available_dbnames = [v.dbname for v in self._vertices_map.values()] + available_names = list(m.keys()) + available_dbnames = [v.dbname for v in m.values()] raise KeyError( f"Vertex '{identifier}' not found by name or dbname. " f"Available names: {available_names}, " @@ -429,13 +397,12 @@ def vertex_dbname(self, vertex_name): Raises: KeyError: If vertex is not found """ + m = self._get_vertices_map() try: - value = self._vertices_map[vertex_name].dbname + value = m[vertex_name].dbname except KeyError as e: logger.error( - "Available vertices :" - f" {self._vertices_map.keys()}; vertex" - f" requested : {vertex_name}" + f"Available vertices : {m.keys()}; vertex requested : {vertex_name}" ) raise e return value @@ -449,7 +416,7 @@ def index(self, vertex_name) -> Index: Returns: Index: Primary index for the vertex """ - return self._vertices_map[vertex_name].indexes[0] + return self._get_vertices_map()[vertex_name].indexes[0] def indexes(self, vertex_name) -> list[Index]: """Get all indexes for a vertex. @@ -460,7 +427,7 @@ def indexes(self, vertex_name) -> list[Index]: Returns: list[Index]: List of indexes for the vertex """ - return self._vertices_map[vertex_name].indexes + return self._get_vertices_map()[vertex_name].indexes def fields(self, vertex_name: str) -> list[Field]: """Get fields for a vertex. 
@@ -504,8 +471,9 @@ def numeric_fields_list(self, vertex_name): ValueError: If vertex is not defined in config """ if vertex_name in self.vertex_set: - if vertex_name in self._vertex_numeric_fields_map: - return self._vertex_numeric_fields_map[vertex_name] + nmap = self._vertex_numeric_fields_map + if nmap is not None and vertex_name in nmap: + return nmap[vertex_name] else: return () else: @@ -514,17 +482,18 @@ def numeric_fields_list(self, vertex_name): f" {vertex_name} was not defined in config" ) - def filters(self, vertex_name) -> list[Expression]: - """Get filter expressions for a vertex. + def filters(self, vertex_name) -> list[Clause]: + """Get filter clauses for a vertex. Args: vertex_name: Name of the vertex Returns: - list[Expression]: List of filter expressions + list[Clause]: List of filter clauses """ - if vertex_name in self._vertices_map: - return self._vertices_map[vertex_name].filters + m = self._get_vertices_map() + if vertex_name in m: + return m[vertex_name].filters else: return [] @@ -540,8 +509,9 @@ def remove_vertices(self, names: set[str]) -> None: if not names: return self.vertices[:] = [v for v in self.vertices if v.name not in names] + m = self._get_vertices_map() for n in names: - self._vertices_map.pop(n, None) + m.pop(n, None) self.blank_vertices[:] = [b for b in self.blank_vertices if b not in names] def update_vertex(self, v: Vertex): @@ -550,7 +520,7 @@ def update_vertex(self, v: Vertex): Args: v: Vertex configuration to update """ - self._vertices_map[v.name] = v + self._get_vertices_map()[v.name] = v def __getitem__(self, key: str): """Get vertex configuration by name. @@ -564,8 +534,9 @@ def __getitem__(self, key: str): Raises: KeyError: If vertex is not found """ - if key in self._vertices_map: - return self._vertices_map[key] + m = self._get_vertices_map() + if key in m: + return m[key] else: raise KeyError(f"Vertex {key} absent") @@ -576,7 +547,7 @@ def __setitem__(self, key: str, value: Vertex): key: Vertex name value: Vertex configuration """ - self._vertices_map[key] = value + self._get_vertices_map()[key] = value def finish_init(self): """Complete initialization of all vertices with database-specific field types. diff --git a/graflo/data_source/factory.py b/graflo/data_source/factory.py index 329f0817..8bc6b066 100644 --- a/graflo/data_source/factory.py +++ b/graflo/data_source/factory.py @@ -226,7 +226,6 @@ def create_data_source( pagination = None if pagination_dict is not None: if isinstance(pagination_dict, dict): - # Manually construct PaginationConfig to avoid dataclass_wizard issues pagination = PaginationConfig(**pagination_dict) else: pagination = pagination_dict @@ -243,7 +242,6 @@ def create_data_source( pagination = None if pagination_dict is not None: if isinstance(pagination_dict, dict): - # Manually construct PaginationConfig to avoid dataclass_wizard issues pagination = PaginationConfig(**pagination_dict) else: pagination = pagination_dict diff --git a/graflo/db/arango/conn.py b/graflo/db/arango/conn.py index d6cad87d..4a157cc7 100644 --- a/graflo/db/arango/conn.py +++ b/graflo/db/arango/conn.py @@ -64,8 +64,11 @@ class ArangoConnection(Connection): Attributes: conn: ArangoDB database connection instance + flavor: Database type (ARANGO) for expression flavor mapping """ + flavor = DBType.ARANGO + def __init__(self, config: ArangoConfig): """Initialize ArangoDB connection. 
@@ -410,6 +413,11 @@ def define_edge_classes(self, edges: list[Edge]) -> None: logger.warning("Edge has no database_name, skipping") continue if not g.has_edge_definition(collection_name): + if item._source is None or item._target is None: + logger.warning( + "Edge has no _source or _target, skipping edge definition" + ) + continue _ = g.create_edge_definition( edge_collection=collection_name, from_vertex_collections=[item._source], diff --git a/graflo/db/arango/query.py b/graflo/db/arango/query.py index faa7deef..bcf0a50b 100644 --- a/graflo/db/arango/query.py +++ b/graflo/db/arango/query.py @@ -21,8 +21,8 @@ from arango import ArangoClient -from graflo.filter.onto import Expression -from graflo.onto import DBType +from graflo.filter.onto import Clause, Expression +from graflo.onto import ExpressionFlavor logger = logging.getLogger(__name__) @@ -129,7 +129,7 @@ def fetch_fields_query( docs, match_keys, keep_keys, - filters: list | dict | None = None, + filters: list | dict | Clause | None = None, ): """Generate and execute a field-fetching AQL query. @@ -165,8 +165,8 @@ def fetch_fields_query( keep_clause = f"KEEP(_x, {list(keep_keys)})" if keep_keys is not None else "_x" if filters is not None: - ff = Expression.from_dict(filters) - extrac_filter_clause = f" && {ff(doc_name='_cdoc', kind=DBType.ARANGO)}" + ff = filters if isinstance(filters, Clause) else Expression.from_dict(filters) + extrac_filter_clause = f" && {ff(doc_name='_cdoc', kind=ExpressionFlavor.AQL)}" else: extrac_filter_clause = "" diff --git a/graflo/db/arango/util.py b/graflo/db/arango/util.py index 11815292..f104291c 100644 --- a/graflo/db/arango/util.py +++ b/graflo/db/arango/util.py @@ -16,7 +16,7 @@ import logging from graflo.architecture.edge import Edge -from graflo.filter.onto import Clause, Expression +from graflo.filter.onto import Clause from graflo.onto import ExpressionFlavor logger = logging.getLogger(__name__) @@ -42,14 +42,17 @@ def define_extra_edges(g: Edge): >>> # Generates query to create user->post edges through comments """ ucol, vcol, wcol = g.source, g.target, g.by - weight = g.weight_dict + weight = g.weights s = f"""FOR w IN {wcol} LET uset = (FOR u IN 1..1 INBOUND w {ucol}_{wcol}_edges RETURN u) LET vset = (FOR v IN 1..1 INBOUND w {vcol}_{wcol}_edges RETURN v) FOR u in uset FOR v in vset """ - s_ins_ = ", ".join([f"{v}: w.{k}" for k, v in weight.items()]) + if weight is None: + raise ValueError("WeightConfig is required for edge list rendering") + # WeightConfig.direct is list[Field]; AQL copies each field from w to the edge + s_ins_ = ", ".join([f"{f.name}: w.{f.name}" for f in weight.direct]) s_ins_ = f"_from: u._id, _to: v._id, {s_ins_}" s_ins = f" INSERT {{{s_ins_}}} " s_last = f"IN {ucol}_{vcol}_edges" @@ -77,10 +80,10 @@ def render_filters(filters: None | list | dict | Clause = None, doc_name="d") -> """ if filters is not None: if not isinstance(filters, Clause): - ff = Expression.from_dict(filters) + ff = Clause.from_dict(filters) else: ff = filters - literal_condition = ff(doc_name=doc_name, kind=ExpressionFlavor.ARANGO) + literal_condition = ff(doc_name=doc_name, kind=ExpressionFlavor.AQL) filter_clause = f"FILTER {literal_condition}" else: filter_clause = "" diff --git a/graflo/db/conn.py b/graflo/db/conn.py index a51bdda1..1f24347f 100644 --- a/graflo/db/conn.py +++ b/graflo/db/conn.py @@ -52,12 +52,17 @@ import abc import logging -from typing import Any, TypeVar +from typing import Any, ClassVar, TypeVar from graflo.architecture.edge import Edge from graflo.architecture.schema import 
Schema from graflo.architecture.vertex import VertexConfig -from graflo.onto import AggregationType +from graflo.onto import ( + AggregationType, + DB_TYPE_TO_EXPRESSION_FLAVOR, + DBType, + ExpressionFlavor, +) logger = logging.getLogger(__name__) ConnectionType = TypeVar("ConnectionType", bound="Connection") @@ -80,13 +85,25 @@ class Connection(abc.ABC): Note: All methods marked with @abc.abstractmethod must be implemented by - concrete connection classes. + concrete connection classes. Subclasses must set the class attribute + `flavor` to their DBType. """ + flavor: ClassVar[DBType] = DBType.ARANGO # Overridden by subclasses + def __init__(self): """Initialize the connection.""" pass + @classmethod + def expression_flavor(cls) -> ExpressionFlavor: + """Expression flavor for filter rendering (AQL, CYPHER, GSQL). + + Graph connection subclasses must set class attribute `flavor` to a + DBType present in DB_TYPE_TO_EXPRESSION_FLAVOR. + """ + return DB_TYPE_TO_EXPRESSION_FLAVOR[cls.flavor] + @abc.abstractmethod def create_database(self, name: str): """Create a new database. diff --git a/graflo/db/falkordb/conn.py b/graflo/db/falkordb/conn.py index 07fd1e46..aa3d5d46 100644 --- a/graflo/db/falkordb/conn.py +++ b/graflo/db/falkordb/conn.py @@ -38,7 +38,7 @@ from graflo.db.conn import Connection, SchemaExistsError from graflo.db.util import serialize_value from graflo.filter.onto import Expression -from graflo.onto import AggregationType, ExpressionFlavor +from graflo.onto import AggregationType from graflo.onto import DBType @@ -690,8 +690,7 @@ def fetch_docs( # Build filter clause if filters is not None: ff = Expression.from_dict(filters) - # Use NEO4J flavor since FalkorDB uses OpenCypher - filter_clause = f"WHERE {ff(doc_name='n', kind=DBType.NEO4J)}" + filter_clause = f"WHERE {ff(doc_name='n', kind=self.expression_flavor())}" else: filter_clause = "" @@ -796,7 +795,7 @@ def fetch_edges( # Add additional filters if provided if filters is not None: ff = Expression.from_dict(filters) - filter_clause = ff(doc_name="r", kind=ExpressionFlavor.NEO4J) + filter_clause = ff(doc_name="r", kind=self.expression_flavor()) where_clauses.append(filter_clause) where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" @@ -914,7 +913,7 @@ def aggregate( # Build filter clause if filters is not None: ff = Expression.from_dict(filters) - filter_clause = f"WHERE {ff(doc_name='n', kind=DBType.NEO4J)}" + filter_clause = f"WHERE {ff(doc_name='n', kind=self.expression_flavor())}" else: filter_clause = "" diff --git a/graflo/db/memgraph/conn.py b/graflo/db/memgraph/conn.py index 97079133..44c7d969 100644 --- a/graflo/db/memgraph/conn.py +++ b/graflo/db/memgraph/conn.py @@ -88,7 +88,7 @@ from graflo.architecture.vertex import VertexConfig from graflo.db.conn import Connection, SchemaExistsError from graflo.filter.onto import Expression -from graflo.onto import AggregationType, ExpressionFlavor +from graflo.onto import AggregationType from graflo.onto import DBType @@ -877,7 +877,7 @@ def fetch_docs( if filters is not None: ff = Expression.from_dict(filters) - filter_str = ff(doc_name="n", kind=ExpressionFlavor.NEO4J) + filter_str = ff(doc_name="n", kind=self.expression_flavor()) q += f" WHERE {filter_str}" # Handle projection @@ -972,7 +972,7 @@ def fetch_edges( # Add relationship property filters if filters is not None: ff = Expression.from_dict(filters) - filter_str = ff(doc_name="r", kind=ExpressionFlavor.NEO4J) + filter_str = ff(doc_name="r", kind=self.expression_flavor()) 
where_clauses.append(filter_str) if where_clauses: @@ -1114,7 +1114,7 @@ def aggregate( filter_clause = "" if filters is not None: ff = Expression.from_dict(filters) - filter_str = ff(doc_name="n", kind=ExpressionFlavor.NEO4J) + filter_str = ff(doc_name="n", kind=self.expression_flavor()) filter_clause = f" WHERE {filter_str}" q = f"MATCH (n:{class_name}){filter_clause}" diff --git a/graflo/db/neo4j/conn.py b/graflo/db/neo4j/conn.py index fc1cdf6d..9b48d147 100644 --- a/graflo/db/neo4j/conn.py +++ b/graflo/db/neo4j/conn.py @@ -34,7 +34,7 @@ from graflo.architecture.vertex import VertexConfig from graflo.db.conn import Connection, SchemaExistsError from graflo.filter.onto import Expression -from graflo.onto import AggregationType, ExpressionFlavor +from graflo.onto import AggregationType from graflo.onto import DBType @@ -510,7 +510,7 @@ def fetch_docs( """ if filters is not None: ff = Expression.from_dict(filters) - filter_clause = f"WHERE {ff(doc_name='n', kind=DBType.NEO4J)}" + filter_clause = f"WHERE {ff(doc_name='n', kind=self.expression_flavor())}" else: filter_clause = "" @@ -592,7 +592,7 @@ def fetch_edges( from graflo.filter.onto import Expression ff = Expression.from_dict(filters) - filter_clause = ff(doc_name="r", kind=ExpressionFlavor.NEO4J) + filter_clause = ff(doc_name="r", kind=self.expression_flavor()) where_clauses.append(filter_clause) where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" diff --git a/graflo/db/postgres/resource_mapping.py b/graflo/db/postgres/resource_mapping.py index b59ef48b..51408f13 100644 --- a/graflo/db/postgres/resource_mapping.py +++ b/graflo/db/postgres/resource_mapping.py @@ -116,8 +116,8 @@ def create_edge_resource( ) # Get primary key fields for source and target vertices - source_vertex_obj = vertex_config._vertices_map[source_table] - target_vertex_obj = vertex_config._vertices_map[target_table] + source_vertex_obj = vertex_config[source_table] + target_vertex_obj = vertex_config[target_table] # Get the primary key field(s) from the first index (primary key) source_pk_fields = ( diff --git a/graflo/db/postgres/schema_inference.py b/graflo/db/postgres/schema_inference.py index 2d39b829..f5b1a34b 100644 --- a/graflo/db/postgres/schema_inference.py +++ b/graflo/db/postgres/schema_inference.py @@ -71,7 +71,8 @@ def infer_vertex_config( fields = [] for col in columns: field_name = col.name - field_type = self.type_mapper.map_type(col.type) + raw_type = self.type_mapper.map_type(col.type) + field_type = FieldType(raw_type) if raw_type else None fields.append(Field(name=field_name, type=field_type)) # Create indexes from primary key @@ -256,13 +257,15 @@ def infer_edge_weights(self, edge_table_info: EdgeTableInfo) -> WeightConfig | N direct_weights = [] for col in weight_columns: # Infer type: use PostgreSQL type first, then sample if needed - field_type = self._infer_type_from_samples( + raw_type = self._infer_type_from_samples( edge_table_info.name, edge_table_info.schema_name, col.name, col.type, ) - direct_weights.append(Field(name=col.name, type=field_type)) + direct_weights.append( + Field(name=col.name, type=FieldType(raw_type) if raw_type else None) + ) logger.debug( f"Inferred {len(direct_weights)} weights for edge table " @@ -313,11 +316,11 @@ def infer_edge_config( # Infer weights weights = self.infer_edge_weights(edge_table_info) indexes = [] - # Create edge + # Create edge (use alias "index" for indexes field) edge = Edge( source=source_table, target=target_table, - indexes=indexes, + index=indexes, 
weights=weights, relation=edge_table_info.relation, ) diff --git a/graflo/db/tigergraph/conn.py b/graflo/db/tigergraph/conn.py index 4d502061..6cf0e2ef 100644 --- a/graflo/db/tigergraph/conn.py +++ b/graflo/db/tigergraph/conn.py @@ -49,8 +49,8 @@ VALID_TIGERGRAPH_TYPES, ) from graflo.db.util import json_serializer -from graflo.filter.onto import Clause, Expression -from graflo.onto import AggregationType, ExpressionFlavor +from graflo.filter.onto import Clause +from graflo.onto import AggregationType from graflo.onto import DBType from graflo.util.transform import pick_unique_dict from urllib.parse import quote @@ -2057,6 +2057,8 @@ def _define_schema_local(self, schema: Schema) -> None: # Vertices for vertex in vertex_config.vertices: # Validate vertex name + if vertex.dbname is None: + raise ValueError(f"Vertex {vertex.name!r} has no dbname") _validate_tigergraph_schema_name(vertex.dbname, "vertex") stmt = self._get_vertex_add_statement(vertex, vertex_config) vertex_stmts.append(stmt) @@ -3697,18 +3699,18 @@ def _render_rest_filter( """ if filters is not None: if not isinstance(filters, Clause): - ff = Expression.from_dict(filters) + ff = Clause.from_dict(filters) else: ff = filters - # Use ExpressionFlavor.TIGERGRAPH with empty doc_name to trigger REST++ format + # Use GSQL flavor with empty doc_name to trigger REST++ format # Pass field_types to help with proper value quoting - filter_str = ff( + result = ff( doc_name="", - kind=ExpressionFlavor.TIGERGRAPH, + kind=self.expression_flavor(), field_types=field_types, ) - return filter_str + return result if isinstance(result, str) else "" else: return "" diff --git a/graflo/filter/onto.py b/graflo/filter/onto.py index ef65bd79..b7a2837a 100644 --- a/graflo/filter/onto.py +++ b/graflo/filter/onto.py @@ -7,13 +7,10 @@ Key Components: - LogicalOperator: Enum for logical operations (AND, OR, NOT, IMPLICATION) - ComparisonOperator: Enum for comparison operations (==, !=, >, <, etc.) - - AbsClause: Abstract base class for filter clauses - - LeafClause: Concrete clause for field comparisons - - Clause: Composite clause combining multiple sub-clauses - - Expression: Factory class for creating filter expressions from dictionaries + - Clause: Unified filter clause (discriminated: kind="leaf" or kind="composite") Example: - >>> expr = Expression.from_dict({ + >>> expr = Clause.from_dict({ ... "AND": [ ... {"field": "age", "cmp_operator": ">=", "value": 18}, ... {"field": "status", "cmp_operator": "==", "value": "active"} @@ -22,13 +19,15 @@ >>> # Converts to: "age >= 18 AND status == 'active'" """ -import dataclasses +from __future__ import annotations + import logging -from abc import ABCMeta, abstractmethod from types import MappingProxyType -from typing import Any +from typing import Any, Literal, Self -from graflo.onto import BaseDataclass, BaseEnum, ExpressionFlavor +from graflo.architecture.base import ConfigBaseModel +from graflo.onto import BaseEnum, ExpressionFlavor +from pydantic import Field, field_validator, model_validator logger = logging.getLogger(__name__) @@ -93,108 +92,176 @@ class ComparisonOperator(BaseEnum): IN = "IN" -@dataclasses.dataclass -class AbsClause(BaseDataclass, metaclass=ABCMeta): - """Abstract base class for filter clauses. +class Clause(ConfigBaseModel): + """Unified filter clause (discriminated: leaf or composite). - This class defines the interface for all filter clauses, requiring - implementation of the __call__ method to evaluate or render the clause. 
+ - kind="leaf": single field comparison (field, cmp_operator, value, optional unary_op). + - kind="composite": logical combination (operator AND/OR/NOT/IF_THEN, deps). """ - @abstractmethod - def __call__( - self, - doc_name, - kind: ExpressionFlavor = ExpressionFlavor.ARANGO, - **kwargs, - ): - """Evaluate or render the clause. - - Args: - doc_name: Name of the document variable in the query - kind: Target expression flavor (ARANGO, NEO4J, PYTHON) - **kwargs: Additional arguments for evaluation + kind: Literal["leaf", "composite"] - Returns: - str: Rendered clause in the target language - """ - pass + # Leaf fields (used when kind="leaf") + cmp_operator: ComparisonOperator | None = None + value: list[Any] = Field(default_factory=list) + field: str | None = None + unary_op: str | None = ( + None # optional operator before comparison (YAML key: "operator") + ) + # Composite fields (used when kind="composite") + operator: LogicalOperator | None = None # AND, OR, NOT, IF_THEN + deps: list[Clause] = Field(default_factory=list) -@dataclasses.dataclass -class LeafClause(AbsClause): - """Concrete clause for field comparisons. + @field_validator("value", mode="before") + @classmethod + def value_to_list(cls, v: list[Any] | Any) -> list[Any]: + """Convert single value to list if necessary. Explicit None becomes [None] for null comparison.""" + if v is None: + return [None] + if isinstance(v, list): + return v + return [v] + + @model_validator(mode="before") + @classmethod + def leaf_operator_to_unary_op(cls, data: Any) -> Any: + """Map leaf 'operator' (YAML/kwargs) to unary_op; infer kind=leaf when missing.""" + if not isinstance(data, dict): + return data + # Only map operator -> unary_op for leaf clauses (never for composite) + if data.get("kind") == "composite": + return data + if "operator" in data and isinstance(data["operator"], str): + data = dict(data) + data["unary_op"] = data.pop("operator") + if data.get("kind") is None: + data["kind"] = "leaf" + return data + + @model_validator(mode="after") + def check_discriminated_shape(self) -> Clause: + """Enforce exactly one shape per kind.""" + if self.kind == "leaf": + if self.operator is not None or self.deps: + raise ValueError("leaf clause must not have operator or deps") + else: + if self.operator is None: + raise ValueError("composite clause must have operator") + return self - This class represents a single field comparison operation, such as - "field >= value" or "field IN [values]". 
+ @field_validator("deps", mode="before") + @classmethod + def parse_deps(cls, v: list[Any]) -> list[Any]: + """Parse dict/list items into Clause instances.""" + if not isinstance(v, list): + return v + result = [] + for item in v: + if isinstance(item, (dict, list)): + result.append(Clause.from_dict(item)) + else: + result.append(item) + return result - Attributes: - cmp_operator: Comparison operator to use - value: Value(s) to compare against - field: Field name to compare - operator: Optional operator to apply before comparison - """ + @classmethod + def from_list(cls, current: list[Any]) -> Clause: + """Build a leaf clause from list form [cmp_operator, value, field?, unary_op?].""" + cmp_operator = current[0] + value = current[1] + field = current[2] if len(current) > 2 else None + unary_op = current[3] if len(current) > 3 else None + return cls( + kind="leaf", + cmp_operator=cmp_operator, + value=value, + field=field, + unary_op=unary_op, + ) - cmp_operator: ComparisonOperator | None = None - value: list = dataclasses.field(default_factory=list) - field: str | None = None - operator: str | None = None + @classmethod + def from_dict(cls, current: dict[str, Any] | list[Any]) -> Self: # type: ignore[override] + """Create a filter clause from a dictionary or list. - def __post_init__(self): - """Convert single value to list if necessary.""" - if not isinstance(self.value, list): - self.value = [self.value] + Returns Clause (leaf or composite). LSP-compliant: return type is Self. + """ + if isinstance(current, list): + if current[0] in ComparisonOperator: + return cls.from_list(current) # type: ignore[return-value] + elif current[0] in LogicalOperator: + return cls(kind="composite", operator=current[0], deps=current[1]) + elif isinstance(current, dict): + k = list(current.keys())[0] + if k in LogicalOperator: + clauses = [cls.from_dict(v) for v in current[k]] + return cls(kind="composite", operator=LogicalOperator(k), deps=clauses) + else: + # Leaf from dict: map YAML "operator" -> unary_op + unary_op = current.get("operator") + return cls( + kind="leaf", + cmp_operator=current.get("cmp_operator"), + value=current.get("value", []), + field=current.get("field"), + unary_op=unary_op, + ) + raise ValueError(f"expected dict or list, got {type(current)}") def __call__( self, doc_name="doc", - kind: ExpressionFlavor = ExpressionFlavor.ARANGO, + kind: ExpressionFlavor = ExpressionFlavor.AQL, **kwargs, - ): - """Render the leaf clause in the target language. 
+ ) -> str | bool: + """Render or evaluate the clause in the target language.""" + if self.kind == "leaf": + return self._call_leaf(doc_name=doc_name, kind=kind, **kwargs) + return self._call_composite(doc_name=doc_name, kind=kind, **kwargs) - Args: - doc_name: Name of the document variable - kind: Target expression flavor - **kwargs: Additional arguments (may include field_types for REST++) - - Returns: - str: Rendered clause - - Raises: - ValueError: If kind is not implemented - """ + def _call_leaf( + self, + doc_name="doc", + kind: ExpressionFlavor = ExpressionFlavor.AQL, + **kwargs, + ) -> str | bool: if not self.value: logger.warning(f"for {self} value is not set : {self.value}") - if kind == ExpressionFlavor.ARANGO: + if kind == ExpressionFlavor.AQL: assert self.cmp_operator is not None return self._cast_arango(doc_name) - elif kind == ExpressionFlavor.NEO4J: + elif kind == ExpressionFlavor.CYPHER: assert self.cmp_operator is not None return self._cast_cypher(doc_name) - elif kind == ExpressionFlavor.TIGERGRAPH: + elif kind == ExpressionFlavor.GSQL: assert self.cmp_operator is not None - # Check if this is for REST++ API (no doc_name prefix) if doc_name == "": field_types = kwargs.get("field_types") return self._cast_restpp(field_types=field_types) - else: - return self._cast_tigergraph(doc_name) + return self._cast_tigergraph(doc_name) elif kind == ExpressionFlavor.PYTHON: return self._cast_python(**kwargs) - else: - raise ValueError(f"kind {kind} not implemented") + raise ValueError(f"kind {kind} not implemented") - def _cast_value(self): - """Format the comparison value for query rendering. + def _call_composite( + self, + doc_name="doc", + kind: ExpressionFlavor = ExpressionFlavor.AQL, + **kwargs, + ) -> str | bool: + if kind in ( + ExpressionFlavor.AQL, + ExpressionFlavor.CYPHER, + ExpressionFlavor.GSQL, + ): + return self._cast_generic(doc_name=doc_name, kind=kind) + elif kind == ExpressionFlavor.PYTHON: + return self._cast_python_composite(kind=kind, **kwargs) + raise ValueError(f"kind {kind} not implemented") - Returns: - str: Formatted value string - """ + def _cast_value(self) -> str: value = f"{self.value[0]}" if len(self.value) == 1 else f"{self.value}" if len(self.value) == 1: if isinstance(self.value[0], str): - # Escape backslashes first, then double quotes escaped = self.value[0].replace("\\", "\\\\").replace('"', '\\"') value = f'"{escaped}"' elif self.value[0] is None: @@ -203,87 +270,42 @@ def _cast_value(self): value = f"{self.value[0]}" return value - def _cast_arango(self, doc_name): - """Render the clause in AQL format. - - Args: - doc_name: Document variable name - - Returns: - str: AQL clause - """ + def _cast_arango(self, doc_name: str) -> str: const = self._cast_value() - lemma = f"{self.cmp_operator} {const}" - if self.operator is not None: - lemma = f"{self.operator} {lemma}" - + if self.unary_op is not None: + lemma = f"{self.unary_op} {lemma}" if self.field is not None: lemma = f'{doc_name}["{self.field}"] {lemma}' return lemma - def _cast_cypher(self, doc_name): - """Render the clause in Cypher format. 
- - Args: - doc_name: Document variable name - - Returns: - str: Cypher clause - """ + def _cast_cypher(self, doc_name: str) -> str: const = self._cast_value() - if self.cmp_operator == ComparisonOperator.EQ: - cmp_operator = "=" - else: - cmp_operator = self.cmp_operator - lemma = f"{cmp_operator} {const}" - if self.operator is not None: - lemma = f"{self.operator} {lemma}" - + cmp_op = ( + "=" if self.cmp_operator == ComparisonOperator.EQ else self.cmp_operator + ) + lemma = f"{cmp_op} {const}" + if self.unary_op is not None: + lemma = f"{self.unary_op} {lemma}" if self.field is not None: lemma = f"{doc_name}.{self.field} {lemma}" return lemma - def _cast_tigergraph(self, doc_name): - """Render the clause in GSQL format. - - Args: - doc_name: Document variable name (typically "v" for vertex) - - Returns: - str: GSQL clause - """ + def _cast_tigergraph(self, doc_name: str) -> str: const = self._cast_value() - # GSQL supports both == and =, but == is more common - if self.cmp_operator == ComparisonOperator.EQ: - cmp_operator = "==" - else: - cmp_operator = self.cmp_operator - lemma = f"{cmp_operator} {const}" - if self.operator is not None: - lemma = f"{self.operator} {lemma}" - + cmp_op = ( + "==" if self.cmp_operator == ComparisonOperator.EQ else self.cmp_operator + ) + lemma = f"{cmp_op} {const}" + if self.unary_op is not None: + lemma = f"{self.unary_op} {lemma}" if self.field is not None: lemma = f"{doc_name}.{self.field} {lemma}" return lemma - def _cast_restpp(self, field_types: dict[str, Any] | None = None): - """Render the clause in REST++ filter format. - - REST++ filter format: "field=value" or "field>value" etc. - Format: fieldoperatorvalue (no spaces, quotes for string values) - Example: "hindex=10" or "hindex>20" or 'name="John"' - - Args: - field_types: Optional mapping of field names to FieldType enum values or type strings - - Returns: - str: REST++ filter clause - """ + def _cast_restpp(self, field_types: dict[str, Any] | None = None) -> str: if not self.field: return "" - - # Map operator if self.cmp_operator == ComparisonOperator.EQ: op_str = "=" elif self.cmp_operator == ComparisonOperator.NEQ: @@ -298,230 +320,69 @@ def _cast_restpp(self, field_types: dict[str, Any] | None = None): op_str = "<=" else: op_str = str(self.cmp_operator) - - # Format value for REST++ API - # Use field_types to determine if value should be quoted - # Default: if no explicit type information, treat as string (quote it) value = self.value[0] if self.value else None if value is None: value_str = "null" elif isinstance(value, (int, float)): - # Numeric values: pass as string without quotes value_str = str(value) elif isinstance(value, str): - # Check field type to determine if it's a string field - is_string_field = True # Default: treat as string unless explicitly numeric + is_string_field = True if field_types and self.field in field_types: field_type = field_types[self.field] - # Handle FieldType enum or string type - if hasattr(field_type, "value"): - # It's a FieldType enum - field_type_str = field_type.value - else: - # It's a string - field_type_str = str(field_type).upper() - # Check if it's explicitly a numeric type - numeric_types = ("INT", "UINT", "FLOAT", "DOUBLE") - if field_type_str in numeric_types: - # Explicitly numeric type, don't quote + field_type_str = ( + field_type.value + if hasattr(field_type, "value") + else str(field_type).upper() + ) + if field_type_str in ("INT", "UINT", "FLOAT", "DOUBLE"): is_string_field = False - else: - # Explicitly string type or other (STRING, 
VARCHAR, TEXT, DATETIME, BOOL, etc.) - # Quote it - is_string_field = True - # If no field_types info, default to treating as string (quote it) - - if is_string_field: - value_str = f'"{value}"' - else: - # Numeric value (explicitly numeric type) - value_str = value + value_str = f'"{value}"' if is_string_field else str(value) else: value_str = str(value) - - # REST++ format: fieldoperatorvalue (no spaces) - # Example: hindex=10, hindex>20, name="John" return f"{self.field}{op_str}{value_str}" - def _cast_python(self, **kwargs): - """Evaluate the clause in Python. - - Args: - **kwargs: Additional arguments - - Returns: - bool: Evaluation result - """ + def _cast_python(self, **kwargs: Any) -> bool: if self.field is not None: - field = kwargs.pop(self.field, None) - if field is not None and self.operator is not None: - foo = getattr(field, self.operator) + field_val = kwargs.pop(self.field, None) + if field_val is not None and self.unary_op is not None: + foo = getattr(field_val, self.unary_op) return foo(self.value[0]) return False - -@dataclasses.dataclass -class Clause(AbsClause): - """Composite clause combining multiple sub-clauses. - - This class represents a logical combination of multiple filter clauses, - such as "clause1 AND clause2" or "NOT clause1". - - Attributes: - operator: Logical operator to combine clauses - deps: List of dependent clauses - """ - - operator: LogicalOperator - deps: list[AbsClause] - - def __call__( - self, - doc_name="doc", - kind: ExpressionFlavor = ExpressionFlavor.ARANGO, - **kwargs, - ): - """Render the composite clause in the target language. - - Args: - doc_name: Document variable name - kind: Target expression flavor - **kwargs: Additional arguments - - Returns: - str: Rendered clause - - Raises: - ValueError: If operator and dependencies don't match - """ - if kind in ( - ExpressionFlavor.ARANGO, - ExpressionFlavor.NEO4J, - ExpressionFlavor.TIGERGRAPH, - ): - return self._cast_generic(doc_name=doc_name, kind=kind) - elif kind == ExpressionFlavor.PYTHON: - return self._cast_python(kind=kind, **kwargs) - - def _cast_generic(self, doc_name, kind): - """Render the clause in a generic format. - - Args: - doc_name: Document variable name - kind: Target expression flavor - - Returns: - str: Rendered clause - - Raises: - ValueError: If operator and dependencies don't match - """ + def _cast_generic(self, doc_name: str, kind: ExpressionFlavor) -> str: + assert self.operator is not None if len(self.deps) == 1: if self.operator == LogicalOperator.NOT: result = self.deps[0](kind=kind, doc_name=doc_name) - # REST++ format uses ! prefix, not "NOT " prefix - if doc_name == "" and kind == ExpressionFlavor.TIGERGRAPH: + if doc_name == "" and kind == ExpressionFlavor.GSQL: return f"!{result}" - else: - return f"{self.operator} {result}" - else: - raise ValueError( - f" length of deps = {len(self.deps)} but operator is not" - f" {LogicalOperator.NOT}" - ) - else: - deps_str = [item(kind=kind, doc_name=doc_name) for item in self.deps] - # REST++ format uses && and || instead of AND and OR - if doc_name == "" and kind == ExpressionFlavor.TIGERGRAPH: - if self.operator == LogicalOperator.AND: - return " && ".join(deps_str) - elif self.operator == LogicalOperator.OR: - return " || ".join(deps_str) - else: - return f" {self.operator} ".join(deps_str) - else: - return f" {self.operator} ".join(deps_str) - - def _cast_python(self, kind, **kwargs): - """Evaluate the clause in Python. 
- - Args: - kind: Expression flavor - **kwargs: Additional arguments - - Returns: - bool: Evaluation result - - Raises: - ValueError: If operator and dependencies don't match - """ + return f"{self.operator} {result}" + raise ValueError( + f" length of deps = {len(self.deps)} but operator is not {LogicalOperator.NOT}" + ) + deps_str = [dep(kind=kind, doc_name=doc_name) for dep in self.deps] + # __call__ returns str | bool; join expects str + deps_str_cast: list[str] = [str(x) for x in deps_str] + if doc_name == "" and kind == ExpressionFlavor.GSQL: + if self.operator == LogicalOperator.AND: + return " && ".join(deps_str_cast) + if self.operator == LogicalOperator.OR: + return " || ".join(deps_str_cast) + return f" {self.operator} ".join(deps_str_cast) + + def _cast_python_composite(self, kind: ExpressionFlavor, **kwargs: Any) -> bool: + assert self.operator is not None if len(self.deps) == 1: if self.operator == LogicalOperator.NOT: return not self.deps[0](kind=kind, **kwargs) - else: - raise ValueError( - f" length of deps = {len(self.deps)} but operator is not" - f" {LogicalOperator.NOT}" - ) - else: - return OperatorMapping[self.operator]( - [item(kind=kind, **kwargs) for item in self.deps] + raise ValueError( + f" length of deps = {len(self.deps)} but operator is not {LogicalOperator.NOT}" ) + return OperatorMapping[self.operator]( + [dep(kind=kind, **kwargs) for dep in self.deps] + ) -@dataclasses.dataclass -class Expression(AbsClause): - """Factory class for creating filter expressions. - - This class provides methods to create filter expressions from dictionaries - and evaluate them in different languages. - """ - - @classmethod - def from_dict(cls, current): - """Create a filter expression from a dictionary. - - Args: - current: Dictionary or list representing the filter expression - - Returns: - AbsClause: Created filter expression - - Example: - >>> expr = Expression.from_dict({ - ... "AND": [ - ... {"field": "age", "cmp_operator": ">=", "value": 18}, - ... {"field": "status", "cmp_operator": "==", "value": "active"} - ... ] - ... }) - """ - if isinstance(current, list): - if current[0] in ComparisonOperator: - return LeafClause(*current) - elif current[0] in LogicalOperator: - return Clause(*current) - elif isinstance(current, dict): - k = list(current.keys())[0] - if k in LogicalOperator: - clauses = [cls.from_dict(v) for v in current[k]] - return Clause(operator=k, deps=clauses) - else: - return LeafClause(**current) - - def __call__( - self, - doc_name="doc", - kind: ExpressionFlavor = ExpressionFlavor.ARANGO, - **kwargs, - ): - """Evaluate the expression in the target language. 
- - Args: - doc_name: Document variable name - kind: Target expression flavor - **kwargs: Additional arguments - - Returns: - str: Rendered expression - """ - pass +# Backward compatibility +Expression = Clause +LeafClause = Clause diff --git a/graflo/hq/inferencer.py b/graflo/hq/inferencer.py index 6b2914ef..5073c254 100644 --- a/graflo/hq/inferencer.py +++ b/graflo/hq/inferencer.py @@ -111,7 +111,7 @@ def infer_complete_schema(self, schema_name: str | None = None) -> Schema: schema.resources = resources # Re-initialize to set up resource mappings - schema.__post_init__() + schema.finish_init() return schema diff --git a/graflo/hq/sanitizer.py b/graflo/hq/sanitizer.py index 5f643055..7aac08b2 100644 --- a/graflo/hq/sanitizer.py +++ b/graflo/hq/sanitizer.py @@ -73,14 +73,17 @@ def sanitize(self, schema: Schema) -> Schema: # First pass: Sanitize vertex dbnames for vertex in schema.vertex_config.vertices: + if vertex.dbname is None: + continue + dbname = vertex.dbname sanitized_vertex_name = sanitize_attribute_name( - vertex.dbname, self.reserved_words, suffix=f"_{VERTEX_SUFFIX}" + dbname, self.reserved_words, suffix=f"_{VERTEX_SUFFIX}" ) - if sanitized_vertex_name != vertex.dbname: + if sanitized_vertex_name != dbname: logger.debug( - f"Sanitizing vertex name '{vertex.dbname}' -> '{sanitized_vertex_name}'" + f"Sanitizing vertex name '{dbname}' -> '{sanitized_vertex_name}'" ) - self.vertex_mappings[vertex.dbname] = sanitized_vertex_name + self.vertex_mappings[dbname] = sanitized_vertex_name vertex.dbname = sanitized_vertex_name # Second pass: Sanitize vertex field names @@ -113,6 +116,8 @@ def sanitize(self, schema: Schema) -> Schema: continue original = edge.relation_dbname + if original is None: + continue # First pass: sanitize against reserved words sanitized = sanitize_attribute_name( diff --git a/graflo/onto.py b/graflo/onto.py index c1e7931d..c5636d24 100644 --- a/graflo/onto.py +++ b/graflo/onto.py @@ -97,25 +97,21 @@ def base_enum_representer(dumper, data): class ExpressionFlavor(BaseEnum): - """Supported expression language types. + """Supported expression language types for filter/query rendering. - This enum defines the supported expression languages for querying and - filtering data. + Uses the actual query language names: AQL (ArangoDB), CYPHER (Neo4j, + FalkorDB, Memgraph), GSQL (TigerGraph), PYTHON for in-memory evaluation. Attributes: - ARANGO: ArangoDB AQL expressions - NEO4J: Neo4j Cypher expressions - TIGERGRAPH: TigerGraph GSQL expressions - FALKORDB: FalkorDB Cypher expressions (OpenCypher compatible) - MEMGRAPH: Memgraph Cypher expressions (OpenCypher compatible) - PYTHON: Python expressions + AQL: ArangoDB AQL expressions + CYPHER: OpenCypher expressions (Neo4j, FalkorDB, Memgraph) + GSQL: TigerGraph GSQL expressions (including REST++ filter format) + PYTHON: Python expression evaluation """ - ARANGO = "arango" - NEO4J = "neo4j" - TIGERGRAPH = "tigergraph" - FALKORDB = "falkordb" - MEMGRAPH = "memgraph" + AQL = "aql" + CYPHER = "cypher" + GSQL = "gsql" PYTHON = "python" @@ -319,3 +315,14 @@ class DBType(StrEnum, metaclass=MetaEnum): MYSQL = "mysql" MONGODB = "mongodb" SQLITE = "sqlite" + + +# Mapping from graph DB type to expression flavor for filter rendering. +# Used by Connection subclasses so filters are rendered in the correct language. 
+DB_TYPE_TO_EXPRESSION_FLAVOR: dict[DBType, ExpressionFlavor] = { + DBType.ARANGO: ExpressionFlavor.AQL, + DBType.NEO4J: ExpressionFlavor.CYPHER, + DBType.FALKORDB: ExpressionFlavor.CYPHER, + DBType.MEMGRAPH: ExpressionFlavor.CYPHER, + DBType.TIGERGRAPH: ExpressionFlavor.GSQL, +} diff --git a/test/architecture/conftest.py b/test/architecture/conftest.py index 1283f5ab..a4ca1d04 100644 --- a/test/architecture/conftest.py +++ b/test/architecture/conftest.py @@ -44,17 +44,17 @@ def vertex_pub(): - OR: - IF_THEN: - field: name - foo: __eq__ + cmp_operator: "==" value: Open - field: value - foo: __gt__ + cmp_operator: ">" value: 0 - IF_THEN: - field: name - foo: __eq__ + cmp_operator: "==" value: Close - field: value - foo: __gt__ + cmp_operator: ">" value: 0 transforms: - foo: cast_ibes_analyst @@ -221,7 +221,7 @@ def edge_config_kg(): - name: publication fields: - _key - - exclude_edge_end_vertices: true + - exclude_edge_endpoints: true unique: false fields: - publication@_key diff --git a/test/architecture/test_actor.py b/test/architecture/test_actor.py index e322d724..4bee74cc 100644 --- a/test/architecture/test_actor.py +++ b/test/architecture/test_actor.py @@ -187,3 +187,38 @@ def test_find_descendants_transform_by_target_vertex( actor_type=TransformActor, vertex={"nonexistent"} ) assert len(by_vertex_empty) == 0 + + +def test_explicit_format_pipeline_transform_create_edge(): + """New explicit format: pipeline with transform (map, to_vertex) and create_edge (from, to).""" + from graflo.architecture.actor_config import ( + TransformActorConfig, + normalize_actor_step, + validate_actor_step, + ) + from graflo.architecture.vertex import VertexConfig + + vc = VertexConfig.from_dict({"vertices": [{"name": "users", "fields": ["id"]}]}) + pipeline = [ + {"transform": {"map": {"follower_id": "id"}, "to_vertex": "users"}}, + {"transform": {"map": {"followed_id": "id"}, "to_vertex": "users"}}, + {"create_edge": {"from": "users", "to": "users"}}, + ] + anw = ActorWrapper(pipeline=pipeline) + anw.finish_init(vertex_config=vc, transforms={}) + # Root is DescendActor with at least 3 descendants: 2 TransformActors, 1 EdgeActor + # (finish_init may auto-add a VertexActor for "users") + assert isinstance(anw.actor, DescendActor) + assert len(anw.actor.descendants) >= 3 + transform_count = sum( + 1 for d in anw.actor.descendants if isinstance(d.actor, TransformActor) + ) + edge_count = sum(1 for d in anw.actor.descendants if isinstance(d.actor, EdgeActor)) + assert transform_count >= 2 and edge_count >= 1 + # Explicit format parses via Pydantic config + step = {"transform": {"map": {"x": "y"}, "to_vertex": "users"}} + config = validate_actor_step(normalize_actor_step(step)) + assert config.type == "transform" + assert isinstance(config, TransformActorConfig) + assert config.map == {"x": "y"} + assert config.to_vertex == "users" diff --git a/test/architecture/test_edge.py b/test/architecture/test_edge.py index 1d203a2a..a4f8c4be 100644 --- a/test/architecture/test_edge.py +++ b/test/architecture/test_edge.py @@ -14,7 +14,7 @@ def test_weight_config_b(vertex_helper_b): def test_init_edge(edge_with_weights): vc = Edge.from_dict(edge_with_weights) - assert len(vc.weights.vertices) == 2 + assert vc.weights is not None and len(vc.weights.vertices) == 2 assert len(vc.indexes) == 0 diff --git a/test/architecture/test_vertex.py b/test/architecture/test_vertex.py index cbd179ff..b194af6c 100644 --- a/test/architecture/test_vertex.py +++ b/test/architecture/test_vertex.py @@ -25,25 +25,25 @@ def 
test_field_with_explicit_type(): assert field.type == field_type.value # Test case insensitive - field = Field(name="test", type="int") + field = Field(name="test", type=FieldType.INT) assert field.type == "INT" - field = Field(name="test", type="string") + field = Field(name="test", type=FieldType.STRING) assert field.type == "STRING" def test_field_type_validation(): """Test that invalid field types raise errors.""" with pytest.raises(ValueError, match="not allowed"): - Field(name="test", type="INVALID_TYPE") + Field.from_dict({"name": "test", "type": "INVALID_TYPE"}) with pytest.raises(ValueError, match="not allowed"): - Field(name="test", type="invalid") + Field.from_dict({"name": "test", "type": "invalid"}) def test_field_string_behavior(): """Test that Field objects behave like strings.""" - field = Field(name="test_field", type="INT") + field = Field(name="test_field", type=FieldType.INT) # String conversion assert str(field) == "test_field" @@ -112,10 +112,10 @@ def test_vertex_with_string_fields_dict_compatibility(): def test_vertex_with_field_objects(): """Test Vertex creation with list of Field objects.""" fields = [ - Field(name="id", type="INT"), - Field(name="name", type="STRING"), - Field(name="age", type="INT"), - Field(name="active", type="BOOL"), + Field(name="id", type=FieldType.INT), + Field(name="name", type=FieldType.STRING), + Field(name="age", type=FieldType.INT), + Field(name="active", type=FieldType.BOOL), ] vertex = Vertex(name="user", fields=fields) @@ -147,7 +147,7 @@ def test_vertex_mixed_field_inputs(): """Test Vertex creation with mixed field types.""" fields = [ "id", # string - Field(name="name", type="STRING"), # Field object + Field(name="name", type=FieldType.STRING), # Field object {"name": "email", "type": "STRING"}, # dict ] vertex = Vertex(name="user", fields=fields) # type: ignore[arg-type] @@ -182,8 +182,8 @@ def test_vertex_config_fields_with_objects(): vertex = Vertex( name="user", fields=[ - Field(name="id", type="INT"), - Field(name="name", type="STRING"), + Field(name="id", type=FieldType.INT), + Field(name="name", type=FieldType.STRING), ], ) config = VertexConfig(vertices=[vertex]) @@ -235,8 +235,8 @@ def test_vertex_indexes_work_with_field_objects(): vertex = Vertex( name="user", fields=[ - Field(name="id", type="INT"), - Field(name="email", type="STRING"), + Field(name="id", type=FieldType.INT), + Field(name="email", type=FieldType.STRING), ], indexes=[], # Will create default index ) @@ -299,9 +299,9 @@ def test_get_fields_with_defaults_tigergraph(): vertex = Vertex( name="user", fields=[ # type: ignore[arg-type] - Field(name="id", type="INT"), # Already has type + Field(name="id", type=FieldType.INT), # Already has type Field(name="name"), # None type - Field(name="email", type="STRING"), # Already has type + Field(name="email", type=FieldType.STRING), # Already has type "address", # String field (will be Field with None type) ], ) @@ -325,7 +325,7 @@ def test_get_fields_with_defaults_other_db(): vertex = Vertex( name="user", fields=[ - Field(name="id", type="INT"), + Field(name="id", type=FieldType.INT), Field(name="name"), # None type ], ) @@ -347,7 +347,7 @@ def test_get_fields_with_defaults_none(): vertex = Vertex( name="user", fields=[ - Field(name="id", type="INT"), + Field(name="id", type=FieldType.INT), Field(name="name"), # None type ], ) @@ -364,7 +364,7 @@ def test_vertex_config_fields_with_db_flavor(): vertex = Vertex( name="user", fields=[ - Field(name="id", type="INT"), + Field(name="id", type=FieldType.INT), 
Field(name="name"), # None type ], ) diff --git a/test/test_filters.py b/test/test_filters.py index 8ae1f72e..13fd5383 100644 --- a/test/test_filters.py +++ b/test/test_filters.py @@ -1,9 +1,9 @@ import pytest from graflo.filter.onto import ( + Clause, ComparisonOperator, Expression, - LeafClause, LogicalOperator, ) @@ -37,12 +37,14 @@ def and_clause(eq_clause, cong_clause): def test_none_leaf(none_clause): - lc = LeafClause(*none_clause) - assert "null" in lc() + lc = Clause.from_list(none_clause) + result = lc() + assert isinstance(result, str) + assert "null" in result def test_leaf_clause_construct(eq_clause): - lc = LeafClause(*eq_clause) + lc = Clause.from_list(eq_clause) assert lc.cmp_operator == ComparisonOperator.EQ assert lc() == 'doc["x"] == "1"' diff --git a/test/test_filters_python.py b/test/test_filters_python.py index 284f839e..e4670eec 100644 --- a/test/test_filters_python.py +++ b/test/test_filters_python.py @@ -1,7 +1,7 @@ import pytest import yaml -from graflo.filter.onto import Expression, LeafClause, LogicalOperator +from graflo.filter.onto import Clause, Expression, LogicalOperator from graflo.onto import ExpressionFlavor @@ -77,13 +77,13 @@ def filter_implication(clause_open, clause_b): def test_python_clause(clause_open): - lc = LeafClause(**clause_open) + lc = Clause(**clause_open) # kind=leaf inferred from operator (str) doc = {"name": "Open"} assert lc(**doc, kind=ExpressionFlavor.PYTHON) def test_condition_b(clause_b): - m = LeafClause(**clause_b) + m = Clause(**clause_b) # kind=leaf inferred from operator (str) doc = {"value": -1} assert m(value=1, kind=ExpressionFlavor.PYTHON) assert not m(kind=ExpressionFlavor.PYTHON, **doc) From ba7992089726109461eb189ce3b4f701d7fef405 Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Mon, 2 Feb 2026 17:03:24 +0100 Subject: [PATCH 2/7] pydantic transition: passing tests --- graflo/architecture/base.py | 5 +- graflo/data_source/api.py | 23 +- graflo/data_source/base.py | 14 +- graflo/data_source/file.py | 39 +--- graflo/data_source/memory.py | 9 +- graflo/data_source/registry.py | 17 +- graflo/data_source/sql.py | 24 +- graflo/onto.py | 168 +------------- graflo/util/onto.py | 257 +++++++++++----------- pyproject.toml | 1 - test/architecture/test_crossing_keys.py | 4 +- test/db/postgres/test_schema_inference.py | 24 +- test/test_patterns.py | 25 ++- uv.lock | 2 - 14 files changed, 202 insertions(+), 410 deletions(-) diff --git a/graflo/architecture/base.py b/graflo/architecture/base.py index 41cd88c8..97ee3cfb 100644 --- a/graflo/architecture/base.py +++ b/graflo/architecture/base.py @@ -14,9 +14,6 @@ class ConfigBaseModel(BaseModel): Provides YAML serialization/deserialization and standard configuration for all Pydantic models in the system. - - This replaces the JSONWizard/YAMLWizard functionality from dataclass-wizard - with Pydantic's superior validation and type safety. 
""" model_config = ConfigDict( @@ -96,7 +93,7 @@ def update(self, other: Self) -> None: raise TypeError( f"Expected {type(self).__name__} instance, got {type(other).__name__}" ) - for name in self.model_fields: + for name in type(self).model_fields: current = getattr(self, name) other_val = getattr(other, name) if other_val is None: diff --git a/graflo/data_source/api.py b/graflo/data_source/api.py index e708489e..80f04c0e 100644 --- a/graflo/data_source/api.py +++ b/graflo/data_source/api.py @@ -7,7 +7,6 @@ from __future__ import annotations -import dataclasses import logging from typing import Any, Iterator @@ -16,14 +15,15 @@ from requests.auth import HTTPBasicAuth, HTTPDigestAuth from urllib3.util.retry import Retry +from pydantic import Field + +from graflo.architecture.base import ConfigBaseModel from graflo.data_source.base import AbstractDataSource, DataSourceType -from graflo.onto import BaseDataclass logger = logging.getLogger(__name__) -@dataclasses.dataclass -class PaginationConfig(BaseDataclass): +class PaginationConfig(ConfigBaseModel): """Configuration for API pagination. Supports multiple pagination strategies: @@ -60,8 +60,7 @@ class PaginationConfig(BaseDataclass): data_path: str | None = None # JSON path to data array, None means root -@dataclasses.dataclass -class APIConfig(BaseDataclass): +class APIConfig(ConfigBaseModel): """Configuration for REST API data source. Attributes: @@ -83,20 +82,19 @@ class APIConfig(BaseDataclass): url: str method: str = "GET" - headers: dict[str, str] = dataclasses.field(default_factory=dict) + headers: dict[str, str] = Field(default_factory=dict) auth: dict[str, Any] | None = None - params: dict[str, Any] = dataclasses.field(default_factory=dict) + params: dict[str, Any] = Field(default_factory=dict) timeout: float | None = None retries: int = 0 retry_backoff_factor: float = 0.1 - retry_status_forcelist: list[int] = dataclasses.field( + retry_status_forcelist: list[int] = Field( default_factory=lambda: [500, 502, 503, 504] ) verify: bool = True pagination: PaginationConfig | None = None -@dataclasses.dataclass class APIDataSource(AbstractDataSource): """Data source for REST API endpoints. @@ -109,10 +107,7 @@ class APIDataSource(AbstractDataSource): """ config: APIConfig - - def __post_init__(self): - """Initialize the API data source.""" - self.source_type = DataSourceType.API + source_type: DataSourceType = DataSourceType.API def _create_session(self) -> requests.Session: """Create a requests session with retry configuration. diff --git a/graflo/data_source/base.py b/graflo/data_source/base.py index 0136abb3..30dd4072 100644 --- a/graflo/data_source/base.py +++ b/graflo/data_source/base.py @@ -8,7 +8,10 @@ import abc from typing import Iterator -from graflo.onto import BaseDataclass, BaseEnum +from pydantic import PrivateAttr + +from graflo.architecture.base import ConfigBaseModel +from graflo.onto import BaseEnum class DataSourceType(BaseEnum): @@ -26,7 +29,7 @@ class DataSourceType(BaseEnum): IN_MEMORY = "in_memory" -class AbstractDataSource(BaseDataclass, abc.ABC): +class AbstractDataSource(ConfigBaseModel, abc.ABC): """Abstract base class for all data sources. 
Data sources handle data retrieval from various sources and provide @@ -40,10 +43,7 @@ class AbstractDataSource(BaseDataclass, abc.ABC): """ source_type: DataSourceType - - def __post_init__(self): - """Initialize the data source after dataclass initialization.""" - self._resource_name: str | None = None + _resource_name: str | None = PrivateAttr(default=None) @property def resource_name(self) -> str | None: @@ -55,7 +55,7 @@ def resource_name(self) -> str | None: return self._resource_name @resource_name.setter - def resource_name(self, value: str | None): + def resource_name(self, value: str | None) -> None: """Set the resource name this data source maps to. Args: diff --git a/graflo/data_source/file.py b/graflo/data_source/file.py index a9b77783..3f7685e3 100644 --- a/graflo/data_source/file.py +++ b/graflo/data_source/file.py @@ -5,16 +5,16 @@ chunker logic for efficient batch processing. """ -import dataclasses from pathlib import Path from typing import Iterator +from pydantic import field_validator + from graflo.architecture.onto import EncodingType from graflo.data_source.base import AbstractDataSource, DataSourceType from graflo.util.chunker import ChunkerFactory, ChunkerType -@dataclasses.dataclass class FileDataSource(AbstractDataSource): """Base class for file-based data sources. @@ -30,12 +30,12 @@ class FileDataSource(AbstractDataSource): path: Path | str file_type: str | None = None encoding: EncodingType = EncodingType.UTF_8 + source_type: DataSourceType = DataSourceType.FILE - def __post_init__(self): - """Initialize the file data source.""" - self.source_type = DataSourceType.FILE - if isinstance(self.path, str): - self.path = Path(self.path) + @field_validator("path", mode="before") + @classmethod + def _path_to_path(cls, v: Path | str) -> Path: + return Path(v) if isinstance(v, str) else v def iter_batches( self, batch_size: int = 1000, limit: int | None = None @@ -73,7 +73,6 @@ def iter_batches( yield batch -@dataclasses.dataclass class JsonFileDataSource(FileDataSource): """Data source for JSON files. @@ -86,13 +85,9 @@ class JsonFileDataSource(FileDataSource): encoding: File encoding (default: UTF_8) """ - def __post_init__(self): - """Initialize the JSON file data source.""" - super().__post_init__() - self.file_type = ChunkerType.JSON.value + file_type: str = ChunkerType.JSON.value -@dataclasses.dataclass class JsonlFileDataSource(FileDataSource): """Data source for JSONL (JSON Lines) files. @@ -104,13 +99,9 @@ class JsonlFileDataSource(FileDataSource): encoding: File encoding (default: UTF_8) """ - def __post_init__(self): - """Initialize the JSONL file data source.""" - super().__post_init__() - self.file_type = ChunkerType.JSONL.value + file_type: str = ChunkerType.JSONL.value -@dataclasses.dataclass class TableFileDataSource(FileDataSource): """Data source for CSV/TSV files. @@ -124,14 +115,9 @@ class TableFileDataSource(FileDataSource): """ sep: str = "," - - def __post_init__(self): - """Initialize the table file data source.""" - super().__post_init__() - self.file_type = ChunkerType.TABLE.value + file_type: str = ChunkerType.TABLE.value -@dataclasses.dataclass class ParquetFileDataSource(FileDataSource): """Data source for Parquet files. 
@@ -142,7 +128,4 @@ class ParquetFileDataSource(FileDataSource): path: Path to the Parquet file """ - def __post_init__(self): - """Initialize the Parquet file data source.""" - super().__post_init__() - self.file_type = ChunkerType.PARQUET.value + file_type: str = ChunkerType.PARQUET.value diff --git a/graflo/data_source/memory.py b/graflo/data_source/memory.py index cf7a3ea6..c598683b 100644 --- a/graflo/data_source/memory.py +++ b/graflo/data_source/memory.py @@ -4,7 +4,6 @@ including lists of dictionaries, lists of lists, and Pandas DataFrames. """ -import dataclasses from typing import Iterator import pandas as pd @@ -13,7 +12,6 @@ from graflo.util.chunker import ChunkerFactory -@dataclasses.dataclass class InMemoryDataSource(AbstractDataSource): """Data source for in-memory data structures. @@ -25,12 +23,11 @@ class InMemoryDataSource(AbstractDataSource): columns: Optional column names for list[list] data """ + model_config = {"arbitrary_types_allowed": True} + data: list[dict] | list[list] | pd.DataFrame columns: list[str] | None = None - - def __post_init__(self): - """Initialize the in-memory data source.""" - self.source_type = DataSourceType.IN_MEMORY + source_type: DataSourceType = DataSourceType.IN_MEMORY def iter_batches( self, batch_size: int = 1000, limit: int | None = None diff --git a/graflo/data_source/registry.py b/graflo/data_source/registry.py index b87bf153..834655ff 100644 --- a/graflo/data_source/registry.py +++ b/graflo/data_source/registry.py @@ -7,17 +7,12 @@ from __future__ import annotations -import dataclasses -from typing import TYPE_CHECKING +from graflo.architecture.base import ConfigBaseModel +from graflo.data_source.base import AbstractDataSource +from pydantic import Field -from graflo.onto import BaseDataclass -if TYPE_CHECKING: - from graflo.data_source.base import AbstractDataSource - - -@dataclasses.dataclass -class DataSourceRegistry(BaseDataclass): +class DataSourceRegistry(ConfigBaseModel): """Registry for mapping data sources to resource names. This class maintains a mapping from resource names to lists of data sources. @@ -28,9 +23,7 @@ class DataSourceRegistry(BaseDataclass): sources: Dictionary mapping resource names to lists of data sources """ - sources: dict[str, list[AbstractDataSource]] = dataclasses.field( - default_factory=dict - ) + sources: dict[str, list[AbstractDataSource]] = Field(default_factory=dict) def register(self, data_source: AbstractDataSource, resource_name: str) -> None: """Register a data source for a resource. diff --git a/graflo/data_source/sql.py b/graflo/data_source/sql.py index fc193540..b421144d 100644 --- a/graflo/data_source/sql.py +++ b/graflo/data_source/sql.py @@ -4,21 +4,20 @@ configuration. It supports parameterized queries and pagination. """ -import dataclasses import logging from typing import Any, Iterator +from pydantic import Field, PrivateAttr from sqlalchemy import create_engine, text from sqlalchemy.engine import Engine +from graflo.architecture.base import ConfigBaseModel from graflo.data_source.base import AbstractDataSource, DataSourceType -from graflo.onto import BaseDataclass logger = logging.getLogger(__name__) -@dataclasses.dataclass -class SQLConfig(BaseDataclass): +class SQLConfig(ConfigBaseModel): """Configuration for SQL data source. Uses SQLAlchemy connection string format. 
@@ -34,12 +33,11 @@ class SQLConfig(BaseDataclass): connection_string: str query: str - params: dict[str, Any] = dataclasses.field(default_factory=dict) + params: dict[str, Any] = Field(default_factory=dict) pagination: bool = True page_size: int = 1000 -@dataclasses.dataclass class SQLDataSource(AbstractDataSource): """Data source for SQL databases. @@ -53,12 +51,8 @@ class SQLDataSource(AbstractDataSource): """ config: SQLConfig - engine: Engine | None = dataclasses.field(default=None, init=False) - - def __post_init__(self): - """Initialize the SQL data source.""" - super().__post_init__() - self.source_type = DataSourceType.SQL + source_type: DataSourceType = DataSourceType.SQL + _engine: Engine | None = PrivateAttr(default=None) def _get_engine(self) -> Engine: """Get or create SQLAlchemy engine. @@ -66,9 +60,9 @@ def _get_engine(self) -> Engine: Returns: SQLAlchemy engine instance """ - if self.engine is None: - self.engine = create_engine(self.config.connection_string) - return self.engine + if self._engine is None: + self._engine = create_engine(self.config.connection_string) + return self._engine def _add_pagination(self, query: str, offset: int, limit: int) -> str: """Add pagination to SQL query. diff --git a/graflo/onto.py b/graflo/onto.py index c5636d24..ce66cfd5 100644 --- a/graflo/onto.py +++ b/graflo/onto.py @@ -1,12 +1,11 @@ """Core ontology and base classes for graph database operations. This module provides the fundamental data structures and base classes used throughout -the graph database system. It includes base classes for enums, dataclasses, and +the graph database system. It includes base classes for enums and database-specific configurations. Key Components: - BaseEnum: Base class for string-based enumerations with flexible membership testing - - BaseDataclass: Base class for dataclasses with JSON/YAML serialization support - ExpressionFlavor: Enum for expression language types - AggregationType: Enum for supported aggregation operations @@ -18,12 +17,8 @@ >>> "invalid" in MyEnum # False """ -import dataclasses -from copy import deepcopy from enum import EnumMeta from strenum import StrEnum -from dataclass_wizard import JSONWizard, YAMLWizard -from dataclass_wizard.enums import DateTimeTo class MetaEnum(EnumMeta): @@ -135,167 +130,6 @@ class AggregationType(BaseEnum): SORTED_UNIQUE = "SORTED_UNIQUE" -@dataclasses.dataclass -class BaseDataclass(JSONWizard, JSONWizard.Meta, YAMLWizard): - """Base class for dataclasses with serialization support. - - This class provides a foundation for dataclasses with JSON and YAML - serialization capabilities. It includes methods for updating instances - and accessing field members. - - Attributes: - marshal_date_time_as: Format for datetime serialization - key_transform_with_dump: Key transformation style for serialization - """ - - class _(JSONWizard.Meta): - """Meta configuration for serialization. - - Set skip_defaults=True here to exclude fields with default values - by default when serializing. Can still be overridden per-call. - """ - - skip_defaults = True - - marshal_date_time_as = DateTimeTo.ISO_FORMAT - key_transform_with_dump = "SNAKE" - - def to_dict(self, skip_defaults: bool | None = None, **kwargs): - """Convert instance to dictionary with enums serialized as strings. - - This method overrides the default to_dict to ensure that all BaseEnum - instances are automatically converted to their string values during - serialization, making YAML/JSON output cleaner and more portable. 
- - Args: - skip_defaults: If True, fields with default values are excluded. - If None, uses the Meta class skip_defaults setting. - **kwargs: Additional arguments passed to parent to_dict method - - Returns: - dict: Dictionary representation with enums as strings - """ - result = super().to_dict(skip_defaults=skip_defaults, **kwargs) - return self._convert_enums_to_strings(result) - - def to_yaml(self, skip_defaults: bool | None = None, **kwargs) -> str: - """Convert instance to YAML string with enums serialized as strings. - - Args: - skip_defaults: If True, fields with default values are excluded. - If None, uses the Meta class skip_defaults setting. - **kwargs: Additional arguments passed to yaml.safe_dump - - Returns: - str: YAML string representation with enums as strings - """ - # Convert to dict first (with enum conversion), then to YAML - data = self.to_dict(skip_defaults=skip_defaults) - try: - import yaml - - return yaml.safe_dump(data, **kwargs) - except ImportError: - # Fallback to parent method if yaml not available - return super().to_yaml(skip_defaults=skip_defaults, **kwargs) - - def to_yaml_file( - self, file_path: str, skip_defaults: bool | None = None, **kwargs - ) -> None: - """Write instance to YAML file with enums serialized as strings. - - Args: - file_path: Path to the YAML file to write - skip_defaults: If True, fields with default values are excluded. - If None, uses the Meta class skip_defaults setting. - **kwargs: Additional arguments passed to yaml.safe_dump - """ - # Convert to dict first (with enum conversion), then write to file - data = self.to_dict(skip_defaults=skip_defaults) - try: - import yaml - - with open(file_path, "w") as f: - yaml.safe_dump(data, f, **kwargs) - except ImportError: - # Fallback to parent method if yaml not available - super().to_yaml_file(file_path, skip_defaults=skip_defaults, **kwargs) - - @staticmethod - def _convert_enums_to_strings(obj): - """Recursively convert BaseEnum instances to their string values. - - Args: - obj: Object to convert (dict, list, enum, or other) - - Returns: - Object with BaseEnum instances converted to strings - """ - if isinstance(obj, BaseEnum): - return obj.value - elif isinstance(obj, dict): - return { - k: BaseDataclass._convert_enums_to_strings(v) for k, v in obj.items() - } - elif isinstance(obj, list): - return [BaseDataclass._convert_enums_to_strings(item) for item in obj] - elif isinstance(obj, tuple): - return tuple(BaseDataclass._convert_enums_to_strings(item) for item in obj) - elif isinstance(obj, set): - return {BaseDataclass._convert_enums_to_strings(item) for item in obj} - else: - return obj - - def update(self, other): - """Update this instance with values from another instance. - - This method performs a deep update of the instance's attributes using - values from another instance of the same type. It handles different - types of attributes (sets, lists, dicts, dataclasses) appropriately. 
- - Args: - other: Another instance of the same type to update from - - Raises: - TypeError: If other is not an instance of the same type - """ - if not isinstance(other, type(self)): - raise TypeError( - f"Expected {type(self).__name__} instance, got {type(other).__name__}" - ) - - for field in dataclasses.fields(self): - name = field.name - current_value = getattr(self, name) - other_value = getattr(other, name) - - if other_value is None: - pass - elif isinstance(other_value, set): - setattr(self, name, current_value | deepcopy(other_value)) - elif isinstance(other_value, list): - setattr(self, name, current_value + deepcopy(other_value)) - elif isinstance(other_value, dict): - setattr(self, name, {**current_value, **deepcopy(other_value)}) - elif dataclasses.is_dataclass(type(other_value)): - if current_value is not None: - current_value.update(other_value) - else: - setattr(self, name, deepcopy(other_value)) - else: - if current_value is None: - setattr(self, name, other_value) - - @classmethod - def get_fields_members(cls): - """Get list of field members excluding private ones. - - Returns: - list[str]: List of public field names - """ - return [k for k in cls.__annotations__ if not k.startswith("_")] - - class DBType(StrEnum, metaclass=MetaEnum): """Enum representing different types of databases. diff --git a/graflo/util/onto.py b/graflo/util/onto.py index eab28414..c68e2706 100644 --- a/graflo/util/onto.py +++ b/graflo/util/onto.py @@ -1,6 +1,6 @@ """Utility ontology classes for resource patterns and configurations. -This module provides data classes for managing resource patterns (files and database tables) +This module provides Pydantic models for managing resource patterns (files and database tables) and configurations used throughout the system. These classes support resource discovery, pattern matching, and configuration management. @@ -12,12 +12,15 @@ """ import abc -import dataclasses +import copy import pathlib import re -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Self -from graflo.onto import BaseDataclass, BaseEnum +from pydantic import AliasChoices, Field, model_validator + +from graflo.architecture.base import ConfigBaseModel +from graflo.onto import BaseEnum if TYPE_CHECKING: from graflo.db.connection.onto import PostgresConfig @@ -45,8 +48,7 @@ class ResourceType(BaseEnum): SQL_TABLE = "sql_table" -@dataclasses.dataclass -class ResourcePattern(BaseDataclass, abc.ABC): +class ResourcePattern(ConfigBaseModel, abc.ABC): """Abstract base class for resource patterns (files or tables). Provides common API for pattern matching and resource identification. @@ -80,7 +82,6 @@ def get_resource_type(self) -> ResourceType: pass -@dataclasses.dataclass class FilePattern(ResourcePattern): """Pattern for matching files. @@ -93,48 +94,38 @@ class FilePattern(ResourcePattern): date_range_days: Number of days after start date (used with date_range_start) """ - class _(BaseDataclass.Meta): - tag = "file" - regex: str | None = None - sub_path: None | pathlib.Path = dataclasses.field( - default_factory=lambda: pathlib.Path("./") - ) + sub_path: pathlib.Path = Field(default_factory=lambda: pathlib.Path("./")) date_field: str | None = None date_filter: str | None = None date_range_start: str | None = None date_range_days: int | None = None - def __post_init__(self): - """Initialize and validate the file pattern. - - Ensures that sub_path is a Path object and is not None. 
- """ - if self.sub_path is not None and not isinstance(self.sub_path, pathlib.Path): - self.sub_path = pathlib.Path(self.sub_path) - elif self.sub_path is None: - self.sub_path = pathlib.Path("./") - assert self.sub_path is not None - # Validate date filtering parameters (note: date filtering for files is not yet implemented) + @model_validator(mode="after") + def _validate_file_pattern(self) -> Self: + """Ensure sub_path is a Path and validate date filtering parameters.""" + if not isinstance(self.sub_path, pathlib.Path): + object.__setattr__(self, "sub_path", pathlib.Path(self.sub_path)) if (self.date_filter or self.date_range_start) and not self.date_field: raise ValueError( "date_field is required when using date_filter or date_range_start" ) if self.date_range_days is not None and not self.date_range_start: raise ValueError("date_range_start is required when using date_range_days") + return self - def matches(self, filename: str) -> bool: + def matches(self, resource_identifier: str) -> bool: """Check if pattern matches a filename. Args: - filename: Filename to match + resource_identifier: Filename to match Returns: bool: True if pattern matches """ if self.regex is None: return False - return bool(re.match(self.regex, filename)) + return bool(re.match(self.regex, resource_identifier)) def get_resource_type(self) -> ResourceType: """Get resource type. @@ -146,7 +137,6 @@ def get_resource_type(self) -> ResourceType: return ResourceType.FILE -@dataclasses.dataclass class TablePattern(ResourcePattern): """Pattern for matching database tables. @@ -160,9 +150,6 @@ class TablePattern(ResourcePattern): date_range_days: Number of days after start date (used with date_range_start) """ - class _(BaseDataclass.Meta): - tag = "table" - table_name: str = "" schema_name: str | None = None database: str | None = None @@ -171,23 +158,24 @@ class _(BaseDataclass.Meta): date_range_start: str | None = None date_range_days: int | None = None - def __post_init__(self): - """Validate table pattern after initialization.""" + @model_validator(mode="after") + def _validate_table_pattern(self) -> Self: + """Validate table_name and date filtering parameters.""" if not self.table_name: raise ValueError("table_name is required for TablePattern") - # Validate date filtering parameters if (self.date_filter or self.date_range_start) and not self.date_field: raise ValueError( "date_field is required when using date_filter or date_range_start" ) if self.date_range_days is not None and not self.date_range_start: raise ValueError("date_range_start is required when using date_range_days") + return self - def matches(self, table_identifier: str) -> bool: + def matches(self, resource_identifier: str) -> bool: """Check if pattern matches a table name. 
Args: - table_identifier: Table name to match (format: schema.table or just table) + resource_identifier: Table name to match (format: schema.table or just table) Returns: bool: True if pattern matches @@ -203,13 +191,13 @@ def matches(self, table_identifier: str) -> bool: # Exact match pattern pattern = re.compile(f"^{re.escape(self.table_name)}$") - # Check if table_identifier matches - if pattern.match(table_identifier): + # Check if resource_identifier matches + if pattern.match(resource_identifier): return True # If schema_name is specified, also check schema.table format if self.schema_name: - full_name = f"{self.schema_name}.{table_identifier}" + full_name = f"{self.schema_name}.{resource_identifier}" if pattern.match(full_name): return True @@ -259,8 +247,7 @@ def build_where_clause(self) -> str: return "" -@dataclasses.dataclass -class Patterns(BaseDataclass): +class Patterns(ConfigBaseModel): """Collection of named resource patterns with connection management. This class manages a collection of resource patterns (files or tables), @@ -281,24 +268,29 @@ class Patterns(BaseDataclass): postgres_table_configs: Dictionary mapping resource_name to (config_key, schema_name, table_name) """ - file_patterns: dict[str, FilePattern] = dataclasses.field(default_factory=dict) - table_patterns: dict[str, TablePattern] = dataclasses.field(default_factory=dict) - postgres_configs: dict[tuple[str, str | None], Any] = dataclasses.field( - default_factory=dict, metadata={"exclude": True} + file_patterns: dict[str, FilePattern] = Field(default_factory=dict) + table_patterns: dict[str, TablePattern] = Field(default_factory=dict) + postgres_configs: dict[tuple[str, str | None], Any] = Field( + default_factory=dict, exclude=True ) - postgres_table_configs: dict[str, tuple[str, str | None, str]] = dataclasses.field( - default_factory=dict, metadata={"exclude": True} + postgres_table_configs: dict[str, tuple[str, str | None, str]] = Field( + default_factory=dict, exclude=True ) - # Initialization parameters (not stored as fields, excluded from serialization) - # Use Any for _postgres_connections to avoid type evaluation issues with dataclass_wizard - _resource_mapping: dict[str, str | tuple[str, str]] | None = dataclasses.field( - default=None, repr=False, compare=False, metadata={"exclude": True} + # Initialization parameters (not stored in serialization); accept both _name and name + resource_mapping: dict[str, str | tuple[str, str]] | None = Field( + default=None, + exclude=True, + validation_alias=AliasChoices("_resource_mapping", "resource_mapping"), ) - _postgres_connections: dict[str, Any] | None = dataclasses.field( - default=None, repr=False, compare=False, metadata={"exclude": True} + postgres_connections: dict[str, Any] | None = Field( + default=None, + exclude=True, + validation_alias=AliasChoices("_postgres_connections", "postgres_connections"), ) - _postgres_tables: dict[str, tuple[str, str | None, str]] | None = dataclasses.field( - default=None, repr=False, compare=False, metadata={"exclude": True} + postgres_tables: dict[str, tuple[str, str | None, str]] | None = Field( + default=None, + exclude=True, + validation_alias=AliasChoices("_postgres_tables", "postgres_tables"), ) @property @@ -313,76 +305,18 @@ def patterns(self) -> dict[str, TablePattern | FilePattern]: result.update(self.table_patterns) return result - @classmethod - def from_dict(cls, data: dict): - """Create Patterns from dictionary, supporting both old and new YAML formats. - - Supports two formats: - 1. 
New format: Separate `file_patterns` and `table_patterns` fields - 2. Old format: Unified `patterns` field with `__tag__` markers (for backward compatibility) - - Args: - data: Dictionary containing patterns data - - Returns: - Patterns: New Patterns instance with properly deserialized patterns - """ - # Check if using new format (separate file_patterns/table_patterns) - if "file_patterns" in data or "table_patterns" in data: - # New format - let JSONWizard handle it directly (no union types!) - return super().from_dict(data) - - # Old format - convert unified patterns dict to separate fields - patterns_data = data.get("patterns", {}) - data_copy = {k: v for k, v in data.items() if k != "patterns"} - - # Call parent from_dict (JSONWizard) to handle other fields - instance = super().from_dict(data_copy) - - # Convert old format to new format - for pattern_name, pattern_dict in patterns_data.items(): - if pattern_dict is None: - continue - # Check for tag to determine pattern type - tag = pattern_dict.get("__tag__") - if tag == "file": - pattern = FilePattern.from_dict(pattern_dict) - instance.file_patterns[pattern_name] = pattern - elif tag == "table": - pattern = TablePattern.from_dict(pattern_dict) - instance.table_patterns[pattern_name] = pattern - else: - # Try to infer from structure if no tag - if "table_name" in pattern_dict: - pattern = TablePattern.from_dict(pattern_dict) - instance.table_patterns[pattern_name] = pattern - elif "regex" in pattern_dict or "sub_path" in pattern_dict: - pattern = FilePattern.from_dict(pattern_dict) - instance.file_patterns[pattern_name] = pattern - else: - raise ValueError( - f"Unable to determine pattern type for '{pattern_name}'. " - "Expected either '__tag__: file' or '__tag__: table', " - "or pattern fields (table_name for TablePattern, " - "regex/sub_path for FilePattern)" - ) - - return instance - - def __post_init__(self): - """Initialize Patterns from resource mappings and PostgreSQL configurations.""" - # Store PostgreSQL connection configs - if self._postgres_connections: - for config_key, config in self._postgres_connections.items(): + @model_validator(mode="after") + def _populate_from_mappings(self) -> Self: + """Populate file_patterns/table_patterns from resource mappings and PostgreSQL configs.""" + if self.postgres_connections: + for config_key, config in self.postgres_connections.items(): if config is not None: - schema_name = config.schema_name + schema_name = getattr(config, "schema_name", None) self.postgres_configs[(config_key, schema_name)] = config - # Process resource mappings - if self._resource_mapping: - for resource_name, resource_spec in self._resource_mapping.items(): + if self.resource_mapping: + for resource_name, resource_spec in self.resource_mapping.items(): if isinstance(resource_spec, str): - # File path - create FilePattern file_path = pathlib.Path(resource_spec) pattern = FilePattern( regex=f"^{re.escape(file_path.name)}$", @@ -391,36 +325,33 @@ def __post_init__(self): ) self.file_patterns[resource_name] = pattern elif isinstance(resource_spec, tuple) and len(resource_spec) == 2: - # (config_key, table_name) tuple - create TablePattern config_key, table_name = resource_spec - # Find the schema_name from the config config = ( - self._postgres_connections.get(config_key) - if self._postgres_connections + self.postgres_connections.get(config_key) + if self.postgres_connections else None ) - schema_name = config.schema_name if config else None - + schema_name = ( + getattr(config, "schema_name", None) if config else 
None + ) pattern = TablePattern( table_name=table_name, schema_name=schema_name, resource_name=resource_name, ) self.table_patterns[resource_name] = pattern - # Store the config mapping self.postgres_table_configs[resource_name] = ( config_key, schema_name, table_name, ) - # Process explicit postgres_tables mapping - if self._postgres_tables: + if self.postgres_tables: for table_name, ( config_key, schema_name, actual_table_name, - ) in self._postgres_tables.items(): + ) in self.postgres_tables.items(): pattern = TablePattern( table_name=actual_table_name, schema_name=schema_name, @@ -432,6 +363,70 @@ def __post_init__(self): schema_name, actual_table_name, ) + return self + + @classmethod + def from_dict(cls, data: dict[str, Any] | list[Any]) -> Self: + """Create Patterns from dictionary, supporting both old and new YAML formats. + + Supports two formats: + 1. New format: Separate `file_patterns` and `table_patterns` fields + 2. Old format: Unified `patterns` field with `__tag__` markers (for backward compatibility) + + Args: + data: Dictionary containing patterns data (or list for base compatibility) + + Returns: + Patterns: New Patterns instance with properly deserialized patterns + """ + if isinstance(data, list): + return cls.model_validate(data) + if "file_patterns" in data or "table_patterns" in data: + # Strip __tag__ from nested pattern dicts so extra="forbid" does not fail + data = copy.deepcopy(data) + for key in ("file_patterns", "table_patterns"): + if key in data and isinstance(data[key], dict): + for name, val in data[key].items(): + if isinstance(val, dict) and "__tag__" in val: + data[key][name] = { + k: v for k, v in val.items() if k != "__tag__" + } + return cls.model_validate(data) + + patterns_data = data.get("patterns", {}) + data_copy = {k: v for k, v in data.items() if k != "patterns"} + instance = cls.model_validate(data_copy) + + for pattern_name, raw in patterns_data.items(): + if raw is None: + continue + pattern_dict = {k: v for k, v in raw.items() if k != "__tag__"} + tag_val = raw.get("__tag__") if isinstance(raw, dict) else None + if tag_val == "file": + instance.file_patterns[pattern_name] = FilePattern.model_validate( + pattern_dict + ) + elif tag_val == "table": + instance.table_patterns[pattern_name] = TablePattern.model_validate( + pattern_dict + ) + else: + if "table_name" in pattern_dict: + instance.table_patterns[pattern_name] = TablePattern.model_validate( + pattern_dict + ) + elif "regex" in pattern_dict or "sub_path" in pattern_dict: + instance.file_patterns[pattern_name] = FilePattern.model_validate( + pattern_dict + ) + else: + raise ValueError( + f"Unable to determine pattern type for '{pattern_name}'. " + "Expected either '__tag__: file' or '__tag__: table', " + "or pattern fields (table_name for TablePattern, " + "regex/sub_path for FilePattern)" + ) + return instance def add_file_pattern(self, name: str, file_pattern: FilePattern): """Add a file pattern to the collection. 
diff --git a/pyproject.toml b/pyproject.toml index 4da33424..2c4fd301 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,6 @@ classifiers = [ ] dependencies = [ "click>=8.2.0,<9", - "dataclass-wizard>=0.34.0", "falkordb>=1.0.9", "ijson>=3.2.3,<4", "neo4j>=5.22.0,<6", diff --git a/test/architecture/test_crossing_keys.py b/test/architecture/test_crossing_keys.py index 9c4ac0df..bf7ec578 100644 --- a/test/architecture/test_crossing_keys.py +++ b/test/architecture/test_crossing_keys.py @@ -32,7 +32,9 @@ def test_actor_wrapper_openalex_implicit( resource_cross_implicit, vertex_config_cross, sample_cross ): ctx = ActionContext() - anw = ActorWrapper(*resource_cross_implicit) + # Pass list as single arg so it is interpreted as pipeline (DescendActor), not + # as a single step; finish_init then adds VertexActors from transform outputs. + anw = ActorWrapper(resource_cross_implicit) anw.finish_init(transforms={}, vertex_config=vertex_config_cross) ctx = anw(ctx, doc=sample_cross) diff --git a/test/db/postgres/test_schema_inference.py b/test/db/postgres/test_schema_inference.py index 3b134288..4cbd7c0c 100644 --- a/test/db/postgres/test_schema_inference.py +++ b/test/db/postgres/test_schema_inference.py @@ -64,14 +64,12 @@ def test_infer_schema_from_postgres(conn_conf, load_mock_schema): # Verify field types (id should be INT, name/email should be STRING) id_field = next(f for f in users_vertex.fields if f.name == "id") assert id_field.type is not None, "id field should have a type" - assert id_field.type.value == "INT", ( - f"Expected id type to be INT, got {id_field.type.value}" - ) + assert id_field.type == "INT", f"Expected id type to be INT, got {id_field.type}" name_field = next(f for f in users_vertex.fields if f.name == "name") assert name_field.type is not None, "name field should have a type" - assert name_field.type.value == "STRING", ( - f"Expected name type to be STRING, got {name_field.type.value}" + assert name_field.type == "STRING", ( + f"Expected name type to be STRING, got {name_field.type}" ) # Verify datetime field type (created_at should be DATETIME) @@ -80,8 +78,8 @@ def test_infer_schema_from_postgres(conn_conf, load_mock_schema): ) if created_at_field: assert created_at_field.type is not None, "created_at field should have a type" - assert created_at_field.type.value == "DATETIME", ( - f"Expected created_at type to be DATETIME, got {created_at_field.type.value}" + assert created_at_field.type == "DATETIME", ( + f"Expected created_at type to be DATETIME, got {created_at_field.type}" ) # Verify purchases edge structure @@ -145,7 +143,7 @@ def test_infer_schema_from_postgres(conn_conf, load_mock_schema): print(f"\nVertices ({len(schema.vertex_config.vertices)}):") for v in schema.vertex_config.vertices: field_types = ", ".join( - [f"{f.name}:{f.type.value if f.type else 'None'}" for f in v.fields[:5]] + [f"{f.name}:{f.type if f.type else 'None'}" for f in v.fields[:5]] ) print(f" - {v.name}: {field_types}...") @@ -250,16 +248,16 @@ def test_infer_schema_with_pg_catalog_fallback(conn_conf, load_mock_schema): assert id_field.type is not None, ( "id field should have a type when using pg_catalog" ) - assert id_field.type.value == "INT", ( - f"Expected id type to be INT when using pg_catalog, got {id_field.type.value}" + assert id_field.type == "INT", ( + f"Expected id type to be INT when using pg_catalog, got {id_field.type}" ) name_field = next(f for f in users_vertex.fields if f.name == "name") assert name_field.type is not None, ( "name field should have a type when 
using pg_catalog" ) - assert name_field.type.value == "STRING", ( - f"Expected name type to be STRING when using pg_catalog, got {name_field.type.value}" + assert name_field.type == "STRING", ( + f"Expected name type to be STRING when using pg_catalog, got {name_field.type}" ) # Verify purchases edge structure - should be correctly inferred via pg_catalog @@ -329,7 +327,7 @@ def test_infer_schema_with_pg_catalog_fallback(conn_conf, load_mock_schema): print(f"\nVertices ({len(schema.vertex_config.vertices)}):") for v in schema.vertex_config.vertices: field_types = ", ".join( - [f"{f.name}:{f.type.value if f.type else 'None'}" for f in v.fields[:5]] + [f"{f.name}:{f.type if f.type else 'None'}" for f in v.fields[:5]] ) print(f" - {v.name}: {field_types}...") diff --git a/test/test_patterns.py b/test/test_patterns.py index d662fddf..2d988255 100644 --- a/test/test_patterns.py +++ b/test/test_patterns.py @@ -19,11 +19,15 @@ def test_patterns(): patterns.add_file_pattern("a", pattern_a) patterns.add_file_pattern("b", pattern_b) - # Test that patterns work correctly - assert patterns.patterns["a"].sub_path is not None - assert isinstance(patterns.patterns["a"].sub_path / "a", pathlib.Path) - assert patterns.patterns["b"].sub_path is not None - assert str(patterns.patterns["b"].sub_path / "a") == "a" + # Test that patterns work correctly (narrow to FilePattern for .sub_path) + pattern_a_loaded = patterns.patterns["a"] + pattern_b_loaded = patterns.patterns["b"] + assert isinstance(pattern_a_loaded, FilePattern) + assert isinstance(pattern_b_loaded, FilePattern) + assert pattern_a_loaded.sub_path is not None + assert isinstance(pattern_a_loaded.sub_path / "a", pathlib.Path) + assert pattern_b_loaded.sub_path is not None + assert str(pattern_b_loaded.sub_path / "a") == "a" # Test that patterns can be accessed by name assert "a" in patterns.patterns @@ -229,10 +233,13 @@ def test_patterns_with_filtering(): ) patterns.add_table_pattern("events", table_pattern) - # Verify patterns are stored correctly - assert patterns.patterns["users"].regex == r".*\.csv$" - assert patterns.patterns["events"].date_field == "created_at" - assert patterns.patterns["events"].date_filter == "> '2020-10-10'" + # Verify patterns are stored correctly (narrow to FilePattern for .regex) + users_pattern = patterns.patterns["users"] + events_pattern = patterns.patterns["events"] + assert isinstance(users_pattern, FilePattern) + assert users_pattern.regex == r".*\.csv$" + assert events_pattern.date_field == "created_at" + assert events_pattern.date_filter == "> '2020-10-10'" def test_table_pattern_sql_query_building(): diff --git a/uv.lock b/uv.lock index 21ec64c2..b1190be3 100644 --- a/uv.lock +++ b/uv.lock @@ -352,7 +352,6 @@ version = "1.4.5" source = { editable = "." 
}
 dependencies = [
     { name = "click" },
-    { name = "dataclass-wizard" },
     { name = "falkordb" },
     { name = "ijson" },
     { name = "neo4j" },
@@ -399,7 +398,6 @@ plot = [
 [package.metadata]
 requires-dist = [
     { name = "click", specifier = ">=8.2.0,<9" },
-    { name = "dataclass-wizard", specifier = ">=0.34.0" },
     { name = "falkordb", specifier = ">=1.0.9" },
     { name = "ijson", specifier = ">=3.2.3,<4" },
     { name = "neo4j", specifier = ">=5.22.0,<6" },

From 0f05727cf744e2d8a52b13b1b3150184f0bb7aed Mon Sep 17 00:00:00 2001
From: Alexander Belikov
Date: Mon, 2 Feb 2026 17:42:35 +0100
Subject: [PATCH 3/7] improved descriptions

---
 graflo/architecture/actor_config.py |  27 +-----
 graflo/architecture/edge.py         | 122 ++++++++++++++++++++++------
 graflo/architecture/onto.py         |  60 +++++++++++---
 graflo/architecture/transform.py    |  52 +++++++++---
 graflo/architecture/vertex.py       |  75 ++++++++++++-----
 graflo/db/arango/conn.py            |  12 +--
 graflo/db/arango/query.py           |  10 ++-
 graflo/db/arango/util.py            |  10 ++-
 graflo/db/falkordb/conn.py          |   8 +-
 graflo/db/memgraph/conn.py          |   8 +-
 graflo/db/neo4j/conn.py             |   8 +-
 graflo/db/tigergraph/conn.py        |  14 ++--
 graflo/filter/__init__.py           |   7 +-
 graflo/filter/onto.py               |  39 ++++-----
 test/test_filters.py                |  15 ++--
 test/test_filters_python.py         |  14 ++--
 16 files changed, 313 insertions(+), 168 deletions(-)

diff --git a/graflo/architecture/actor_config.py b/graflo/architecture/actor_config.py
index 93f63a1a..263ced2b 100644
--- a/graflo/architecture/actor_config.py
+++ b/graflo/architecture/actor_config.py
@@ -17,6 +17,7 @@
 from pydantic import Field, TypeAdapter, model_validator
 
 from graflo.architecture.base import ConfigBaseModel
+from graflo.architecture.edge import EdgeBase
 
 logger = logging.getLogger(__name__)
 
@@ -179,42 +180,20 @@ def set_type_and_flatten(cls, data: Any) -> Any:
         return data
 
 
-class EdgeActorConfig(ConfigBaseModel):
-    """Configuration for an EdgeActor. Supports 'from'/'to' and 'source'/'target'."""
+class EdgeActorConfig(EdgeBase):
+    """Configuration for an EdgeActor. Extends EdgeBase; supports 'from'/'to' and 'source'/'target'."""
 
     type: Literal["edge"] = Field(
         default="edge", description="Actor type discriminator"
     )
     source: str = Field(..., alias="from", description="Source vertex type name")
     target: str = Field(..., alias="to", description="Target vertex type name")
-    match_source: str | None = Field(
-        default=None, description="Field for matching source vertices"
-    )
-    match_target: str | None = Field(
-        default=None, description="Field for matching target vertices"
-    )
     weights: dict[str, list[str]] | None = Field(
         default=None, description="Weight configuration"
     )
     indexes: list[dict[str, Any]] | None = Field(
         default=None, description="Index configuration"
     )
-    relation: str | None = Field(
-        default=None, description="Relation name (e.g. 
for Neo4j)" - ) - relation_field: str | None = Field( - default=None, description="Field to extract relation from" - ) - relation_from_key: bool = Field( - default=False, description="Extract relation from location key" - ) - exclude_target: str | None = Field( - default=None, description="Exclude target from edge rendering" - ) - exclude_source: str | None = Field( - default=None, description="Exclude source from edge rendering" - ) - match: str | None = Field(default=None, description="Match discriminant") @model_validator(mode="before") @classmethod diff --git a/graflo/architecture/edge.py b/graflo/architecture/edge.py index 92941170..8ab88a67 100644 --- a/graflo/architecture/edge.py +++ b/graflo/architecture/edge.py @@ -5,6 +5,7 @@ The module supports both ArangoDB and Neo4j through the DBType enum. Key Components: + - EdgeBase: Shared base for edge-like configs (Edge and EdgeActorConfig) - Edge: Represents an edge with its source, target, and configuration - EdgeConfig: Manages collections of edges and their configurations - WeightConfig: Configuration for edge weights and relationships @@ -87,8 +88,14 @@ class WeightConfig(ConfigBaseModel): ... ]) """ - vertices: list[Weight] = PydanticField(default_factory=list) - direct: list[Field] = PydanticField(default_factory=list) + vertices: list[Weight] = PydanticField( + default_factory=list, + description="List of weight definitions for vertex-based edge attributes.", + ) + direct: list[Field] = PydanticField( + default_factory=list, + description="Direct edge attributes (field names, Field objects, or dicts). Normalized to Field objects.", + ) @field_validator("direct", mode="before") @classmethod @@ -107,7 +114,56 @@ def direct_names(self) -> list[str]: return [field.name for field in self.direct] -class Edge(ConfigBaseModel): +class EdgeBase(ConfigBaseModel): + """Shared base for edge-like configs (Edge schema and EdgeActorConfig). + + Holds the common scalar fields so Edge and EdgeActorConfig stay in sync + without duplication. + """ + + source: str = PydanticField( + ..., + description="Source vertex type name (e.g. user, company).", + ) + target: str = PydanticField( + ..., + description="Target vertex type name (e.g. post, company).", + ) + match_source: str | None = PydanticField( + default=None, + description="Field used to match source vertices when creating edges.", + ) + match_target: str | None = PydanticField( + default=None, + description="Field used to match target vertices when creating edges.", + ) + relation: str | None = PydanticField( + default=None, + description="Relation/edge type name (e.g. Neo4j relationship type). For ArangoDB used as weight.", + ) + relation_field: str | None = PydanticField( + default=None, + description="Field name to store or read relation type (e.g. for TigerGraph).", + ) + relation_from_key: bool = PydanticField( + default=False, + description="If True, derive relation value from the location key during ingestion.", + ) + exclude_source: str | None = PydanticField( + default=None, + description="Exclude source vertices matching this field from edge creation.", + ) + exclude_target: str | None = PydanticField( + default=None, + description="Exclude target vertices matching this field from edge creation.", + ) + match: str | None = PydanticField( + default=None, + description="Match discriminant for edge creation.", + ) + + +class Edge(EdgeBase): """Represents an edge in the graph database. 
An edge connects two vertices and can have various configurations for @@ -130,35 +186,44 @@ class Edge(ConfigBaseModel): For ArangoDB, this corresponds to the edge collection name. """ - source: str - target: str - indexes: list[Index] = PydanticField(default_factory=list, alias="index") - weights: WeightConfig | None = None + indexes: list[Index] = PydanticField( + default_factory=list, + alias="index", + description="List of index definitions for this edge. Alias: index.", + ) + weights: WeightConfig | None = PydanticField( + default=None, + description="Optional edge weight/attribute configuration (direct fields and vertex-based weights).", + ) - # relation represents Class in neo4j, for arango it becomes a weight - relation: str | None = None _relation_dbname: str | None = PrivateAttr(default=None) - relation_field: str | None = None - relation_from_key: bool = False - - # used to create extra utility collections between the same type of vertices (A, B) - purpose: str | None = None - - match_source: str | None = None - match_target: str | None = None - exclude_source: str | None = None - exclude_target: str | None = None - match: str | None = None + purpose: str | None = PydanticField( + default=None, + description="Optional purpose label for utility edge collections between same vertex types.", + ) - type: EdgeType = EdgeType.DIRECT + type: EdgeType = PydanticField( + default=EdgeType.DIRECT, + description="Edge type: DIRECT (created during ingestion) or INDIRECT (pre-existing collection).", + ) - aux: bool = False # aux=True edges are init in the db but not considered by graflo + aux: bool = PydanticField( + default=False, + description="If True, edge is initialized in DB but not used by graflo ingestion.", + ) - by: str | None = None - graph_name: str | None = None # ArangoDB-specific: graph name (set in finish_init) - database_name: str | None = ( - None # ArangoDB-specific: edge collection name (set in finish_init) + by: str | None = PydanticField( + default=None, + description="For INDIRECT edges: vertex type name used to define the edge (set to dbname in finish_init).", + ) + graph_name: str | None = PydanticField( + default=None, + description="ArangoDB graph name (set in finish_init).", + ) + database_name: str | None = PydanticField( + default=None, + description="ArangoDB edge collection name (set in finish_init).", ) _source: str | None = PrivateAttr(default=None) @@ -314,7 +379,10 @@ class EdgeConfig(ConfigBaseModel): edges: List of edge configurations """ - edges: list[Edge] = PydanticField(default_factory=list) + edges: list[Edge] = PydanticField( + default_factory=list, + description="List of edge definitions (source, target, weights, indexes, relation, etc.).", + ) _edges_map: dict[EdgeId, Edge] = PrivateAttr() @model_validator(mode="after") diff --git a/graflo/architecture/onto.py b/graflo/architecture/onto.py index 1a320954..3b18d848 100644 --- a/graflo/architecture/onto.py +++ b/graflo/architecture/onto.py @@ -98,9 +98,18 @@ class ABCFields(ConfigBaseModel): fields: List of field names """ - name: str | None = None - fields: list[str] = Field(default_factory=list) - keep_vertex_name: bool = True + name: str | None = Field( + default=None, + description="Optional name of the entity (e.g. 
vertex name for composite field prefix).", + ) + fields: list[str] = Field( + default_factory=list, + description="List of field names for this entity.", + ) + keep_vertex_name: bool = Field( + default=True, + description="If True, composite field names use entity@field format; otherwise use field only.", + ) def cfield(self, x: str) -> str: """Creates a composite field name by combining the entity name with a field name. @@ -122,8 +131,14 @@ class Weight(ABCFields): filter: Dictionary of filter conditions for weights """ - map: dict = Field(default_factory=dict) - filter: dict = Field(default_factory=dict) + map: dict = Field( + default_factory=dict, + description="Mapping of field values to weight values for vertex-based edge attributes.", + ) + filter: dict = Field( + default_factory=dict, + description="Filter conditions applied when resolving vertex-based weights.", + ) class Index(ConfigBaseModel): @@ -139,13 +154,34 @@ class Index(ConfigBaseModel): exclude_edge_endpoints: Whether to exclude edge endpoints from index """ - name: str | None = None - fields: list[str] = Field(default_factory=list) - unique: bool = True - type: IndexType = IndexType.PERSISTENT - deduplicate: bool = True - sparse: bool = False - exclude_edge_endpoints: bool = False + name: str | None = Field( + default=None, + description="Optional index name. For edges, can reference a vertex name for composite fields.", + ) + fields: list[str] = Field( + default_factory=list, + description="List of field names included in this index.", + ) + unique: bool = Field( + default=True, + description="If True, index enforces uniqueness on the field combination.", + ) + type: IndexType = Field( + default=IndexType.PERSISTENT, + description="Index type (PERSISTENT, HASH, SKIPLIST, FULLTEXT).", + ) + deduplicate: bool = Field( + default=True, + description="Whether to deduplicate index entries (e.g. ArangoDB).", + ) + sparse: bool = Field( + default=False, + description="If True, create a sparse index (exclude null/missing values).", + ) + exclude_edge_endpoints: bool = Field( + default=False, + description="If True, do not add _from/_to to edge index (e.g. ArangoDB).", + ) def __iter__(self): """Iterate over the indexed fields.""" diff --git a/graflo/architecture/transform.py b/graflo/architecture/transform.py index a32c0e19..5fb7b133 100644 --- a/graflo/architecture/transform.py +++ b/graflo/architecture/transform.py @@ -78,12 +78,30 @@ class ProtoTransform(ConfigBaseModel): _foo: Internal reference to the transform function """ - name: str | None = None - module: str | None = None - params: dict[str, Any] = Field(default_factory=dict) - foo: str | None = None - input: tuple[str, ...] = Field(default_factory=tuple) - output: tuple[str, ...] = Field(default_factory=tuple) + name: str | None = Field( + default=None, + description="Optional name for this transform (e.g. for reference in schema.transforms).", + ) + module: str | None = Field( + default=None, + description="Python module path containing the transform function (e.g. my_package.transforms).", + ) + params: dict[str, Any] = Field( + default_factory=dict, + description="Extra parameters passed to the transform function at runtime.", + ) + foo: str | None = Field( + default=None, + description="Name of the callable in module to use as the transform function.", + ) + input: tuple[str, ...] = Field( + default_factory=tuple, + description="Input field names passed to the transform function.", + ) + output: tuple[str, ...] 
= Field( + default_factory=tuple, + description="Output field names produced by the transform (defaults to input if unset).", + ) _foo: Any = PrivateAttr(default=None) @@ -149,11 +167,23 @@ class Transform(ProtoTransform): functional_transform: Whether this is a functional transform """ - fields: tuple[str, ...] = Field(default_factory=tuple) - map: dict[str, str] = Field(default_factory=dict) - switch: dict[str, Any] = Field(default_factory=dict) - - functional_transform: bool = False + fields: tuple[str, ...] = Field( + default_factory=tuple, + description="Field names for declarative transform (used to derive input when input unset).", + ) + map: dict[str, str] = Field( + default_factory=dict, + description="Mapping of output_key -> input_key for pure field renaming (no function).", + ) + switch: dict[str, Any] = Field( + default_factory=dict, + description="Switch/case-style mapping for conditional field values (key -> output spec).", + ) + + functional_transform: bool = Field( + default=False, + description="True when a callable (module.foo) is set; False for pure map/switch transforms.", + ) @model_validator(mode="before") @classmethod diff --git a/graflo/architecture/vertex.py b/graflo/architecture/vertex.py index 248e34dd..73e69cb6 100644 --- a/graflo/architecture/vertex.py +++ b/graflo/architecture/vertex.py @@ -32,7 +32,7 @@ from graflo.architecture.base import ConfigBaseModel from graflo.architecture.onto import Index -from graflo.filter.onto import Clause +from graflo.filter.onto import FilterExpression from graflo.onto import DBType from graflo.onto import BaseEnum @@ -84,8 +84,14 @@ class Field(ConfigBaseModel): model_config = ConfigDict(extra="forbid") - name: str - type: FieldType | None = None + name: str = PydanticField( + ..., + description="Name of the field (e.g. column or attribute name).", + ) + type: FieldType | None = PydanticField( + default=None, + description="Optional field type for databases that require it (e.g. TigerGraph: INT, STRING). None for schema-agnostic backends.", + ) @field_validator("type", mode="before") @classmethod @@ -208,13 +214,26 @@ class Vertex(ConfigBaseModel): # Allow extra keys when loading from YAML (e.g. transforms, other runtime keys) model_config = ConfigDict(extra="ignore") - name: str - fields: list[Field] = PydanticField(default_factory=list) - indexes: list[Index] = PydanticField(default_factory=list) - filters: list[Any] = PydanticField( - default_factory=list - ) # items become Clause via convert_to_expressions - dbname: str | None = None + name: str = PydanticField( + ..., + description="Name of the vertex type (e.g. user, post, company).", + ) + fields: list[Field] = PydanticField( + default_factory=list, + description="List of fields (names, Field objects, or dicts). Normalized to Field objects.", + ) + indexes: list[Index] = PydanticField( + default_factory=list, + description="List of index definitions for this vertex. Defaults to primary index on all fields if empty.", + ) + filters: list[FilterExpression] = PydanticField( + default_factory=list, + description="Filter expressions (logical formulae) applied when querying this vertex.", + ) + dbname: str | None = PydanticField( + default=None, + description="Optional database collection/table name. 
Defaults to vertex name if not set.", + ) @field_validator("fields", mode="before") @classmethod @@ -241,12 +260,16 @@ def convert_to_indexes(cls, v: Any) -> Any: def convert_to_expressions(cls, v: Any) -> Any: if not isinstance(v, list): return v - result: list[Any] = [] + result: list[FilterExpression] = [] for item in v: - if isinstance(item, dict): - result.append(Clause.from_dict(item)) - else: + if isinstance(item, FilterExpression): result.append(item) + elif isinstance(item, (dict, list)): + result.append(FilterExpression.from_dict(item)) + else: + raise ValueError( + "each filter must be a FilterExpression instance or a dict/list (parsed as FilterExpression)" + ) return result @model_validator(mode="after") @@ -309,10 +332,22 @@ class VertexConfig(ConfigBaseModel): # Allow extra keys when loading from YAML (e.g. vertex_config wrapper key) model_config = ConfigDict(extra="ignore") - vertices: list[Vertex] - blank_vertices: list[str] = PydanticField(default_factory=list) - force_types: dict[str, list] = PydanticField(default_factory=dict) - db_flavor: DBType = DBType.ARANGO + vertices: list[Vertex] = PydanticField( + ..., + description="List of vertex type definitions (name, fields, indexes, filters).", + ) + blank_vertices: list[str] = PydanticField( + default_factory=list, + description="Vertex names that may be created without explicit data (e.g. placeholders).", + ) + force_types: dict[str, list] = PydanticField( + default_factory=dict, + description="Override mapping: vertex name -> list of field type names for type inference.", + ) + db_flavor: DBType = PydanticField( + default=DBType.ARANGO, + description="Database flavor (ARANGO, NEO4J, TIGERGRAPH) for schema and index generation.", + ) _vertices_map: dict[str, Vertex] | None = PrivateAttr(default=None) _vertex_numeric_fields_map: dict[str, object] | None = PrivateAttr(default=None) @@ -482,14 +517,14 @@ def numeric_fields_list(self, vertex_name): f" {vertex_name} was not defined in config" ) - def filters(self, vertex_name) -> list[Clause]: + def filters(self, vertex_name) -> list[FilterExpression]: """Get filter clauses for a vertex. Args: vertex_name: Name of the vertex Returns: - list[Clause]: List of filter clauses + list[FilterExpression]: List of filter expressions """ m = self._get_vertices_map() if vertex_name in m: diff --git a/graflo/db/arango/conn.py b/graflo/db/arango/conn.py index 4a157cc7..504f778e 100644 --- a/graflo/db/arango/conn.py +++ b/graflo/db/arango/conn.py @@ -41,7 +41,7 @@ from graflo.db.arango.util import render_filters from graflo.db.conn import Connection, SchemaExistsError from graflo.db.util import get_data_from_cursor, json_serializer -from graflo.filter.onto import Clause +from graflo.filter.onto import FilterExpression from graflo.onto import AggregationType from graflo.util.transform import pick_unique_dict from graflo.onto import DBType @@ -856,7 +856,7 @@ def fetch_present_documents( match_keys: list[str] | tuple[str, ...], keep_keys: list[str] | tuple[str, ...] | None = None, flatten: bool = False, - filters: None | Clause | list[Any] | dict[str, Any] = None, + filters: None | FilterExpression | list[Any] | dict[str, Any] = None, ) -> list[dict[str, Any]] | dict[int, list[dict[str, Any]]]: """Fetch documents that exist in the database. 
@@ -899,7 +899,7 @@ def fetch_present_documents( def fetch_docs( self, class_name: str, - filters: None | Clause | list[Any] | dict[str, Any] = None, + filters: None | FilterExpression | list[Any] | dict[str, Any] = None, limit: int | None = None, return_keys: list[str] | None = None, unset_keys: list[str] | None = None, @@ -954,7 +954,7 @@ def fetch_edges( edge_type: str | None = None, to_type: str | None = None, to_id: str | None = None, - filters: list[Any] | dict[str, Any] | Clause | None = None, + filters: list[Any] | dict[str, Any] | FilterExpression | None = None, limit: int | None = None, return_keys: list[str] | None = None, unset_keys: list[str] | None = None, @@ -1040,7 +1040,7 @@ def aggregate( aggregation_function: AggregationType, discriminant: str | None = None, aggregated_field: str | None = None, - filters: None | Clause | list[Any] | dict[str, Any] = None, + filters: None | FilterExpression | list[Any] | dict[str, Any] = None, ) -> list[dict[str, Any]]: """Perform aggregation on a collection. @@ -1097,7 +1097,7 @@ def keep_absent_documents( class_name: str, match_keys: list[str] | tuple[str, ...], keep_keys: list[str] | tuple[str, ...] | None = None, - filters: None | Clause | list[Any] | dict[str, Any] = None, + filters: None | FilterExpression | list[Any] | dict[str, Any] = None, ) -> list[dict[str, Any]]: """Keep documents that don't exist in the database. diff --git a/graflo/db/arango/query.py b/graflo/db/arango/query.py index bcf0a50b..454ed1d4 100644 --- a/graflo/db/arango/query.py +++ b/graflo/db/arango/query.py @@ -21,7 +21,7 @@ from arango import ArangoClient -from graflo.filter.onto import Clause, Expression +from graflo.filter.onto import FilterExpression from graflo.onto import ExpressionFlavor logger = logging.getLogger(__name__) @@ -129,7 +129,7 @@ def fetch_fields_query( docs, match_keys, keep_keys, - filters: list | dict | Clause | None = None, + filters: list | dict | FilterExpression | None = None, ): """Generate and execute a field-fetching AQL query. @@ -165,7 +165,11 @@ def fetch_fields_query( keep_clause = f"KEEP(_x, {list(keep_keys)})" if keep_keys is not None else "_x" if filters is not None: - ff = filters if isinstance(filters, Clause) else Expression.from_dict(filters) + ff = ( + filters + if isinstance(filters, FilterExpression) + else FilterExpression.from_dict(filters) + ) extrac_filter_clause = f" && {ff(doc_name='_cdoc', kind=ExpressionFlavor.AQL)}" else: extrac_filter_clause = "" diff --git a/graflo/db/arango/util.py b/graflo/db/arango/util.py index f104291c..a2e6e514 100644 --- a/graflo/db/arango/util.py +++ b/graflo/db/arango/util.py @@ -16,7 +16,7 @@ import logging from graflo.architecture.edge import Edge -from graflo.filter.onto import Clause +from graflo.filter.onto import FilterExpression from graflo.onto import ExpressionFlavor logger = logging.getLogger(__name__) @@ -60,7 +60,9 @@ def define_extra_edges(g: Edge): return query0 -def render_filters(filters: None | list | dict | Clause = None, doc_name="d") -> str: +def render_filters( + filters: None | list | dict | FilterExpression = None, doc_name="d" +) -> str: """Convert filter expressions to AQL filter clauses. 
This function converts filter expressions into AQL filter clauses that @@ -79,8 +81,8 @@ def render_filters(filters: None | list | dict | Clause = None, doc_name="d") -> >>> # Returns: "FILTER user.field == 'value' && user.age > 18" """ if filters is not None: - if not isinstance(filters, Clause): - ff = Clause.from_dict(filters) + if not isinstance(filters, FilterExpression): + ff = FilterExpression.from_dict(filters) else: ff = filters literal_condition = ff(doc_name=doc_name, kind=ExpressionFlavor.AQL) diff --git a/graflo/db/falkordb/conn.py b/graflo/db/falkordb/conn.py index aa3d5d46..5551bc06 100644 --- a/graflo/db/falkordb/conn.py +++ b/graflo/db/falkordb/conn.py @@ -37,7 +37,7 @@ from graflo.architecture.vertex import VertexConfig from graflo.db.conn import Connection, SchemaExistsError from graflo.db.util import serialize_value -from graflo.filter.onto import Expression +from graflo.filter.onto import FilterExpression from graflo.onto import AggregationType from graflo.onto import DBType @@ -689,7 +689,7 @@ def fetch_docs( """ # Build filter clause if filters is not None: - ff = Expression.from_dict(filters) + ff = FilterExpression.from_dict(filters) filter_clause = f"WHERE {ff(doc_name='n', kind=self.expression_flavor())}" else: filter_clause = "" @@ -794,7 +794,7 @@ def fetch_edges( # Add additional filters if provided if filters is not None: - ff = Expression.from_dict(filters) + ff = FilterExpression.from_dict(filters) filter_clause = ff(doc_name="r", kind=self.expression_flavor()) where_clauses.append(filter_clause) @@ -912,7 +912,7 @@ def aggregate( """ # Build filter clause if filters is not None: - ff = Expression.from_dict(filters) + ff = FilterExpression.from_dict(filters) filter_clause = f"WHERE {ff(doc_name='n', kind=self.expression_flavor())}" else: filter_clause = "" diff --git a/graflo/db/memgraph/conn.py b/graflo/db/memgraph/conn.py index 44c7d969..1760fac4 100644 --- a/graflo/db/memgraph/conn.py +++ b/graflo/db/memgraph/conn.py @@ -87,7 +87,7 @@ from graflo.architecture.schema import Schema from graflo.architecture.vertex import VertexConfig from graflo.db.conn import Connection, SchemaExistsError -from graflo.filter.onto import Expression +from graflo.filter.onto import FilterExpression from graflo.onto import AggregationType from graflo.onto import DBType @@ -876,7 +876,7 @@ def fetch_docs( q = f"MATCH (n:{class_name})" if filters is not None: - ff = Expression.from_dict(filters) + ff = FilterExpression.from_dict(filters) filter_str = ff(doc_name="n", kind=self.expression_flavor()) q += f" WHERE {filter_str}" @@ -971,7 +971,7 @@ def fetch_edges( # Add relationship property filters if filters is not None: - ff = Expression.from_dict(filters) + ff = FilterExpression.from_dict(filters) filter_str = ff(doc_name="r", kind=self.expression_flavor()) where_clauses.append(filter_str) @@ -1113,7 +1113,7 @@ def aggregate( # Build filter clause filter_clause = "" if filters is not None: - ff = Expression.from_dict(filters) + ff = FilterExpression.from_dict(filters) filter_str = ff(doc_name="n", kind=self.expression_flavor()) filter_clause = f" WHERE {filter_str}" diff --git a/graflo/db/neo4j/conn.py b/graflo/db/neo4j/conn.py index 9b48d147..320706cd 100644 --- a/graflo/db/neo4j/conn.py +++ b/graflo/db/neo4j/conn.py @@ -33,7 +33,7 @@ from graflo.architecture.schema import Schema from graflo.architecture.vertex import VertexConfig from graflo.db.conn import Connection, SchemaExistsError -from graflo.filter.onto import Expression +from graflo.filter.onto import FilterExpression 
from graflo.onto import AggregationType from graflo.onto import DBType @@ -509,7 +509,7 @@ def fetch_docs( list: Fetched nodes """ if filters is not None: - ff = Expression.from_dict(filters) + ff = FilterExpression.from_dict(filters) filter_clause = f"WHERE {ff(doc_name='n', kind=self.expression_flavor())}" else: filter_clause = "" @@ -589,9 +589,7 @@ def fetch_edges( # Add additional filters if provided if filters is not None: - from graflo.filter.onto import Expression - - ff = Expression.from_dict(filters) + ff = FilterExpression.from_dict(filters) filter_clause = ff(doc_name="r", kind=self.expression_flavor()) where_clauses.append(filter_clause) diff --git a/graflo/db/tigergraph/conn.py b/graflo/db/tigergraph/conn.py index 6cf0e2ef..7d307b12 100644 --- a/graflo/db/tigergraph/conn.py +++ b/graflo/db/tigergraph/conn.py @@ -49,7 +49,7 @@ VALID_TIGERGRAPH_TYPES, ) from graflo.db.util import json_serializer -from graflo.filter.onto import Clause +from graflo.filter.onto import FilterExpression from graflo.onto import AggregationType from graflo.onto import DBType from graflo.util.transform import pick_unique_dict @@ -3681,7 +3681,7 @@ def insert_return_batch( def _render_rest_filter( self, - filters: list | dict | Clause | None, + filters: list | dict | FilterExpression | None, field_types: dict[str, FieldType] | None = None, ) -> str: """Convert filter expressions to REST++ filter format. @@ -3698,8 +3698,8 @@ def _render_rest_filter( str: REST++ filter string (empty if no filters) """ if filters is not None: - if not isinstance(filters, Clause): - ff = Clause.from_dict(filters) + if not isinstance(filters, FilterExpression): + ff = FilterExpression.from_dict(filters) else: ff = filters @@ -3717,7 +3717,7 @@ def _render_rest_filter( def fetch_docs( self, class_name: str, - filters: list[Any] | dict[str, Any] | Clause | None = None, + filters: list[Any] | dict[str, Any] | FilterExpression | None = None, limit: int | None = None, return_keys: list[str] | None = None, unset_keys: list[str] | None = None, @@ -3728,7 +3728,7 @@ def fetch_docs( Args: class_name: Vertex type name (or dbname) - filters: Filter expression (list, dict, or Clause) + filters: Filter expression (list, dict, or FilterExpression) limit: Maximum number of documents to return return_keys: Keys to return (projection) unset_keys: Keys to exclude (projection) @@ -3816,7 +3816,7 @@ def fetch_edges( edge_type: str | None = None, to_type: str | None = None, to_id: str | None = None, - filters: list[Any] | dict[str, Any] | Clause | None = None, + filters: list[Any] | dict[str, Any] | FilterExpression | None = None, limit: int | None = None, return_keys: list[str] | None = None, unset_keys: list[str] | None = None, diff --git a/graflo/filter/__init__.py b/graflo/filter/__init__.py index 9b746568..55a22bd4 100644 --- a/graflo/filter/__init__.py +++ b/graflo/filter/__init__.py @@ -6,12 +6,11 @@ Key Components: - LogicalOperator: Logical operations (AND, OR, NOT, IMPLICATION) - ComparisonOperator: Comparison operations (==, !=, >, <, etc.) - - Clause: Filter clause implementation - - Expression: Filter expression factory + - FilterExpression: Filter expression (leaf or composite logical formulae) Example: - >>> from graflo.filter import Expression - >>> expr = Expression.from_dict({ + >>> from graflo.filter.onto import FilterExpression + >>> expr = FilterExpression.from_dict({ ... "AND": [ ... {"field": "age", "cmp_operator": ">=", "value": 18}, ... 
{"field": "status", "cmp_operator": "==", "value": "active"} diff --git a/graflo/filter/onto.py b/graflo/filter/onto.py index b7a2837a..43380fe5 100644 --- a/graflo/filter/onto.py +++ b/graflo/filter/onto.py @@ -7,10 +7,10 @@ Key Components: - LogicalOperator: Enum for logical operations (AND, OR, NOT, IMPLICATION) - ComparisonOperator: Enum for comparison operations (==, !=, >, <, etc.) - - Clause: Unified filter clause (discriminated: kind="leaf" or kind="composite") + - FilterExpression: Unified filter expression (discriminated: kind="leaf" or kind="composite") Example: - >>> expr = Clause.from_dict({ + >>> expr = FilterExpression.from_dict({ ... "AND": [ ... {"field": "age", "cmp_operator": ">=", "value": 18}, ... {"field": "status", "cmp_operator": "==", "value": "active"} @@ -92,8 +92,8 @@ class ComparisonOperator(BaseEnum): IN = "IN" -class Clause(ConfigBaseModel): - """Unified filter clause (discriminated: leaf or composite). +class FilterExpression(ConfigBaseModel): + """Unified filter expression (discriminated: leaf or composite). - kind="leaf": single field comparison (field, cmp_operator, value, optional unary_op). - kind="composite": logical combination (operator AND/OR/NOT/IF_THEN, deps). @@ -111,7 +111,7 @@ class Clause(ConfigBaseModel): # Composite fields (used when kind="composite") operator: LogicalOperator | None = None # AND, OR, NOT, IF_THEN - deps: list[Clause] = Field(default_factory=list) + deps: list[FilterExpression] = Field(default_factory=list) @field_validator("value", mode="before") @classmethod @@ -140,33 +140,33 @@ def leaf_operator_to_unary_op(cls, data: Any) -> Any: return data @model_validator(mode="after") - def check_discriminated_shape(self) -> Clause: + def check_discriminated_shape(self) -> FilterExpression: """Enforce exactly one shape per kind.""" if self.kind == "leaf": if self.operator is not None or self.deps: - raise ValueError("leaf clause must not have operator or deps") + raise ValueError("leaf expression must not have operator or deps") else: if self.operator is None: - raise ValueError("composite clause must have operator") + raise ValueError("composite expression must have operator") return self @field_validator("deps", mode="before") @classmethod def parse_deps(cls, v: list[Any]) -> list[Any]: - """Parse dict/list items into Clause instances.""" + """Parse dict/list items into FilterExpression instances.""" if not isinstance(v, list): return v result = [] for item in v: if isinstance(item, (dict, list)): - result.append(Clause.from_dict(item)) + result.append(FilterExpression.from_dict(item)) else: result.append(item) return result @classmethod - def from_list(cls, current: list[Any]) -> Clause: - """Build a leaf clause from list form [cmp_operator, value, field?, unary_op?].""" + def from_list(cls, current: list[Any]) -> FilterExpression: + """Build a leaf expression from list form [cmp_operator, value, field?, unary_op?].""" cmp_operator = current[0] value = current[1] field = current[2] if len(current) > 2 else None @@ -181,9 +181,9 @@ def from_list(cls, current: list[Any]) -> Clause: @classmethod def from_dict(cls, current: dict[str, Any] | list[Any]) -> Self: # type: ignore[override] - """Create a filter clause from a dictionary or list. + """Create a filter expression from a dictionary or list. - Returns Clause (leaf or composite). LSP-compliant: return type is Self. + Returns FilterExpression (leaf or composite). LSP-compliant: return type is Self. 
""" if isinstance(current, list): if current[0] in ComparisonOperator: @@ -193,8 +193,8 @@ def from_dict(cls, current: dict[str, Any] | list[Any]) -> Self: # type: ignore elif isinstance(current, dict): k = list(current.keys())[0] if k in LogicalOperator: - clauses = [cls.from_dict(v) for v in current[k]] - return cls(kind="composite", operator=LogicalOperator(k), deps=clauses) + deps = [cls.from_dict(v) for v in current[k]] + return cls(kind="composite", operator=LogicalOperator(k), deps=deps) else: # Leaf from dict: map YAML "operator" -> unary_op unary_op = current.get("operator") @@ -213,7 +213,7 @@ def __call__( kind: ExpressionFlavor = ExpressionFlavor.AQL, **kwargs, ) -> str | bool: - """Render or evaluate the clause in the target language.""" + """Render or evaluate the expression in the target language.""" if self.kind == "leaf": return self._call_leaf(doc_name=doc_name, kind=kind, **kwargs) return self._call_composite(doc_name=doc_name, kind=kind, **kwargs) @@ -381,8 +381,3 @@ def _cast_python_composite(self, kind: ExpressionFlavor, **kwargs: Any) -> bool: return OperatorMapping[self.operator]( [dep(kind=kind, **kwargs) for dep in self.deps] ) - - -# Backward compatibility -Expression = Clause -LeafClause = Clause diff --git a/test/test_filters.py b/test/test_filters.py index 13fd5383..8896f442 100644 --- a/test/test_filters.py +++ b/test/test_filters.py @@ -1,9 +1,8 @@ import pytest from graflo.filter.onto import ( - Clause, ComparisonOperator, - Expression, + FilterExpression, LogicalOperator, ) @@ -37,35 +36,35 @@ def and_clause(eq_clause, cong_clause): def test_none_leaf(none_clause): - lc = Clause.from_list(none_clause) + lc = FilterExpression.from_list(none_clause) result = lc() assert isinstance(result, str) assert "null" in result def test_leaf_clause_construct(eq_clause): - lc = Clause.from_list(eq_clause) + lc = FilterExpression.from_list(eq_clause) assert lc.cmp_operator == ComparisonOperator.EQ assert lc() == 'doc["x"] == "1"' def test_leaf_clause_construct_(eq_clause): - lc = Expression.from_dict(eq_clause) + lc = FilterExpression.from_dict(eq_clause) assert lc.cmp_operator == ComparisonOperator.EQ assert lc() == 'doc["x"] == "1"' def test_init_filter_and(and_clause): - c = Expression.from_dict(and_clause) + c = FilterExpression.from_dict(and_clause) assert c.operator == LogicalOperator.AND assert c() == 'doc["x"] == "1" AND doc["y"] % 2 == 2' def test_init_filter_eq(eq_clause): - c = Expression.from_dict(eq_clause) + c = FilterExpression.from_dict(eq_clause) assert c() == 'doc["x"] == "1"' def test_init_filter_in(in_clause): - c = Expression.from_dict(in_clause) + c = FilterExpression.from_dict(in_clause) assert c() == "IN [1, 2]" diff --git a/test/test_filters_python.py b/test/test_filters_python.py index e4670eec..678a1b08 100644 --- a/test/test_filters_python.py +++ b/test/test_filters_python.py @@ -1,7 +1,7 @@ import pytest import yaml -from graflo.filter.onto import Clause, Expression, LogicalOperator +from graflo.filter.onto import FilterExpression, LogicalOperator from graflo.onto import ExpressionFlavor @@ -77,20 +77,20 @@ def filter_implication(clause_open, clause_b): def test_python_clause(clause_open): - lc = Clause(**clause_open) # kind=leaf inferred from operator (str) + lc = FilterExpression(**clause_open) # kind=leaf inferred from operator (str) doc = {"name": "Open"} assert lc(**doc, kind=ExpressionFlavor.PYTHON) def test_condition_b(clause_b): - m = Clause(**clause_b) # kind=leaf inferred from operator (str) + m = FilterExpression(**clause_b) # 
kind=leaf inferred from operator (str) doc = {"value": -1} assert m(value=1, kind=ExpressionFlavor.PYTHON) assert not m(kind=ExpressionFlavor.PYTHON, **doc) def test_clause_a(clause_a): - m = Expression.from_dict(clause_a) + m = FilterExpression.from_dict(clause_a) doc = {"name": "Open", "value": 5.0} assert m(kind=ExpressionFlavor.PYTHON, **doc) @@ -100,7 +100,7 @@ def test_clause_a(clause_a): def test_clause_ab(clause_ab): - m = Expression.from_dict(clause_ab) + m = FilterExpression.from_dict(clause_ab) doc = {"name": "Open", "value": 5.0} assert m(kind=ExpressionFlavor.PYTHON, **doc) @@ -116,7 +116,7 @@ def test_clause_ab(clause_ab): def test_filter_implication(filter_implication): - m = Expression.from_dict(filter_implication) + m = FilterExpression.from_dict(filter_implication) doc = {"name": "Open", "value": -1.0} assert not m(kind=ExpressionFlavor.PYTHON, **doc) @@ -126,7 +126,7 @@ def test_filter_implication(filter_implication): def test_filter_neq(clause_volume): - m = Expression.from_dict(clause_volume) + m = FilterExpression.from_dict(clause_volume) doc = {"name": "Open", "value": -1.0} assert m(kind=ExpressionFlavor.PYTHON, **doc) From a67f9d9c3f47ccd7370cf08aa92c3e153cfd13d4 Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Mon, 2 Feb 2026 18:35:01 +0100 Subject: [PATCH 4/7] fixed examples --- CHANGELOG.md | 11 + docs/examples/example-1.md | 2 +- docs/examples/example-2.md | 2 +- docs/examples/example-3.md | 2 +- docs/examples/example-4.md | 2 +- docs/examples/example-5.md | 4 +- docs/examples/example-6.md | 5 +- docs/getting_started/creating_schema.md | 257 ++++++++++++++++++ docs/getting_started/quickstart.md | 2 +- examples/1-ingest-csv/ingest.py | 1 + examples/2-ingest-self-references/ingest.py | 1 + examples/2-ingest-self-references/schema.yaml | 3 +- examples/3-ingest-csv-edge-weights/ingest.py | 2 +- examples/4-ingest-neo4j/ingest.py | 6 +- .../5-ingest-postgres/generated-schema.yaml | 55 +++- examples/5-ingest-postgres/ingest.py | 4 +- graflo/architecture/edge.py | 9 - graflo/architecture/transform.py | 7 +- graflo/db/connection/onto.py | 90 ++++-- graflo/hq/graph_engine.py | 4 + mkdocs.yml | 1 + test/db/connection/test_onto.py | 41 ++- 22 files changed, 450 insertions(+), 61 deletions(-) create mode 100644 docs/getting_started/creating_schema.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 1164c6aa..ecb1025f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,17 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.4.6] - 2026-02-02 + +### Added +... 
+ +### Changed +- **Configs use Pydantic**: Schema and all schema-related configs now use Pydantic `BaseModel` (via `ConfigBaseModel`) instead of dataclasses + - `Schema`, `SchemaMetadata`, `VertexConfig`, `Vertex`, `EdgeConfig`, `Edge`, `Resource`, `WeightConfig`, `Field`, and actor configs are Pydantic models + - Validation, YAML/dict loading via `model_validate()` / `from_dict()` / `from_yaml()`, and consistent serialization + - Backward compatible: `resources` accepts empty dict as empty list; field/weight inputs accept strings, `Field` objects, or dicts + ## [1.4.5] - 2026-02-02 ### Added diff --git a/docs/examples/example-1.md b/docs/examples/example-1.md index 99829d11..cc538214 100644 --- a/docs/examples/example-1.md +++ b/docs/examples/example-1.md @@ -121,7 +121,7 @@ from graflo.hq.caster import IngestionParams caster = Caster(schema) ingestion_params = IngestionParams( - recreate_schema=False, # Set to True to drop and redefine schema (script halts if schema exists) + clear_data=True, # Clear existing data before ingesting # max_items=1000, # Optional: limit number of items to process ) diff --git a/docs/examples/example-2.md b/docs/examples/example-2.md index fed7bbe8..a544f277 100644 --- a/docs/examples/example-2.md +++ b/docs/examples/example-2.md @@ -132,7 +132,7 @@ patterns.add_file_pattern( from graflo.hq.caster import IngestionParams ingestion_params = IngestionParams( - recreate_schema=True, # Wipe existing schema before defining and ingesting + clear_data=True, # Clear existing data before ingesting ) caster.ingest( diff --git a/docs/examples/example-3.md b/docs/examples/example-3.md index 9a5616f7..f5a22da7 100644 --- a/docs/examples/example-3.md +++ b/docs/examples/example-3.md @@ -120,7 +120,7 @@ from graflo.hq.caster import IngestionParams caster = Caster(schema) ingestion_params = IngestionParams( - recreate_schema=True, # Wipe existing schema before defining and ingesting + clear_data=True, # Clear existing data before ingesting ) caster.ingest( diff --git a/docs/examples/example-4.md b/docs/examples/example-4.md index 70cad859..184cc9db 100644 --- a/docs/examples/example-4.md +++ b/docs/examples/example-4.md @@ -214,7 +214,7 @@ from graflo.hq.caster import IngestionParams caster = Caster(schema) ingestion_params = IngestionParams( - recreate_schema=True, # Wipe existing schema before defining and ingesting + clear_data=True, # Clear existing data before ingesting ) caster.ingest( diff --git a/docs/examples/example-5.md b/docs/examples/example-5.md index 328f411b..c39c7d7a 100644 --- a/docs/examples/example-5.md +++ b/docs/examples/example-5.md @@ -373,7 +373,7 @@ from graflo.hq.caster import IngestionParams # Use GraphEngine for schema definition and ingestion engine = GraphEngine() ingestion_params = IngestionParams( - recreate_schema=True, # Drop existing schema and define new one before ingesting + clear_data=True, # Clear existing data before ingesting ) engine.define_and_ingest( @@ -436,7 +436,7 @@ patterns = engine.create_patterns(postgres_conf, schema_name="public") # Step 7: Define schema and ingest data ingestion_params = IngestionParams( - recreate_schema=True, # Drop existing schema and define new one before ingesting + clear_data=True, # Clear existing data before ingesting ) # Use GraphEngine to define schema and ingest data diff --git a/docs/examples/example-6.md b/docs/examples/example-6.md index debd13d0..32538656 100644 --- a/docs/examples/example-6.md +++ b/docs/examples/example-6.md @@ -91,6 +91,7 @@ config_data = FileHandle.load("db.yaml") 
conn_conf = DBConfig.from_dict(config_data) ingestion_params = IngestionParams( + clear_data=True, batch_size=1000, # Process 1000 items per batch ) @@ -218,7 +219,9 @@ registry.register(file_source, resource_name="users") # Both will be processed and combined from graflo.hq.caster import IngestionParams -ingestion_params = IngestionParams() # Use default parameters +ingestion_params = IngestionParams( + clear_data=True, +) caster.ingest_data_sources( data_source_registry=registry, diff --git a/docs/getting_started/creating_schema.md b/docs/getting_started/creating_schema.md new file mode 100644 index 00000000..4af822d5 --- /dev/null +++ b/docs/getting_started/creating_schema.md @@ -0,0 +1,257 @@ +# Creating a Schema + +This guide explains how to define a graflo **Schema**: the central configuration that describes your graph structure (vertices and edges), how data is transformed (resources and actors), and optional metadata. The content is structured so that both developers and automated agents can follow the same principles. + +## Principles + +1. **Schema is the single source of truth** for the graph: vertex types, edge types, indexes, and the mapping from raw data to vertices/edges. +2. **All schema configs are Pydantic models** (`ConfigBaseModel`). You can load from YAML or dicts; validation runs at load time. +3. **Resources define data pipelines**: each resource has a unique `resource_name` and an `apply` (or `pipeline`) list of **actor steps**. Data sources (files, APIs, SQL) are bound to resources by name elsewhere (e.g. `Patterns`). +4. **Order of definition is flexible** in YAML: `general`, `vertex_config`, `edge_config`, `resources`, and `transforms` can appear in any order. References (e.g. vertex names in edges or in `apply`) must refer to names defined in the same schema. + +## Schema structure + +A Schema has five top-level parts: + +| Section | Required | Description | +|------------------|----------|-------------| +| `general` | Yes | Schema name and optional version. | +| `vertex_config` | Yes | Vertex types and their fields, indexes, filters. | +| `edge_config` | Yes | Edge types (source, target, weights, indexes). | +| `resources` | No | List of resources: data pipelines (apply/pipeline) that map data to vertices and edges. | +| `transforms` | No | Named transform functions used by resources. | + +## `general` (SchemaMetadata) + +Identifies the schema. Used for versioning and as fallback graph/schema name when the database config does not set one. + +```yaml +general: + name: my_graph # required + version: "1.0" # optional +``` + +- **`name`**: Required. Identifier for the schema (e.g. graph or database name). +- **`version`**: Optional. Semantic or custom version string. + +## `vertex_config` + +Defines **vertex types**: their fields, indexes, and optional filters. Each vertex type has a unique `name` and is referenced by that name in edges and in resources. + +### Structure + +```yaml +vertex_config: + vertices: + - name: person + fields: [id, name, age] + indexes: + - fields: [id] + - name: department + fields: [name] + indexes: + - fields: [name] + blank_vertices: [] # optional: vertex names allowed without explicit data + force_types: {} # optional: vertex -> list of field type names + db_flavor: ARANGO # optional: ARANGO | NEO4J | TIGERGRAPH +``` + +### Vertex fields + +- **`name`**: Required. Vertex type name (e.g. `person`, `department`). Must be unique. +- **`fields`**: List of field definitions. 
Each item can be: + - A **string** (field name, type inferred or omitted). + - A **dict** with `name` and optional `type`: `{"name": "created_at", "type": "DATETIME"}`. + - For TigerGraph or typed backends, use types: `INT`, `UINT`, `FLOAT`, `DOUBLE`, `BOOL`, `STRING`, `DATETIME`. +- **`indexes`**: List of index definitions. If empty, a single primary index on all fields is created. Each index can specify `fields` and optionally `unique: true/false`. +- **`filters`**: Optional list of filter expressions for querying this vertex. +- **`dbname`**: Optional. Database-specific name (e.g. collection/table). Defaults to `name` if not set. + +### VertexConfig-level options + +- **`blank_vertices`**: Vertex names that may be created without explicit row data (e.g. placeholders). Each must exist in `vertices`. +- **`force_types`**: Override mapping from vertex name to list of field type names for inference. +- **`db_flavor`**: Database flavor used for schema/index generation: `ARANGO`, `NEO4J`, or `TIGERGRAPH`. + +## `edge_config` + +Defines **edge types**: source and target vertex types, relation name, weights, and indexes. + +### Structure + +```yaml +edge_config: + edges: + - source: person + target: department + # optional: relation, match_source, match_target, weights, indexes, etc. +``` + +### Edge fields + +- **`source`**, **`target`**: Required. Vertex type names (must exist in `vertex_config.vertices`). +- **`relation`**: Optional. Relationship/edge type name (especially for Neo4j). For ArangoDB can be used as weight. +- **`relation_field`**: Optional. Field name that stores or reads the relation type (e.g. for TigerGraph). +- **`relation_from_key`**: Optional. If true, derive relation from the location key during ingestion (e.g. JSON key). +- **`match_source`**, **`match_target`**: Optional. Fields used to match source/target vertices when creating edges. +- **`weights`**: Optional. Weight/attribute configuration: + - **`direct`**: List of field names or typed fields to attach directly to the edge (e.g. `["date", "weight"]` or `[{"name": "date", "type": "DATETIME"}]`). + - **`vertices`**: List of vertex-based weight definitions. +- **`indexes`** (or **`index`**): Optional. List of index definitions for the edge. +- **`purpose`**: Optional. Extra label for utility edges between the same vertex types. +- **`type`**: Optional. `DIRECT` (default) or `INDIRECT`. +- **`aux`**: Optional. If true, edge is created in DB but not used by graflo ingestion. +- **`by`**: Optional. For `INDIRECT` edges: vertex type name used to define the edge. + +## `resources` (focus) + +Resources define **how** each data stream is turned into vertices and edges. Each resource has a unique **`resource_name`** (used by Patterns / DataSourceRegistry to bind files, APIs, or SQL to this pipeline) and an **`apply`** (or **`pipeline`**) list of **actor steps**. Steps are executed in order; the pipeline can branch with **descend** steps. + +### Resource-level fields + +- **`resource_name`**: Required. Unique identifier (e.g. table or file name). Used when mapping data sources to this resource. +- **`apply`** (or **`pipeline`**): Required. List of actor steps (see below). +- **`encoding`**: Optional. Character encoding (default `UTF_8`). +- **`merge_collections`**: Optional. List of collection names to merge when writing. +- **`extra_weights`**: Optional. Additional edge weight configs for this resource. +- **`types`**: Optional. Field name → Python type expression for casting (e.g. `{"amount": "float"}`). 
+- **`edge_greedy`**: Optional. If true (default), emit edges as soon as source/target exist; if false, wait for explicit targets. + +### Actor steps in `apply` / `pipeline` + +Each step is a dict. The system recognizes: + +1. **Vertex step** — create vertices of a given type from the current document level: + ```yaml + - vertex: person + ``` + Optional: `keep_fields: [id, name]`. + +2. **Transform step** — rename fields, change shape, or apply a named transform; optionally send result to a vertex: + ```yaml + - map: + person: name + person_id: id + target_vertex: department # or to_vertex + ``` + Or use a **named transform** (defined in `transforms`): + ```yaml + - name: keep_suffix_id + params: { sep: "/", keep: -1 } + input: [id] + output: [_key] + ``` + +3. **Edge step** — create edges between two vertex types: + ```yaml + - source: person + target: department + ``` + Or: + ```yaml + - edge: + from: person + to: department + ``` + You can add edge-specific `weights`, `indexes`, etc. in the step when needed. + +4. **Descend step** — go into a nested key and run a sub-pipeline (or process all keys with `any_key`): + ```yaml + - key: referenced_works + apply: + - vertex: work + - source: work + target: work + ``` + Or with **`any_key`** to iterate over all keys: + ```yaml + - any_key: true + apply: [...] + ``` + +### Rules for resources (for agents) + +- **Unique names**: Every `resource_name` in the schema must be unique. +- **References**: All vertex names in `apply` (e.g. `vertex: person`, `source`/`target`, `target_vertex`) must exist in `vertex_config.vertices`. All edge relationships implied by `source`/`target` should exist in `edge_config.edges` (or be compatible). +- **Order**: Steps run in sequence. Typically you create vertices before creating edges that reference them; use **transform** to reshape data and **descend** to handle nested structures. +- **Transforms**: If a step uses `name: `, that name must exist in `transforms` (see below). + +## `transforms` + +Optional dictionary of **named transforms** used by resources. Keys are transform names; values are configs (e.g. `foo`, `module`, `params`, `input`, `output`). + +```yaml +transforms: + keep_suffix_id: + foo: split_keep_part + module: graflo.util.transform + params: { sep: "/", keep: -1 } + input: [id] + output: [_key] +``` + +Resources refer to them with `name: keep_suffix_id` (and optional `params`, `input`, `output` overrides) in a transform step. + +## Loading a schema + +All schema configs are Pydantic models. You can load a Schema from a dict or YAML: + +```python +from graflo import Schema +from suthing import FileHandle + +# From dict (e.g. from YAML already parsed) +schema = Schema.model_validate(FileHandle.load("schema.yaml")) +# Or explicit method +schema = Schema.from_dict(FileHandle.load("schema.yaml")) + +# From YAML file path (if your root is the schema dict) +data = FileHandle.load("schema.yaml") +schema = Schema.model_validate(data) +``` + +After loading, the schema runs `finish_init()` (transform names, edge init, resource pipelines, and the internal resource name map). If you modify `resources` programmatically, call `schema.finish_init()` so that `fetch_resource(name)` and ingestion use the updated pipelines. 
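+
+A minimal sketch of that last point (assuming `resources` is a plain list of resource models exposing the `resource_name` attribute described above; the resource names come from the example schema below):
+
+```python
+from graflo import Schema
+from suthing import FileHandle
+
+schema = Schema.model_validate(FileHandle.load("schema.yaml"))
+
+# Drop one pipeline programmatically; names are the resource_name values declared in YAML.
+schema.resources = [r for r in schema.resources if r.resource_name != "departments"]
+
+# Re-run finish_init() so the internal resource-name map matches the updated list.
+schema.finish_init()
+
+people = schema.fetch_resource("people")  # remaining pipelines stay resolvable
+```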
+ +## Minimal full example + +```yaml +general: + name: hr + +vertex_config: + vertices: + - name: person + fields: [id, name, age] + indexes: + - fields: [id] + - name: department + fields: [name] + indexes: + - fields: [name] + +edge_config: + edges: + - source: person + target: department + +resources: + - resource_name: people + apply: + - vertex: person + - resource_name: departments + apply: + - map: + person: name + person_id: id + - target_vertex: department + map: + department: name +``` + +This defines two vertex types (`person`, `department`), one edge type (`person` → `department`), and two resources: **people** (each row → one `person` vertex) and **departments** (transform + `department` vertices). Data sources are attached to these resources by name (e.g. via `Patterns` or `DataSourceRegistry`) as shown in the [Quick Start](quickstart.md). + +## See also + +- [Concepts — Schema and constituents](../concepts/index.md#schema) for higher-level overview. +- [Quick Start](quickstart.md) for loading a schema and running ingestion. +- [API Reference — architecture](../reference/architecture/index.md) for Pydantic model details. diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index 9466133f..f03b4f8f 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -105,7 +105,7 @@ engine.define_and_ingest( # ) ``` -Here `schema` defines the graph and the mapping the sources to vertices and edges (refer to [Schema](../concepts/index.md#schema) for details on schema and its components). +Here `schema` defines the graph and the mapping of sources to vertices and edges. See [Creating a Schema](creating_schema.md) for how to define `vertex_config`, `edge_config`, and **resources**; see [Concepts — Schema](../concepts/index.md#schema) for a high-level overview. 
The `Patterns` class maps resource names (from `Schema`) to their physical data sources: - **FilePattern**: For file-based resources with `regex` for matching filenames and `sub_path` for the directory to search diff --git a/examples/1-ingest-csv/ingest.py b/examples/1-ingest-csv/ingest.py index a3d008ed..6bc40f27 100644 --- a/examples/1-ingest-csv/ingest.py +++ b/examples/1-ingest-csv/ingest.py @@ -57,4 +57,5 @@ target_db_config=conn_conf, patterns=patterns, ingestion_params=ingestion_params, + recreate_schema=True, ) diff --git a/examples/2-ingest-self-references/ingest.py b/examples/2-ingest-self-references/ingest.py index 1a0b7c0b..d81b9da1 100644 --- a/examples/2-ingest-self-references/ingest.py +++ b/examples/2-ingest-self-references/ingest.py @@ -48,4 +48,5 @@ target_db_config=conn_conf, patterns=patterns, ingestion_params=ingestion_params, + recreate_schema=True, ) diff --git a/examples/2-ingest-self-references/schema.yaml b/examples/2-ingest-self-references/schema.yaml index 78d75f5f..9180d40a 100644 --- a/examples/2-ingest-self-references/schema.yaml +++ b/examples/2-ingest-self-references/schema.yaml @@ -33,7 +33,8 @@ vertex_config: - unique: false fields: - doi -edge_config: [] +edge_config: + edges: [] transforms: keep_suffix_id: foo: split_keep_part diff --git a/examples/3-ingest-csv-edge-weights/ingest.py b/examples/3-ingest-csv-edge-weights/ingest.py index a93f1a43..fe4f5acc 100644 --- a/examples/3-ingest-csv-edge-weights/ingest.py +++ b/examples/3-ingest-csv-edge-weights/ingest.py @@ -19,7 +19,6 @@ conn_conf = Neo4jConfig.from_docker_env() # from graflo.db.connection.onto import TigergraphConfig -# # conn_conf = TigergraphConfig.from_docker_env() # Alternative: Create config directly or use environment variables @@ -56,4 +55,5 @@ target_db_config=conn_conf, patterns=patterns, ingestion_params=ingestion_params, + recreate_schema=True, ) diff --git a/examples/4-ingest-neo4j/ingest.py b/examples/4-ingest-neo4j/ingest.py index 59a5ab88..cdbc3292 100644 --- a/examples/4-ingest-neo4j/ingest.py +++ b/examples/4-ingest-neo4j/ingest.py @@ -55,13 +55,11 @@ # Create GraphEngine and define schema + ingest in one operation engine = GraphEngine(target_db_flavor=db_type) -ingestion_params = IngestionParams( - # max_items=5, -) +ingestion_params = IngestionParams(clear_data=True) engine.define_and_ingest( schema=schema, target_db_config=conn_conf, # Target database config patterns=patterns, # Source data patterns ingestion_params=ingestion_params, - recreate_schema=True, # Wipe existing schema before defining and ingesting + recreate_schema=True, ) diff --git a/examples/5-ingest-postgres/generated-schema.yaml b/examples/5-ingest-postgres/generated-schema.yaml index 64b3a492..e4668371 100644 --- a/examples/5-ingest-postgres/generated-schema.yaml +++ b/examples/5-ingest-postgres/generated-schema.yaml @@ -1,15 +1,24 @@ edge_config: edges: - - relation: follows + - aux: false + index: [] + relation: follows + relation_from_key: false source: users target: users + type: direct weights: direct: - name: created_at type: DATETIME - - relation: purchases + vertices: [] + - aux: false + index: [] + relation: purchases + relation_from_key: false source: users target: products + type: direct weights: direct: - name: purchase_date @@ -18,15 +27,26 @@ edge_config: type: INT - name: total_amount type: FLOAT + vertices: [] general: name: accounting resources: - apply: - vertex: products + edge_greedy: true + encoding: utf-8 + extra_weights: [] + merge_collections: [] resource_name: products + types: {} - 
apply: - vertex: users + edge_greedy: true + encoding: utf-8 + extra_weights: [] + merge_collections: [] resource_name: users + types: {} - apply: - map: follower_id: id @@ -34,7 +54,12 @@ resources: - map: followed_id: id target_vertex: users + edge_greedy: true + encoding: utf-8 + extra_weights: [] + merge_collections: [] resource_name: follows + types: {} - apply: - map: user_id: id @@ -42,11 +67,17 @@ resources: - map: product_id: id target_vertex: products + edge_greedy: true + encoding: utf-8 + extra_weights: [] + merge_collections: [] resource_name: purchases + types: {} transforms: {} vertex_config: - db_flavor: !!python/object/apply:graflo.onto.DBType - - tigergraph + blank_vertices: [] + db_flavor: tigergraph + force_types: {} vertices: - dbname: products fields: @@ -60,9 +91,15 @@ vertex_config: type: STRING - name: created_at type: DATETIME + filters: [] indexes: - - fields: + - deduplicate: true + exclude_edge_endpoints: false + fields: - id + sparse: false + type: persistent + unique: true name: products - dbname: users fields: @@ -74,7 +111,13 @@ vertex_config: type: STRING - name: created_at type: DATETIME + filters: [] indexes: - - fields: + - deduplicate: true + exclude_edge_endpoints: false + fields: - id + sparse: false + type: persistent + unique: true name: users diff --git a/examples/5-ingest-postgres/ingest.py b/examples/5-ingest-postgres/ingest.py index 5271657c..08260f01 100644 --- a/examples/5-ingest-postgres/ingest.py +++ b/examples/5-ingest-postgres/ingest.py @@ -117,8 +117,8 @@ schema=schema, target_db_config=conn_conf, patterns=patterns, - ingestion_params=IngestionParams(clear_data=False), - recreate_schema=True, # Drop existing schema and define new one before ingesting + ingestion_params=IngestionParams(clear_data=True), + recreate_schema=True, ) print("\n" + "=" * 80) diff --git a/graflo/architecture/edge.py b/graflo/architecture/edge.py index 8ab88a67..5d09d505 100644 --- a/graflo/architecture/edge.py +++ b/graflo/architecture/edge.py @@ -464,12 +464,3 @@ def vertices(self): set[str]: Set of vertex names """ return {e.source for e in self.edges} | {e.target for e in self.edges} - - # def __getitem__(self, key: EdgeId): - # if key in self._reset_edges(): - # return self._edges_map[key] - # else: - # raise KeyError(f"Vertex {key} absent") - # - # def __setitem__(self, key: EdgeId, value: Edge): - # self._edges_map[key] = value diff --git a/graflo/architecture/transform.py b/graflo/architecture/transform.py index 5fb7b133..0a5bc5ca 100644 --- a/graflo/architecture/transform.py +++ b/graflo/architecture/transform.py @@ -112,8 +112,11 @@ def _normalize_input_output(cls, data: Any) -> Any: return data data = dict(data) for key in ("input", "output"): - if key in data and data[key] is not None: - data[key] = _tuple_it(data[key]) + if key in data: + if data[key] is not None: + data[key] = _tuple_it(data[key]) + else: + data[key] = () return data @model_validator(mode="after") diff --git a/graflo/db/connection/onto.py b/graflo/db/connection/onto.py index 3e8256ea..7c9ab049 100644 --- a/graflo/db/connection/onto.py +++ b/graflo/db/connection/onto.py @@ -1,5 +1,6 @@ import abc import logging +import os import warnings from pathlib import Path from typing import Any, Dict, Type, TypeVar @@ -399,50 +400,85 @@ def from_docker_env(cls, docker_dir: str | Path | None = None) -> "DBConfig": raise NotImplementedError("Subclasses must implement from_docker_env") @classmethod - def from_env(cls: Type[T], prefix: str | None = None) -> T: + def from_env( + cls: Type[T], + *, + 
prefix: str | None = None, + profile: str | None = None, + suffix: str | None = None, + ) -> T: """Load config from environment variables using Pydantic BaseSettings. - Supports custom prefixes for multiple configs: - - Default (prefix=None): Uses {BASE_PREFIX}URI, {BASE_PREFIX}USERNAME, etc. - - With prefix (prefix="USER"): Uses USER_{BASE_PREFIX}URI, USER_{BASE_PREFIX}USERNAME, etc. + Supports qualifiers for multiple configs from the same env: + + - **prefix**: outer prefix → ``{prefix}_{BASE_PREFIX}URI`` (e.g. ``USER_ARANGO_URI``). + - **profile**: segment after base → ``{BASE_PREFIX}{profile}_URI`` (e.g. ``ARANGO_DEV_URI``). + - **suffix**: after field name → ``{BASE_PREFIX}URI_{suffix}`` (e.g. ``ARANGO_URI_DEV``). + + At most one of ``prefix``, ``profile``, ``suffix`` should be set. Args: - prefix: Optional prefix for environment variables (e.g., "USER", "LAKE", "KG"). - If None, uses default {BASE_PREFIX}* variables. + prefix: Outer env prefix (e.g. ``"USER"`` → ``USER_ARANGO_URI``). + profile: Env segment after base (e.g. ``"DEV"`` → ``ARANGO_DEV_URI``). + suffix: Env segment after field name (e.g. ``"DEV"`` → ``ARANGO_URI_DEV``). Returns: - DBConfig instance loaded from environment variables using Pydantic BaseSettings + DBConfig instance loaded from environment variables. Examples: - # Load default config (ARANGO_URI, ARANGO_USERNAME, etc.) + # Default (ARANGO_URI, ARANGO_USERNAME, ...) config = ArangoConfig.from_env() - # Load config with prefix (USER_ARANGO_URI, USER_ARANGO_USERNAME, etc.) + # By profile: ARANGO_DEV_URI, ARANGO_DEV_USERNAME, ... + dev = ArangoConfig.from_env(profile="DEV") + + # By suffix: ARANGO_URI_DEV, ARANGO_USERNAME_DEV, ... + dev2 = ArangoConfig.from_env(suffix="DEV") + + # Outer prefix: USER_ARANGO_URI, ... user_config = ArangoConfig.from_env(prefix="USER") """ + base_prefix = cls.model_config.get("env_prefix") + if not base_prefix: + raise ValueError( + f"Class {cls.__name__} does not have env_prefix configured in model_config" + ) + case_sensitive = cls.model_config.get("case_sensitive", False) + qualifiers = sum(1 for q in (prefix, profile, suffix) if q is not None) + if qualifiers > 1: + raise ValueError("At most one of prefix, profile, suffix may be set") + + if suffix: + # Pydantic doesn't support env_suffix; read suffixed vars manually. 
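+            # e.g. ArangoConfig.from_env(suffix="DEV") reads ARANGO_URI_DEV,
+            # ARANGO_USERNAME_DEV, ARANGO_PASSWORD_DEV, ARANGO_DATABASE_DEV
+            # (exercised by test_from_env_with_suffix below).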
+ data: Dict[str, Any] = {} + suf = suffix if case_sensitive else suffix.upper() + for name in cls.model_fields: + env_name = f"{base_prefix}{name.upper()}_{suf}" + if not case_sensitive: + # Match pydantic-settings: first try exact, then uppercase + val = os.environ.get(env_name) or os.environ.get(env_name.lower()) + else: + val = os.environ.get(env_name) + if val is not None: + data[name] = val + return cls(**data) + if prefix: - # Get the base prefix from the class's model_config - base_prefix = cls.model_config.get("env_prefix") - if not base_prefix: - raise ValueError( - f"Class {cls.__name__} does not have env_prefix configured in model_config" - ) - # Create a new model class with modified env_prefix new_prefix = f"{prefix.upper()}_{base_prefix}" - case_sensitive = cls.model_config.get("case_sensitive", False) - model_config = SettingsConfigDict( - env_prefix=new_prefix, - case_sensitive=case_sensitive, - ) - # Create a new class dynamically with the modified prefix - temp_class = type( - f"{cls.__name__}WithPrefix", (cls,), {"model_config": model_config} - ) - return temp_class() + elif profile: + new_prefix = f"{base_prefix}{profile.upper()}_" else: - # Use default prefix - Pydantic will read from environment automatically return cls() + model_config = SettingsConfigDict( + env_prefix=new_prefix, + case_sensitive=case_sensitive, + ) + temp_class = type( + f"{cls.__name__}WithPrefix", (cls,), {"model_config": model_config} + ) + return temp_class() + class ArangoConfig(DBConfig): """Configuration for ArangoDB connections.""" diff --git a/graflo/hq/graph_engine.py b/graflo/hq/graph_engine.py index c28b3c9d..ff01604e 100644 --- a/graflo/hq/graph_engine.py +++ b/graflo/hq/graph_engine.py @@ -155,6 +155,10 @@ def define_schema( # ArangoDB, Neo4j use 'database' field (which maps to effective_schema) target_db_config.database = schema_name + # Ensure schema's vertex_config reflects target DB so Edge.finish_init() + # applies DB-specific defaults (e.g. TigerGraph default relation name) + schema.vertex_config.db_flavor = target_db_config.connection_type + # Initialize database with schema definition # init_db() handles database/schema creation automatically # It checks if the database exists and creates it if needed diff --git a/mkdocs.yml b/mkdocs.yml index ab5d6870..b7da2314 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -53,6 +53,7 @@ nav: - Getting Started: - Installation: getting_started/installation.md - Quick Start: getting_started/quickstart.md + - Creating a Schema: getting_started/creating_schema.md - Concepts: concepts/index.md - Examples: - Examples Index: examples/index.md diff --git a/test/db/connection/test_onto.py b/test/db/connection/test_onto.py index 9a36fad2..d1431453 100644 --- a/test/db/connection/test_onto.py +++ b/test/db/connection/test_onto.py @@ -1,4 +1,6 @@ -"""Tests for DBConfig.from_env() method with prefix support.""" +"""Tests for DBConfig.from_env() method with prefix, profile, and suffix support.""" + +import pytest from graflo.db.connection.onto import ArangoConfig, Neo4jConfig, TigergraphConfig @@ -74,6 +76,43 @@ def test_from_env_with_port_components(self, monkeypatch): # automatically read. Only fields defined in the model are read. assert config.username == "admin" + def test_from_env_with_profile(self, monkeypatch): + """Test profile: ARANGO_{profile}_* (e.g. 
ARANGO_DEV_URI).""" + monkeypatch.setenv("ARANGO_DEV_URI", "http://dev-db:8529") + monkeypatch.setenv("ARANGO_DEV_USERNAME", "dev_user") + monkeypatch.setenv("ARANGO_DEV_PASSWORD", "dev_pass") + monkeypatch.setenv("ARANGO_DEV_DATABASE", "dev_db") + + config = ArangoConfig.from_env(profile="DEV") + + assert config.uri == "http://dev-db:8529" + assert config.username == "dev_user" + assert config.password == "dev_pass" + assert config.database == "dev_db" + + def test_from_env_with_suffix(self, monkeypatch): + """Test suffix: ARANGO_*_{suffix} (e.g. ARANGO_URI_DEV).""" + monkeypatch.setenv("ARANGO_URI_DEV", "http://dev-db:8529") + monkeypatch.setenv("ARANGO_USERNAME_DEV", "dev_user") + monkeypatch.setenv("ARANGO_PASSWORD_DEV", "dev_pass") + monkeypatch.setenv("ARANGO_DATABASE_DEV", "dev_db") + + config = ArangoConfig.from_env(suffix="DEV") + + assert config.uri == "http://dev-db:8529" + assert config.username == "dev_user" + assert config.password == "dev_pass" + assert config.database == "dev_db" + + def test_from_env_at_most_one_qualifier(self): + """Test that at most one of prefix, profile, suffix may be set.""" + with pytest.raises(ValueError, match="At most one of prefix, profile, suffix"): + ArangoConfig.from_env(prefix="USER", profile="DEV") + with pytest.raises(ValueError, match="At most one of prefix, profile, suffix"): + ArangoConfig.from_env(prefix="USER", suffix="DEV") + with pytest.raises(ValueError, match="At most one of prefix, profile, suffix"): + ArangoConfig.from_env(profile="DEV", suffix="DEV") + class TestNeo4jConfigFromEnv: """Tests for Neo4jConfig.from_env().""" From ca45da6054b96c258b06cba9cd1cb76701943e32 Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Mon, 2 Feb 2026 19:25:17 +0100 Subject: [PATCH 5/7] cosmetic changes --- docs/getting_started/creating_schema.md | 4 ++-- graflo/architecture/resource.py | 10 ++++----- graflo/db/postgres/resource_mapping.py | 4 ++-- test/db/postgres/test_schema_inference.py | 26 +++++++++++------------ 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/docs/getting_started/creating_schema.md b/docs/getting_started/creating_schema.md index 4af822d5..8fbd587c 100644 --- a/docs/getting_started/creating_schema.md +++ b/docs/getting_started/creating_schema.md @@ -114,12 +114,12 @@ Resources define **how** each data stream is turned into vertices and edges. Eac - **`encoding`**: Optional. Character encoding (default `UTF_8`). - **`merge_collections`**: Optional. List of collection names to merge when writing. - **`extra_weights`**: Optional. Additional edge weight configs for this resource. -- **`types`**: Optional. Field name → Python type expression for casting (e.g. `{"amount": "float"}`). +- **`types`**: Optional. Field name → Python type expression for casting during ingestion (e.g. `{"age": "int"}`, `{"amount": "float"}`, `{"created_at": "datetime"}`). Useful when input is string-only (CSV, JSON) and you need numeric or date values. - **`edge_greedy`**: Optional. If true (default), emit edges as soon as source/target exist; if false, wait for explicit targets. ### Actor steps in `apply` / `pipeline` -Each step is a dict. The system recognizes: +Each step is a dict. You can write steps in shorthand (e.g. `vertex: person`) or with an explicit **`type`** (`vertex`, `transform`, `edge`, `descend`). The system recognizes: 1. 
**Vertex step** — create vertices of a given type from the current document level: ```yaml diff --git a/graflo/architecture/resource.py b/graflo/architecture/resource.py index 25244248..65481646 100644 --- a/graflo/architecture/resource.py +++ b/graflo/architecture/resource.py @@ -19,7 +19,7 @@ Example: >>> resource = Resource( ... resource_name="users", - ... apply=[{"vertex": "user"}, {"edge": {"from": "user", "to": "user"}}], + ... pipeline=[{"vertex": "user"}, {"edge": {"from": "user", "to": "user"}}], ... encoding=EncodingType.UTF_8 ... ) >>> result = resource(doc) @@ -62,11 +62,11 @@ class Resource(ConfigBaseModel): ..., description="Name of the resource (e.g. table or file identifier).", ) - apply: list[dict[str, Any]] = PydanticField( + pipeline: list[dict[str, Any]] = PydanticField( ..., description="Pipeline of actor steps to apply in sequence (vertex, edge, transform, descend). " 'Each step is a dict, e.g. {"vertex": "user"} or {"edge": {"from": "a", "to": "b"}}.', - validation_alias=AliasChoices("apply", "pipeline"), + validation_alias=AliasChoices("pipeline", "apply"), ) encoding: EncodingType = PydanticField( default=EncodingType.UTF_8, @@ -96,8 +96,8 @@ class Resource(ConfigBaseModel): @model_validator(mode="after") def _build_root_and_types(self) -> Resource: - """Build root ActorWrapper from apply and evaluate type expressions.""" - object.__setattr__(self, "_root", ActorWrapper(*self.apply)) + """Build root ActorWrapper from pipeline and evaluate type expressions.""" + object.__setattr__(self, "_root", ActorWrapper(*self.pipeline)) object.__setattr__(self, "_types", {}) for k, v in self.types.items(): try: diff --git a/graflo/db/postgres/resource_mapping.py b/graflo/db/postgres/resource_mapping.py index 51408f13..0026e5e0 100644 --- a/graflo/db/postgres/resource_mapping.py +++ b/graflo/db/postgres/resource_mapping.py @@ -67,7 +67,7 @@ def create_vertex_resource( resource = Resource( resource_name=table_name, - apply=apply, + pipeline=apply, ) logger.debug( @@ -202,7 +202,7 @@ def create_edge_resource( resource = Resource( resource_name=table_name, - apply=apply, + pipeline=apply, ) relation_info = f" with relation '{relation}'" if relation else "" diff --git a/test/db/postgres/test_schema_inference.py b/test/db/postgres/test_schema_inference.py index 4cbd7c0c..2c7dd67d 100644 --- a/test/db/postgres/test_schema_inference.py +++ b/test/db/postgres/test_schema_inference.py @@ -124,16 +124,16 @@ def test_infer_schema_from_postgres(conn_conf, load_mock_schema): # Verify resource actors users_resource = next(r for r in schema.resources if r.name == "users") - assert users_resource.apply is not None, "users resource should have apply list" - assert len(users_resource.apply) > 0, ( + assert users_resource.pipeline is not None, "users resource should have pipeline" + assert len(users_resource.pipeline) > 0, ( "users resource should have at least one actor" ) purchases_resource = next(r for r in schema.resources if r.name == "purchases") - assert purchases_resource.apply is not None, ( - "purchases resource should have apply list" + assert purchases_resource.pipeline is not None, ( + "purchases resource should have pipeline" ) - assert len(purchases_resource.apply) > 0, ( + assert len(purchases_resource.pipeline) > 0, ( "purchases resource should have at least one actor" ) @@ -158,7 +158,7 @@ def test_infer_schema_from_postgres(conn_conf, load_mock_schema): print(f"\nResources ({len(schema.resources)}):") for r in schema.resources: - actor_types = [type(a).__name__ for a in 
r.apply] + actor_types = [type(a).__name__ for a in r.pipeline] print(f" - {r.name} (actors: {', '.join(actor_types)})") print("=" * 80) @@ -306,18 +306,18 @@ def test_infer_schema_with_pg_catalog_fallback(conn_conf, load_mock_schema): # Verify resource actors - should be correctly created via pg_catalog users_resource = next(r for r in schema.resources if r.name == "users") - assert users_resource.apply is not None, ( - "users resource should have apply list when using pg_catalog" + assert users_resource.pipeline is not None, ( + "users resource should have pipeline when using pg_catalog" ) - assert len(users_resource.apply) > 0, ( + assert len(users_resource.pipeline) > 0, ( "users resource should have at least one actor when using pg_catalog" ) purchases_resource = next(r for r in schema.resources if r.name == "purchases") - assert purchases_resource.apply is not None, ( - "purchases resource should have apply list when using pg_catalog" + assert purchases_resource.pipeline is not None, ( + "purchases resource should have pipeline when using pg_catalog" ) - assert len(purchases_resource.apply) > 0, ( + assert len(purchases_resource.pipeline) > 0, ( "purchases resource should have at least one actor when using pg_catalog" ) @@ -344,7 +344,7 @@ def test_infer_schema_with_pg_catalog_fallback(conn_conf, load_mock_schema): print(f"\nResources ({len(schema.resources)}):") for r in schema.resources: - actor_types = [type(a).__name__ for a in r.apply] + actor_types = [type(a).__name__ for a in r.pipeline] print(f" - {r.name} (actors: {', '.join(actor_types)})") print("=" * 80) From 4e45be33bea5105c43f117293d36c219dd6b383b Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Mon, 2 Feb 2026 19:53:41 +0100 Subject: [PATCH 6/7] added ranged ingest --- docs/reference/architecture/actor_config.md | 3 + docs/reference/architecture/base.md | 3 + examples/5-ingest-postgres/ingest.py | 25 ++++- graflo/filter/onto.py | 32 ++++++ graflo/hq/caster.py | 82 +++++++++++++- graflo/hq/graph_engine.py | 13 ++- graflo/hq/resource_mapper.py | 13 ++- graflo/onto.py | 4 +- .../db/postgres/test_ingest_datetime_range.py | 104 ++++++++++++++++++ test/test_filters_python.py | 44 +++++++- test/test_ingestion_datetime.py | 88 +++++++++++++++ 11 files changed, 401 insertions(+), 10 deletions(-) create mode 100644 docs/reference/architecture/actor_config.md create mode 100644 docs/reference/architecture/base.md create mode 100644 test/db/postgres/test_ingest_datetime_range.py create mode 100644 test/test_ingestion_datetime.py diff --git a/docs/reference/architecture/actor_config.md b/docs/reference/architecture/actor_config.md new file mode 100644 index 00000000..c4feaf60 --- /dev/null +++ b/docs/reference/architecture/actor_config.md @@ -0,0 +1,3 @@ +# `graflo.architecture.actor_config` + +::: graflo.architecture.actor_config diff --git a/docs/reference/architecture/base.md b/docs/reference/architecture/base.md new file mode 100644 index 00000000..4519d2bc --- /dev/null +++ b/docs/reference/architecture/base.md @@ -0,0 +1,3 @@ +# `graflo.architecture.base` + +::: graflo.architecture.base diff --git a/examples/5-ingest-postgres/ingest.py b/examples/5-ingest-postgres/ingest.py index 08260f01..54f4ba0d 100644 --- a/examples/5-ingest-postgres/ingest.py +++ b/examples/5-ingest-postgres/ingest.py @@ -107,12 +107,35 @@ # Step 4: Create Patterns from PostgreSQL tables # This maps PostgreSQL tables to resource patterns that Caster can use # Connection is automatically managed inside create_patterns() -patterns = 
engine.create_patterns(postgres_conf, schema_name="public") +# +# Optional: provide a datetime column per resource for date-range filtering. +# Use with IngestionParams(datetime_after=..., datetime_before=...) to ingest +# only rows where the column falls in [datetime_after, datetime_before). +datetime_columns = { + "purchases": "purchase_date", + "users": "created_at", + "products": "created_at", + "follows": "created_at", +} +patterns = engine.create_patterns( + postgres_conf, + schema_name="public", + datetime_columns=datetime_columns, +) # Step 4.5 & 5: Define schema and ingest data in one operation # This creates/initializes the database schema and then ingests data # Some databases don't require explicit schema definition, but this ensures proper initialization # Note: ingestion will create its own PostgreSQL connections per table internally +# +# Optional: ingest only rows in a date range (requires datetime_columns above or +# IngestionParams(datetime_column="column_name") for a single default column). +# ingestion_params = IngestionParams( +# clear_data=True, +# datetime_after="2020-01-01", +# datetime_before="2021-01-01", +# datetime_column="created_at", # default if a pattern has no date_field +# ) engine.define_and_ingest( schema=schema, target_db_config=conn_conf, diff --git a/graflo/filter/onto.py b/graflo/filter/onto.py index 43380fe5..942727df 100644 --- a/graflo/filter/onto.py +++ b/graflo/filter/onto.py @@ -238,6 +238,9 @@ def _call_leaf( field_types = kwargs.get("field_types") return self._cast_restpp(field_types=field_types) return self._cast_tigergraph(doc_name) + elif kind == ExpressionFlavor.SQL: + assert self.cmp_operator is not None + return self._cast_sql() elif kind == ExpressionFlavor.PYTHON: return self._cast_python(**kwargs) raise ValueError(f"kind {kind} not implemented") @@ -252,6 +255,7 @@ def _call_composite( ExpressionFlavor.AQL, ExpressionFlavor.CYPHER, ExpressionFlavor.GSQL, + ExpressionFlavor.SQL, ): return self._cast_generic(doc_name=doc_name, kind=kind) elif kind == ExpressionFlavor.PYTHON: @@ -303,6 +307,34 @@ def _cast_tigergraph(self, doc_name: str) -> str: lemma = f"{doc_name}.{self.field} {lemma}" return lemma + def _cast_sql(self) -> str: + """Render leaf as SQL WHERE fragment: \"column\" op value (strings/dates single-quoted).""" + if not self.field: + return "" + if self.cmp_operator == ComparisonOperator.EQ: + op_str = "=" + elif self.cmp_operator == ComparisonOperator.NEQ: + op_str = "!=" + elif self.cmp_operator in ( + ComparisonOperator.GT, + ComparisonOperator.LT, + ComparisonOperator.GE, + ComparisonOperator.LE, + ): + op_str = str(self.cmp_operator) + else: + op_str = str(self.cmp_operator) + value = self.value[0] if self.value else None + if value is None: + value_str = "null" + elif isinstance(value, (int, float)): + value_str = str(value) + else: + # Strings and ISO datetimes: single-quoted for SQL + value_str = str(value).replace("'", "''") + value_str = f"'{value_str}'" + return f'"{self.field}" {op_str} {value_str}' + def _cast_restpp(self, field_types: dict[str, Any] | None = None) -> str: if not self.field: return "" diff --git a/graflo/hq/caster.py b/graflo/hq/caster.py index 3d71765f..93c5c458 100644 --- a/graflo/hq/caster.py +++ b/graflo/hq/caster.py @@ -28,6 +28,12 @@ from graflo.architecture.edge import Edge from graflo.architecture.onto import EncodingType, GraphContainer from graflo.architecture.schema import Schema +from graflo.filter.onto import ( + ComparisonOperator, + FilterExpression, + LogicalOperator, +) +from 
graflo.onto import ExpressionFlavor from graflo.data_source import ( AbstractDataSource, DataSourceFactory, @@ -57,6 +63,12 @@ class IngestionParams(BaseModel): max_concurrent_db_ops: Maximum number of concurrent database operations (for vertices/edges). If None, uses n_cores. Set to 1 to prevent deadlocks in databases that don't handle concurrent transactions well (e.g., Neo4j). Database-independent setting. + datetime_after: Inclusive lower bound for datetime filtering (ISO format). + Rows with date_column >= datetime_after are included. Used with SQL/table sources. + datetime_before: Exclusive upper bound for datetime filtering (ISO format). + Rows with date_column < datetime_before are included. Range is [datetime_after, datetime_before). + datetime_column: Default column name for datetime filtering when the pattern does not + specify date_field. Per-table override: set date_field on TablePattern (or FilePattern). """ clear_data: bool = False @@ -67,6 +79,9 @@ class IngestionParams(BaseModel): init_only: bool = False limit_files: int | None = None max_concurrent_db_ops: int | None = None + datetime_after: str | None = None + datetime_before: str | None = None + datetime_column: str | None = None class Caster: @@ -106,6 +121,46 @@ def __init__( self.ingestion_params = ingestion_params self.schema = schema + @staticmethod + def _datetime_range_where_sql( + datetime_after: str | None, + datetime_before: str | None, + date_column: str, + ) -> str: + """Build SQL WHERE fragment for [datetime_after, datetime_before) via FilterExpression. + + Returns empty string if both bounds are None; otherwise uses column with >= and <. + """ + if not datetime_after and not datetime_before: + return "" + parts: list[FilterExpression] = [] + if datetime_after is not None: + parts.append( + FilterExpression( + kind="leaf", + field=date_column, + cmp_operator=ComparisonOperator.GE, + value=[datetime_after], + ) + ) + if datetime_before is not None: + parts.append( + FilterExpression( + kind="leaf", + field=date_column, + cmp_operator=ComparisonOperator.LT, + value=[datetime_before], + ) + ) + if len(parts) == 1: + return cast(str, parts[0](kind=ExpressionFlavor.SQL)) + expr = FilterExpression( + kind="composite", + operator=LogicalOperator.AND, + deps=parts, + ) + return cast(str, expr(kind=ExpressionFlavor.SQL)) + @staticmethod def discover_files( fpath: Path | str, pattern: FilePattern, limit_files=None @@ -646,9 +701,30 @@ def _register_sql_table_sources( try: # Build base query query = f'SELECT * FROM "{effective_schema}"."{table_name}"' - where_clause = pattern.build_where_clause() - if where_clause: - query += f" WHERE {where_clause}" + where_parts: list[str] = [] + pattern_where = pattern.build_where_clause() + if pattern_where: + where_parts.append(pattern_where) + # Ingestion datetime range [datetime_after, datetime_before) + date_column = pattern.date_field or ingestion_params.datetime_column + if ( + ingestion_params.datetime_after or ingestion_params.datetime_before + ) and date_column: + datetime_where = Caster._datetime_range_where_sql( + ingestion_params.datetime_after, + ingestion_params.datetime_before, + date_column, + ) + if datetime_where: + where_parts.append(datetime_where) + elif ingestion_params.datetime_after or ingestion_params.datetime_before: + logger.warning( + "datetime_after/datetime_before set but no date column: " + "set TablePattern.date_field or IngestionParams.datetime_column for resource %s", + resource_name, + ) + if where_parts: + query += " WHERE " + " AND 
".join(where_parts) # Get SQLAlchemy connection string from PostgresConfig connection_string = postgres_config.to_sqlalchemy_connection_string() diff --git a/graflo/hq/graph_engine.py b/graflo/hq/graph_engine.py index ff01604e..7d758c85 100644 --- a/graflo/hq/graph_engine.py +++ b/graflo/hq/graph_engine.py @@ -105,19 +105,26 @@ def create_patterns( self, postgres_config: PostgresConfig, schema_name: str | None = None, + datetime_columns: dict[str, str] | None = None, ) -> Patterns: """Create Patterns from PostgreSQL tables. Args: postgres_config: PostgresConfig instance schema_name: Schema name to introspect + datetime_columns: Optional mapping of resource/table name to datetime + column name for date-range filtering (sets date_field per + TablePattern). Use with IngestionParams.datetime_after / + datetime_before. Returns: Patterns: Patterns object with TablePattern instances for all tables """ with PostgresConnection(postgres_config) as postgres_conn: return self.resource_mapper.create_patterns_from_postgres( - conn=postgres_conn, schema_name=schema_name + conn=postgres_conn, + schema_name=schema_name, + datetime_columns=datetime_columns, ) def define_schema( @@ -206,8 +213,8 @@ def define_and_ingest( ) # Then ingest data (clear_data is applied inside ingest() when ingestion_params.clear_data) - ingestion_params = IngestionParams( - **{**ingestion_params.model_dump(), "clear_data": clear_data} + ingestion_params = ingestion_params.model_copy( + update={"clear_data": clear_data} ) self.ingest( schema=schema, diff --git a/graflo/hq/resource_mapper.py b/graflo/hq/resource_mapper.py index cfe57a7e..1e711e48 100644 --- a/graflo/hq/resource_mapper.py +++ b/graflo/hq/resource_mapper.py @@ -20,13 +20,20 @@ class ResourceMapper: """ def create_patterns_from_postgres( - self, conn: PostgresConnection, schema_name: str | None = None + self, + conn: PostgresConnection, + schema_name: str | None = None, + datetime_columns: dict[str, str] | None = None, ) -> Patterns: """Create Patterns from PostgreSQL tables. Args: conn: PostgresConnection instance schema_name: Schema name to introspect + datetime_columns: Optional mapping of resource/table name to datetime + column name for date-range filtering (sets date_field on each + TablePattern). Used with IngestionParams.datetime_after / + datetime_before. Returns: Patterns: Patterns object with TablePattern instances for all tables @@ -44,6 +51,8 @@ def create_patterns_from_postgres( config_key = "default" patterns.postgres_configs[(config_key, effective_schema)] = conn.config + date_cols = datetime_columns or {} + # Add patterns for vertex tables for table_info in introspection_result.vertex_tables: table_name = table_info.name @@ -51,6 +60,7 @@ def create_patterns_from_postgres( table_name=table_name, schema_name=effective_schema, resource_name=table_name, + date_field=date_cols.get(table_name), ) patterns.table_patterns[table_name] = table_pattern patterns.postgres_table_configs[table_name] = ( @@ -66,6 +76,7 @@ def create_patterns_from_postgres( table_name=table_name, schema_name=effective_schema, resource_name=table_name, + date_field=date_cols.get(table_name), ) patterns.table_patterns[table_name] = table_pattern patterns.postgres_table_configs[table_name] = ( diff --git a/graflo/onto.py b/graflo/onto.py index ce66cfd5..8c4d0687 100644 --- a/graflo/onto.py +++ b/graflo/onto.py @@ -95,18 +95,20 @@ class ExpressionFlavor(BaseEnum): """Supported expression language types for filter/query rendering. 
Uses the actual query language names: AQL (ArangoDB), CYPHER (Neo4j, - FalkorDB, Memgraph), GSQL (TigerGraph), PYTHON for in-memory evaluation. + FalkorDB, Memgraph), GSQL (TigerGraph), SQL for WHERE clauses, PYTHON for in-memory evaluation. Attributes: AQL: ArangoDB AQL expressions CYPHER: OpenCypher expressions (Neo4j, FalkorDB, Memgraph) GSQL: TigerGraph GSQL expressions (including REST++ filter format) + SQL: SQL WHERE clause fragments (column names, single-quoted values) PYTHON: Python expression evaluation """ AQL = "aql" CYPHER = "cypher" GSQL = "gsql" + SQL = "sql" PYTHON = "python" diff --git a/test/db/postgres/test_ingest_datetime_range.py b/test/db/postgres/test_ingest_datetime_range.py new file mode 100644 index 00000000..0dfa8df1 --- /dev/null +++ b/test/db/postgres/test_ingest_datetime_range.py @@ -0,0 +1,104 @@ +"""Real tests for ingesting in a date range from PostgreSQL. + +Requires PostgreSQL running (e.g. docker/postgres). Uses mock_schema tables +and asserts that datetime_after/datetime_before and per-resource date_field +filter rows correctly. +""" + +from graflo.hq.caster import Caster, IngestionParams +from graflo.hq.graph_engine import GraphEngine +from graflo.onto import DBType +from graflo.util.onto import TablePattern + + +def _set_purchase_dates(postgres_conn): + """Set purchases to known dates so we can test range [2020-02-01, 2020-06-01).""" + updates = [ + (1, "2020-01-10"), + (2, "2020-03-15"), + (3, "2020-05-20"), + (4, "2020-07-01"), + (5, "2020-09-01"), + (6, "2020-12-01"), + ] + with postgres_conn.conn.cursor() as cursor: + for pid, dt in updates: + cursor.execute( + "UPDATE purchases SET purchase_date = %s::timestamp WHERE id = %s", + (dt, pid), + ) + postgres_conn.conn.commit() + + +def test_datetime_columns_sets_date_field_on_patterns(conn_conf, load_mock_schema): + """create_patterns(..., datetime_columns=...) 
sets date_field on TablePatterns.""" + _ = load_mock_schema # ensure tables exist + engine = GraphEngine(target_db_flavor=DBType.ARANGO) + patterns = engine.create_patterns( + conn_conf, + schema_name="public", + datetime_columns={ + "purchases": "purchase_date", + "users": "created_at", + }, + ) + assert patterns.table_patterns["purchases"].date_field == "purchase_date" + assert patterns.table_patterns["users"].date_field == "created_at" + # Tables not in the map have no date_field + if "follows" in patterns.table_patterns: + assert patterns.table_patterns["follows"].date_field is None + + +def test_ingest_datetime_range_postgres(postgres_conn, load_mock_schema): + """Real Postgres: query with datetime_after/datetime_before returns only rows in range.""" + _ = load_mock_schema + _set_purchase_dates(postgres_conn) + + pattern = TablePattern( + table_name="purchases", + schema_name="public", + resource_name="purchases", + date_field="purchase_date", + ) + datetime_where = Caster._datetime_range_where_sql( + "2020-02-01", + "2020-06-01", + pattern.date_field or "purchase_date", + ) + assert datetime_where + query = f'SELECT * FROM "public"."purchases" WHERE {datetime_where}' + + rows = postgres_conn.read(query) + # Range [2020-02-01, 2020-06-01): only id 2 (2020-03-15) and id 3 (2020-05-20) + assert len(rows) == 2 + ids = {r["id"] for r in rows} + assert ids == {2, 3} + + +def test_ingest_datetime_range_with_global_column(postgres_conn, load_mock_schema): + """IngestionParams.datetime_column is used when pattern has no date_field.""" + _ = load_mock_schema + _set_purchase_dates(postgres_conn) + + pattern = TablePattern( + table_name="purchases", + schema_name="public", + resource_name="purchases", + date_field=None, + ) + ingestion_params = IngestionParams( + datetime_after="2020-02-01", + datetime_before="2020-06-01", + datetime_column="purchase_date", + ) + date_column = pattern.date_field or ingestion_params.datetime_column + assert date_column == "purchase_date" + datetime_where = Caster._datetime_range_where_sql( + ingestion_params.datetime_after, + ingestion_params.datetime_before, + date_column, + ) + query = f'SELECT * FROM "public"."purchases" WHERE {datetime_where}' + rows = postgres_conn.read(query) + assert len(rows) == 2 + assert {r["id"] for r in rows} == {2, 3} diff --git a/test/test_filters_python.py b/test/test_filters_python.py index 678a1b08..5607ae8d 100644 --- a/test/test_filters_python.py +++ b/test/test_filters_python.py @@ -1,7 +1,11 @@ import pytest import yaml -from graflo.filter.onto import FilterExpression, LogicalOperator +from graflo.filter.onto import ( + ComparisonOperator, + FilterExpression, + LogicalOperator, +) from graflo.onto import ExpressionFlavor @@ -133,3 +137,41 @@ def test_filter_neq(clause_volume): doc = {"name": "Volume", "value": -1.0} assert not m(kind=ExpressionFlavor.PYTHON, **doc) + + +def test_filter_expression_sql_leaf(): + """FilterExpression renders leaf to SQL WHERE fragment (ExpressionFlavor.SQL).""" + leaf = FilterExpression( + kind="leaf", + field="created_at", + cmp_operator=ComparisonOperator.GE, + value=["2020-01-01T00:00:00"], + ) + out = leaf(kind=ExpressionFlavor.SQL) + assert out == "\"created_at\" >= '2020-01-01T00:00:00'" + + +def test_filter_expression_sql_composite_and(): + """FilterExpression AND composite renders to SQL with AND.""" + ge = FilterExpression( + kind="leaf", + field="dt", + cmp_operator=ComparisonOperator.GE, + value=["2020-01-01"], + ) + lt = FilterExpression( + kind="leaf", + field="dt", + 
cmp_operator=ComparisonOperator.LT, + value=["2020-12-31"], + ) + expr = FilterExpression( + kind="composite", + operator=LogicalOperator.AND, + deps=[ge, lt], + ) + out = expr(kind=ExpressionFlavor.SQL) + assert isinstance(out, str) + assert "\"dt\" >= '2020-01-01'" in out + assert "\"dt\" < '2020-12-31'" in out + assert " AND " in out diff --git a/test/test_ingestion_datetime.py b/test/test_ingestion_datetime.py new file mode 100644 index 00000000..f3225681 --- /dev/null +++ b/test/test_ingestion_datetime.py @@ -0,0 +1,88 @@ +"""Tests for ingestion datetime range params and SQL WHERE building.""" + +from graflo.hq.caster import Caster, IngestionParams +from graflo.util.onto import TablePattern + + +def test_ingestion_params_datetime_defaults(): + """IngestionParams has None for datetime fields by default.""" + params = IngestionParams() + assert params.datetime_after is None + assert params.datetime_before is None + assert params.datetime_column is None + + +def test_ingestion_params_datetime_set(): + """IngestionParams accepts datetime_after, datetime_before, datetime_column.""" + params = IngestionParams( + datetime_after="2020-01-01T00:00:00", + datetime_before="2020-12-31T23:59:59", + datetime_column="created_at", + ) + assert params.datetime_after == "2020-01-01T00:00:00" + assert params.datetime_before == "2020-12-31T23:59:59" + assert params.datetime_column == "created_at" + + +def test_datetime_range_where_sql_empty(): + """_datetime_range_where_sql returns empty when both bounds None.""" + out = Caster._datetime_range_where_sql(None, None, "dt") + assert out == "" + + +def test_datetime_range_where_sql_both_bounds(): + """_datetime_range_where_sql produces [after, before) with AND.""" + out = Caster._datetime_range_where_sql( + "2020-01-01", + "2020-12-31", + "created_at", + ) + assert "\"created_at\" >= '2020-01-01'" in out + assert "\"created_at\" < '2020-12-31'" in out + assert " AND " in out + + +def test_datetime_range_where_sql_only_after(): + """_datetime_range_where_sql with only datetime_after.""" + out = Caster._datetime_range_where_sql("2020-06-01", None, "dt") + assert out == "\"dt\" >= '2020-06-01'" + + +def test_datetime_range_where_sql_only_before(): + """_datetime_range_where_sql with only datetime_before.""" + out = Caster._datetime_range_where_sql(None, "2021-01-01", "ts") + assert out == "\"ts\" < '2021-01-01'" + + +def test_datetime_range_where_sql_iso_format(): + """_datetime_range_where_sql accepts ISO datetime strings.""" + out = Caster._datetime_range_where_sql( + "2020-01-15T10:30:00", + "2020-01-15T18:00:00", + "updated_at", + ) + assert "2020-01-15T10:30:00" in out + assert "2020-01-15T18:00:00" in out + assert "updated_at" in out + + +def test_sql_query_where_combines_pattern_and_ingestion_datetime(): + """Query WHERE combines TablePattern date_filter and ingestion datetime range.""" + # Simulate the logic in _register_sql_table_sources: pattern WHERE + datetime WHERE + pattern = TablePattern( + table_name="events", + date_field="dt", + date_filter="!= '2020-01-01'", + ) + pattern_where = pattern.build_where_clause() + datetime_where = Caster._datetime_range_where_sql( + "2020-06-01", + "2020-07-01", + pattern.date_field or "dt", + ) + where_parts = [p for p in [pattern_where, datetime_where] if p] + combined = " AND ".join(where_parts) + assert "\"dt\" != '2020-01-01'" in combined + assert "\"dt\" >= '2020-06-01'" in combined + assert "\"dt\" < '2020-07-01'" in combined + assert combined.count(" AND ") == 2 From 
4fe5b16e78de3ad9e38f13df76acddd0ecdf3986 Mon Sep 17 00:00:00 2001 From: Alexander Belikov Date: Mon, 2 Feb 2026 20:02:37 +0100 Subject: [PATCH 7/7] docs update --- CHANGELOG.md | 7 ++-- docs/examples/example-5.md | 34 +++++++++++++++++++ docs/getting_started/quickstart.md | 2 ++ .../5-ingest-postgres/generated-schema.yaml | 30 ++++++++-------- pyproject.toml | 2 +- uv.lock | 2 +- 6 files changed, 58 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ecb1025f..19d2fec4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,10 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [1.4.6] - 2026-02-02 +## [1.5.0] - 2026-02-02 ### Added -... +- **Ingestion date range**: `IngestionParams` supports `datetime_after`, `datetime_before`, and `datetime_column` so ingestion can be restricted to a date range + - Use with `GraphEngine.create_patterns(..., datetime_columns={...})` for per-resource datetime columns, or set `IngestionParams.datetime_column` for a single default column + - Rows are included when the datetime column value is in `[datetime_after, datetime_before)` (inclusive lower, exclusive upper) + - Applies to SQL/PostgreSQL table ingestion; enables sampling or incremental loads by time window ### Changed - **Configs use Pydantic**: Schema and all schema-related configs now use Pydantic `BaseModel` (via `ConfigBaseModel`) instead of dataclasses diff --git a/docs/examples/example-5.md b/docs/examples/example-5.md index c39c7d7a..dd442e30 100644 --- a/docs/examples/example-5.md +++ b/docs/examples/example-5.md @@ -337,11 +337,31 @@ patterns = engine.create_patterns( ) ``` +**Optional: datetime columns for date-range filtering** + +To restrict ingestion to a time window, pass `datetime_columns`: a mapping from resource (table) name to the name of the datetime column used for filtering. Use this together with `IngestionParams(datetime_after=..., datetime_before=...)` in the ingestion step: + +```python +# Optional: map each table to its datetime column for date-range filtering +datetime_columns = { + "purchases": "purchase_date", + "users": "created_at", + "products": "created_at", + "follows": "created_at", +} +patterns = engine.create_patterns( + postgres_conf, + schema_name="public", + datetime_columns=datetime_columns, +) +``` + This creates `TablePattern` instances for each table, which: - Map table names to resource names (e.g., `users` table → `users` resource) - Store PostgreSQL connection configuration - Enable the Caster to query data directly from PostgreSQL using SQL +- Optionally store a `date_field` for date-range filtering when `datetime_columns` is provided **How Patterns Work:** @@ -366,6 +386,15 @@ Finally, ingest the data from PostgreSQL into your target graph database. This i 5. **Graph Database Storage**: Data is written to the target graph database (ArangoDB/Neo4j/TigerGraph) using database-specific APIs for optimal performance. The system handles duplicates and updates based on indexes. +**Restricting ingestion by date range** + +You can limit which rows are ingested by providing a date range in `IngestionParams`. Use `datetime_after` and `datetime_before` (ISO-format strings); only rows whose datetime column value falls in `[datetime_after, datetime_before)` are included. 
This requires either: + +- Passing `datetime_columns` when creating patterns (see Step 5), or +- Setting `datetime_column` in `IngestionParams` as a single default column for all resources. + +Example: + ```python from graflo.hq import GraphEngine from graflo.hq.caster import IngestionParams @@ -374,6 +403,11 @@ from graflo.hq.caster import IngestionParams engine = GraphEngine() ingestion_params = IngestionParams( clear_data=True, # Clear existing data before ingesting + # Optional: ingest only rows in this date range (requires datetime_columns in create_patterns + # or datetime_column below) + # datetime_after="2020-01-01", + # datetime_before="2021-01-01", + # datetime_column="created_at", # default column when a pattern has no date_field ) engine.define_and_ingest( diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index f03b4f8f..2b892d02 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -162,6 +162,8 @@ arango_config = ArangoConfig.from_docker_env() # Target graph database engine = GraphEngine() ingestion_params = IngestionParams( recreate_schema=False, # Set to True to drop and redefine schema (script halts if schema exists) + # Optional: restrict to a date range with datetime_after, datetime_before, datetime_column + # (use with create_patterns(..., datetime_columns={...}) for per-table columns) ) engine.define_and_ingest( diff --git a/examples/5-ingest-postgres/generated-schema.yaml b/examples/5-ingest-postgres/generated-schema.yaml index e4668371..5bb02a60 100644 --- a/examples/5-ingest-postgres/generated-schema.yaml +++ b/examples/5-ingest-postgres/generated-schema.yaml @@ -31,46 +31,46 @@ edge_config: general: name: accounting resources: -- apply: - - vertex: products - edge_greedy: true +- edge_greedy: true encoding: utf-8 extra_weights: [] merge_collections: [] + pipeline: + - vertex: products resource_name: products types: {} -- apply: - - vertex: users - edge_greedy: true +- edge_greedy: true encoding: utf-8 extra_weights: [] merge_collections: [] + pipeline: + - vertex: users resource_name: users types: {} -- apply: +- edge_greedy: true + encoding: utf-8 + extra_weights: [] + merge_collections: [] + pipeline: - map: follower_id: id target_vertex: users - map: followed_id: id target_vertex: users - edge_greedy: true + resource_name: follows + types: {} +- edge_greedy: true encoding: utf-8 extra_weights: [] merge_collections: [] - resource_name: follows - types: {} -- apply: + pipeline: - map: user_id: id target_vertex: users - map: product_id: id target_vertex: products - edge_greedy: true - encoding: utf-8 - extra_weights: [] - merge_collections: [] resource_name: purchases types: {} transforms: {} diff --git a/pyproject.toml b/pyproject.toml index 2c4fd301..bfb00f57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ description = "A framework for transforming tabular (CSV, SQL) and hierarchical name = "graflo" readme = "README.md" requires-python = ">=3.11" -version = "1.4.5" +version = "1.5.0" [project.optional-dependencies] plot = [ diff --git a/uv.lock b/uv.lock index b1190be3..cac297b7 100644 --- a/uv.lock +++ b/uv.lock @@ -348,7 +348,7 @@ wheels = [ [[package]] name = "graflo" -version = "1.4.5" +version = "1.5.0" source = { editable = "." } dependencies = [ { name = "click" },