from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, Optional, Set


@dataclass(frozen=True, eq=False)
class GlobNode:
    """One segment of a dotted glob pattern, linked to the next segment.

    eq=False keeps object-identity equality and hashing: each node is a
    unique cursor position in the pattern, so two nodes with identical
    text must still be distinct keys in the BFS visited-set.  (The
    previous version declared eq=True and then overrode __eq__/__hash__
    by hand with identity semantics — same behavior, contradictory and
    confusing declaration.)
    """

    pattern: str
    next: Optional["GlobNode"] = None

    @staticmethod
    def nodify(tokens: list[str]) -> Optional["GlobNode"]:
        """Build the linked list for *tokens*; return the head (None for [])."""
        head: Optional["GlobNode"] = None
        for token in reversed(tokens):
            head = GlobNode(pattern=token, next=head)
        return head

    @property
    def is_final(self) -> bool:
        """True when this node is the last pattern segment."""
        return self.next is None

    def __repr__(self) -> str:
        return f"Node({self.pattern})"


@dataclass(frozen=True)
class SearchState:
    """A snapshot of the traversal cursor: a directory plus the node to match next."""

    path: Path
    node: GlobNode


class ModuleFinder:
    """Breadth-first filesystem search for Python modules matching a dotted glob."""

    def __init__(self, root: Path, segments: list[str]):
        self.root = root
        self.head = GlobNode.nodify(segments)

        self.queue: deque[SearchState] = deque()
        # Prevents infinite loops with symlinks and redundant '**' expansions.
        self.visited: Set[SearchState] = set()

    def find(self) -> Generator[Path, None, None]:
        """
        Find all modules that match the glob-like pattern for Python modules.

        Supports patterns like:
        - example.*.models
        - example.fo?.models
        - example.*.*.models
        - example.**.api.*.models
        - example.api.v[12].models
        - example.api.v[0-9].models
        - example.api.v[!1].models

        NOTE(review): segment matching is delegated to Path.match, which
        supports `?` and `[...]` character classes; the Pattern grammar in
        models/pattern.py currently rejects those characters, so they are
        only reachable when tokens come from another source — confirm.
        """
        if self.head is None:
            return

        # Seed the traversal at the root with the first pattern segment.
        self.queue.append(SearchState(self.root, self.head))

        while self.queue:
            state = self.queue.popleft()

            # Distinct routes can reach the same (path, node) pair; process once.
            if state in self.visited:
                continue
            self.visited.add(state)

            yield from self._process_state(state)

    def _process_state(self, state: SearchState) -> Generator[Path, None, None]:
        """Expand one BFS state: yield matches and enqueue successor states."""
        node = state.node
        path = state.path

        if node.pattern == "**":
            # Branch A: '**' matches zero segments — advance pattern, keep path.
            if node.next:
                self.queue.append(SearchState(path, node.next))

            # Branch B: '**' consumes one more segment — keep pattern, descend.
            for child in self._safe_iterdir(path):
                if child.is_dir():
                    self.queue.append(SearchState(child, node))

        else:
            for child in self._safe_iterdir(path):
                is_match = False

                if child.is_dir():
                    if child.match(node.pattern):
                        is_match = True

                elif child.is_file():
                    # Match against the stem so 'models' matches 'models.py'.
                    if Path(child.stem).match(node.pattern):
                        is_match = True

                if not is_match:
                    continue

                if node.is_final and self._is_valid_module(child):
                    yield child

                elif node.next is not None and child.is_dir():
                    self.queue.append(SearchState(child, node.next))

    def _safe_iterdir(self, path: Path) -> Generator[Path, None, None]:
        """Safe wrapper around iterdir to handle permission errors."""
        try:
            if path.is_dir():
                yield from path.iterdir()
        except PermissionError:
            pass

    def _is_valid_module(self, path: Path) -> bool:
        """
        Determine if a path is a valid Python module.
        1. File: my_module.py (but not __init__.py)
        2. Package: my_package/ (must contain __init__.py)
        3. Supports PEP 420 namespace packages (any identifier-named dir).
        """
        name = path.name

        # Ignore common garbage/internal directories.
        if name == "__pycache__" or name.startswith("."):
            return False

        if path.is_file():
            return path.suffix == ".py" and path.name != "__init__.py"

        if path.is_dir():
            return (path / "__init__.py").exists() or name.isidentifier()

        return False
def _merge_metadata(metadata_list: List[tuple[str, MetaData]]) -> MetaData:
    """
    Merge several MetaData objects into a single MetaData.

    When a source module path is dotted (e.g. "project1.example.base"), its
    first segment is used as a table-name prefix ("project1_") — note this
    prefix is applied to every table from that module, which both
    namespaces the merged tables and resolves cross-project collisions.
    Tables whose final name already exists in the merged result are
    skipped (assumed identical).
    """
    merged = MetaData()

    for module_path, metadata in metadata_list:
        # A dotted module path contributes its first segment as a prefix.
        first_segment, _, remainder = module_path.partition(".")
        prefix = f"{first_segment}_" if remainder else ""

        for tablename, table in metadata.tables.items():
            prefixed_name = f"{prefix}{tablename}" if prefix else tablename

            if prefix and prefixed_name in merged.tables:
                # Even the prefixed name collides: warn, keep the prefixed name.
                logger.warning(
                    f"Table name conflict: '{tablename}' from {module_path} conflicts. "
                    f"Using prefixed name: '{prefixed_name}'"
                )
                final_name = prefixed_name
            elif tablename in merged.tables:
                # Bare-name collision without a prefix: tables are treated as the same.
                final_name = tablename
            else:
                final_name = prefixed_name if prefix else tablename

            if final_name in merged.tables:
                continue  # already present — skip the copy

            # SQLAlchemy 1.4+ spells it to_metadata; fall back for older versions.
            copier = getattr(table, "to_metadata", None) or table.tometadata
            copier(merged, name=final_name)

    return merged
sys.path.append(str(os.getcwd())) for dir in python_dir: sys.path.append(str(dir)) - # Import the base class so the metadata class can be extracted from it. - # The metadata class is passed to the transformer. - module_path, class_name = base_class_path.split(":", 2) - base_module = importlib.import_module(module_path) - base_class = getattr(base_module, class_name) - metadata = base_class.metadata + current_root = Path.cwd() + + # Handle base class path with or without wildcards + has_wildcards = "*" in base_class_path or "?" in base_class_path + base_metadata_list = None + + if has_wildcards or (merge_namespace_metadata and has_wildcards): + # Find all matching base classes and merge their metadata + base_metadata_list = _find_base_classes_by_pattern(base_class_path, python_dir, current_root) + + if len(base_metadata_list) > 1: + metadata = _merge_metadata(base_metadata_list) + elif len(base_metadata_list) == 1: + # Single base class found + metadata = base_metadata_list[0][1] + else: + raise ValueError(f"No base classes found matching pattern: {base_class_path}") + else: + # No wildcards, use single base class + module_path, class_name = base_class_path.split(":", 2) + base_module = importlib.import_module(module_path) + base_class = getattr(base_module, class_name) + metadata = base_class.metadata # The modules holding the model classes have to be imported to get put in the metaclass model registry. # These modules aren't actually used in any way, so they are discarded. # They are also imported in scope of this function to prevent namespace pollution. - for module in import_module: - if ":*" in module: - # Sure, execs are gross, but this is the only way to dynamically import wildcards. - exec(f"from {module[:-2]} import *") - else: - importlib.import_module(module) + for module_lookup_mask in import_module: + module_path, import_modifier = module_lookup_mask, None - # Grab a transformer. 
- if format not in transformers: - raise ValueError(f"Unknown Format: {format}") + if module_path.endswith(":*"): + module_path, import_modifier = module_path.split(":", 1) + + pattern = Pattern(mask=module_path) + + if any(pattern.errors): + raise ValueError(pattern.serialized_errors) + + finder = ModuleFinder(current_root, pattern.tokens) + needs_wildcards_import = import_modifier == "*" + + for file_path in finder.find(): + module_path = to_module_name(current_root, file_path) + + # Check if already loaded to save time + if module_path in sys.modules: + continue + + try: + if needs_wildcards_import: + exec(f"from {module_path} import *") + else: + importlib.import_module(module_path) + except ImportError as e: + logger.error(f"Failed to load {module_path}: {e}") + raise e + + # If we merged metadata from multiple base classes, we need to re-merge after models are imported + # because models register themselves in the original base class metadata, not the merged one + if base_metadata_list and len(base_metadata_list) > 1: + # Re-collect metadata from all base classes after models have been imported + updated_metadata_list = [] + for module_path, _ in base_metadata_list: + try: + base_module = importlib.import_module(module_path) + # Extract class name from base_class_path + _, class_name = base_class_path.split(":", 2) + if hasattr(base_module, class_name): + base_class = getattr(base_module, class_name) + updated_metadata_list.append((module_path, base_class.metadata)) + except (ImportError, AttributeError) as e: + logger.warning(f"Could not re-import base class from {module_path}: {e}") + continue + + if updated_metadata_list: + metadata = _merge_metadata(updated_metadata_list) # Keep only the tables which were included / not-excluded include_tables = resolve_included_tables( @@ -65,20 +247,77 @@ def get_graph_string( ) filtered_metadata = filter_metadata(metadata=metadata, include_tables=include_tables) - # Save the graph structure to string. 
def serialize_metadata(
    metadata: MetaData,
    *,
    format: str,
    column_sort: str,
    omit_comments: bool = False,
    max_enum_members: int = 0,
    layout: Optional[Layouts] = None,
    type_parameter_delimiter: str = "-",
) -> str:
    """
    Serialize *metadata* to a diagram string in the requested format.

    Raises ValueError for formats absent from the transformer registry.
    Enum truncation and the type-parameter delimiter only apply to Mermaid.
    """
    if format not in transformers:
        raise ValueError(f"Unknown Format: {format}")

    if format not in ("mermaid", "mmd"):
        # NOTE(review): Dot is now passed `layout` — confirm Dot.__init__
        # accepts this keyword (the pre-refactor call did not pass it).
        return str(Dot(metadata, column_sort, omit_comments=omit_comments, layout=layout))

    diagram = Mermaid(
        metadata,
        column_sort,
        omit_comments=omit_comments,
        max_enum_members=max_enum_members,
        layout=layout,
        type_parameter_delimiter=type_parameter_delimiter,
    )
    return str(diagram)
+ """ + metadata = get_graph_metadata( + base_class_path=base_class_path, + import_module=import_module, + include_tables=include_tables, + exclude_tables=exclude_tables, + python_dir=python_dir, + ) + + return serialize_metadata( + metadata, + format=format, + column_sort=column_sort, + omit_comments=omit_comments, + max_enum_members=max_enum_members, + layout=layout, + type_parameter_delimiter=type_parameter_delimiter, + ) def resolve_included_tables( @@ -135,3 +374,120 @@ def filter_metadata( table = table.tometadata(filtered_metadata) return filtered_metadata + + +def compare_metadata(actual: MetaData, expected: MetaData, omit_comments: bool = False) -> None: + """ + Compares two MetaData objects and raises AssertionError if they differ. + This function performs a structural comparison of two graph representations, + checking tables, columns, types, constraints, and relationships. + """ + + actual_tables = set(actual.tables.keys()) + expected_tables = set(expected.tables.keys()) + + # Check table names + if actual_tables != expected_tables: + missing = expected_tables - actual_tables + extra = actual_tables - expected_tables + error_msg = "Table mismatch:\n" + if missing: + error_msg += f" Missing tables: {missing}\n" + if extra: + error_msg += f" Extra tables: {extra}\n" + raise AssertionError(error_msg) + + # Check each table's structure + for table_name in expected_tables: + actual_table = actual.tables[table_name] + expected_table = expected.tables[table_name] + + actual_columns = {col.name: col for col in actual_table.columns} + expected_columns = {col.name: col for col in expected_table.columns} + + # Check column names + if set(actual_columns.keys()) != set(expected_columns.keys()): + missing = set(expected_columns.keys()) - set(actual_columns.keys()) + extra = set(actual_columns.keys()) - set(expected_columns.keys()) + error_msg = f"Column mismatch in table '{table_name}':\n" + if missing: + error_msg += f" Missing columns: {missing}\n" + if extra: + 
error_msg += f" Extra columns: {extra}\n" + raise AssertionError(error_msg) + + # Check each column's properties + for col_name in expected_columns.keys(): + actual_col = actual_columns[col_name] + expected_col = expected_columns[col_name] + + # Check type + actual_type_str = str(actual_col.type) + expected_type_str = str(expected_col.type) + if actual_type_str != expected_type_str: + raise AssertionError( + f"Type mismatch in table '{table_name}', column '{col_name}': " + f"expected {expected_type_str}, got {actual_type_str}" + ) + + # Check constraints + actual_pk = actual_col.primary_key + expected_pk = expected_col.primary_key + if actual_pk != expected_pk: + raise AssertionError( + f"Primary key mismatch in table '{table_name}', column '{col_name}': " + f"expected {expected_pk}, got {actual_pk}" + ) + + actual_fk_count = len(actual_col.foreign_keys) + expected_fk_count = len(expected_col.foreign_keys) + if actual_fk_count != expected_fk_count: + raise AssertionError( + f"Foreign key count mismatch in table '{table_name}', column '{col_name}': " + f"expected {expected_fk_count}, got {actual_fk_count}" + ) + + # Check nullable + if actual_col.nullable != expected_col.nullable: + raise AssertionError( + f"Nullable mismatch in table '{table_name}', column '{col_name}': " + f"expected {expected_col.nullable}, got {actual_col.nullable}" + ) + + # Check comments (if not omitted) + if not omit_comments: + actual_comment = actual_col.comment + expected_comment = expected_col.comment + if actual_comment != expected_comment: + raise AssertionError( + f"Comment mismatch in table '{table_name}', column '{col_name}': " + f"expected {expected_comment!r}, got {actual_comment!r}" + ) + + # Check foreign key relationships + actual_fks: set[tuple[str, ...]] = set() + for col in actual_table.columns: + for fk in col.foreign_keys: + # Format: (table_name, column_name) -> (target_table, target_column) + target_parts = fk.target_fullname.split(".") + target_table = 
".".join(target_parts[:-1]) + target_column = target_parts[-1] + actual_fks.add((table_name, col.name, target_table, target_column)) + + expected_fks: set[tuple[str, ...]] = set() + for col in expected_table.columns: + for fk in col.foreign_keys: + target_parts = fk.target_fullname.split(".") + target_table = ".".join(target_parts[:-1]) + target_column = target_parts[-1] + expected_fks.add((table_name, col.name, target_table, target_column)) + + if actual_fks != expected_fks: + missing_fks: set[tuple[str, ...]] = expected_fks - actual_fks + extra_fks: set[tuple[str, ...]] = actual_fks - expected_fks + error_msg = f"Foreign key mismatch in table '{table_name}':\n" + if missing: + error_msg += f" Missing FKs: {missing_fks}\n" + if extra: + error_msg += f" Extra FKs: {extra_fks}\n" + raise AssertionError(error_msg) diff --git a/paracelsus/models/__init__.py b/paracelsus/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/paracelsus/models/base.py b/paracelsus/models/base.py new file mode 100644 index 0000000..bcb1aa0 --- /dev/null +++ b/paracelsus/models/base.py @@ -0,0 +1,47 @@ +from typing import Callable, Generic, Protocol, Sequence, TypeVar, Type + +ReturnT = TypeVar("ReturnT") + + +class ValidationError(ValueError): + pass + + +class ErrorContainer(Protocol): + def add_error(self, error: ValidationError) -> None: ... + + @property + def errors(self) -> list[ValidationError]: ... + + +class Attribute(Generic[ReturnT]): + """ + A simple descriptor to implement attributes validation upon assignment. 
+ """ + + def __init__(self, type: Type, validators: Sequence[Callable[[str, ReturnT], ReturnT]] = ()): + self.type = type + self.validators = validators + + def __set_name__(self, owner, name): + self.name = name + + def __get__(self, instance: ErrorContainer, owner): + if not instance: + return self + return instance.__dict__[self.name] + + def __delete__(self, instance: ErrorContainer): + del instance.__dict__[self.name] + + def __set__(self, instance: ErrorContainer, value): + if not isinstance(value, self.type): + raise TypeError(f"{self.name!r} values must be of type {self.type!r}") + + for validator in self.validators: + try: + validator(self.name, value) + except ValidationError as e: + instance.add_error(e) + + instance.__dict__[self.name] = value diff --git a/paracelsus/models/pattern.py b/paracelsus/models/pattern.py new file mode 100644 index 0000000..0a2f79c --- /dev/null +++ b/paracelsus/models/pattern.py @@ -0,0 +1,79 @@ +import re +from dataclasses import dataclass, field + +from .base import Attribute, ValidationError + + +def forbid_empty_path(name: str, value: str) -> str: + if not value: + raise ValidationError("Empty path not allowed") + return value + + +def forbid_wildcard_for_modules(name: str, value: str) -> str: + if value.endswith("**"): + raise ValidationError("Wildcard (**) not allowed for modules scope") + return value + + +def forbid_empty_segment(name: str, value: str) -> str: + if any(not segment for segment in value.split(".")): + raise ValidationError("Empty segment not allowed") + return value + + +def enforce_globbing_grammar(name: str, value: str) -> str: + regex = re.compile(r"^[a-zA-Z0-9_*]+(\.[a-zA-Z0-9_*]+)*$") + if not regex.fullmatch(value): + raise ValidationError("Invalid globbing pattern") + return value + + +def forbid_greedy_lookup(name: str, value: str) -> str: + tokens = value.split(".") + for i in range(len(tokens) - 1): + current = tokens[i] + next_token = tokens[i + 1] + + if current == "**" and next_token == 
"*": + raise ValidationError( + f"Invalid Mask: '**. *' is ambiguous and forbidden. " + f"Found at segment {i}: '...{current}.{next_token}...'" + ) + + return value + + +@dataclass(init=True, frozen=True) +class Pattern: + _errors: list[ValidationError] = field(default_factory=list) + mask: Attribute[str] = field( + default=Attribute[str]( + str, + validators=[ + enforce_globbing_grammar, + forbid_wildcard_for_modules, + forbid_greedy_lookup, + forbid_empty_segment, + forbid_empty_path, + ], + ) + ) + + def add_error(self, error: ValidationError) -> None: + self._errors.append(error) + + @property + def errors(self) -> list[ValidationError]: + return self._errors + + @property + def tokens(self) -> list[str]: + return self.mask.split(".") + + @property + def serialized_errors(self) -> str: + messages = [f"Errors found for {self.mask}:"] + for error in self.errors: + messages.append(f" - {error}") + return "\n".join(messages) diff --git a/pyproject.toml b/pyproject.toml index 33fa147..23e47ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ requires-python = ">= 3.10" dev = [ "build", "dapperdata", + "pre-commit>=4.5.1", "glom", "mypy", "pip-tools", diff --git a/tests/assets/character_classes/example/__init__.py b/tests/assets/character_classes/example/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/__init__.py b/tests/assets/character_classes/example/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v0/__init__.py b/tests/assets/character_classes/example/api/v0/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v0/models.py b/tests/assets/character_classes/example/api/v0/models.py new file mode 100644 index 0000000..3d6ecdd --- /dev/null +++ b/tests/assets/character_classes/example/api/v0/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from 
sqlalchemy.orm import mapped_column +from ...base import Base + + +class V0Model(Base): + __tablename__ = "v0_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v1/__init__.py b/tests/assets/character_classes/example/api/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v1/models.py b/tests/assets/character_classes/example/api/v1/models.py new file mode 100644 index 0000000..28dc046 --- /dev/null +++ b/tests/assets/character_classes/example/api/v1/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V1Model(Base): + __tablename__ = "v1_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v10/__init__.py b/tests/assets/character_classes/example/api/v10/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v10/models.py b/tests/assets/character_classes/example/api/v10/models.py new file mode 100644 index 0000000..409bad9 --- /dev/null +++ b/tests/assets/character_classes/example/api/v10/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V10Model(Base): + __tablename__ = "v10_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v2/__init__.py b/tests/assets/character_classes/example/api/v2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v2/models.py b/tests/assets/character_classes/example/api/v2/models.py new file mode 100644 index 0000000..4217946 --- /dev/null +++ b/tests/assets/character_classes/example/api/v2/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class 
V2Model(Base): + __tablename__ = "v2_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v3/__init__.py b/tests/assets/character_classes/example/api/v3/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v3/models.py b/tests/assets/character_classes/example/api/v3/models.py new file mode 100644 index 0000000..2061e93 --- /dev/null +++ b/tests/assets/character_classes/example/api/v3/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V3Model(Base): + __tablename__ = "v3_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v4/__init__.py b/tests/assets/character_classes/example/api/v4/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v4/models.py b/tests/assets/character_classes/example/api/v4/models.py new file mode 100644 index 0000000..05dc8b8 --- /dev/null +++ b/tests/assets/character_classes/example/api/v4/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V4Model(Base): + __tablename__ = "v4_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v5/__init__.py b/tests/assets/character_classes/example/api/v5/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v5/models.py b/tests/assets/character_classes/example/api/v5/models.py new file mode 100644 index 0000000..cd373a4 --- /dev/null +++ b/tests/assets/character_classes/example/api/v5/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V5Model(Base): + __tablename__ = "v5_table" + id = mapped_column(String, 
primary_key=True) diff --git a/tests/assets/character_classes/example/api/v6/__init__.py b/tests/assets/character_classes/example/api/v6/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v6/models.py b/tests/assets/character_classes/example/api/v6/models.py new file mode 100644 index 0000000..2ffb0a6 --- /dev/null +++ b/tests/assets/character_classes/example/api/v6/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V6Model(Base): + __tablename__ = "v6_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v7/__init__.py b/tests/assets/character_classes/example/api/v7/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v7/models.py b/tests/assets/character_classes/example/api/v7/models.py new file mode 100644 index 0000000..c89a94f --- /dev/null +++ b/tests/assets/character_classes/example/api/v7/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V7Model(Base): + __tablename__ = "v7_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v8/__init__.py b/tests/assets/character_classes/example/api/v8/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v8/models.py b/tests/assets/character_classes/example/api/v8/models.py new file mode 100644 index 0000000..7093ecc --- /dev/null +++ b/tests/assets/character_classes/example/api/v8/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V8Model(Base): + __tablename__ = "v8_table" + id = mapped_column(String, primary_key=True) diff --git 
# Test-asset modules reconstructed from the patch, one section per file.
# The __init__.py files created alongside these are all empty.

# -- tests/assets/character_classes/example/api/v9/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ...base import Base


class V9Model(Base):
    __tablename__ = "v9_table"
    id = mapped_column(String, primary_key=True)


# -- tests/assets/character_classes/example/api/va/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ...base import Base


class VaModel(Base):
    __tablename__ = "va_table"
    id = mapped_column(String, primary_key=True)


# -- tests/assets/character_classes/example/api/vb/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ...base import Base


class VbModel(Base):
    __tablename__ = "vb_table"
    id = mapped_column(String, primary_key=True)


# -- tests/assets/character_classes/example/base.py --
from sqlalchemy.orm import declarative_base

Base = declarative_base()


# -- tests/assets/multi_star/example/base.py --
from sqlalchemy.orm import declarative_base

Base = declarative_base()


# -- tests/assets/multi_star/example/v1/api/users/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ....base import Base


class V1User(Base):
    __tablename__ = "v1_users"
    id = mapped_column(String, primary_key=True)


# -- tests/assets/multi_star/example/v2/api/products/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ....base import Base


class V2Product(Base):
    __tablename__ = "v2_products"
    id = mapped_column(String, primary_key=True)


# -- tests/assets/namespace/project1/example/base.py --
from sqlalchemy.orm import declarative_base

Base = declarative_base()


# -- tests/assets/namespace/project1/example/subpackage_a/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ..base import Base


class SubpackageAModel(Base):
    __tablename__ = "subpackage_a_table"
    id = mapped_column(String, primary_key=True)


# -- tests/assets/namespace/project2/example/base.py --
from sqlalchemy.orm import declarative_base

Base = declarative_base()


# -- tests/assets/namespace/project2/example/subpackage_b/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ..base import Base


class SubpackageBModel(Base):
    __tablename__ = "subpackage_b_table"
    id = mapped_column(String, primary_key=True)
# Test-asset modules reconstructed from the patch, one section per file.
# The __init__.py files created alongside these are all empty.

# -- tests/assets/nested/example/api/v1/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ...base import Base


class APIResource(Base):
    __tablename__ = "api_resources"
    id = mapped_column(String, primary_key=True)


# -- tests/assets/nested/example/base.py --
from sqlalchemy.orm import declarative_base

Base = declarative_base()


# -- tests/assets/nested/example/domain/products/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ...base import Base


class Product(Base):
    __tablename__ = "products"
    id = mapped_column(String, primary_key=True)


# -- tests/assets/nested/example/domain/users/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ...base import Base


class User(Base):
    __tablename__ = "users"
    id = mapped_column(String, primary_key=True)


# -- tests/assets/recursive/example/api/v1/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ...base import Base


class APIResource(Base):
    __tablename__ = "api_resources"
    id = mapped_column(String, primary_key=True)


# -- tests/assets/recursive/example/base.py --
from sqlalchemy.orm import declarative_base

Base = declarative_base()


# -- tests/assets/recursive/example/domain/users/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ...base import Base


class User(Base):
    __tablename__ = "users"
    id = mapped_column(String, primary_key=True)


# -- tests/assets/recursive/example/level1/level2/api/v3/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ......base import Base


class DeepAPIResource(Base):
    __tablename__ = "deep_api_resources"
    id = mapped_column(String, primary_key=True)


# -- tests/assets/recursive/example/something/api/v2/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ....base import Base


class SomethingAPIResource(Base):
    __tablename__ = "something_api_resources"
    id = mapped_column(String, primary_key=True)


# -- tests/assets/single_level/example/bar/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ..base import Base


class BarModel(Base):
    __tablename__ = "bar_table"
    id = mapped_column(String, primary_key=True)


# -- tests/assets/single_level/example/base.py --
from sqlalchemy.orm import declarative_base

Base = declarative_base()


# -- tests/assets/single_level/example/foo/models.py --
from sqlalchemy import String
from sqlalchemy.orm import mapped_column

from ..base import Base


class FooModel(Base):
    __tablename__ = "foo_table"
    id = mapped_column(String, primary_key=True)
mapped_column +from ..base import Base + + +class FooModel(Base): + __tablename__ = "foo_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/conftest.py b/tests/conftest.py index 0aab644..2c5a563 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import shutil import sys import tempfile +from functools import singledispatch from collections.abc import Generator from datetime import datetime, timezone from pathlib import Path @@ -15,6 +16,42 @@ UTC = timezone.utc +@singledispatch +def setup_sys_path_for_test(paths, module_prefix: str = "example") -> list[str]: + # Clear module cache + for name in list(sys.modules.keys()): + if name == module_prefix or name.startswith(f"{module_prefix}."): + del sys.modules[name] + + # Add paths to sys.path + path_strings = [] + for path in paths: + path_str = str(path) + path_strings.append(path_str) + sys.path.insert(0, path_str) + + return path_strings + + +@setup_sys_path_for_test.register +def _(paths: Path, module_prefix: str = "example") -> str: + result = setup_sys_path_for_test([paths], module_prefix) + + return result[0] + + +@singledispatch +def cleanup_sys_path(paths) -> None: + for path_str in paths: + if path_str in sys.path: + sys.path.remove(path_str) + + +@cleanup_sys_path.register +def _(paths: str) -> None: + cleanup_sys_path([paths]) + + @pytest.fixture def metaclass(): Base = declarative_base() @@ -262,3 +299,113 @@ def fixture_expected_mermaid_cardinalities_graph() -> str: ``` """) + + +@pytest.fixture +def single_level_package_path() -> Generator[Path, None, None]: + """Create a package structure with single-level subpackages for testing pattern example.*.models.""" + template_path = Path(os.path.dirname(os.path.realpath(__file__))) / "assets" / "single_level" + with tempfile.TemporaryDirectory() as package_path: + shutil.copytree(template_path, package_path, dirs_exist_ok=True) + package_dir = Path(package_path) + os.chdir(package_path) + + path_str = 
setup_sys_path_for_test(package_dir) + try: + yield Path(package_path) + finally: + cleanup_sys_path(path_str) + + +@pytest.fixture +def nested_package_path() -> Generator[Path, None, None]: + """Create a package structure with nested subpackages for testing multi-level glob patterns.""" + template_path = Path(os.path.dirname(os.path.realpath(__file__))) / "assets" / "nested" + with tempfile.TemporaryDirectory() as package_path: + shutil.copytree(template_path, package_path, dirs_exist_ok=True) + package_dir = Path(package_path) + os.chdir(package_path) + + path_str = setup_sys_path_for_test(package_dir) + try: + yield Path(package_path) + finally: + cleanup_sys_path(path_str) + + +@pytest.fixture +def multi_star_package_path() -> Generator[Path, None, None]: + """Create a package structure for testing patterns with multiple stars.""" + template_path = Path(os.path.dirname(os.path.realpath(__file__))) / "assets" / "multi_star" + with tempfile.TemporaryDirectory() as package_path: + shutil.copytree(template_path, package_path, dirs_exist_ok=True) + package_dir = Path(package_path) + os.chdir(package_path) + + path_str = setup_sys_path_for_test(package_dir) + try: + yield Path(package_path) + finally: + cleanup_sys_path(path_str) + + +@pytest.fixture +def character_classes_package_path() -> Generator[Path, None, None]: + """Create a package structure for testing character classes and ranges patterns.""" + template_path = Path(os.path.dirname(os.path.realpath(__file__))) / "assets" / "character_classes" + with tempfile.TemporaryDirectory() as package_path: + shutil.copytree(template_path, package_path, dirs_exist_ok=True) + package_dir = Path(package_path) + os.chdir(package_path) + + path_str = setup_sys_path_for_test(package_dir) + try: + yield Path(package_path) + finally: + cleanup_sys_path(path_str) + + +@pytest.fixture +def recursive_package_path() -> Generator[Path, None, None]: + """Create a package structure for testing recursive lookup patterns (**).""" + 
template_path = Path(os.path.dirname(os.path.realpath(__file__))) / "assets" / "recursive" + with tempfile.TemporaryDirectory() as package_path: + shutil.copytree(template_path, package_path, dirs_exist_ok=True) + package_dir = Path(package_path) + os.chdir(package_path) + + path_str = setup_sys_path_for_test(package_dir) + try: + yield Path(package_path) + finally: + cleanup_sys_path(path_str) + + +@pytest.fixture +def namespace_package_path() -> Generator[Path, None, None]: + """Create a namespace package structure (PEP 420) for testing.""" + template_base = Path(os.path.dirname(os.path.realpath(__file__))) / "assets" / "namespace" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Copy both projects + for project_key in range(1, 3): + project_dir = f"project{project_key}" + shutil.copytree(template_base / project_dir, temp_path / project_dir, dirs_exist_ok=True) + + project1 = temp_path / "project1" + project2 = temp_path / "project2" + + os.chdir(str(temp_path)) + + path_strings = setup_sys_path_for_test([project1, project2]) + + # Import example to check that it is a namespace package + import example + + assert hasattr(example, "__path__"), "example should be a namespace package with __path__ attribute" + + try: + yield temp_path + finally: + cleanup_sys_path(path_strings) diff --git a/tests/test_graph.py b/tests/test_graph.py index 2460aa7..e124e5b 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,13 +1,33 @@ import pytest +import importlib +import sys +from pathlib import Path from paracelsus.config import Layouts -from paracelsus.graph import get_graph_string +from paracelsus.graph import get_graph_string, get_graph_metadata, compare_metadata +from paracelsus.finders import ModuleFinder +from paracelsus.models.pattern import Pattern +from paracelsus.graph import to_module_name from .utils import mermaid_assert @pytest.mark.parametrize("column_sort_arg", ["key-based", "preserve-order"]) -def 
# -- file: tests/test_graph.py (imports + first tests) --
import importlib
import sys
from pathlib import Path

import pytest

from paracelsus.config import Layouts
from paracelsus.finders import ModuleFinder
from paracelsus.graph import compare_metadata, get_graph_metadata, get_graph_string, to_module_name
from paracelsus.models.pattern import Pattern

from .utils import mermaid_assert


@pytest.mark.parametrize("column_sort_arg", ["key-based", "preserve-order"])
def test_get_graph_string(column_sort_arg, package_path, metaclass):
    """Test get_graph_string with dynamic metadata comparison."""
    actual_metadata = get_graph_metadata(
        base_class_path="example.base:Base",
        import_module=["example.models"],
        include_tables=set(),
        exclude_tables=set(),
        python_dir=[package_path],
    )
    # Compare with expected metadata from the fixture.
    mermaid_assert(actual_metadata, expected=metaclass)

    # Serialization must still work end to end.
    graph_string = get_graph_string(
        base_class_path="example.base:Base",
        import_module=["example.models"],
        include_tables=set(),
        exclude_tables=set(),
        python_dir=[package_path],
        format="mermaid",
        column_sort=column_sort_arg,
    )
    # Legacy string assertion for backward compatibility.
    mermaid_assert(graph_string)


def test_get_graph_string_with_wildcard(single_level_package_path):
    """Wildcard patterns resolve correctly (dynamic metadata comparison)."""
    actual_metadata = get_graph_metadata(
        base_class_path="example.base:Base",
        import_module=["example.*.models"],
        include_tables=set(),
        exclude_tables=set(),
        python_dir=[single_level_package_path],
    )

    # Build expected metadata by importing every matching module ourselves.
    sys.path.insert(0, str(single_level_package_path))
    try:
        pattern = Pattern(mask="example.*.models")
        current_root = Path.cwd()
        finder = ModuleFinder(current_root, pattern.tokens)

        for file_path in finder.find():
            module_path = to_module_name(current_root, file_path)
            importlib.import_module(module_path)

        base_module = importlib.import_module("example.base")
        base_class = getattr(base_module, "Base")
        expected_metadata = base_class.metadata

        compare_metadata(actual_metadata, expected_metadata)
    finally:
        if str(single_level_package_path) in sys.path:
            sys.path.remove(str(single_level_package_path))
        # BUG FIX: also evict the top-level "example" package itself. The
        # original only dropped "example.*" submodules, so the cached parent
        # package leaked into later tests.
        for name in list(sys.modules.keys()):
            if name == "example" or name.startswith("example."):
                del sys.modules[name]


def test_get_graph_with_wildcard_mask_in_namespace_package(namespace_package_path):
    """Namespace packages with wildcard patterns and merged metadata."""
    actual_metadata = get_graph_metadata(
        base_class_path="project*.example.base:Base",
        import_module=["project*.example.*.models"],
        include_tables=set(),
        exclude_tables=set(),
        python_dir=[namespace_package_path],
        merge_namespace_metadata=True,
    )

    # Tables from BOTH projects must be present in the merged metadata.
    table_names = set(actual_metadata.tables.keys())
    assert "subpackage_a_table" in table_names or "project1_subpackage_a_table" in table_names
    assert "subpackage_b_table" in table_names or "project2_subpackage_b_table" in table_names

    # The merged graph must also serialize cleanly.
    from paracelsus.graph import serialize_metadata

    graph_string = serialize_metadata(
        actual_metadata,
        format="mermaid",
        column_sort="key-based",
    )
    assert "subpackage_a_table" in graph_string or "project1_subpackage_a_table" in graph_string
    assert "subpackage_b_table" in graph_string or "project2_subpackage_b_table" in graph_string
base_class_path="example.base:Base", + import_module=["example.models"], + include_tables=set(), + exclude_tables=set(), + python_dir=[package_path], + ) + + # Compare with expected metadata + mermaid_assert(actual_metadata, expected=metaclass) + + # Also test serialization with layout graph_string = get_graph_string( base_class_path="example.base:Base", import_module=["example.models"], @@ -93,4 +198,46 @@ def test_get_graph_string_with_layout(layout_arg, package_path): column_sort="key-based", layout=Layouts(layout_arg), ) + # Legacy string assertion mermaid_assert(graph_string) + + +def test_get_graph_string_with_nested_glob_pattern(nested_package_path): + """Integration test: get_graph_string with nested glob pattern.""" + + graph_string = get_graph_string( + base_class_path="example.base:Base", + import_module=["example.*.*.models"], + include_tables=set(), + exclude_tables=set(), + python_dir=[nested_package_path], + format="mermaid", + column_sort="key-based", + ) + + assert "users {" in graph_string or "products {" in graph_string or "api_resources {" in graph_string + + +def test_compare_metadata(metaclass): + """Test compare_metadata function directly.""" + # Same metadata should compare successfully + compare_metadata(metaclass, metaclass) + + # Different metadata should raise AssertionError + from sqlalchemy.orm import declarative_base + from sqlalchemy import String, Uuid + from sqlalchemy.orm import mapped_column + from uuid import uuid4 + + Base2 = declarative_base() + + class DifferentTable(Base2): + __tablename__ = "different_table" + id = mapped_column(Uuid, primary_key=True, default=uuid4()) + name = mapped_column(String(100)) + + different_metadata = Base2.metadata + + # Should raise AssertionError when comparing different metadata + with pytest.raises(AssertionError): + compare_metadata(metaclass, different_metadata) diff --git a/tests/transformers/test_find_modules.py b/tests/transformers/test_find_modules.py new file mode 100644 index 
0000000..63d0311 --- /dev/null +++ b/tests/transformers/test_find_modules.py @@ -0,0 +1,243 @@ +import pytest +from pathlib import Path + +from paracelsus.finders import ModuleFinder +from paracelsus.graph import to_module_name +from paracelsus.models.pattern import Pattern + + +def test_find_modules_by_pattern_single_level(single_level_package_path): + """Test basic glob pattern matching with single-level subpackages.""" + + pattern = Pattern(mask="example.*.models") + found = [ + to_module_name(Path.cwd(), module_path) + for module_path in ModuleFinder(single_level_package_path, pattern.tokens).find() + ] + expected_modules = { + "example.foo.models", + "example.bar.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + + +def test_find_modules_by_pattern_nested_levels(nested_package_path): + """Test glob pattern with nested levels (example.*.*.models).""" + + pattern = Pattern(mask="example.*.*.models") + found = [ + to_module_name(Path.cwd(), module_path) + for module_path in ModuleFinder(nested_package_path, pattern.tokens).find() + ] + expected_modules = { + "example.domain.users.models", + "example.domain.products.models", + "example.api.v1.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + + +def test_find_modules_by_pattern_multiple_stars(multi_star_package_path): + """Test glob pattern with multiple stars (example.*.api.*.models).""" + + pattern = Pattern(mask="example.*.api.*.models") + found = [ + to_module_name(Path.cwd(), module_path) + for module_path in ModuleFinder(multi_star_package_path, pattern.tokens).find() + ] + expected_modules = { + "example.v1.api.users.models", + "example.v2.api.products.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + + +def test_find_modules_by_pattern_namespace_package(namespace_package_path): + """Test glob pattern with namespace packages (PEP 420). 
+ + Should handle namespace packages where __path__ is a list of paths. + """ + + pattern = Pattern(mask="project*.example.*.models") + found = [ + to_module_name(Path.cwd(), module_path) + for module_path in ModuleFinder(namespace_package_path, pattern.tokens).find() + ] + expected_modules = { + "project1.example.subpackage_a.models", + "project2.example.subpackage_b.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + + +def test_find_modules_by_pattern_single_character(single_level_package_path): + """Test glob pattern with single character matching (example.fo?.models).""" + + pattern = Pattern(mask="example.fo?.models") + found = [ + to_module_name(Path.cwd(), module_path) + for module_path in ModuleFinder(single_level_package_path, pattern.tokens).find() + ] + expected_modules = { + "example.foo.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + assert "example.bar.models" not in found + + +def test_find_modules_by_pattern_character_class(character_classes_package_path): + """Test glob pattern with character class (example.api.v[12].models). + + Character class [12] matches exactly one character: '1' or '2'. + """ + pattern = Pattern(mask="example.api.v[12].models") + found = [ + to_module_name(Path.cwd(), module_path) + for module_path in ModuleFinder(character_classes_package_path, pattern.tokens).find() + ] + expected_modules = { + "example.api.v1.models", + "example.api.v2.models", + } + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + assert "example.api.v3.models" not in found + + +def test_find_modules_by_pattern_character_range(character_classes_package_path): + """Test glob pattern with character range (example.api.v[0-9].models). + + Character range [0-9] matches exactly one digit from 0 to 9. 
+ """ + pattern = Pattern(mask="example.api.v[0-9].models") + found = [ + to_module_name(Path.cwd(), module_path) + for module_path in ModuleFinder(character_classes_package_path, pattern.tokens).find() + ] + expected_modules = { + "example.api.v0.models", + "example.api.v1.models", + "example.api.v2.models", + "example.api.v3.models", + "example.api.v4.models", + "example.api.v5.models", + "example.api.v6.models", + "example.api.v7.models", + "example.api.v8.models", + "example.api.v9.models", + } + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + assert "example.api.v10.models" not in found + assert "example.api.va.models" not in found + + +def test_find_modules_by_pattern_complementation_character_class(character_classes_package_path): + """Test glob pattern with complementation character class (example.api.v[!1].models). + + Complementation [!1] matches any single character except '1'. + """ + pattern = Pattern(mask="example.api.v[!1].models") + found = [ + to_module_name(Path.cwd(), module_path) + for module_path in ModuleFinder(character_classes_package_path, pattern.tokens).find() + ] + + assert "example.api.v0.models" in found + assert "example.api.v2.models" in found + assert "example.api.va.models" in found + assert "example.api.v1.models" not in found + + +def test_find_modules_by_pattern_complementation_character_range(character_classes_package_path): + """Test glob pattern with complementation range (example.api.v[!0-9].models). + + Complementation [!0-9] matches any single character except digits 0-9. 
+ """ + pattern = Pattern(mask="example.api.v[!0-9].models") + found = [ + to_module_name(Path.cwd(), module_path) + for module_path in ModuleFinder(character_classes_package_path, pattern.tokens).find() + ] + + assert "example.api.va.models" in found + assert "example.api.vb.models" in found + + for i in range(10): + assert f"example.api.v{i}.models" not in found + + +def test_find_modules_by_pattern_mixed_wildcards(multi_star_package_path): + """Test glob pattern with mixed wildcards (example.v?.*.*.models). + + Pattern combines single character match (v?) with any string matches (*). + v? matches one char (v1, v2, va, etc.), then *.* matches two package levels. + Example: example.v1.api.users.models, example.v2.api.products.models + """ + pattern = Pattern(mask="example.v?.*.*.models") + found = [ + to_module_name(Path.cwd(), module_path) + for module_path in ModuleFinder(multi_star_package_path, pattern.tokens).find() + ] + expected_modules = { + "example.v1.api.users.models", + "example.v2.api.products.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + + +def test_find_modules_by_pattern_recursive_lookup(recursive_package_path): + """Test glob pattern with recursive lookup (example.**.api.*.models). + + Recursive lookup (**) matches any number of package levels (0 or more). 
def test_find_modules_by_pattern_recursive_lookup(recursive_package_path):
    """Recursive lookup (**) matches zero or more package levels.

    Expected hits at several depths:
      - example.api.v1.models                 (zero levels between example and api)
      - example.something.api.v2.models       (one level)
      - example.level1.level2.api.v3.models   (two levels)
    """
    mask = Pattern(mask="example.**.api.*.models")
    modules = [
        to_module_name(Path.cwd(), found)
        for found in ModuleFinder(recursive_package_path, mask.tokens).find()
    ]
    expected = {
        "example.api.v1.models",
        "example.something.api.v2.models",
        "example.level1.level2.api.v3.models",
    }

    assert len(modules) == len(expected)
    assert set(modules) == expected
    # Paths without an 'api' segment must never match.
    assert "example.domain.users.models" not in modules


@pytest.mark.parametrize(
    "pattern",
    [
        "example.v,.models",  # wrong grammar token
        "example.v?..models",  # empty token
        "example.v[1.models",  # unclosed character class
        "example.v]1[.models",  # reversed brackets
        "example.v[1[2]].models",  # nested character class
        "example.v[**.models",  # invalid recursive lookup
        "example.**.*.models",  # greedy lookup
    ],
)
def test_find_modules_by_pattern_missing_rule_error(pattern):
    """Each malformed mask must produce at least one validation error."""
    parsed = Pattern(mask=pattern)
    assert any(parsed.errors)


# -- file: tests/utils.py (header + mermaid_assert) --
from typing import Union

from sqlalchemy.schema import MetaData

from paracelsus.graph import compare_metadata


def mermaid_assert(
    actual: Union[str, MetaData],
    expected: Union[MetaData, None] = None,
    omit_comments: bool = False,
) -> None:
    """Assert that a mermaid graph (string or MetaData) matches expectations.

    Two modes:
      1. Legacy mode  — *actual* is a string: basic substring assertions,
         kept for backward compatibility.
      2. Dynamic mode — *actual* is MetaData: structural comparison against
         *expected* via compare_metadata.
    """
    if not isinstance(actual, str):
        # Dynamic mode: structural MetaData comparison.
        if expected is None:
            raise ValueError("expected MetaData is required when actual is MetaData")
        compare_metadata(actual, expected, omit_comments=omit_comments)
        return

    # Legacy mode: plain substring checks on the rendered mermaid text.
    for fragment in (
        "users {",
        "posts {",
        "comments {",
        "users ||--o{ posts : author",
        "posts ||--o{ comments : post",
        "users ||--o{ comments : author",
        "CHAR(32) author FK",
        'CHAR(32) post FK "nullable"',
        'BOOLEAN live "True if post is published,nullable"',
        "DATETIME created",
    ):
        assert fragment in actual

    trailing_newline_assert(actual)