From 952199dba20328c81ef4bd368705925e73312a14 Mon Sep 17 00:00:00 2001 From: Apti Date: Thu, 11 Dec 2025 18:19:10 +0300 Subject: [PATCH 01/20] WIP: add glob pattern support for --import-module --- paracelsus/graph.py | 63 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 3 deletions(-) diff --git a/paracelsus/graph.py b/paracelsus/graph.py index 935137a..d6d9aeb 100644 --- a/paracelsus/graph.py +++ b/paracelsus/graph.py @@ -3,6 +3,7 @@ import sys from pathlib import Path import re +import pkgutil from typing import List, Set, Optional, Dict, Union from sqlalchemy.schema import MetaData @@ -18,6 +19,46 @@ "gv": Dot, } +def find_modules_by_pattern(pattern: str) -> List[str]: + """Finds all modules that match the glob pattern.""" + parts = pattern.split(".") + + star_index = None + for i, part in enumerate(parts): + if part == "*": + star_index = i + break + + prefix_parts = parts[:star_index] + suffix_parts = parts[star_index + 1:] + + if not suffix_parts: + raise ValueError( + f"Glob pattern '{pattern}' must specify a module name after '*'. " + ) + + base_package_name = ".".join(prefix_parts) + base_package = importlib.import_module(base_package_name) + base_path = base_package.__path__[0] + + found_modules = [] + + for importer, modname, ispkg in pkgutil.iter_modules([base_path]): + # Form a subpackage name + subpackage_name = f"{base_package_name}.{modname}" + + # Form the full name of the target module + target_module_name = f"{subpackage_name}.{'.'.join(suffix_parts)}" + + # Check that the module exists and add it to the list. + try: + importlib.import_module(target_module_name) + found_modules.append(target_module_name) + except ImportError: + continue + + return found_modules + def get_graph_string( *, @@ -48,10 +89,26 @@ def get_graph_string( # These modules aren't actually used in any way, so they are discarded. # They are also imported in scope of this function to prevent namespace pollution. for module in import_module: - if ":*" in module: - # Sure, execs are gross, but this is the only way to dynamically import wildcards. - exec(f"from {module[:-2]} import *") + needs_wildcards_import = module.endswith(":*") + + search_pattern = module[:-2] if needs_wildcards_import else module + + if "*" in search_pattern: + # This is a glob pattern, find all the corresponding modules + found_models = find_modules_by_pattern(search_pattern) + + for found_model in found_models: + if needs_wildcards_import: + # Combination: glob search + wildcard import + exec(f"from {found_model} import *") + else: + # Glob search only, normal import + importlib.import_module(found_model) + elif needs_wildcards_import: + # Wildcard import only + exec(f"from {search_pattern} import *") else: + # Normal module import importlib.import_module(module) # Grab a transformer. From b8d6f62d6803ec6c85723ac458adc95306e8e089 Mon Sep 17 00:00:00 2001 From: Apti Date: Sat, 13 Dec 2025 19:20:06 +0300 Subject: [PATCH 02/20] Implemented tests --- paracelsus/graph.py | 2 +- tests/conftest.py | 303 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_graph.py | 102 ++++++++++++++- 3 files changed, 405 insertions(+), 2 deletions(-) diff --git a/paracelsus/graph.py b/paracelsus/graph.py index d6d9aeb..4cc46f7 100644 --- a/paracelsus/graph.py +++ b/paracelsus/graph.py @@ -20,7 +20,7 @@ } def find_modules_by_pattern(pattern: str) -> List[str]: - """Finds all modules that match the glob pattern.""" + """Finds all modules that match the glob-like pattern for Python modules.""" parts = pattern.split(".") star_index = None diff --git a/tests/conftest.py b/tests/conftest.py index 01864b9..da41915 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -262,3 +262,306 @@ def fixture_expected_mermaid_cardinalities_graph() -> str: ``` """) + + +@pytest.fixture +def single_level_package_path() -> Generator[Path, None, None]: + """Create a package structure with single-level subpackages for testing pattern example.*.models. + + Structure: + example/ + base.py + foo/ + models.py + bar/ + models.py + """ + with tempfile.TemporaryDirectory() as package_path: + package_dir = Path(package_path) + example_dir = package_dir / "example" + example_dir.mkdir(parents=True, exist_ok=True) + + # Create base + (example_dir / "base.py").write_text( + dedent("""\ + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + """) + ) + (example_dir / "__init__.py").write_text("") + + # Create example.foo.models + (example_dir / "foo").mkdir(parents=True, exist_ok=True) + (example_dir / "foo" / "__init__.py").write_text("") + (example_dir / "foo" / "models.py").write_text( + dedent("""\ + from sqlalchemy import String + from sqlalchemy.orm import mapped_column + from ..base import Base + + class FooModel(Base): + __tablename__ = 'foo_table' + id = mapped_column(String, primary_key=True) + """) + ) + + # Create example.bar.models + (example_dir / "bar").mkdir(parents=True, exist_ok=True) + (example_dir / "bar" / "__init__.py").write_text("") + (example_dir / "bar" / "models.py").write_text( + dedent("""\ + from sqlalchemy import String + from sqlalchemy.orm import mapped_column + from ..base import Base + + class BarModel(Base): + __tablename__ = 'bar_table' + id = mapped_column(String, primary_key=True) + """) + ) + + os.chdir(package_path) + + # Cleanup + for name in list(sys.modules.keys()): + if name == "example" or name.startswith("example."): + del sys.modules[name] + + yield Path(package_path) + + +@pytest.fixture +def nested_package_path() -> Generator[Path, None, None]: + """Create a package structure with nested subpackages for testing multi-level glob patterns. + + Structure: + example/ + domain/ + users/ + models.py + products/ + models.py + api/ + v1/ + models.py + """ + with tempfile.TemporaryDirectory() as package_path: + package_dir = Path(package_path) + example_dir = package_dir / "example" + example_dir.mkdir(parents=True, exist_ok=True) + + # Create base + (example_dir / "base.py").write_text( + dedent("""\ + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + """) + ) + (example_dir / "__init__.py").write_text("") + + # Create domain.users.models + (example_dir / "domain" / "users").mkdir(parents=True, exist_ok=True) + (example_dir / "domain" / "users" / "__init__.py").write_text("") + (example_dir / "domain" / "users" / "models.py").write_text( + dedent("""\ + from sqlalchemy import String + from sqlalchemy.orm import mapped_column + from ...base import Base + + class User(Base): + __tablename__ = 'users' + id = mapped_column(String, primary_key=True) + """) + ) + + # Create domain.products.models + (example_dir / "domain" / "products").mkdir(parents=True, exist_ok=True) + (example_dir / "domain" / "products" / "__init__.py").write_text("") + (example_dir / "domain" / "products" / "models.py").write_text( + dedent("""\ + from sqlalchemy import String + from sqlalchemy.orm import mapped_column + from ...base import Base + + class Product(Base): + __tablename__ = 'products' + id = mapped_column(String, primary_key=True) + """) + ) + + # Create api.v1.models + (example_dir / "api" / "v1").mkdir(parents=True, exist_ok=True) + (example_dir / "api" / "v1" / "__init__.py").write_text("") + (example_dir / "api" / "v1" / "models.py").write_text( + dedent("""\ + from sqlalchemy import String + from sqlalchemy.orm import mapped_column + from ...base import Base + + class APIResource(Base): + __tablename__ = 'api_resources' + id = mapped_column(String, primary_key=True) + """) + ) + + os.chdir(package_path) + + # Cleanup + for name in list(sys.modules.keys()): + if name == "example" or name.startswith("example."): + del sys.modules[name] + + yield Path(package_path) + + +@pytest.fixture +def multi_star_package_path() -> Generator[Path, None, None]: + """Create a package structure for testing patterns with multiple stars. + + Structure: + example/ + v1/ + api/ + users/ + models.py + v2/ + api/ + products/ + models.py + """ + with tempfile.TemporaryDirectory() as package_path: + package_dir = Path(package_path) + example_dir = package_dir / "example" + example_dir.mkdir(parents=True, exist_ok=True) + + # Create base + (example_dir / "base.py").write_text( + dedent("""\ + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + """) + ) + (example_dir / "__init__.py").write_text("") + + # Create v1.api.users.models + (example_dir / "v1" / "api" / "users").mkdir(parents=True, exist_ok=True) + (example_dir / "v1" / "api" / "users" / "__init__.py").write_text("") + (example_dir / "v1" / "api" / "users" / "models.py").write_text( + dedent("""\ + from sqlalchemy import String + from sqlalchemy.orm import mapped_column + from ....base import Base + + class V1User(Base): + __tablename__ = 'v1_users' + id = mapped_column(String, primary_key=True) + """) + ) + + # Create v2.api.products.models + (example_dir / "v2" / "api" / "products").mkdir(parents=True, exist_ok=True) + (example_dir / "v2" / "api" / "products" / "__init__.py").write_text("") + (example_dir / "v2" / "api" / "products" / "models.py").write_text( + dedent("""\ + from sqlalchemy import String + from sqlalchemy.orm import mapped_column + from ....base import Base + + class V2Product(Base): + __tablename__ = 'v2_products' + id = mapped_column(String, primary_key=True) + """) + ) + + os.chdir(package_path) + + # Cleanup + for name in list(sys.modules.keys()): + if name == "example" or name.startswith("example."): + del sys.modules[name] + + yield Path(package_path) + + +@pytest.fixture +def namespace_package_path() -> Generator[Path, None, None]: + """Create a namespace package structure (PEP 420) for testing. + + Structure (two separate directories that form one namespace): + project1/ + example/ (NO __init__.py) + subpackage_a/ + models.py + project2/ + example/ (NO __init__.py) + subpackage_b/ + models.py + + Both are added to sys.path separately. + """ + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Project 1 + project1 = temp_path / "project1" + (project1 / "example").mkdir(parents=True, exist_ok=True) + (project1 / "example" / "subpackage_a").mkdir(parents=True, exist_ok=True) + (project1 / "example" / "subpackage_a" / "__init__.py").write_text("") + (project1 / "example" / "base.py").write_text( + dedent("""\ + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + """) + ) + (project1 / "example" / "subpackage_a" / "models.py").write_text( + dedent("""\ + from sqlalchemy import String + from sqlalchemy.orm import mapped_column + from ..base import Base + + class SubpackageAModel(Base): + __tablename__ = 'subpackage_a_table' + id = mapped_column(String, primary_key=True) + """) + ) + + # Project 2 + project2 = temp_path / "project2" + (project2 / "example").mkdir(parents=True, exist_ok=True) + (project2 / "example" / "subpackage_b").mkdir(parents=True, exist_ok=True) + (project2 / "example" / "subpackage_b" / "__init__.py").write_text("") + (project2 / "example" / "subpackage_b" / "models.py").write_text( + dedent("""\ + from sqlalchemy import String + from sqlalchemy.orm import mapped_column + from ..base import Base + + class SubpackageBModel(Base): + __tablename__ = 'subpackage_b_table' + id = mapped_column(String, primary_key=True) + """) + ) + + # We need to add both to sys.path + sys.path.insert(0, str(project1)) + sys.path.insert(0, str(project2)) + + os.chdir(str(temp_path)) + + # Cleanup + for name in list(sys.modules.keys()): + if name == "example" or name.startswith("example."): + del sys.modules[name] + + try: + yield temp_path + finally: + # Remove from sys.path + if str(project1) in sys.path: + sys.path.remove(str(project1)) + if str(project2) in sys.path: + sys.path.remove(str(project2)) \ No newline at end of file diff --git a/tests/test_graph.py b/tests/test_graph.py index 2460aa7..988c101 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,7 +1,8 @@ +import sys import pytest from paracelsus.config import Layouts -from paracelsus.graph import get_graph_string +from paracelsus.graph import get_graph_string, find_modules_by_pattern from .utils import mermaid_assert @@ -94,3 +95,102 @@ def test_get_graph_string_with_layout(layout_arg, package_path): layout=Layouts(layout_arg), ) mermaid_assert(graph_string) + + +def test_find_modules_by_pattern_single_level(single_level_package_path): + """Test basic glob pattern matching with single-level subpackages.""" + + if str(single_level_package_path) not in sys.path: + sys.path.insert(0, str(single_level_package_path)) + + found = find_modules_by_pattern("example.*.models") + expected_modules = { + "example.foo.models", + "example.bar.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + + +def test_find_modules_by_pattern_nested_levels(nested_package_path): + """Test glob pattern with nested levels (example.*.*.models). + + Should find: + - example.domain.users.models + - example.domain.products.models + - example.api.v1.models + """ + + if str(nested_package_path) not in sys.path: + sys.path.insert(0, str(nested_package_path)) + + found = find_modules_by_pattern("example.*.*.models") + expected_modules = { + "example.domain.users.models", + "example.domain.products.models", + "example.api.v1.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + + +def test_find_modules_by_pattern_multiple_stars(multi_star_package_path): + """Test glob pattern with multiple stars (example.*.api.*.models). + + Should find: + - example.v1.api.users.models + - example.v2.api.products.models + """ + + if str(multi_star_package_path) not in sys.path: + sys.path.insert(0, str(multi_star_package_path)) + + found = find_modules_by_pattern("example.*.api.*.models") + expected_modules = { + "example.v1.api.users.models", + "example.v2.api.products.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + + +def test_get_graph_string_with_nested_glob_pattern(nested_package_path): + """Integration test: get_graph_string with nested glob pattern.""" + + if str(nested_package_path) not in sys.path: + sys.path.insert(0, str(nested_package_path)) + + graph_string = get_graph_string( + base_class_path="example.base:Base", + import_module=["example.*.*.models"], + include_tables=set(), + exclude_tables=set(), + python_dir=[nested_package_path], + format="mermaid", + column_sort="key-based", + ) + + assert "users {" in graph_string or "products {" in graph_string or "api_resources {" in graph_string + + +def test_find_modules_by_pattern_namespace_package(namespace_package_path): + """Test glob pattern with namespace packages (PEP 420). + + Should handle namespace packages where __path__ is a list of paths. + """ + import example + + + assert hasattr(example, '__path__') + + found = find_modules_by_pattern("example.*.models") + expected_modules = { + "example.subpackage_a.models", + "example.subpackage_b.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules From e1d0d178e69ae5764bb91dd70863e15ffa8b2e11 Mon Sep 17 00:00:00 2001 From: Apti Date: Tue, 16 Dec 2025 15:55:07 +0300 Subject: [PATCH 03/20] Refactor fixtures to use asset templates --- paracelsus/graph.py | 13 +- tests/assets/multi_star/example/__init__.py | 0 tests/assets/multi_star/example/base.py | 3 + .../example/v1/api/users/__init__.py | 0 .../multi_star/example/v1/api/users/models.py | 8 + .../example/v2/api/products/__init__.py | 0 .../example/v2/api/products/models.py | 8 + .../assets/namespace/project1/example/base.py | 3 + .../project1/example/subpackage_a/__init__.py | 0 .../project1/example/subpackage_a/models.py | 8 + .../project2/example/subpackage_b/__init__.py | 0 .../project2/example/subpackage_b/models.py | 8 + tests/assets/nested/example/__init__.py | 0 .../assets/nested/example/api/v1/__init__.py | 0 tests/assets/nested/example/api/v1/models.py | 8 + tests/assets/nested/example/base.py | 3 + .../example/domain/products/__init__.py | 0 .../nested/example/domain/products/models.py | 8 + .../nested/example/domain/users/__init__.py | 0 .../nested/example/domain/users/models.py | 8 + tests/assets/single_level/example/__init__.py | 0 .../single_level/example/bar/__init__.py | 0 .../assets/single_level/example/bar/models.py | 8 + tests/assets/single_level/example/base.py | 3 + .../single_level/example/foo/__init__.py | 0 .../assets/single_level/example/foo/models.py | 8 + tests/conftest.py | 345 ++++-------------- tests/test_graph.py | 90 +---- tests/transformers/test_find_modules.py | 57 +++ 29 files changed, 228 insertions(+), 361 deletions(-) create mode 100644 tests/assets/multi_star/example/__init__.py create mode 100644 tests/assets/multi_star/example/base.py create mode 100644 tests/assets/multi_star/example/v1/api/users/__init__.py create mode 100644 tests/assets/multi_star/example/v1/api/users/models.py create mode 100644 tests/assets/multi_star/example/v2/api/products/__init__.py create mode 100644 tests/assets/multi_star/example/v2/api/products/models.py create mode 100644 tests/assets/namespace/project1/example/base.py create mode 100644 tests/assets/namespace/project1/example/subpackage_a/__init__.py create mode 100644 tests/assets/namespace/project1/example/subpackage_a/models.py create mode 100644 tests/assets/namespace/project2/example/subpackage_b/__init__.py create mode 100644 tests/assets/namespace/project2/example/subpackage_b/models.py create mode 100644 tests/assets/nested/example/__init__.py create mode 100644 tests/assets/nested/example/api/v1/__init__.py create mode 100644 tests/assets/nested/example/api/v1/models.py create mode 100644 tests/assets/nested/example/base.py create mode 100644 tests/assets/nested/example/domain/products/__init__.py create mode 100644 tests/assets/nested/example/domain/products/models.py create mode 100644 tests/assets/nested/example/domain/users/__init__.py create mode 100644 tests/assets/nested/example/domain/users/models.py create mode 100644 tests/assets/single_level/example/__init__.py create mode 100644 tests/assets/single_level/example/bar/__init__.py create mode 100644 tests/assets/single_level/example/bar/models.py create mode 100644 tests/assets/single_level/example/base.py create mode 100644 tests/assets/single_level/example/foo/__init__.py create mode 100644 tests/assets/single_level/example/foo/models.py create mode 100644 tests/transformers/test_find_modules.py diff --git a/paracelsus/graph.py b/paracelsus/graph.py index 4cc46f7..c66d091 100644 --- a/paracelsus/graph.py +++ b/paracelsus/graph.py @@ -19,6 +19,7 @@ "gv": Dot, } + def find_modules_by_pattern(pattern: str) -> List[str]: """Finds all modules that match the glob-like pattern for Python modules.""" parts = pattern.split(".") @@ -28,14 +29,12 @@ def find_modules_by_pattern(pattern: str) -> List[str]: if part == "*": star_index = i break - + prefix_parts = parts[:star_index] - suffix_parts = parts[star_index + 1:] + suffix_parts = parts[star_index + 1 :] if not suffix_parts: - raise ValueError( - f"Glob pattern '{pattern}' must specify a module name after '*'. " - ) + raise ValueError(f"Glob pattern '{pattern}' must specify a module name after '*'. ") base_package_name = ".".join(prefix_parts) base_package = importlib.import_module(base_package_name) @@ -56,9 +55,9 @@ def find_modules_by_pattern(pattern: str) -> List[str]: found_modules.append(target_module_name) except ImportError: continue - + return found_modules - + def get_graph_string( *, diff --git a/tests/assets/multi_star/example/__init__.py b/tests/assets/multi_star/example/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/multi_star/example/base.py b/tests/assets/multi_star/example/base.py new file mode 100644 index 0000000..59be703 --- /dev/null +++ b/tests/assets/multi_star/example/base.py @@ -0,0 +1,3 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() diff --git a/tests/assets/multi_star/example/v1/api/users/__init__.py b/tests/assets/multi_star/example/v1/api/users/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/multi_star/example/v1/api/users/models.py b/tests/assets/multi_star/example/v1/api/users/models.py new file mode 100644 index 0000000..bd17ca9 --- /dev/null +++ b/tests/assets/multi_star/example/v1/api/users/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ....base import Base + + +class V1User(Base): + __tablename__ = "v1_users" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/multi_star/example/v2/api/products/__init__.py b/tests/assets/multi_star/example/v2/api/products/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/multi_star/example/v2/api/products/models.py b/tests/assets/multi_star/example/v2/api/products/models.py new file mode 100644 index 0000000..fc1416c --- /dev/null +++ b/tests/assets/multi_star/example/v2/api/products/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ....base import Base + + +class V2Product(Base): + __tablename__ = "v2_products" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/namespace/project1/example/base.py b/tests/assets/namespace/project1/example/base.py new file mode 100644 index 0000000..59be703 --- /dev/null +++ b/tests/assets/namespace/project1/example/base.py @@ -0,0 +1,3 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() diff --git a/tests/assets/namespace/project1/example/subpackage_a/__init__.py b/tests/assets/namespace/project1/example/subpackage_a/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/namespace/project1/example/subpackage_a/models.py b/tests/assets/namespace/project1/example/subpackage_a/models.py new file mode 100644 index 0000000..762e170 --- /dev/null +++ b/tests/assets/namespace/project1/example/subpackage_a/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ..base import Base + + +class SubpackageAModel(Base): + __tablename__ = "subpackage_a_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/namespace/project2/example/subpackage_b/__init__.py b/tests/assets/namespace/project2/example/subpackage_b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/namespace/project2/example/subpackage_b/models.py b/tests/assets/namespace/project2/example/subpackage_b/models.py new file mode 100644 index 0000000..3bb28a9 --- /dev/null +++ b/tests/assets/namespace/project2/example/subpackage_b/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ..base import Base + + +class SubpackageBModel(Base): + __tablename__ = "subpackage_b_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/nested/example/__init__.py b/tests/assets/nested/example/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/nested/example/api/v1/__init__.py b/tests/assets/nested/example/api/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/nested/example/api/v1/models.py b/tests/assets/nested/example/api/v1/models.py new file mode 100644 index 0000000..fe4a781 --- /dev/null +++ b/tests/assets/nested/example/api/v1/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class APIResource(Base): + __tablename__ = "api_resources" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/nested/example/base.py b/tests/assets/nested/example/base.py new file mode 100644 index 0000000..59be703 --- /dev/null +++ b/tests/assets/nested/example/base.py @@ -0,0 +1,3 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() diff --git a/tests/assets/nested/example/domain/products/__init__.py b/tests/assets/nested/example/domain/products/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/nested/example/domain/products/models.py b/tests/assets/nested/example/domain/products/models.py new file mode 100644 index 0000000..b1f8cb7 --- /dev/null +++ b/tests/assets/nested/example/domain/products/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class Product(Base): + __tablename__ = "products" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/nested/example/domain/users/__init__.py b/tests/assets/nested/example/domain/users/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/nested/example/domain/users/models.py b/tests/assets/nested/example/domain/users/models.py new file mode 100644 index 0000000..dfd6458 --- /dev/null +++ b/tests/assets/nested/example/domain/users/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class User(Base): + __tablename__ = "users" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/single_level/example/__init__.py b/tests/assets/single_level/example/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/single_level/example/bar/__init__.py b/tests/assets/single_level/example/bar/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/single_level/example/bar/models.py b/tests/assets/single_level/example/bar/models.py new file mode 100644 index 0000000..1893459 --- /dev/null +++ b/tests/assets/single_level/example/bar/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ..base import Base + + +class BarModel(Base): + __tablename__ = "bar_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/single_level/example/base.py b/tests/assets/single_level/example/base.py new file mode 100644 index 0000000..59be703 --- /dev/null +++ b/tests/assets/single_level/example/base.py @@ -0,0 +1,3 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() diff --git a/tests/assets/single_level/example/foo/__init__.py b/tests/assets/single_level/example/foo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/single_level/example/foo/models.py b/tests/assets/single_level/example/foo/models.py new file mode 100644 index 0000000..aae5a59 --- /dev/null +++ b/tests/assets/single_level/example/foo/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ..base import Base + + +class FooModel(Base): + __tablename__ = "foo_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/conftest.py b/tests/conftest.py index da41915..15a8587 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,43 @@ UTC = timezone.utc +def setup_sys_path_for_test(paths: Path | list[Path], module_prefix: str = "example") -> str | list[str]: + """Add paths to sys.path and clear module cache. Returns path(s) for cleanup.""" + # Clear module cache + for name in list(sys.modules.keys()): + if name == module_prefix or name.startswith(f"{module_prefix}."): + del sys.modules[name] + + # Normalize input + if isinstance(paths, Path): + paths_list = [paths] + single_path = True + else: + paths_list = paths + single_path = False + + # Add paths to sys.path + path_strings = [] + for path in paths_list: + path_str = str(path) + path_strings.append(path_str) + sys.path.insert(0, path_str) + + return path_strings[0] if single_path else path_strings + + +def cleanup_sys_path(paths: str | list[str]) -> None: + """Remove path(s) from sys.path.""" + if isinstance(paths, str): + paths_list = [paths] + else: + paths_list = paths + + for path_str in paths_list: + if path_str in sys.path: + sys.path.remove(path_str) + + @pytest.fixture def metaclass(): Base = declarative_base() @@ -266,302 +303,76 @@ def fixture_expected_mermaid_cardinalities_graph() -> str: @pytest.fixture def single_level_package_path() -> Generator[Path, None, None]: - """Create a package structure with single-level subpackages for testing pattern example.*.models. - - Structure: - example/ - base.py - foo/ - models.py - bar/ - models.py - """ + """Create a package structure with single-level subpackages for testing pattern example.*.models.""" + template_path = Path(os.path.dirname(os.path.realpath(__file__))) / "assets" / "single_level" with tempfile.TemporaryDirectory() as package_path: + shutil.copytree(template_path, package_path, dirs_exist_ok=True) package_dir = Path(package_path) - example_dir = package_dir / "example" - example_dir.mkdir(parents=True, exist_ok=True) - - # Create base - (example_dir / "base.py").write_text( - dedent("""\ - from sqlalchemy.orm import declarative_base - - Base = declarative_base() - """) - ) - (example_dir / "__init__.py").write_text("") - - # Create example.foo.models - (example_dir / "foo").mkdir(parents=True, exist_ok=True) - (example_dir / "foo" / "__init__.py").write_text("") - (example_dir / "foo" / "models.py").write_text( - dedent("""\ - from sqlalchemy import String - from sqlalchemy.orm import mapped_column - from ..base import Base - - class FooModel(Base): - __tablename__ = 'foo_table' - id = mapped_column(String, primary_key=True) - """) - ) - - # Create example.bar.models - (example_dir / "bar").mkdir(parents=True, exist_ok=True) - (example_dir / "bar" / "__init__.py").write_text("") - (example_dir / "bar" / "models.py").write_text( - dedent("""\ - from sqlalchemy import String - from sqlalchemy.orm import mapped_column - from ..base import Base - - class BarModel(Base): - __tablename__ = 'bar_table' - id = mapped_column(String, primary_key=True) - """) - ) - os.chdir(package_path) - # Cleanup - for name in list(sys.modules.keys()): - if name == "example" or name.startswith("example."): - del sys.modules[name] - - yield Path(package_path) + path_str = setup_sys_path_for_test(package_dir) + try: + yield Path(package_path) + finally: + cleanup_sys_path(path_str) @pytest.fixture def nested_package_path() -> Generator[Path, None, None]: - """Create a package structure with nested subpackages for testing multi-level glob patterns. - - Structure: - example/ - domain/ - users/ - models.py - products/ - models.py - api/ - v1/ - models.py - """ + """Create a package structure with nested subpackages for testing multi-level glob patterns.""" + template_path = Path(os.path.dirname(os.path.realpath(__file__))) / "assets" / "nested" with tempfile.TemporaryDirectory() as package_path: + shutil.copytree(template_path, package_path, dirs_exist_ok=True) package_dir = Path(package_path) - example_dir = package_dir / "example" - example_dir.mkdir(parents=True, exist_ok=True) - - # Create base - (example_dir / "base.py").write_text( - dedent("""\ - from sqlalchemy.orm import declarative_base - - Base = declarative_base() - """) - ) - (example_dir / "__init__.py").write_text("") - - # Create domain.users.models - (example_dir / "domain" / "users").mkdir(parents=True, exist_ok=True) - (example_dir / "domain" / "users" / "__init__.py").write_text("") - (example_dir / "domain" / "users" / "models.py").write_text( - dedent("""\ - from sqlalchemy import String - from sqlalchemy.orm import mapped_column - from ...base import Base - - class User(Base): - __tablename__ = 'users' - id = mapped_column(String, primary_key=True) - """) - ) - - # Create domain.products.models - (example_dir / "domain" / "products").mkdir(parents=True, exist_ok=True) - (example_dir / "domain" / "products" / "__init__.py").write_text("") - (example_dir / "domain" / "products" / "models.py").write_text( - dedent("""\ - from sqlalchemy import String - from sqlalchemy.orm import mapped_column - from ...base import Base - - class Product(Base): - __tablename__ = 'products' - id = mapped_column(String, primary_key=True) - """) - ) - - # Create api.v1.models - (example_dir / "api" / "v1").mkdir(parents=True, exist_ok=True) - (example_dir / "api" / "v1" / "__init__.py").write_text("") - (example_dir / "api" / "v1" / "models.py").write_text( - dedent("""\ - from sqlalchemy import String - from sqlalchemy.orm import mapped_column - from ...base import Base - - class APIResource(Base): - __tablename__ = 'api_resources' - id = mapped_column(String, primary_key=True) - """) - ) - os.chdir(package_path) - # Cleanup - for name in list(sys.modules.keys()): - if name == "example" or name.startswith("example."): - del sys.modules[name] - - yield Path(package_path) + path_str = setup_sys_path_for_test(package_dir) + try: + yield Path(package_path) + finally: + cleanup_sys_path(path_str) @pytest.fixture def multi_star_package_path() -> Generator[Path, None, None]: - """Create a package structure for testing patterns with multiple stars. - - Structure: - example/ - v1/ - api/ - users/ - models.py - v2/ - api/ - products/ - models.py - """ + """Create a package structure for testing patterns with multiple stars.""" + template_path = Path(os.path.dirname(os.path.realpath(__file__))) / "assets" / "multi_star" with tempfile.TemporaryDirectory() as package_path: + shutil.copytree(template_path, package_path, dirs_exist_ok=True) package_dir = Path(package_path) - example_dir = package_dir / "example" - example_dir.mkdir(parents=True, exist_ok=True) - - # Create base - (example_dir / "base.py").write_text( - dedent("""\ - from sqlalchemy.orm import declarative_base - - Base = declarative_base() - """) - ) - (example_dir / "__init__.py").write_text("") - - # Create v1.api.users.models - (example_dir / "v1" / "api" / "users").mkdir(parents=True, exist_ok=True) - (example_dir / "v1" / "api" / "users" / "__init__.py").write_text("") - (example_dir / "v1" / "api" / "users" / "models.py").write_text( - dedent("""\ - from sqlalchemy import String - from sqlalchemy.orm import mapped_column - from ....base import Base - - class V1User(Base): - __tablename__ = 'v1_users' - id = mapped_column(String, primary_key=True) - """) - ) - - # Create v2.api.products.models - (example_dir / "v2" / "api" / "products").mkdir(parents=True, exist_ok=True) - (example_dir / "v2" / "api" / "products" / "__init__.py").write_text("") - (example_dir / "v2" / "api" / "products" / "models.py").write_text( - dedent("""\ - from sqlalchemy import String - from sqlalchemy.orm import mapped_column - from ....base import Base - - class V2Product(Base): - __tablename__ = 'v2_products' - id = mapped_column(String, primary_key=True) - """) - ) - os.chdir(package_path) - # Cleanup - for name in list(sys.modules.keys()): - if name == "example" or name.startswith("example."): - del sys.modules[name] - - yield Path(package_path) + path_str = setup_sys_path_for_test(package_dir) + try: + yield Path(package_path) + finally: + cleanup_sys_path(path_str) @pytest.fixture def namespace_package_path() -> Generator[Path, None, None]: - """Create a namespace package structure (PEP 420) for testing. - - Structure (two separate directories that form one namespace): - project1/ - example/ (NO __init__.py) - subpackage_a/ - models.py - project2/ - example/ (NO __init__.py) - subpackage_b/ - models.py - - Both are added to sys.path separately. - """ + """Create a namespace package structure (PEP 420) for testing.""" + template_base = Path(os.path.dirname(os.path.realpath(__file__))) / "assets" / "namespace" with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) - - # Project 1 + + # Copy both projects + shutil.copytree(template_base / "project1", temp_path / "project1", dirs_exist_ok=True) + shutil.copytree(template_base / "project2", temp_path / "project2", dirs_exist_ok=True) + project1 = temp_path / "project1" - (project1 / "example").mkdir(parents=True, exist_ok=True) - (project1 / "example" / "subpackage_a").mkdir(parents=True, exist_ok=True) - (project1 / "example" / "subpackage_a" / "__init__.py").write_text("") - (project1 / "example" / "base.py").write_text( - dedent("""\ - from sqlalchemy.orm import declarative_base - - Base = declarative_base() - """) - ) - (project1 / "example" / "subpackage_a" / "models.py").write_text( - dedent("""\ - from sqlalchemy import String - from sqlalchemy.orm import mapped_column - from ..base import Base - - class SubpackageAModel(Base): - __tablename__ = 'subpackage_a_table' - id = mapped_column(String, primary_key=True) - """) - ) - - # Project 2 project2 = temp_path / "project2" - (project2 / "example").mkdir(parents=True, exist_ok=True) - (project2 / "example" / "subpackage_b").mkdir(parents=True, exist_ok=True) - (project2 / "example" / "subpackage_b" / "__init__.py").write_text("") - (project2 / "example" / "subpackage_b" / "models.py").write_text( - dedent("""\ - from sqlalchemy import String - from sqlalchemy.orm import mapped_column - from ..base import Base - - class SubpackageBModel(Base): - __tablename__ = 'subpackage_b_table' - id = mapped_column(String, primary_key=True) - """) - ) - - # We need to add both to sys.path - sys.path.insert(0, str(project1)) - sys.path.insert(0, str(project2)) - + os.chdir(str(temp_path)) - - # Cleanup - for name in list(sys.modules.keys()): - if name == "example" or name.startswith("example."): - del sys.modules[name] - + + path_strings = setup_sys_path_for_test([project1, project2]) + + # Import example to check that it is a namespace package + import example + + assert hasattr(example, "__path__"), "example should be a namespace package with __path__ attribute" + try: yield temp_path finally: - # Remove from sys.path - if str(project1) in sys.path: - sys.path.remove(str(project1)) - if str(project2) in sys.path: - sys.path.remove(str(project2)) \ No newline at end of file + cleanup_sys_path(path_strings) diff --git a/tests/test_graph.py b/tests/test_graph.py index 988c101..ec182ac 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,8 +1,7 @@ -import sys import pytest from paracelsus.config import Layouts -from paracelsus.graph import get_graph_string, find_modules_by_pattern +from paracelsus.graph import get_graph_string from .utils import mermaid_assert @@ -97,72 +96,9 @@ def test_get_graph_string_with_layout(layout_arg, package_path): mermaid_assert(graph_string) -def test_find_modules_by_pattern_single_level(single_level_package_path): - """Test basic glob pattern matching with single-level subpackages.""" - - if str(single_level_package_path) not in sys.path: - sys.path.insert(0, str(single_level_package_path)) - - found = find_modules_by_pattern("example.*.models") - expected_modules = { - "example.foo.models", - "example.bar.models", - } - - assert len(found) == len(expected_modules) - assert set(found) == expected_modules - - -def test_find_modules_by_pattern_nested_levels(nested_package_path): - """Test glob pattern with nested levels (example.*.*.models). - - Should find: - - example.domain.users.models - - example.domain.products.models - - example.api.v1.models - """ - - if str(nested_package_path) not in sys.path: - sys.path.insert(0, str(nested_package_path)) - - found = find_modules_by_pattern("example.*.*.models") - expected_modules = { - "example.domain.users.models", - "example.domain.products.models", - "example.api.v1.models", - } - - assert len(found) == len(expected_modules) - assert set(found) == expected_modules - - -def test_find_modules_by_pattern_multiple_stars(multi_star_package_path): - """Test glob pattern with multiple stars (example.*.api.*.models). - - Should find: - - example.v1.api.users.models - - example.v2.api.products.models - """ - - if str(multi_star_package_path) not in sys.path: - sys.path.insert(0, str(multi_star_package_path)) - - found = find_modules_by_pattern("example.*.api.*.models") - expected_modules = { - "example.v1.api.users.models", - "example.v2.api.products.models", - } - - assert len(found) == len(expected_modules) - assert set(found) == expected_modules - - def test_get_graph_string_with_nested_glob_pattern(nested_package_path): """Integration test: get_graph_string with nested glob pattern.""" - - if str(nested_package_path) not in sys.path: - sys.path.insert(0, str(nested_package_path)) - + graph_string = get_graph_string( base_class_path="example.base:Base", import_module=["example.*.*.models"], @@ -172,25 +108,5 @@ def test_get_graph_string_with_nested_glob_pattern(nested_package_path): format="mermaid", column_sort="key-based", ) - - assert "users {" in graph_string or "products {" in graph_string or "api_resources {" in graph_string - -def test_find_modules_by_pattern_namespace_package(namespace_package_path): - """Test glob pattern with namespace packages (PEP 420). - - Should handle namespace packages where __path__ is a list of paths. - """ - import example - - - assert hasattr(example, '__path__') - - found = find_modules_by_pattern("example.*.models") - expected_modules = { - "example.subpackage_a.models", - "example.subpackage_b.models", - } - - assert len(found) == len(expected_modules) - assert set(found) == expected_modules + assert "users {" in graph_string or "products {" in graph_string or "api_resources {" in graph_string diff --git a/tests/transformers/test_find_modules.py b/tests/transformers/test_find_modules.py new file mode 100644 index 0000000..269fecc --- /dev/null +++ b/tests/transformers/test_find_modules.py @@ -0,0 +1,57 @@ +from paracelsus.graph import find_modules_by_pattern + + +def test_find_modules_by_pattern_single_level(single_level_package_path): + """Test basic glob pattern matching with single-level subpackages.""" + + found = find_modules_by_pattern("example.*.models") + expected_modules = { + "example.foo.models", + "example.bar.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + + +def test_find_modules_by_pattern_nested_levels(nested_package_path): + """Test glob pattern with nested levels (example.*.*.models).""" + + found = find_modules_by_pattern("example.*.*.models") + expected_modules = { + "example.domain.users.models", + "example.domain.products.models", + "example.api.v1.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + + +def test_find_modules_by_pattern_multiple_stars(multi_star_package_path): + """Test glob pattern with multiple stars (example.*.api.*.models).""" + + found = find_modules_by_pattern("example.*.api.*.models") + expected_modules = { + "example.v1.api.users.models", + "example.v2.api.products.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + + +def test_find_modules_by_pattern_namespace_package(namespace_package_path): + """Test glob pattern with namespace packages (PEP 420). + + Should handle namespace packages where __path__ is a list of paths. + """ + + found = find_modules_by_pattern("example.*.models") + expected_modules = { + "example.subpackage_a.models", + "example.subpackage_b.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules From ba08efe0904c6a56215ee0cf3b98c65d753b1b48 Mon Sep 17 00:00:00 2001 From: Apti Date: Fri, 19 Dec 2025 19:22:57 +0300 Subject: [PATCH 04/20] refactor: use singledispatch for test path utilities --- tests/conftest.py | 43 ++++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 15a8587..bc6605f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import shutil import sys import tempfile +from functools import singledispatch from collections.abc import Generator from datetime import datetime, timezone from pathlib import Path @@ -15,43 +16,42 @@ UTC = timezone.utc -def setup_sys_path_for_test(paths: Path | list[Path], module_prefix: str = "example") -> str | list[str]: - """Add paths to sys.path and clear module cache. Returns path(s) for cleanup.""" +@singledispatch +def setup_sys_path_for_test(paths, module_prefix: str = "example") -> list[str]: # Clear module cache for name in list(sys.modules.keys()): if name == module_prefix or name.startswith(f"{module_prefix}."): del sys.modules[name] - # Normalize input - if isinstance(paths, Path): - paths_list = [paths] - single_path = True - else: - paths_list = paths - single_path = False - # Add paths to sys.path path_strings = [] - for path in paths_list: + for path in paths: path_str = str(path) path_strings.append(path_str) sys.path.insert(0, path_str) - return path_strings[0] if single_path else path_strings + return path_strings + + +@setup_sys_path_for_test.register +def _(paths: Path, module_prefix: str = "example") -> str: + result = setup_sys_path_for_test([paths], module_prefix) + return result[0] -def cleanup_sys_path(paths: str | list[str]) -> None: - """Remove path(s) from sys.path.""" - if isinstance(paths, str): - paths_list = [paths] - else: - paths_list = paths - for path_str in paths_list: +@singledispatch +def cleanup_sys_path(paths) -> None: + for path_str in paths: if path_str in sys.path: sys.path.remove(path_str) +@cleanup_sys_path.register +def _(paths: str) -> None: + cleanup_sys_path([paths]) + + @pytest.fixture def metaclass(): Base = declarative_base() @@ -357,8 +357,9 @@ def namespace_package_path() -> Generator[Path, None, None]: temp_path = Path(temp_dir) # Copy both projects - shutil.copytree(template_base / "project1", temp_path / "project1", dirs_exist_ok=True) - shutil.copytree(template_base / "project2", temp_path / "project2", dirs_exist_ok=True) + for project_key in range(1, 3): + project_dir = f"project{project_key}" + shutil.copytree(template_base / project_dir, temp_path / project_dir, dirs_exist_ok=True) project1 = temp_path / "project1" project2 = temp_path / "project2" From fbf6cb02bfcba7ccafb92519e419364cc1ae7666 Mon Sep 17 00:00:00 2001 From: Apti Date: Fri, 26 Dec 2025 17:45:09 +0300 Subject: [PATCH 05/20] refactor and extend glob pattern matching with advanced features --- paracelsus/finders.py | 248 ++++++++++++++++++ paracelsus/graph.py | 49 ++-- .../character_classes/example/__init__.py | 0 .../character_classes/example/api/__init__.py | 0 .../example/api/v0/__init__.py | 0 .../example/api/v0/models.py | 8 + .../example/api/v1/__init__.py | 0 .../example/api/v1/models.py | 8 + .../example/api/v10/__init__.py | 0 .../example/api/v10/models.py | 8 + .../example/api/v2/__init__.py | 0 .../example/api/v2/models.py | 8 + .../example/api/v3/__init__.py | 0 .../example/api/v3/models.py | 8 + .../example/api/v4/__init__.py | 0 .../example/api/v4/models.py | 8 + .../example/api/v5/__init__.py | 0 .../example/api/v5/models.py | 8 + .../example/api/v6/__init__.py | 0 .../example/api/v6/models.py | 8 + .../example/api/v7/__init__.py | 0 .../example/api/v7/models.py | 8 + .../example/api/v8/__init__.py | 0 .../example/api/v8/models.py | 8 + .../example/api/v9/__init__.py | 0 .../example/api/v9/models.py | 8 + .../example/api/va/__init__.py | 0 .../example/api/va/models.py | 8 + .../example/api/vb/__init__.py | 0 .../example/api/vb/models.py | 8 + .../assets/character_classes/example/base.py | 3 + tests/assets/recursive/example/__init__.py | 0 .../assets/recursive/example/api/__init__.py | 0 .../recursive/example/api/v1/__init__.py | 0 .../assets/recursive/example/api/v1/models.py | 8 + tests/assets/recursive/example/base.py | 3 + .../recursive/example/domain/__init__.py | 0 .../example/domain/users/__init__.py | 0 .../recursive/example/domain/users/models.py | 8 + .../recursive/example/level1/__init__.py | 0 .../example/level1/level2/__init__.py | 0 .../example/level1/level2/api/__init__.py | 0 .../example/level1/level2/api/v3/__init__.py | 0 .../example/level1/level2/api/v3/models.py | 8 + .../recursive/example/something/__init__.py | 0 .../example/something/api/__init__.py | 0 .../example/something/api/v2/__init__.py | 0 .../example/something/api/v2/models.py | 8 + tests/conftest.py | 32 +++ tests/transformers/test_find_modules.py | 143 +++++++++- 50 files changed, 579 insertions(+), 35 deletions(-) create mode 100644 paracelsus/finders.py create mode 100644 tests/assets/character_classes/example/__init__.py create mode 100644 tests/assets/character_classes/example/api/__init__.py create mode 100644 tests/assets/character_classes/example/api/v0/__init__.py create mode 100644 tests/assets/character_classes/example/api/v0/models.py create mode 100644 tests/assets/character_classes/example/api/v1/__init__.py create mode 100644 tests/assets/character_classes/example/api/v1/models.py create mode 100644 tests/assets/character_classes/example/api/v10/__init__.py create mode 100644 tests/assets/character_classes/example/api/v10/models.py create mode 100644 tests/assets/character_classes/example/api/v2/__init__.py create mode 100644 tests/assets/character_classes/example/api/v2/models.py create mode 100644 tests/assets/character_classes/example/api/v3/__init__.py create mode 100644 tests/assets/character_classes/example/api/v3/models.py create mode 100644 tests/assets/character_classes/example/api/v4/__init__.py create mode 100644 tests/assets/character_classes/example/api/v4/models.py create mode 100644 tests/assets/character_classes/example/api/v5/__init__.py create mode 100644 tests/assets/character_classes/example/api/v5/models.py create mode 100644 tests/assets/character_classes/example/api/v6/__init__.py create mode 100644 tests/assets/character_classes/example/api/v6/models.py create mode 100644 tests/assets/character_classes/example/api/v7/__init__.py create mode 100644 tests/assets/character_classes/example/api/v7/models.py create mode 100644 tests/assets/character_classes/example/api/v8/__init__.py create mode 100644 tests/assets/character_classes/example/api/v8/models.py create mode 100644 tests/assets/character_classes/example/api/v9/__init__.py create mode 100644 tests/assets/character_classes/example/api/v9/models.py create mode 100644 tests/assets/character_classes/example/api/va/__init__.py create mode 100644 tests/assets/character_classes/example/api/va/models.py create mode 100644 tests/assets/character_classes/example/api/vb/__init__.py create mode 100644 tests/assets/character_classes/example/api/vb/models.py create mode 100644 tests/assets/character_classes/example/base.py create mode 100644 tests/assets/recursive/example/__init__.py create mode 100644 tests/assets/recursive/example/api/__init__.py create mode 100644 tests/assets/recursive/example/api/v1/__init__.py create mode 100644 tests/assets/recursive/example/api/v1/models.py create mode 100644 tests/assets/recursive/example/base.py create mode 100644 tests/assets/recursive/example/domain/__init__.py create mode 100644 tests/assets/recursive/example/domain/users/__init__.py create mode 100644 tests/assets/recursive/example/domain/users/models.py create mode 100644 tests/assets/recursive/example/level1/__init__.py create mode 100644 tests/assets/recursive/example/level1/level2/__init__.py create mode 100644 tests/assets/recursive/example/level1/level2/api/__init__.py create mode 100644 tests/assets/recursive/example/level1/level2/api/v3/__init__.py create mode 100644 tests/assets/recursive/example/level1/level2/api/v3/models.py create mode 100644 tests/assets/recursive/example/something/__init__.py create mode 100644 tests/assets/recursive/example/something/api/__init__.py create mode 100644 tests/assets/recursive/example/something/api/v2/__init__.py create mode 100644 tests/assets/recursive/example/something/api/v2/models.py diff --git a/paracelsus/finders.py b/paracelsus/finders.py new file mode 100644 index 0000000..0a4aff7 --- /dev/null +++ b/paracelsus/finders.py @@ -0,0 +1,248 @@ +import fnmatch +import importlib +import os +import pkgutil +import re +import types +from typing import List, Set, Optional + + +def _get_package_paths(package: types.ModuleType) -> List[str]: + """Get all paths for a package, handling namespace packages. + Namespace packages (PEP 420) can have multiple paths in __path__. + """ + if not hasattr(package, "__path__"): + return [] + + paths = package.__path__ + + # Handle _NamespacePath and other iterable path objects + try: + if hasattr(paths, "__iter__") and not isinstance(paths, (str, bytes)): + return [str(p) for p in paths] + except (TypeError, ValueError): + pass + + return [str(paths)] + + +def _match_pattern(name: str, pattern: str) -> bool: + """Match a name against a glob pattern. + + Supports standard fnmatch patterns plus character classes with prefix (e.g., "v[12]"). + """ + # Handle character classes with prefix (e.g., "v[12]") + # fnmatch doesn't support this directly, so we need custom handling + char_class_match = re.search(r"\[([!]?)([^\]]+)\]", pattern) + if char_class_match and not pattern.startswith("["): + # Pattern has a prefix before the character class + prefix = pattern[: char_class_match.start()] + char_class_content = char_class_match.group(2) + negation = char_class_match.group(1) == "!" + + # Check prefix first + if not name.startswith(prefix): + return False + + # Check character class - must match exactly one character + remaining = name[len(prefix) :] + if len(remaining) != 1: + return False + + char = remaining[0] + + # Handle ranges like [0-9] + if "-" in char_class_content and len(char_class_content) == 3: + start, end = char_class_content[0], char_class_content[2] + if negation: + return not (start <= char <= end) + return start <= char <= end + + # Handle character sets like [12] or [abc] + if negation: + return char not in char_class_content + return char in char_class_content + + # Use fnmatch for standard patterns (*, ?, [abc], etc.) + return fnmatch.fnmatch(name, pattern) + + +def _get_modules_in_path(path: str, package_name: str) -> Set[tuple[str, bool]]: + """Get all modules and packages in a given path.""" + + items = set() + + # Use pkgutil for standard modules/packages + try: + for importer, modname, ispkg in pkgutil.iter_modules([path]): + items.add((modname, ispkg)) + except (OSError, TypeError): + pass + + # Also check directories for namespace packages (without __init__.py) + try: + for item in os.listdir(path): + if item.startswith(".") or item == "__pycache__": + continue + item_path = os.path.join(path, item) + if os.path.isdir(item_path) and item not in [name for name, _ in items]: + # Try to import to verify it's a valid package + try: + full_name = f"{package_name}.{item}" if package_name else item + test_module = importlib.import_module(full_name) + is_package = hasattr(test_module, "__path__") + items.add((item, is_package)) + except (ImportError, ValueError): + pass + except (OSError, PermissionError): + pass + + return items + + +def _find_modules_recursive( + package_name: str, + package_paths: List[str], + pattern_segments: List[str], + found_modules: Optional[Set[str]] = None, +) -> Set[str]: + """Recursively find modules matching pattern segments.""" + + if found_modules is None: + found_modules = set() + + # Base case: no more segments to match + if not pattern_segments: + found_modules.add(package_name) + return found_modules + + current_pattern = pattern_segments[0] + remaining_patterns = pattern_segments[1:] + + # Handle recursive pattern (**) + if current_pattern == "**": + # ** matches zero or more levels + # First, try matching remaining segments at current level (zero levels) + if remaining_patterns: + _find_modules_recursive( + package_name, + package_paths, + remaining_patterns, + found_modules, + ) + + # Then, recursively search all subpackages (one or more levels) + for path in package_paths: + path_str = str(path) if not isinstance(path, str) else path + items = _get_modules_in_path(path_str, package_name) + + for modname, ispkg in items: + subpackage_name = f"{package_name}.{modname}" if package_name else modname + try: + subpackage = importlib.import_module(subpackage_name) + subpackage_paths = _get_package_paths(subpackage) + + if subpackage_paths: + # Continue recursive search with ** pattern + _find_modules_recursive( + subpackage_name, + subpackage_paths, + pattern_segments, + found_modules, + ) + except (ImportError, AttributeError): + continue + + # Handle normal segments + else: + for path in package_paths: + path_str = str(path) if not isinstance(path, str) else path + items = _get_modules_in_path(path_str, package_name) + + for modname, ispkg in items: + # Check if name matches current pattern + if not _match_pattern(modname, current_pattern): + continue + + subpackage_name = f"{package_name}.{modname}" if package_name else modname + + # If this is the last segment, add it + if not remaining_patterns: + found_modules.add(subpackage_name) + else: + # Continue searching in subpackage + try: + subpackage = importlib.import_module(subpackage_name) + subpackage_paths = _get_package_paths(subpackage) + + if subpackage_paths: + _find_modules_recursive( + subpackage_name, + subpackage_paths, + remaining_patterns, + found_modules, + ) + except (ImportError, AttributeError): + continue + + return found_modules + + +def find_modules_by_pattern(pattern: str) -> List[str]: + """Finds all modules that match the glob-like pattern for Python modules. + + Supports patterns like: + - example.*.models + - example.fo?.models + - example.*.*.models + - example.**.api.*.models + - example.api.v[12].models + - example.api.v[0-9].models + - example.api.v[!1].models + """ + # Validate pattern + if ".." in pattern: + raise ValueError(f"Invalid pattern '{pattern}': consecutive dots are not allowed") + + if "," in pattern: + raise ValueError(f"Invalid pattern '{pattern}': invalid delimiter, commas are not valid delimiters, use dots") + + # Split pattern into segments + segments = [s for s in pattern.split(".") if s] + + if not segments: + raise ValueError(f"Invalid pattern '{pattern}': pattern cannot be empty") + + # Find base package (all literal segments before first wildcard) + base_parts = [] + for seg in segments: + # Check if segment contains any wildcard characters + if any(c in seg for c in "*?["): + break + base_parts.append(seg) + + if not base_parts: + raise ValueError(f"Invalid pattern '{pattern}': pattern must start with at least one literal package name") + + # Import base package + base_name = ".".join(base_parts) + try: + base_package = importlib.import_module(base_name) + except ImportError as e: + raise ValueError(f"Cannot import base package '{base_name}': {e}") + + base_paths = _get_package_paths(base_package) + if not base_paths: + raise ValueError(f"Package '{base_name}' is not a package (no __path__)") + + # Remaining pattern segments + remaining_segments = segments[len(base_parts) :] + + # Start recursive search + found_modules = _find_modules_recursive( + base_name, + base_paths, + remaining_segments, + ) + + return sorted(list(found_modules)) diff --git a/paracelsus/graph.py b/paracelsus/graph.py index c66d091..84261ca 100644 --- a/paracelsus/graph.py +++ b/paracelsus/graph.py @@ -11,6 +11,7 @@ from .config import Layouts from .transformers.dot import Dot from .transformers.mermaid import Mermaid +from .finders import find_modules_by_pattern transformers: Dict[str, type[Union[Mermaid, Dot]]] = { "mmd": Mermaid, @@ -20,43 +21,23 @@ } -def find_modules_by_pattern(pattern: str) -> List[str]: - """Finds all modules that match the glob-like pattern for Python modules.""" - parts = pattern.split(".") +def _is_glob_pattern(pattern: str) -> bool: + """Check if a pattern contains any glob wildcard characters. - star_index = None - for i, part in enumerate(parts): - if part == "*": - star_index = i - break - - prefix_parts = parts[:star_index] - suffix_parts = parts[star_index + 1 :] - - if not suffix_parts: - raise ValueError(f"Glob pattern '{pattern}' must specify a module name after '*'. ") - - base_package_name = ".".join(prefix_parts) - base_package = importlib.import_module(base_package_name) - base_path = base_package.__path__[0] - - found_modules = [] - - for importer, modname, ispkg in pkgutil.iter_modules([base_path]): - # Form a subpackage name - subpackage_name = f"{base_package_name}.{modname}" + Glob patterns can contain: + - * (any string) + - ? (single character) + - ** (recursive) + - [abc], [0-9], [!1] (character classes) + """ - # Form the full name of the target module - target_module_name = f"{subpackage_name}.{'.'.join(suffix_parts)}" + if "*" in pattern or "?" in pattern: + return True - # Check that the module exists and add it to the list. - try: - importlib.import_module(target_module_name) - found_modules.append(target_module_name) - except ImportError: - continue + if "[" in pattern and "]" in pattern: + return True - return found_modules + return False def get_graph_string( @@ -92,7 +73,7 @@ def get_graph_string( search_pattern = module[:-2] if needs_wildcards_import else module - if "*" in search_pattern: + if _is_glob_pattern(search_pattern): # This is a glob pattern, find all the corresponding modules found_models = find_modules_by_pattern(search_pattern) diff --git a/tests/assets/character_classes/example/__init__.py b/tests/assets/character_classes/example/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/__init__.py b/tests/assets/character_classes/example/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v0/__init__.py b/tests/assets/character_classes/example/api/v0/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v0/models.py b/tests/assets/character_classes/example/api/v0/models.py new file mode 100644 index 0000000..3d6ecdd --- /dev/null +++ b/tests/assets/character_classes/example/api/v0/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V0Model(Base): + __tablename__ = "v0_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v1/__init__.py b/tests/assets/character_classes/example/api/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v1/models.py b/tests/assets/character_classes/example/api/v1/models.py new file mode 100644 index 0000000..28dc046 --- /dev/null +++ b/tests/assets/character_classes/example/api/v1/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V1Model(Base): + __tablename__ = "v1_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v10/__init__.py b/tests/assets/character_classes/example/api/v10/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v10/models.py b/tests/assets/character_classes/example/api/v10/models.py new file mode 100644 index 0000000..409bad9 --- /dev/null +++ b/tests/assets/character_classes/example/api/v10/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V10Model(Base): + __tablename__ = "v10_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v2/__init__.py b/tests/assets/character_classes/example/api/v2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v2/models.py b/tests/assets/character_classes/example/api/v2/models.py new file mode 100644 index 0000000..4217946 --- /dev/null +++ b/tests/assets/character_classes/example/api/v2/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V2Model(Base): + __tablename__ = "v2_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v3/__init__.py b/tests/assets/character_classes/example/api/v3/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v3/models.py b/tests/assets/character_classes/example/api/v3/models.py new file mode 100644 index 0000000..2061e93 --- /dev/null +++ b/tests/assets/character_classes/example/api/v3/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V3Model(Base): + __tablename__ = "v3_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v4/__init__.py b/tests/assets/character_classes/example/api/v4/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v4/models.py b/tests/assets/character_classes/example/api/v4/models.py new file mode 100644 index 0000000..05dc8b8 --- /dev/null +++ b/tests/assets/character_classes/example/api/v4/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V4Model(Base): + __tablename__ = "v4_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v5/__init__.py b/tests/assets/character_classes/example/api/v5/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v5/models.py b/tests/assets/character_classes/example/api/v5/models.py new file mode 100644 index 0000000..cd373a4 --- /dev/null +++ b/tests/assets/character_classes/example/api/v5/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V5Model(Base): + __tablename__ = "v5_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v6/__init__.py b/tests/assets/character_classes/example/api/v6/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v6/models.py b/tests/assets/character_classes/example/api/v6/models.py new file mode 100644 index 0000000..2ffb0a6 --- /dev/null +++ b/tests/assets/character_classes/example/api/v6/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V6Model(Base): + __tablename__ = "v6_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v7/__init__.py b/tests/assets/character_classes/example/api/v7/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v7/models.py b/tests/assets/character_classes/example/api/v7/models.py new file mode 100644 index 0000000..c89a94f --- /dev/null +++ b/tests/assets/character_classes/example/api/v7/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V7Model(Base): + __tablename__ = "v7_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v8/__init__.py b/tests/assets/character_classes/example/api/v8/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v8/models.py b/tests/assets/character_classes/example/api/v8/models.py new file mode 100644 index 0000000..7093ecc --- /dev/null +++ b/tests/assets/character_classes/example/api/v8/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V8Model(Base): + __tablename__ = "v8_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/v9/__init__.py b/tests/assets/character_classes/example/api/v9/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/v9/models.py b/tests/assets/character_classes/example/api/v9/models.py new file mode 100644 index 0000000..410002d --- /dev/null +++ b/tests/assets/character_classes/example/api/v9/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class V9Model(Base): + __tablename__ = "v9_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/va/__init__.py b/tests/assets/character_classes/example/api/va/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/va/models.py b/tests/assets/character_classes/example/api/va/models.py new file mode 100644 index 0000000..787b4c4 --- /dev/null +++ b/tests/assets/character_classes/example/api/va/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class VaModel(Base): + __tablename__ = "va_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/api/vb/__init__.py b/tests/assets/character_classes/example/api/vb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/character_classes/example/api/vb/models.py b/tests/assets/character_classes/example/api/vb/models.py new file mode 100644 index 0000000..a9bc41d --- /dev/null +++ b/tests/assets/character_classes/example/api/vb/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class VbModel(Base): + __tablename__ = "vb_table" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/character_classes/example/base.py b/tests/assets/character_classes/example/base.py new file mode 100644 index 0000000..59be703 --- /dev/null +++ b/tests/assets/character_classes/example/base.py @@ -0,0 +1,3 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() diff --git a/tests/assets/recursive/example/__init__.py b/tests/assets/recursive/example/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/recursive/example/api/__init__.py b/tests/assets/recursive/example/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/recursive/example/api/v1/__init__.py b/tests/assets/recursive/example/api/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/recursive/example/api/v1/models.py b/tests/assets/recursive/example/api/v1/models.py new file mode 100644 index 0000000..fe4a781 --- /dev/null +++ b/tests/assets/recursive/example/api/v1/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class APIResource(Base): + __tablename__ = "api_resources" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/recursive/example/base.py b/tests/assets/recursive/example/base.py new file mode 100644 index 0000000..59be703 --- /dev/null +++ b/tests/assets/recursive/example/base.py @@ -0,0 +1,3 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() diff --git a/tests/assets/recursive/example/domain/__init__.py b/tests/assets/recursive/example/domain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/recursive/example/domain/users/__init__.py b/tests/assets/recursive/example/domain/users/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/recursive/example/domain/users/models.py b/tests/assets/recursive/example/domain/users/models.py new file mode 100644 index 0000000..dfd6458 --- /dev/null +++ b/tests/assets/recursive/example/domain/users/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ...base import Base + + +class User(Base): + __tablename__ = "users" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/recursive/example/level1/__init__.py b/tests/assets/recursive/example/level1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/recursive/example/level1/level2/__init__.py b/tests/assets/recursive/example/level1/level2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/recursive/example/level1/level2/api/__init__.py b/tests/assets/recursive/example/level1/level2/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/recursive/example/level1/level2/api/v3/__init__.py b/tests/assets/recursive/example/level1/level2/api/v3/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/recursive/example/level1/level2/api/v3/models.py b/tests/assets/recursive/example/level1/level2/api/v3/models.py new file mode 100644 index 0000000..58f437d --- /dev/null +++ b/tests/assets/recursive/example/level1/level2/api/v3/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ......base import Base + + +class DeepAPIResource(Base): + __tablename__ = "deep_api_resources" + id = mapped_column(String, primary_key=True) diff --git a/tests/assets/recursive/example/something/__init__.py b/tests/assets/recursive/example/something/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/recursive/example/something/api/__init__.py b/tests/assets/recursive/example/something/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/recursive/example/something/api/v2/__init__.py b/tests/assets/recursive/example/something/api/v2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/recursive/example/something/api/v2/models.py b/tests/assets/recursive/example/something/api/v2/models.py new file mode 100644 index 0000000..e17bcbe --- /dev/null +++ b/tests/assets/recursive/example/something/api/v2/models.py @@ -0,0 +1,8 @@ +from sqlalchemy import String +from sqlalchemy.orm import mapped_column +from ....base import Base + + +class SomethingAPIResource(Base): + __tablename__ = "something_api_resources" + id = mapped_column(String, primary_key=True) diff --git a/tests/conftest.py b/tests/conftest.py index bc6605f..6662813 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -349,6 +349,38 @@ def multi_star_package_path() -> Generator[Path, None, None]: cleanup_sys_path(path_str) +@pytest.fixture +def character_classes_package_path() -> Generator[Path, None, None]: + """Create a package structure for testing character classes and ranges patterns.""" + template_path = Path(os.path.dirname(os.path.realpath(__file__))) / "assets" / "character_classes" + with tempfile.TemporaryDirectory() as package_path: + shutil.copytree(template_path, package_path, dirs_exist_ok=True) + package_dir = Path(package_path) + os.chdir(package_path) + + path_str = setup_sys_path_for_test(package_dir) + try: + yield Path(package_path) + finally: + cleanup_sys_path(path_str) + + +@pytest.fixture +def recursive_package_path() -> Generator[Path, None, None]: + """Create a package structure for testing recursive lookup patterns (**).""" + template_path = Path(os.path.dirname(os.path.realpath(__file__))) / "assets" / "recursive" + with tempfile.TemporaryDirectory() as package_path: + shutil.copytree(template_path, package_path, dirs_exist_ok=True) + package_dir = Path(package_path) + os.chdir(package_path) + + path_str = setup_sys_path_for_test(package_dir) + try: + yield Path(package_path) + finally: + cleanup_sys_path(path_str) + + @pytest.fixture def namespace_package_path() -> Generator[Path, None, None]: """Create a namespace package structure (PEP 420) for testing.""" diff --git a/tests/transformers/test_find_modules.py b/tests/transformers/test_find_modules.py index 269fecc..fa554fa 100644 --- a/tests/transformers/test_find_modules.py +++ b/tests/transformers/test_find_modules.py @@ -1,4 +1,5 @@ -from paracelsus.graph import find_modules_by_pattern +import pytest +from paracelsus.finders import find_modules_by_pattern def test_find_modules_by_pattern_single_level(single_level_package_path): @@ -55,3 +56,143 @@ def test_find_modules_by_pattern_namespace_package(namespace_package_path): assert len(found) == len(expected_modules) assert set(found) == expected_modules + + +def test_find_modules_by_pattern_single_character(single_level_package_path): + """Test glob pattern with single character matching (example.fo?.models).""" + + found = find_modules_by_pattern("example.fo?.models") + expected_modules = { + "example.foo.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + assert "example.bar.models" not in found + + +def test_find_modules_by_pattern_character_class(character_classes_package_path): + """Test glob pattern with character class (example.api.v[12].models). + + Character class [12] matches exactly one character: '1' or '2'. + """ + found = find_modules_by_pattern("example.api.v[12].models") + expected_modules = { + "example.api.v1.models", + "example.api.v2.models", + } + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + assert "example.api.v3.models" not in found + + +def test_find_modules_by_pattern_character_range(character_classes_package_path): + """Test glob pattern with character range (example.api.v[0-9].models). + + Character range [0-9] matches exactly one digit from 0 to 9. + """ + found = find_modules_by_pattern("example.api.v[0-9].models") + expected_modules = { + "example.api.v0.models", + "example.api.v1.models", + "example.api.v2.models", + "example.api.v3.models", + "example.api.v4.models", + "example.api.v5.models", + "example.api.v6.models", + "example.api.v7.models", + "example.api.v8.models", + "example.api.v9.models", + } + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + assert "example.api.v10.models" not in found + assert "example.api.va.models" not in found + + +def test_find_modules_by_pattern_complementation_character_class(character_classes_package_path): + """Test glob pattern with complementation character class (example.api.v[!1].models). + + Complementation [!1] matches any single character except '1'. + """ + found = find_modules_by_pattern("example.api.v[!1].models") + + assert "example.api.v0.models" in found + assert "example.api.v2.models" in found + assert "example.api.va.models" in found + assert "example.api.v1.models" not in found + + +def test_find_modules_by_pattern_complementation_character_range(character_classes_package_path): + """Test glob pattern with complementation range (example.api.v[!0-9].models). + + Complementation [!0-9] matches any single character except digits 0-9. + """ + found = find_modules_by_pattern("example.api.v[!0-9].models") + + assert "example.api.va.models" in found + assert "example.api.vb.models" in found + + for i in range(10): + assert f"example.api.v{i}.models" not in found + + +def test_find_modules_by_pattern_mixed_wildcards(multi_star_package_path): + """Test glob pattern with mixed wildcards (example.v?.*.*.models). + + Pattern combines single character match (v?) with any string matches (*). + v? matches one char (v1, v2, va, etc.), then *.* matches two package levels. + Example: example.v1.api.users.models, example.v2.api.products.models + """ + found = find_modules_by_pattern("example.v?.*.*.models") + expected_modules = { + "example.v1.api.users.models", + "example.v2.api.products.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + + +def test_find_modules_by_pattern_recursive_lookup(recursive_package_path): + """Test glob pattern with recursive lookup (example.**.api.*.models). + + Recursive lookup (**) matches any number of package levels (0 or more). + Should find modules at different depths: + - example.api.v1.models (0 levels between example and api) + - example.something.api.v2.models (1 level deep) + - example.level1.level2.api.v3.models (2 levels deep) + """ + found = find_modules_by_pattern("example.**.api.*.models") + expected_modules = { + "example.api.v1.models", + "example.something.api.v2.models", + "example.level1.level2.api.v3.models", + } + + assert len(found) == len(expected_modules) + assert set(found) == expected_modules + + # Should not find modules that don't have 'api' in their path + assert "example.domain.users.models" not in found + + +# Error Cases +def test_find_modules_by_pattern_missing_rule_error(): + """Test that missing rule (example.v?..models) raises ValueError. + + Pattern 'v?..models' has two consecutive dots, which is invalid. + Should raise ValueError with descriptive message. + """ + with pytest.raises(ValueError, match=".*missing.*rule.*|.*invalid.*pattern.*|.*consecutive.*"): + find_modules_by_pattern("example.v?..models") + + +def test_find_modules_by_pattern_invalid_delimiter_error(): + """Test that invalid delimiter (example.v?,,models) raises ValueError. + + Pattern 'v?,,models' uses comma instead of dot as delimiter, which is invalid. + Should raise ValueError with descriptive message. + """ + with pytest.raises(ValueError, match=".*invalid.*delimiter.*|.*invalid.*pattern.*"): + find_modules_by_pattern("example.v?,,models") From 559b393a4a116c96a7eb9059650756f5eae4d699 Mon Sep 17 00:00:00 2001 From: TheLazzzies Date: Thu, 1 Jan 2026 22:28:52 +0300 Subject: [PATCH 06/20] Feature, Add Pattern model --- paracelsus/models/__init__.py | 0 paracelsus/models/base.py | 43 +++++++++++++++++++++++++++++++++++ paracelsus/models/pattern.py | 39 +++++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+) create mode 100644 paracelsus/models/__init__.py create mode 100644 paracelsus/models/base.py create mode 100644 paracelsus/models/pattern.py diff --git a/paracelsus/models/__init__.py b/paracelsus/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/paracelsus/models/base.py b/paracelsus/models/base.py new file mode 100644 index 0000000..e34fb74 --- /dev/null +++ b/paracelsus/models/base.py @@ -0,0 +1,43 @@ +from typing import Callable, Generic, Protocol, Sequence, TypeVar, Type + +ReturnT = TypeVar("ReturnT") + + +class ValidationError(ValueError): + pass + +class ErrorContainer(Protocol): + errors: list[ValidationError] + + +class Attribute(Generic[ReturnT]): + """ + A simple descriptor to implement attributes validation upon assignment. + """ + + def __init__(self, type: Type, validators: Sequence[Callable[[str, ReturnT], ReturnT]] = ()): + self.type = type + self.validators = validators + + def __set_name__(self, owner, name): + self.name = name + + def __get__(self, instance: ErrorContainer, owner): + if not instance: + return self + return instance.__dict__[self.name] + + def __delete__(self, instance: ErrorContainer): + del instance.__dict__[self.name] + + def __set__(self, instance: ErrorContainer, value): + if not isinstance(value, self.type): + raise TypeError(f"{self.name!r} values must be of type {self.type!r}") + + for validator in self.validators: + try: + validator(self.name, value) + except ValidationError as e: + instance.errors.append(e) + + instance.__dict__[self.name] = value diff --git a/paracelsus/models/pattern.py b/paracelsus/models/pattern.py new file mode 100644 index 0000000..b4ba05d --- /dev/null +++ b/paracelsus/models/pattern.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass, field + +from .base import ValidationError, Attribute + + + +def forbid_wildcard_for_modules(name: str, value: str) -> str: + if value.endswith("**"): + raise ValidationError("Wildcard (**) not allowed for modules scope") + return value + + +def forbid_empty_segment(name: str, value: str) -> str: + if any(not segment for segment in value.split(".")): + raise ValidationError("Empty segment not allowed") + return value + + +@dataclass(init=True, frozen=True) +class Pattern: + errors: list[ValidationError] = field(default_factory=list) + mask: Attribute[str] = field( + default=Attribute[str](str, validators=[ + forbid_wildcard_for_modules, + forbid_empty_segment + ]) + ) + + @property + def tokens(self) -> list[str]: + return self.mask.split(".") + + @property + def serialized_errors(self) -> str: + messages = [f"Errors found for {self.mask}:"] + for error in self.errors: + messages.append(f" - {error}") + return "\n".join(messages) + From 7f35f75974acfd4b3221462deb514092f25b3afb Mon Sep 17 00:00:00 2001 From: TheLazzzies Date: Thu, 1 Jan 2026 22:35:47 +0300 Subject: [PATCH 07/20] Feature, Add ModuleFinder --- paracelsus/finders.py | 368 ++++++++++++++++-------------------------- 1 file changed, 135 insertions(+), 233 deletions(-) diff --git a/paracelsus/finders.py b/paracelsus/finders.py index 0a4aff7..99d892f 100644 --- a/paracelsus/finders.py +++ b/paracelsus/finders.py @@ -1,248 +1,150 @@ -import fnmatch -import importlib -import os -import pkgutil -import re -import types -from typing import List, Set, Optional - - -def _get_package_paths(package: types.ModuleType) -> List[str]: - """Get all paths for a package, handling namespace packages. - Namespace packages (PEP 420) can have multiple paths in __path__. - """ - if not hasattr(package, "__path__"): - return [] - - paths = package.__path__ - - # Handle _NamespacePath and other iterable path objects - try: - if hasattr(paths, "__iter__") and not isinstance(paths, (str, bytes)): - return [str(p) for p in paths] - except (TypeError, ValueError): - pass - - return [str(paths)] - - -def _match_pattern(name: str, pattern: str) -> bool: - """Match a name against a glob pattern. - - Supports standard fnmatch patterns plus character classes with prefix (e.g., "v[12]"). - """ - # Handle character classes with prefix (e.g., "v[12]") - # fnmatch doesn't support this directly, so we need custom handling - char_class_match = re.search(r"\[([!]?)([^\]]+)\]", pattern) - if char_class_match and not pattern.startswith("["): - # Pattern has a prefix before the character class - prefix = pattern[: char_class_match.start()] - char_class_content = char_class_match.group(2) - negation = char_class_match.group(1) == "!" - - # Check prefix first - if not name.startswith(prefix): - return False +from collections import deque +from dataclasses import dataclass +from pathlib import Path +from typing import Generator, Optional, Set - # Check character class - must match exactly one character - remaining = name[len(prefix) :] - if len(remaining) != 1: - return False - char = remaining[0] +@dataclass(frozen=True, eq=True) +class GlobNode: + pattern: str + next: Optional["GlobNode"] = None + + @staticmethod + def nodify(tokens: list[str]) -> Optional["GlobNode"]: + head = None + for part in reversed(tokens): + new_node = GlobNode(pattern=part, next=head) + head = new_node + return head + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, value: object) -> bool: + return isinstance(value, GlobNode) and self is value + + @property + def is_final(self): + return self.next is None + + def __repr__(self): + return f"Node({self.pattern})" + + +@dataclass(frozen=True) +class SearchState: + """Represents a snapshot of the traversal cursor.""" + + path: Path + node: GlobNode - # Handle ranges like [0-9] - if "-" in char_class_content and len(char_class_content) == 3: - start, end = char_class_content[0], char_class_content[2] - if negation: - return not (start <= char <= end) - return start <= char <= end - # Handle character sets like [12] or [abc] - if negation: - return char not in char_class_content - return char in char_class_content +class ModuleFinder: + def __init__(self, root: Path, segments: list[str]): + self.root = root + self.head = GlobNode.nodify(segments) - # Use fnmatch for standard patterns (*, ?, [abc], etc.) - return fnmatch.fnmatch(name, pattern) + self.queue: deque[SearchState] = deque() + # To prevent infinite loops with symlinks or redundant '**' paths + self.visited: Set[SearchState] = set() + def find(self) -> Generator[Path, None, None]: + """ + Finds all modules that match the glob-like pattern for Python modules. -def _get_modules_in_path(path: str, package_name: str) -> Set[tuple[str, bool]]: - """Get all modules and packages in a given path.""" + Supports patterns like: + - example.*.models + - example.fo?.models + - example.*.*.models + - example.**.api.*.models + - example.api.v[12].models + - example.api.v[0-9].models + - example.api.v[!1].models - items = set() + """ + if self.head is None: + return - # Use pkgutil for standard modules/packages - try: - for importer, modname, ispkg in pkgutil.iter_modules([path]): - items.add((modname, ispkg)) - except (OSError, TypeError): - pass + # Initialize state + self.queue.append(SearchState(self.root, self.head)) - # Also check directories for namespace packages (without __init__.py) - try: - for item in os.listdir(path): - if item.startswith(".") or item == "__pycache__": + while self.queue: + state = self.queue.popleft() + + # Optimization: distinct paths to the same state are redundant + if state in self.visited: continue - item_path = os.path.join(path, item) - if os.path.isdir(item_path) and item not in [name for name, _ in items]: - # Try to import to verify it's a valid package - try: - full_name = f"{package_name}.{item}" if package_name else item - test_module = importlib.import_module(full_name) - is_package = hasattr(test_module, "__path__") - items.add((item, is_package)) - except (ImportError, ValueError): - pass - except (OSError, PermissionError): - pass - - return items - - -def _find_modules_recursive( - package_name: str, - package_paths: List[str], - pattern_segments: List[str], - found_modules: Optional[Set[str]] = None, -) -> Set[str]: - """Recursively find modules matching pattern segments.""" - - if found_modules is None: - found_modules = set() - - # Base case: no more segments to match - if not pattern_segments: - found_modules.add(package_name) - return found_modules - - current_pattern = pattern_segments[0] - remaining_patterns = pattern_segments[1:] - - # Handle recursive pattern (**) - if current_pattern == "**": - # ** matches zero or more levels - # First, try matching remaining segments at current level (zero levels) - if remaining_patterns: - _find_modules_recursive( - package_name, - package_paths, - remaining_patterns, - found_modules, - ) - - # Then, recursively search all subpackages (one or more levels) - for path in package_paths: - path_str = str(path) if not isinstance(path, str) else path - items = _get_modules_in_path(path_str, package_name) - - for modname, ispkg in items: - subpackage_name = f"{package_name}.{modname}" if package_name else modname - try: - subpackage = importlib.import_module(subpackage_name) - subpackage_paths = _get_package_paths(subpackage) - - if subpackage_paths: - # Continue recursive search with ** pattern - _find_modules_recursive( - subpackage_name, - subpackage_paths, - pattern_segments, - found_modules, - ) - except (ImportError, AttributeError): + self.visited.add(state) + + yield from self._process_state(state) + + def _process_state(self, state: SearchState) -> Generator[Path, None, None]: + """ + Implement BFS + """ + node = state.node + path = state.path + + # === 1. Recursive Wildcard (**) === + if node.pattern == "**": + # Branch A: Skip (0 matches). + # Move to next node, keep path same. + if node.next: + self.queue.append(SearchState(path, node.next)) + + # Branch B: Consume (1+ matches). + # Stay on current node, move deeper into filesystem. + for child in self._safe_iterdir(path): + if child.is_dir(): + self.queue.append(SearchState(child, node)) + + else: + for child in self._safe_iterdir(path): + is_match = False + + if child.is_dir(): + if child.match(node.pattern): + is_match = True + + elif child.is_file(): + if Path(child.stem).match(node.pattern): + is_match = True + + if not is_match: continue - # Handle normal segments - else: - for path in package_paths: - path_str = str(path) if not isinstance(path, str) else path - items = _get_modules_in_path(path_str, package_name) + if node.is_final and self._is_valid_module(child): + yield child - for modname, ispkg in items: - # Check if name matches current pattern - if not _match_pattern(modname, current_pattern): - continue + elif node.next is not None and child.is_dir(): + self.queue.append(SearchState(child, node.next)) + + def _safe_iterdir(self, path: Path) -> Generator[Path, None, None]: + """Safe wrapper around iterdir to handle permission errors.""" + try: + if path.is_dir(): + yield from path.iterdir() + except PermissionError: + pass + + def _is_valid_module(self, path: Path) -> bool: + """ + Determines if a path is a valid python module. + 1. File: my_module.py (but not __init__.py) + 2. Package: my_package/ (must contain __init__.py) + 3. Supports PEP 420 Namespace Packages. + """ + + name = path.name + + # 1. Ignore common garbage/internal directories + if name == "__pycache__" or name.startswith("."): + return False + + if path.is_file(): + return path.suffix == ".py" and path.name != "__init__.py" + + if path.is_dir(): + return (path / "__init__.py").exists() or name.isidentifier() - subpackage_name = f"{package_name}.{modname}" if package_name else modname - - # If this is the last segment, add it - if not remaining_patterns: - found_modules.add(subpackage_name) - else: - # Continue searching in subpackage - try: - subpackage = importlib.import_module(subpackage_name) - subpackage_paths = _get_package_paths(subpackage) - - if subpackage_paths: - _find_modules_recursive( - subpackage_name, - subpackage_paths, - remaining_patterns, - found_modules, - ) - except (ImportError, AttributeError): - continue - - return found_modules - - -def find_modules_by_pattern(pattern: str) -> List[str]: - """Finds all modules that match the glob-like pattern for Python modules. - - Supports patterns like: - - example.*.models - - example.fo?.models - - example.*.*.models - - example.**.api.*.models - - example.api.v[12].models - - example.api.v[0-9].models - - example.api.v[!1].models - """ - # Validate pattern - if ".." in pattern: - raise ValueError(f"Invalid pattern '{pattern}': consecutive dots are not allowed") - - if "," in pattern: - raise ValueError(f"Invalid pattern '{pattern}': invalid delimiter, commas are not valid delimiters, use dots") - - # Split pattern into segments - segments = [s for s in pattern.split(".") if s] - - if not segments: - raise ValueError(f"Invalid pattern '{pattern}': pattern cannot be empty") - - # Find base package (all literal segments before first wildcard) - base_parts = [] - for seg in segments: - # Check if segment contains any wildcard characters - if any(c in seg for c in "*?["): - break - base_parts.append(seg) - - if not base_parts: - raise ValueError(f"Invalid pattern '{pattern}': pattern must start with at least one literal package name") - - # Import base package - base_name = ".".join(base_parts) - try: - base_package = importlib.import_module(base_name) - except ImportError as e: - raise ValueError(f"Cannot import base package '{base_name}': {e}") - - base_paths = _get_package_paths(base_package) - if not base_paths: - raise ValueError(f"Package '{base_name}' is not a package (no __path__)") - - # Remaining pattern segments - remaining_segments = segments[len(base_parts) :] - - # Start recursive search - found_modules = _find_modules_recursive( - base_name, - base_paths, - remaining_segments, - ) - - return sorted(list(found_modules)) + return False From e6afae30168c0ba3288febf33a49e35ebb9b6e1d Mon Sep 17 00:00:00 2001 From: TheLazzzies Date: Fri, 2 Jan 2026 00:37:25 +0300 Subject: [PATCH 08/20] Feature, Fix validation issues --- paracelsus/models/base.py | 8 ++++++-- paracelsus/models/pattern.py | 20 ++++++++++++++------ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/paracelsus/models/base.py b/paracelsus/models/base.py index e34fb74..bcb1aa0 100644 --- a/paracelsus/models/base.py +++ b/paracelsus/models/base.py @@ -6,8 +6,12 @@ class ValidationError(ValueError): pass + class ErrorContainer(Protocol): - errors: list[ValidationError] + def add_error(self, error: ValidationError) -> None: ... + + @property + def errors(self) -> list[ValidationError]: ... class Attribute(Generic[ReturnT]): @@ -38,6 +42,6 @@ def __set__(self, instance: ErrorContainer, value): try: validator(self.name, value) except ValidationError as e: - instance.errors.append(e) + instance.add_error(e) instance.__dict__[self.name] = value diff --git a/paracelsus/models/pattern.py b/paracelsus/models/pattern.py index b4ba05d..b403cce 100644 --- a/paracelsus/models/pattern.py +++ b/paracelsus/models/pattern.py @@ -3,6 +3,11 @@ from .base import ValidationError, Attribute +def forbid_empty_path(name: str, value: str) -> str: + if not value: + raise ValidationError("Empty path not allowed") + return value + def forbid_wildcard_for_modules(name: str, value: str) -> str: if value.endswith("**"): @@ -18,14 +23,18 @@ def forbid_empty_segment(name: str, value: str) -> str: @dataclass(init=True, frozen=True) class Pattern: - errors: list[ValidationError] = field(default_factory=list) + _errors: list[ValidationError] = field(default_factory=list) mask: Attribute[str] = field( - default=Attribute[str](str, validators=[ - forbid_wildcard_for_modules, - forbid_empty_segment - ]) + default=Attribute[str](str, validators=[forbid_wildcard_for_modules, forbid_empty_segment, forbid_empty_path]) ) + def add_error(self, error: ValidationError) -> None: + self._errors.append(error) + + @property + def errors(self) -> list[ValidationError]: + return self._errors + @property def tokens(self) -> list[str]: return self.mask.split(".") @@ -36,4 +45,3 @@ def serialized_errors(self) -> str: for error in self.errors: messages.append(f" - {error}") return "\n".join(messages) - From 8c3a769e226a79db093e6917074e1129158cf6be Mon Sep 17 00:00:00 2001 From: TheLazzzies Date: Fri, 2 Jan 2026 00:54:02 +0300 Subject: [PATCH 09/20] Feature, Reimplment module lookup block. Update tests --- paracelsus/graph.py | 81 +++++++++--------- tests/transformers/test_find_modules.py | 109 +++++++++++++++++------- 2 files changed, 119 insertions(+), 71 deletions(-) diff --git a/paracelsus/graph.py b/paracelsus/graph.py index 84261ca..2472935 100644 --- a/paracelsus/graph.py +++ b/paracelsus/graph.py @@ -1,17 +1,18 @@ import importlib import os +import re import sys +from concurrent.futures import ThreadPoolExecutor from pathlib import Path -import re -import pkgutil from typing import List, Set, Optional, Dict, Union +from paracelsus.models.pattern import Pattern from sqlalchemy.schema import MetaData from .config import Layouts from .transformers.dot import Dot from .transformers.mermaid import Mermaid -from .finders import find_modules_by_pattern +from .finders import ModuleFinder transformers: Dict[str, type[Union[Mermaid, Dot]]] = { "mmd": Mermaid, @@ -21,23 +22,31 @@ } -def _is_glob_pattern(pattern: str) -> bool: - """Check if a pattern contains any glob wildcard characters. +def do_import(needs_wildcards_import: bool, module_path: str) -> bool: + if needs_wildcards_import: + exec(f"from {module_path} import *") + else: + importlib.import_module(module_path) + return True - Glob patterns can contain: - - * (any string) - - ? (single character) - - ** (recursive) - - [abc], [0-9], [!1] (character classes) - """ - if "*" in pattern or "?" in pattern: - return True +def to_module_name(root: Path, path: Path) -> str: + """ + Converts a filesystem path to a Python dotted module string. + Example: /root/app/models.py -> app.models + """ + try: + relative_path = path.resolve().relative_to(root) + except ValueError: + # Fallback if path is not relative to root (should not happen in normal usage) + return path.name - if "[" in pattern and "]" in pattern: - return True + if path.is_file(): + clean_path = relative_path.with_suffix("") + else: + clean_path = relative_path - return False + return ".".join(clean_path.parts) def get_graph_string( @@ -68,28 +77,24 @@ def get_graph_string( # The modules holding the model classes have to be imported to get put in the metaclass model registry. # These modules aren't actually used in any way, so they are discarded. # They are also imported in scope of this function to prevent namespace pollution. - for module in import_module: - needs_wildcards_import = module.endswith(":*") - - search_pattern = module[:-2] if needs_wildcards_import else module - - if _is_glob_pattern(search_pattern): - # This is a glob pattern, find all the corresponding modules - found_models = find_modules_by_pattern(search_pattern) - - for found_model in found_models: - if needs_wildcards_import: - # Combination: glob search + wildcard import - exec(f"from {found_model} import *") - else: - # Glob search only, normal import - importlib.import_module(found_model) - elif needs_wildcards_import: - # Wildcard import only - exec(f"from {search_pattern} import *") - else: - # Normal module import - importlib.import_module(module) + for module_lookup_mask in import_module: + module_path, import_modifier = module_lookup_mask, None + + if module_path.endswith(":*"): + module_path, import_modifier = module_path.split(":", 1) + + pattern = Pattern(mask=module_path) + + if any(pattern.errors): + raise ValueError(pattern.serialized_errors) + + current_root = Path.cwd() + finder = ModuleFinder(current_root, pattern.tokens) + + with ThreadPoolExecutor() as executor: + for found_module_path in finder.find(): + dot_path = to_module_name(current_root, found_module_path) + executor.submit(do_import, import_modifier == "*", dot_path) # Grab a transformer. if format not in transformers: diff --git a/tests/transformers/test_find_modules.py b/tests/transformers/test_find_modules.py index fa554fa..ac4f5be 100644 --- a/tests/transformers/test_find_modules.py +++ b/tests/transformers/test_find_modules.py @@ -1,11 +1,18 @@ import pytest -from paracelsus.finders import find_modules_by_pattern + +from paracelsus.graph import to_module_name +from paracelsus.finders import ModuleFinder +from paracelsus.models.pattern import Pattern def test_find_modules_by_pattern_single_level(single_level_package_path): """Test basic glob pattern matching with single-level subpackages.""" - found = find_modules_by_pattern("example.*.models") + pattern = Pattern(mask="example.*.models") + found = [ + to_module_name(single_level_package_path, module_path) + for module_path in ModuleFinder(single_level_package_path, pattern.tokens).find() + ] expected_modules = { "example.foo.models", "example.bar.models", @@ -18,7 +25,11 @@ def test_find_modules_by_pattern_single_level(single_level_package_path): def test_find_modules_by_pattern_nested_levels(nested_package_path): """Test glob pattern with nested levels (example.*.*.models).""" - found = find_modules_by_pattern("example.*.*.models") + pattern = Pattern(mask="example.*.*.models") + found = [ + to_module_name(nested_package_path, module_path) + for module_path in ModuleFinder(nested_package_path, pattern.tokens).find() + ] expected_modules = { "example.domain.users.models", "example.domain.products.models", @@ -32,7 +43,11 @@ def test_find_modules_by_pattern_nested_levels(nested_package_path): def test_find_modules_by_pattern_multiple_stars(multi_star_package_path): """Test glob pattern with multiple stars (example.*.api.*.models).""" - found = find_modules_by_pattern("example.*.api.*.models") + pattern = Pattern(mask="example.*.api.*.models") + found = [ + to_module_name(multi_star_package_path, module_path) + for module_path in ModuleFinder(multi_star_package_path, pattern.tokens).find() + ] expected_modules = { "example.v1.api.users.models", "example.v2.api.products.models", @@ -48,10 +63,14 @@ def test_find_modules_by_pattern_namespace_package(namespace_package_path): Should handle namespace packages where __path__ is a list of paths. """ - found = find_modules_by_pattern("example.*.models") + pattern = Pattern(mask="project*.example.*.models") + found = [ + to_module_name(namespace_package_path, module_path) + for module_path in ModuleFinder(namespace_package_path, pattern.tokens).find() + ] expected_modules = { - "example.subpackage_a.models", - "example.subpackage_b.models", + "project1.example.subpackage_a.models", + "project2.example.subpackage_b.models", } assert len(found) == len(expected_modules) @@ -61,7 +80,11 @@ def test_find_modules_by_pattern_namespace_package(namespace_package_path): def test_find_modules_by_pattern_single_character(single_level_package_path): """Test glob pattern with single character matching (example.fo?.models).""" - found = find_modules_by_pattern("example.fo?.models") + pattern = Pattern(mask="example.fo?.models") + found = [ + to_module_name(single_level_package_path, module_path) + for module_path in ModuleFinder(single_level_package_path, pattern.tokens).find() + ] expected_modules = { "example.foo.models", } @@ -76,7 +99,11 @@ def test_find_modules_by_pattern_character_class(character_classes_package_path) Character class [12] matches exactly one character: '1' or '2'. """ - found = find_modules_by_pattern("example.api.v[12].models") + pattern = Pattern(mask="example.api.v[12].models") + found = [ + to_module_name(character_classes_package_path, module_path) + for module_path in ModuleFinder(character_classes_package_path, pattern.tokens).find() + ] expected_modules = { "example.api.v1.models", "example.api.v2.models", @@ -91,7 +118,11 @@ def test_find_modules_by_pattern_character_range(character_classes_package_path) Character range [0-9] matches exactly one digit from 0 to 9. """ - found = find_modules_by_pattern("example.api.v[0-9].models") + pattern = Pattern(mask="example.api.v[0-9].models") + found = [ + to_module_name(character_classes_package_path, module_path) + for module_path in ModuleFinder(character_classes_package_path, pattern.tokens).find() + ] expected_modules = { "example.api.v0.models", "example.api.v1.models", @@ -115,7 +146,11 @@ def test_find_modules_by_pattern_complementation_character_class(character_class Complementation [!1] matches any single character except '1'. """ - found = find_modules_by_pattern("example.api.v[!1].models") + pattern = Pattern(mask="example.api.v[!1].models") + found = [ + to_module_name(character_classes_package_path, module_path) + for module_path in ModuleFinder(character_classes_package_path, pattern.tokens).find() + ] assert "example.api.v0.models" in found assert "example.api.v2.models" in found @@ -128,7 +163,11 @@ def test_find_modules_by_pattern_complementation_character_range(character_class Complementation [!0-9] matches any single character except digits 0-9. """ - found = find_modules_by_pattern("example.api.v[!0-9].models") + pattern = Pattern(mask="example.api.v[!0-9].models") + found = [ + to_module_name(character_classes_package_path, module_path) + for module_path in ModuleFinder(character_classes_package_path, pattern.tokens).find() + ] assert "example.api.va.models" in found assert "example.api.vb.models" in found @@ -144,7 +183,11 @@ def test_find_modules_by_pattern_mixed_wildcards(multi_star_package_path): v? matches one char (v1, v2, va, etc.), then *.* matches two package levels. Example: example.v1.api.users.models, example.v2.api.products.models """ - found = find_modules_by_pattern("example.v?.*.*.models") + pattern = Pattern(mask="example.v?.*.*.models") + found = [ + to_module_name(multi_star_package_path, module_path) + for module_path in ModuleFinder(multi_star_package_path, pattern.tokens).find() + ] expected_modules = { "example.v1.api.users.models", "example.v2.api.products.models", @@ -163,7 +206,11 @@ def test_find_modules_by_pattern_recursive_lookup(recursive_package_path): - example.something.api.v2.models (1 level deep) - example.level1.level2.api.v3.models (2 levels deep) """ - found = find_modules_by_pattern("example.**.api.*.models") + pattern = Pattern(mask="example.**.api.*.models") + found = [ + to_module_name(recursive_package_path, module_path) + for module_path in ModuleFinder(recursive_package_path, pattern.tokens).find() + ] expected_modules = { "example.api.v1.models", "example.something.api.v2.models", @@ -177,22 +224,18 @@ def test_find_modules_by_pattern_recursive_lookup(recursive_package_path): assert "example.domain.users.models" not in found -# Error Cases -def test_find_modules_by_pattern_missing_rule_error(): - """Test that missing rule (example.v?..models) raises ValueError. - - Pattern 'v?..models' has two consecutive dots, which is invalid. - Should raise ValueError with descriptive message. - """ - with pytest.raises(ValueError, match=".*missing.*rule.*|.*invalid.*pattern.*|.*consecutive.*"): - find_modules_by_pattern("example.v?..models") - - -def test_find_modules_by_pattern_invalid_delimiter_error(): - """Test that invalid delimiter (example.v?,,models) raises ValueError. - - Pattern 'v?,,models' uses comma instead of dot as delimiter, which is invalid. - Should raise ValueError with descriptive message. - """ - with pytest.raises(ValueError, match=".*invalid.*delimiter.*|.*invalid.*pattern.*"): - find_modules_by_pattern("example.v?,,models") +@pytest.mark.skip(reason="Need to implement validation rules") +@pytest.mark.parametrize( + "pattern", + [ + "example.v,.models", # Wrong grammar token + "example.v?..models", # Empty tokens + "example.v[1.models", # Unclosed + "example.v]1[.models", # Reversed + "example.v[1[2]].models", # Nested + ], +) +def test_find_modules_by_pattern_missing_rule_error(pattern): + """Test token validation rule (example.v?..models) raises ValueError.""" + pattern = Pattern(mask=pattern) + assert any(pattern.errors) From 2a0fc66f78e6ab85dc88e28b5ac88e3bd3195935 Mon Sep 17 00:00:00 2001 From: TheLazzzies Date: Fri, 2 Jan 2026 00:54:47 +0300 Subject: [PATCH 10/20] Feature, Add pre-commit to optional dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 3add789..e4f106b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ requires-python = ">= 3.10" dev = [ "build", "dapperdata", + "pre-commit>=4.5.1", "glom", "mypy", "pip-tools", From eabad477cdacf32fc7a3b432585b6c64a565f805 Mon Sep 17 00:00:00 2001 From: TheLazzzies Date: Fri, 2 Jan 2026 22:59:14 +0300 Subject: [PATCH 11/20] Feature, Improve the method description for state processing --- paracelsus/finders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paracelsus/finders.py b/paracelsus/finders.py index 99d892f..62dd23c 100644 --- a/paracelsus/finders.py +++ b/paracelsus/finders.py @@ -80,7 +80,7 @@ def find(self) -> Generator[Path, None, None]: def _process_state(self, state: SearchState) -> Generator[Path, None, None]: """ - Implement BFS + Implement BFS to search for files and directories matching the given pattern. """ node = state.node path = state.path From 408019b12312120d94b58d0b38e78757e46bbf38 Mon Sep 17 00:00:00 2001 From: TheLazzzies Date: Fri, 2 Jan 2026 23:01:46 +0300 Subject: [PATCH 12/20] Feature, Replace pool executor with a separate thread --- paracelsus/graph.py | 53 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/paracelsus/graph.py b/paracelsus/graph.py index 2472935..f263297 100644 --- a/paracelsus/graph.py +++ b/paracelsus/graph.py @@ -1,18 +1,21 @@ import importlib +import logging import os import re import sys -from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import List, Set, Optional, Dict, Union +from queue import Queue +from threading import Thread +from typing import Dict, List, Optional, Set, Union -from paracelsus.models.pattern import Pattern from sqlalchemy.schema import MetaData +from paracelsus.models.pattern import Pattern + from .config import Layouts +from .finders import ModuleFinder from .transformers.dot import Dot from .transformers.mermaid import Mermaid -from .finders import ModuleFinder transformers: Dict[str, type[Union[Mermaid, Dot]]] = { "mmd": Mermaid, @@ -21,6 +24,8 @@ "gv": Dot, } +logger = logging.getLogger(__name__) + def do_import(needs_wildcards_import: bool, module_path: str) -> bool: if needs_wildcards_import: @@ -49,6 +54,31 @@ def to_module_name(root: Path, path: Path) -> str: return ".".join(clean_path.parts) +def consume_import_tasks(queue: Queue[dict], sentinel: object): + while True: + item = queue.get() + + if item is sentinel: + break + + needs_wildcards_import, module_name = item.values() + try: + # Check if already loaded to save time + if module_name in sys.modules: + continue + + if needs_wildcards_import: + exec(f"from {module_name} import *") + else: + importlib.import_module(module_name) + + except ImportError as e: + logger.error(f"Failed to load {module_name}: {e}") + raise e + finally: + queue.task_done() + + def get_graph_string( *, base_class_path: str, @@ -74,6 +104,10 @@ def get_graph_string( base_class = getattr(base_module, class_name) metadata = base_class.metadata + import_queue_sentinel = object() + import_queue: Queue[Union[Dict[str, str], object]] = Queue() + import_worker = Thread(target=consume_import_tasks, args=(import_queue, import_queue_sentinel), daemon=True) + import_worker.start() # The modules holding the model classes have to be imported to get put in the metaclass model registry. # These modules aren't actually used in any way, so they are discarded. # They are also imported in scope of this function to prevent namespace pollution. @@ -90,11 +124,14 @@ def get_graph_string( current_root = Path.cwd() finder = ModuleFinder(current_root, pattern.tokens) + needs_wildcards_import = import_modifier == "*" + + for file_path in finder.find(): + module_path = to_module_name(current_root, file_path) + import_queue.put({"needs_wildcards_import": needs_wildcards_import, "module_name": module_path}) - with ThreadPoolExecutor() as executor: - for found_module_path in finder.find(): - dot_path = to_module_name(current_root, found_module_path) - executor.submit(do_import, import_modifier == "*", dot_path) + import_queue.put(import_queue_sentinel) + import_worker.join() # Grab a transformer. if format not in transformers: From de0a59882203b20e5a284a53659c71d98e96f2d0 Mon Sep 17 00:00:00 2001 From: TheLazzzies Date: Fri, 2 Jan 2026 23:02:46 +0300 Subject: [PATCH 13/20] Feature, Add validation rules for pattern masks --- paracelsus/models/pattern.py | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/paracelsus/models/pattern.py b/paracelsus/models/pattern.py index b403cce..0a2f79c 100644 --- a/paracelsus/models/pattern.py +++ b/paracelsus/models/pattern.py @@ -1,6 +1,7 @@ +import re from dataclasses import dataclass, field -from .base import ValidationError, Attribute +from .base import Attribute, ValidationError def forbid_empty_path(name: str, value: str) -> str: @@ -21,11 +22,42 @@ def forbid_empty_segment(name: str, value: str) -> str: return value +def enforce_globbing_grammar(name: str, value: str) -> str: + regex = re.compile(r"^[a-zA-Z0-9_*]+(\.[a-zA-Z0-9_*]+)*$") + if not regex.fullmatch(value): + raise ValidationError("Invalid globbing pattern") + return value + + +def forbid_greedy_lookup(name: str, value: str) -> str: + tokens = value.split(".") + for i in range(len(tokens) - 1): + current = tokens[i] + next_token = tokens[i + 1] + + if current == "**" and next_token == "*": + raise ValidationError( + f"Invalid Mask: '**. *' is ambiguous and forbidden. " + f"Found at segment {i}: '...{current}.{next_token}...'" + ) + + return value + + @dataclass(init=True, frozen=True) class Pattern: _errors: list[ValidationError] = field(default_factory=list) mask: Attribute[str] = field( - default=Attribute[str](str, validators=[forbid_wildcard_for_modules, forbid_empty_segment, forbid_empty_path]) + default=Attribute[str]( + str, + validators=[ + enforce_globbing_grammar, + forbid_wildcard_for_modules, + forbid_greedy_lookup, + forbid_empty_segment, + forbid_empty_path, + ], + ) ) def add_error(self, error: ValidationError) -> None: From 8eb7a619489b15311be7bb67719cdf704e69a18b Mon Sep 17 00:00:00 2001 From: TheLazzzies Date: Fri, 2 Jan 2026 23:03:15 +0300 Subject: [PATCH 14/20] Feature, Add base.py module to the namespace case --- tests/assets/namespace/project2/example/base.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 tests/assets/namespace/project2/example/base.py diff --git a/tests/assets/namespace/project2/example/base.py b/tests/assets/namespace/project2/example/base.py new file mode 100644 index 0000000..59be703 --- /dev/null +++ b/tests/assets/namespace/project2/example/base.py @@ -0,0 +1,3 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() From 2156ce102f12b53254569dbf24d38c59782b31cf Mon Sep 17 00:00:00 2001 From: TheLazzzies Date: Fri, 2 Jan 2026 23:04:07 +0300 Subject: [PATCH 15/20] Feature, Fix validation tests for patterns. Add tests for get_graph_string --- tests/test_graph.py | 28 +++++++++++++++++++++++++ tests/transformers/test_find_modules.py | 7 ++++--- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/tests/test_graph.py b/tests/test_graph.py index ec182ac..0371e98 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -20,6 +20,34 @@ def test_get_graph_string(column_sort_arg, package_path): mermaid_assert(graph_string) +@pytest.mark.skip(reason="Update mermaid_assert function to dynamically detect required models for validation") +def test_get_graph_string_with_wildcard(single_level_package_path): + get_graph_string( + base_class_path="example.base:Base", + import_module=["example.*.models"], + include_tables=set(), + exclude_tables=set(), + python_dir=[single_level_package_path], + format="mermaid", + column_sort="key-based", + ) + # mermaid_assert(graph_string) + + +@pytest.mark.skip(reason="Update mermaid_assert function to dynamically detect required models for validation") +def test_get_graph_with_wildcard_mask_in_namespace_package(namespace_package_path): + get_graph_string( + base_class_path="project1.example.base:Base", # @TODO: How to resolve a base class within separate multiple packages + import_module=["project*.example.*.models"], + include_tables=set(), + exclude_tables=set(), + python_dir=[namespace_package_path], + format="mermaid", + column_sort="key-based", + ) + # mermaid_assert(graph_string) + + def test_get_graph_string_with_exclude(package_path): """Excluding tables removes them from the graph string.""" graph_string = get_graph_string( diff --git a/tests/transformers/test_find_modules.py b/tests/transformers/test_find_modules.py index ac4f5be..bcb3947 100644 --- a/tests/transformers/test_find_modules.py +++ b/tests/transformers/test_find_modules.py @@ -1,7 +1,7 @@ import pytest -from paracelsus.graph import to_module_name from paracelsus.finders import ModuleFinder +from paracelsus.graph import to_module_name from paracelsus.models.pattern import Pattern @@ -224,7 +224,6 @@ def test_find_modules_by_pattern_recursive_lookup(recursive_package_path): assert "example.domain.users.models" not in found -@pytest.mark.skip(reason="Need to implement validation rules") @pytest.mark.parametrize( "pattern", [ @@ -233,9 +232,11 @@ def test_find_modules_by_pattern_recursive_lookup(recursive_package_path): "example.v[1.models", # Unclosed "example.v]1[.models", # Reversed "example.v[1[2]].models", # Nested + "example.v[**.models", # Invalid recursive lookup + "example.**.*.models", # Greedy lookup ], ) def test_find_modules_by_pattern_missing_rule_error(pattern): - """Test token validation rule (example.v?..models) raises ValueError.""" + """Test token validation rule. Must raise errors""" pattern = Pattern(mask=pattern) assert any(pattern.errors) From c5d99f2d5d847afffdf6e622976164d73cbadc6f Mon Sep 17 00:00:00 2001 From: TheLazzzies Date: Fri, 2 Jan 2026 23:18:01 +0300 Subject: [PATCH 16/20] Feature, Remove do_import --- paracelsus/graph.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/paracelsus/graph.py b/paracelsus/graph.py index f263297..8f2992e 100644 --- a/paracelsus/graph.py +++ b/paracelsus/graph.py @@ -27,14 +27,6 @@ logger = logging.getLogger(__name__) -def do_import(needs_wildcards_import: bool, module_path: str) -> bool: - if needs_wildcards_import: - exec(f"from {module_path} import *") - else: - importlib.import_module(module_path) - return True - - def to_module_name(root: Path, path: Path) -> str: """ Converts a filesystem path to a Python dotted module string. From ca98f29e77d021d55556377ce85b3f59690eb1e6 Mon Sep 17 00:00:00 2001 From: Apti Date: Wed, 14 Jan 2026 18:10:11 +0300 Subject: [PATCH 17/20] Separate graph building from serialization and add dynamic comparison --- paracelsus/graph.py | 346 ++++++++++++++++++++++-- tests/test_graph.py | 126 ++++++++- tests/transformers/test_find_modules.py | 23 +- tests/utils.py | 49 +++- 4 files changed, 490 insertions(+), 54 deletions(-) diff --git a/paracelsus/graph.py b/paracelsus/graph.py index 8f2992e..e39c9f0 100644 --- a/paracelsus/graph.py +++ b/paracelsus/graph.py @@ -71,30 +71,149 @@ def consume_import_tasks(queue: Queue[dict], sentinel: object): queue.task_done() -def get_graph_string( +def _find_base_classes_by_pattern( + base_class_path: str, + python_dir: List[Path], + current_root: Path, +) -> List[tuple[str, MetaData]]: + """ + Finds all base classes matching a glob pattern and returns their MetaData. + """ + if "*" not in base_class_path and "?" not in base_class_path: + # No wildcards, return single base class + module_path, class_name = base_class_path.split(":", 2) + try: + base_module = importlib.import_module(module_path) + base_class = getattr(base_module, class_name) + return [(module_path, base_class.metadata)] + except (ImportError, AttributeError) as e: + raise ValueError(f"Could not import base class from {base_class_path}: {e}") + + # Extract pattern parts + parts = base_class_path.split(":") + if len(parts) != 2: + raise ValueError(f"Invalid base_class_path format: {base_class_path}") + + pattern_str, class_name = parts + + # Create pattern for finding base.py modules + pattern = Pattern(mask=pattern_str) + + if any(pattern.errors): + raise ValueError(pattern.serialized_errors) + + # Find all matching base.py files + finder = ModuleFinder(current_root, pattern.tokens) + base_metadata_list = [] + + for file_path in finder.find(): + # Only consider base.py files + if file_path.name != "base.py" and not file_path.name.endswith("base.py"): + continue + + module_path = to_module_name(current_root, file_path) + + try: + base_module = importlib.import_module(module_path) + if hasattr(base_module, class_name): + base_class = getattr(base_module, class_name) + base_metadata_list.append((module_path, base_class.metadata)) + except (ImportError, AttributeError) as e: + logger.warning(f"Could not import base class from {module_path}: {e}") + continue + + if not base_metadata_list: + raise ValueError(f"No base classes found matching pattern: {base_class_path}") + + return base_metadata_list + + +def _merge_metadata(metadata_list: List[tuple[str, MetaData]]) -> MetaData: + """ + Merges multiple MetaData objects into a single MetaData. + If there are table name conflicts, prefixes are added based on the module path. + """ + merged_metadata = MetaData() + + for module_path, metadata in metadata_list: + # Extract a prefix from module path to avoid conflicts + # e.g., "project1.example.base" -> "project1_" + parts = module_path.split(".") + prefix = "" + if len(parts) > 1: + # Use first part as prefix (e.g., "project1") + prefix = f"{parts[0]}_" + + for tablename, table in metadata.tables.items(): + # Check for conflicts + prefixed_name = f"{prefix}{tablename}" if prefix else tablename + + # If there's a conflict and we have a prefix, use prefixed name + if prefixed_name in merged_metadata.tables and prefix: + logger.warning( + f"Table name conflict: '{tablename}' from {module_path} conflicts. " + f"Using prefixed name: '{prefixed_name}'" + ) + final_name = prefixed_name + elif tablename in merged_metadata.tables: + # Conflict without prefix - use original name (tables are the same) + final_name = tablename + else: + # No conflict + final_name = tablename if not prefix else prefixed_name + + # Copy table to merged metadata + if final_name not in merged_metadata.tables: + if hasattr(table, "to_metadata"): + table.to_metadata(merged_metadata, name=final_name) + else: + table.tometadata(merged_metadata, name=final_name) + + return merged_metadata + + +def get_graph_metadata( *, base_class_path: str, import_module: List[str], include_tables: Set[str], exclude_tables: Set[str], python_dir: List[Path], - format: str, - column_sort: str, - omit_comments: bool = False, - max_enum_members: int = 0, - layout: Optional[Layouts] = None, -) -> str: + merge_namespace_metadata: bool = False, +) -> MetaData: + """ + Builds a graph structure by importing modules and returns the filtered MetaData. + This function separates the graph building logic from serialization, allowing + tests to compare MetaData objects directly without parsing strings. + """ # Update the PYTHON_PATH to allow more module imports. sys.path.append(str(os.getcwd())) for dir in python_dir: sys.path.append(str(dir)) - # Import the base class so the metadata class can be extracted from it. - # The metadata class is passed to the transformer. - module_path, class_name = base_class_path.split(":", 2) - base_module = importlib.import_module(module_path) - base_class = getattr(base_module, class_name) - metadata = base_class.metadata + current_root = Path.cwd() + + # Handle base class path with or without wildcards + has_wildcards = "*" in base_class_path or "?" in base_class_path + base_metadata_list = None + + if has_wildcards or (merge_namespace_metadata and has_wildcards): + # Find all matching base classes and merge their metadata + base_metadata_list = _find_base_classes_by_pattern(base_class_path, python_dir, current_root) + + if len(base_metadata_list) > 1: + metadata = _merge_metadata(base_metadata_list) + elif len(base_metadata_list) == 1: + # Single base class found + metadata = base_metadata_list[0][1] + else: + raise ValueError(f"No base classes found matching pattern: {base_class_path}") + else: + # No wildcards, use single base class + module_path, class_name = base_class_path.split(":", 2) + base_module = importlib.import_module(module_path) + base_class = getattr(base_module, class_name) + metadata = base_class.metadata import_queue_sentinel = object() import_queue: Queue[Union[Dict[str, str], object]] = Queue() @@ -125,10 +244,25 @@ def get_graph_string( import_queue.put(import_queue_sentinel) import_worker.join() - # Grab a transformer. - if format not in transformers: - raise ValueError(f"Unknown Format: {format}") - transformer = transformers[format] + # If we merged metadata from multiple base classes, we need to re-merge after models are imported + # because models register themselves in the original base class metadata, not the merged one + if base_metadata_list and len(base_metadata_list) > 1: + # Re-collect metadata from all base classes after models have been imported + updated_metadata_list = [] + for module_path, _ in base_metadata_list: + try: + base_module = importlib.import_module(module_path) + # Extract class name from base_class_path + _, class_name = base_class_path.split(":", 2) + if hasattr(base_module, class_name): + base_class = getattr(base_module, class_name) + updated_metadata_list.append((module_path, base_class.metadata)) + except (ImportError, AttributeError) as e: + logger.warning(f"Could not re-import base class from {module_path}: {e}") + continue + + if updated_metadata_list: + metadata = _merge_metadata(updated_metadata_list) # Keep only the tables which were included / not-excluded include_tables = resolve_included_tables( @@ -136,8 +270,65 @@ def get_graph_string( ) filtered_metadata = filter_metadata(metadata=metadata, include_tables=include_tables) - # Save the graph structure to string. - return str(transformer(filtered_metadata, column_sort, omit_comments=omit_comments, layout=layout)) + return filtered_metadata + + +def serialize_metadata( + metadata: MetaData, + *, + format: str, + column_sort: str, + omit_comments: bool = False, + max_enum_members: int = 0, + layout: Optional[Layouts] = None, +) -> str: + """ + Serializes MetaData to a string representation in the specified format. + """ + # Grab a transformer. + if format not in transformers: + raise ValueError(f"Unknown Format: {format}") + transformer = transformers[format] + + # Serialize the graph structure to string. + return str(transformer(metadata, column_sort, omit_comments=omit_comments, layout=layout)) + + +def get_graph_string( + *, + base_class_path: str, + import_module: List[str], + include_tables: Set[str], + exclude_tables: Set[str], + python_dir: List[Path], + format: str, + column_sort: str, + omit_comments: bool = False, + max_enum_members: int = 0, + layout: Optional[Layouts] = None, +) -> str: + """ + Builds a graph structure and returns it as a serialized string. + + This is a convenience wrapper that combines get_graph_metadata() and serialize_metadata() + for backward compatibility. + """ + metadata = get_graph_metadata( + base_class_path=base_class_path, + import_module=import_module, + include_tables=include_tables, + exclude_tables=exclude_tables, + python_dir=python_dir, + ) + + return serialize_metadata( + metadata, + format=format, + column_sort=column_sort, + omit_comments=omit_comments, + max_enum_members=max_enum_members, + layout=layout, + ) def resolve_included_tables( @@ -194,3 +385,120 @@ def filter_metadata( table = table.tometadata(filtered_metadata) return filtered_metadata + + +def compare_metadata(actual: MetaData, expected: MetaData, omit_comments: bool = False) -> None: + """ + Compares two MetaData objects and raises AssertionError if they differ. + This function performs a structural comparison of two graph representations, + checking tables, columns, types, constraints, and relationships. + """ + + actual_tables = set(actual.tables.keys()) + expected_tables = set(expected.tables.keys()) + + # Check table names + if actual_tables != expected_tables: + missing = expected_tables - actual_tables + extra = actual_tables - expected_tables + error_msg = "Table mismatch:\n" + if missing: + error_msg += f" Missing tables: {missing}\n" + if extra: + error_msg += f" Extra tables: {extra}\n" + raise AssertionError(error_msg) + + # Check each table's structure + for table_name in expected_tables: + actual_table = actual.tables[table_name] + expected_table = expected.tables[table_name] + + actual_columns = {col.name: col for col in actual_table.columns} + expected_columns = {col.name: col for col in expected_table.columns} + + # Check column names + if set(actual_columns.keys()) != set(expected_columns.keys()): + missing = set(expected_columns.keys()) - set(actual_columns.keys()) + extra = set(actual_columns.keys()) - set(expected_columns.keys()) + error_msg = f"Column mismatch in table '{table_name}':\n" + if missing: + error_msg += f" Missing columns: {missing}\n" + if extra: + error_msg += f" Extra columns: {extra}\n" + raise AssertionError(error_msg) + + # Check each column's properties + for col_name in expected_columns.keys(): + actual_col = actual_columns[col_name] + expected_col = expected_columns[col_name] + + # Check type + actual_type_str = str(actual_col.type) + expected_type_str = str(expected_col.type) + if actual_type_str != expected_type_str: + raise AssertionError( + f"Type mismatch in table '{table_name}', column '{col_name}': " + f"expected {expected_type_str}, got {actual_type_str}" + ) + + # Check constraints + actual_pk = actual_col.primary_key + expected_pk = expected_col.primary_key + if actual_pk != expected_pk: + raise AssertionError( + f"Primary key mismatch in table '{table_name}', column '{col_name}': " + f"expected {expected_pk}, got {actual_pk}" + ) + + actual_fk_count = len(actual_col.foreign_keys) + expected_fk_count = len(expected_col.foreign_keys) + if actual_fk_count != expected_fk_count: + raise AssertionError( + f"Foreign key count mismatch in table '{table_name}', column '{col_name}': " + f"expected {expected_fk_count}, got {actual_fk_count}" + ) + + # Check nullable + if actual_col.nullable != expected_col.nullable: + raise AssertionError( + f"Nullable mismatch in table '{table_name}', column '{col_name}': " + f"expected {expected_col.nullable}, got {actual_col.nullable}" + ) + + # Check comments (if not omitted) + if not omit_comments: + actual_comment = actual_col.comment + expected_comment = expected_col.comment + if actual_comment != expected_comment: + raise AssertionError( + f"Comment mismatch in table '{table_name}', column '{col_name}': " + f"expected {expected_comment!r}, got {actual_comment!r}" + ) + + # Check foreign key relationships + actual_fks = set() + for col in actual_table.columns: + for fk in col.foreign_keys: + # Format: (table_name, column_name) -> (target_table, target_column) + target_parts = fk.target_fullname.split(".") + target_table = ".".join(target_parts[:-1]) + target_column = target_parts[-1] + actual_fks.add((table_name, col.name, target_table, target_column)) + + expected_fks = set() + for col in expected_table.columns: + for fk in col.foreign_keys: + target_parts = fk.target_fullname.split(".") + target_table = ".".join(target_parts[:-1]) + target_column = target_parts[-1] + expected_fks.add((table_name, col.name, target_table, target_column)) + + if actual_fks != expected_fks: + missing = expected_fks - actual_fks + extra = actual_fks - expected_fks + error_msg = f"Foreign key mismatch in table '{table_name}':\n" + if missing: + error_msg += f" Missing FKs: {missing}\n" + if extra: + error_msg += f" Extra FKs: {extra}\n" + raise AssertionError(error_msg) diff --git a/tests/test_graph.py b/tests/test_graph.py index 0371e98..3193e11 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,13 +1,33 @@ import pytest +import importlib +import sys +from pathlib import Path from paracelsus.config import Layouts -from paracelsus.graph import get_graph_string +from paracelsus.graph import get_graph_string, get_graph_metadata, compare_metadata +from paracelsus.finders import ModuleFinder +from paracelsus.models.pattern import Pattern +from paracelsus.graph import to_module_name from .utils import mermaid_assert @pytest.mark.parametrize("column_sort_arg", ["key-based", "preserve-order"]) -def test_get_graph_string(column_sort_arg, package_path): +def test_get_graph_string(column_sort_arg, package_path, metaclass): + """Test get_graph_string with dynamic metadata comparison.""" + # Get actual metadata + actual_metadata = get_graph_metadata( + base_class_path="example.base:Base", + import_module=["example.models"], + include_tables=set(), + exclude_tables=set(), + python_dir=[package_path], + ) + + # Compare with expected metadata from fixture + mermaid_assert(actual_metadata, expected=metaclass) + + # Also test that serialization still works graph_string = get_graph_string( base_class_path="example.base:Base", import_module=["example.models"], @@ -17,35 +37,77 @@ def test_get_graph_string(column_sort_arg, package_path): format="mermaid", column_sort=column_sort_arg, ) + # Legacy string assertion for backward compatibility mermaid_assert(graph_string) -@pytest.mark.skip(reason="Update mermaid_assert function to dynamically detect required models for validation") def test_get_graph_string_with_wildcard(single_level_package_path): - get_graph_string( + """Test that wildcard patterns work correctly with dynamic metadata comparison.""" + actual_metadata = get_graph_metadata( base_class_path="example.base:Base", import_module=["example.*.models"], include_tables=set(), exclude_tables=set(), python_dir=[single_level_package_path], - format="mermaid", - column_sort="key-based", ) - # mermaid_assert(graph_string) + + # Build expected metadata by manually importing all matching modules + sys.path.insert(0, str(single_level_package_path)) + try: + # Find all modules matching the pattern + pattern = Pattern(mask="example.*.models") + current_root = Path.cwd() + finder = ModuleFinder(current_root, pattern.tokens) + + # Import all matching modules + for file_path in finder.find(): + module_path = to_module_name(current_root, file_path) + importlib.import_module(module_path) + + # Get expected metadata from base class + base_module = importlib.import_module("example.base") + base_class = getattr(base_module, "Base") + expected_metadata = base_class.metadata + + # Compare metadata + compare_metadata(actual_metadata, expected_metadata) + finally: + # Cleanup + if str(single_level_package_path) in sys.path: + sys.path.remove(str(single_level_package_path)) + # Clear imported modules + for name in list(sys.modules.keys()): + if name.startswith("example."): + del sys.modules[name] -@pytest.mark.skip(reason="Update mermaid_assert function to dynamically detect required models for validation") def test_get_graph_with_wildcard_mask_in_namespace_package(namespace_package_path): - get_graph_string( - base_class_path="project1.example.base:Base", # @TODO: How to resolve a base class within separate multiple packages + """Test namespace packages with wildcard patterns and merged metadata.""" + # Get actual metadata with namespace merging enabled + actual_metadata = get_graph_metadata( + base_class_path="project*.example.base:Base", import_module=["project*.example.*.models"], include_tables=set(), exclude_tables=set(), python_dir=[namespace_package_path], + merge_namespace_metadata=True, + ) + + # Verify that we have tables from both projects + table_names = set(actual_metadata.tables.keys()) + # Should have tables from both project1 and project2 + assert "subpackage_a_table" in table_names or "project1_subpackage_a_table" in table_names + assert "subpackage_b_table" in table_names or "project2_subpackage_b_table" in table_names + + # Verify the graph can be serialized + from paracelsus.graph import serialize_metadata + graph_string = serialize_metadata( + actual_metadata, format="mermaid", column_sort="key-based", ) - # mermaid_assert(graph_string) + assert "subpackage_a_table" in graph_string or "project1_subpackage_a_table" in graph_string + assert "subpackage_b_table" in graph_string or "project2_subpackage_b_table" in graph_string def test_get_graph_string_with_exclude(package_path): @@ -110,7 +172,21 @@ def test_get_graph_string_with_include(package_path): @pytest.mark.parametrize("layout_arg", ["dagre", "elk"]) -def test_get_graph_string_with_layout(layout_arg, package_path): +def test_get_graph_string_with_layout(layout_arg, package_path, metaclass): + """Test get_graph_string with layout using dynamic metadata comparison.""" + # Get actual metadata + actual_metadata = get_graph_metadata( + base_class_path="example.base:Base", + import_module=["example.models"], + include_tables=set(), + exclude_tables=set(), + python_dir=[package_path], + ) + + # Compare with expected metadata + mermaid_assert(actual_metadata, expected=metaclass) + + # Also test serialization with layout graph_string = get_graph_string( base_class_path="example.base:Base", import_module=["example.models"], @@ -121,6 +197,7 @@ def test_get_graph_string_with_layout(layout_arg, package_path): column_sort="key-based", layout=Layouts(layout_arg), ) + # Legacy string assertion mermaid_assert(graph_string) @@ -138,3 +215,28 @@ def test_get_graph_string_with_nested_glob_pattern(nested_package_path): ) assert "users {" in graph_string or "products {" in graph_string or "api_resources {" in graph_string + + +def test_compare_metadata(metaclass): + """Test compare_metadata function directly.""" + # Same metadata should compare successfully + compare_metadata(metaclass, metaclass) + + # Different metadata should raise AssertionError + from sqlalchemy.orm import declarative_base + from sqlalchemy import String, Uuid + from sqlalchemy.orm import mapped_column + from uuid import uuid4 + + Base2 = declarative_base() + + class DifferentTable(Base2): + __tablename__ = "different_table" + id = mapped_column(Uuid, primary_key=True, default=uuid4()) + name = mapped_column(String(100)) + + different_metadata = Base2.metadata + + # Should raise AssertionError when comparing different metadata + with pytest.raises(AssertionError): + compare_metadata(metaclass, different_metadata) diff --git a/tests/transformers/test_find_modules.py b/tests/transformers/test_find_modules.py index bcb3947..63d0311 100644 --- a/tests/transformers/test_find_modules.py +++ b/tests/transformers/test_find_modules.py @@ -1,4 +1,5 @@ import pytest +from pathlib import Path from paracelsus.finders import ModuleFinder from paracelsus.graph import to_module_name @@ -10,7 +11,7 @@ def test_find_modules_by_pattern_single_level(single_level_package_path): pattern = Pattern(mask="example.*.models") found = [ - to_module_name(single_level_package_path, module_path) + to_module_name(Path.cwd(), module_path) for module_path in ModuleFinder(single_level_package_path, pattern.tokens).find() ] expected_modules = { @@ -27,7 +28,7 @@ def test_find_modules_by_pattern_nested_levels(nested_package_path): pattern = Pattern(mask="example.*.*.models") found = [ - to_module_name(nested_package_path, module_path) + to_module_name(Path.cwd(), module_path) for module_path in ModuleFinder(nested_package_path, pattern.tokens).find() ] expected_modules = { @@ -45,7 +46,7 @@ def test_find_modules_by_pattern_multiple_stars(multi_star_package_path): pattern = Pattern(mask="example.*.api.*.models") found = [ - to_module_name(multi_star_package_path, module_path) + to_module_name(Path.cwd(), module_path) for module_path in ModuleFinder(multi_star_package_path, pattern.tokens).find() ] expected_modules = { @@ -65,7 +66,7 @@ def test_find_modules_by_pattern_namespace_package(namespace_package_path): pattern = Pattern(mask="project*.example.*.models") found = [ - to_module_name(namespace_package_path, module_path) + to_module_name(Path.cwd(), module_path) for module_path in ModuleFinder(namespace_package_path, pattern.tokens).find() ] expected_modules = { @@ -82,7 +83,7 @@ def test_find_modules_by_pattern_single_character(single_level_package_path): pattern = Pattern(mask="example.fo?.models") found = [ - to_module_name(single_level_package_path, module_path) + to_module_name(Path.cwd(), module_path) for module_path in ModuleFinder(single_level_package_path, pattern.tokens).find() ] expected_modules = { @@ -101,7 +102,7 @@ def test_find_modules_by_pattern_character_class(character_classes_package_path) """ pattern = Pattern(mask="example.api.v[12].models") found = [ - to_module_name(character_classes_package_path, module_path) + to_module_name(Path.cwd(), module_path) for module_path in ModuleFinder(character_classes_package_path, pattern.tokens).find() ] expected_modules = { @@ -120,7 +121,7 @@ def test_find_modules_by_pattern_character_range(character_classes_package_path) """ pattern = Pattern(mask="example.api.v[0-9].models") found = [ - to_module_name(character_classes_package_path, module_path) + to_module_name(Path.cwd(), module_path) for module_path in ModuleFinder(character_classes_package_path, pattern.tokens).find() ] expected_modules = { @@ -148,7 +149,7 @@ def test_find_modules_by_pattern_complementation_character_class(character_class """ pattern = Pattern(mask="example.api.v[!1].models") found = [ - to_module_name(character_classes_package_path, module_path) + to_module_name(Path.cwd(), module_path) for module_path in ModuleFinder(character_classes_package_path, pattern.tokens).find() ] @@ -165,7 +166,7 @@ def test_find_modules_by_pattern_complementation_character_range(character_class """ pattern = Pattern(mask="example.api.v[!0-9].models") found = [ - to_module_name(character_classes_package_path, module_path) + to_module_name(Path.cwd(), module_path) for module_path in ModuleFinder(character_classes_package_path, pattern.tokens).find() ] @@ -185,7 +186,7 @@ def test_find_modules_by_pattern_mixed_wildcards(multi_star_package_path): """ pattern = Pattern(mask="example.v?.*.*.models") found = [ - to_module_name(multi_star_package_path, module_path) + to_module_name(Path.cwd(), module_path) for module_path in ModuleFinder(multi_star_package_path, pattern.tokens).find() ] expected_modules = { @@ -208,7 +209,7 @@ def test_find_modules_by_pattern_recursive_lookup(recursive_package_path): """ pattern = Pattern(mask="example.**.api.*.models") found = [ - to_module_name(recursive_package_path, module_path) + to_module_name(Path.cwd(), module_path) for module_path in ModuleFinder(recursive_package_path, pattern.tokens).find() ] expected_modules = { diff --git a/tests/utils.py b/tests/utils.py index 0a1fb9f..cdc833c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,18 +1,43 @@ -def mermaid_assert(output: str) -> None: - assert "users {" in output - assert "posts {" in output - assert "comments {" in output +from typing import Union +from sqlalchemy.schema import MetaData +from paracelsus.graph import compare_metadata - assert "users ||--o{ posts : author" in output - assert "posts ||--o{ comments : post" in output - assert "users ||--o{ comments : author" in output - assert "CHAR(32) author FK" in output - assert 'CHAR(32) post FK "nullable"' in output - assert 'BOOLEAN live "True if post is published,nullable"' in output - assert "DATETIME created" in output +def mermaid_assert( + actual: Union[str, MetaData], + expected: Union[MetaData, None] = None, + omit_comments: bool = False, +) -> None: + """ + Asserts that a mermaid graph (string or MetaData) matches expected structure. + + This function supports two modes: + 1. Legacy mode: If actual is a string, performs basic string assertions (for backward compatibility) + 2. Dynamic mode: If actual is MetaData, compares it with expected MetaData + """ + # Legacy mode: string comparison (for backward compatibility) + if isinstance(actual, str): + # Basic assertions for backward compatibility + assert "users {" in actual + assert "posts {" in actual + assert "comments {" in actual - trailing_newline_assert(output) + assert "users ||--o{ posts : author" in actual + assert "posts ||--o{ comments : post" in actual + assert "users ||--o{ comments : author" in actual + + assert "CHAR(32) author FK" in actual + assert 'CHAR(32) post FK "nullable"' in actual + assert 'BOOLEAN live "True if post is published,nullable"' in actual + assert "DATETIME created" in actual + + trailing_newline_assert(actual) + + # Dynamic mode: MetaData comparison + else: + if expected is None: + raise ValueError("expected MetaData is required when actual is MetaData") + compare_metadata(actual, expected, omit_comments=omit_comments) def dot_assert(output: str) -> None: From 26a903e4f48840872853f210c06a20c823114a94 Mon Sep 17 00:00:00 2001 From: Apti Date: Wed, 14 Jan 2026 18:17:23 +0300 Subject: [PATCH 18/20] fix format --- paracelsus/graph.py | 70 ++++++++++++++++++++++----------------------- tests/test_graph.py | 31 ++++++++++---------- tests/utils.py | 4 +-- 3 files changed, 53 insertions(+), 52 deletions(-) diff --git a/paracelsus/graph.py b/paracelsus/graph.py index e39c9f0..253bbf8 100644 --- a/paracelsus/graph.py +++ b/paracelsus/graph.py @@ -88,31 +88,31 @@ def _find_base_classes_by_pattern( return [(module_path, base_class.metadata)] except (ImportError, AttributeError) as e: raise ValueError(f"Could not import base class from {base_class_path}: {e}") - + # Extract pattern parts parts = base_class_path.split(":") if len(parts) != 2: raise ValueError(f"Invalid base_class_path format: {base_class_path}") - + pattern_str, class_name = parts - + # Create pattern for finding base.py modules pattern = Pattern(mask=pattern_str) - + if any(pattern.errors): raise ValueError(pattern.serialized_errors) - + # Find all matching base.py files finder = ModuleFinder(current_root, pattern.tokens) base_metadata_list = [] - + for file_path in finder.find(): # Only consider base.py files if file_path.name != "base.py" and not file_path.name.endswith("base.py"): continue - + module_path = to_module_name(current_root, file_path) - + try: base_module = importlib.import_module(module_path) if hasattr(base_module, class_name): @@ -121,10 +121,10 @@ def _find_base_classes_by_pattern( except (ImportError, AttributeError) as e: logger.warning(f"Could not import base class from {module_path}: {e}") continue - + if not base_metadata_list: raise ValueError(f"No base classes found matching pattern: {base_class_path}") - + return base_metadata_list @@ -134,7 +134,7 @@ def _merge_metadata(metadata_list: List[tuple[str, MetaData]]) -> MetaData: If there are table name conflicts, prefixes are added based on the module path. """ merged_metadata = MetaData() - + for module_path, metadata in metadata_list: # Extract a prefix from module path to avoid conflicts # e.g., "project1.example.base" -> "project1_" @@ -143,11 +143,11 @@ def _merge_metadata(metadata_list: List[tuple[str, MetaData]]) -> MetaData: if len(parts) > 1: # Use first part as prefix (e.g., "project1") prefix = f"{parts[0]}_" - + for tablename, table in metadata.tables.items(): # Check for conflicts prefixed_name = f"{prefix}{tablename}" if prefix else tablename - + # If there's a conflict and we have a prefix, use prefixed name if prefixed_name in merged_metadata.tables and prefix: logger.warning( @@ -161,14 +161,14 @@ def _merge_metadata(metadata_list: List[tuple[str, MetaData]]) -> MetaData: else: # No conflict final_name = tablename if not prefix else prefixed_name - + # Copy table to merged metadata if final_name not in merged_metadata.tables: if hasattr(table, "to_metadata"): table.to_metadata(merged_metadata, name=final_name) else: table.tometadata(merged_metadata, name=final_name) - + return merged_metadata @@ -192,15 +192,15 @@ def get_graph_metadata( sys.path.append(str(dir)) current_root = Path.cwd() - + # Handle base class path with or without wildcards has_wildcards = "*" in base_class_path or "?" in base_class_path base_metadata_list = None - + if has_wildcards or (merge_namespace_metadata and has_wildcards): # Find all matching base classes and merge their metadata base_metadata_list = _find_base_classes_by_pattern(base_class_path, python_dir, current_root) - + if len(base_metadata_list) > 1: metadata = _merge_metadata(base_metadata_list) elif len(base_metadata_list) == 1: @@ -260,7 +260,7 @@ def get_graph_metadata( except (ImportError, AttributeError) as e: logger.warning(f"Could not re-import base class from {module_path}: {e}") continue - + if updated_metadata_list: metadata = _merge_metadata(updated_metadata_list) @@ -309,7 +309,7 @@ def get_graph_string( ) -> str: """ Builds a graph structure and returns it as a serialized string. - + This is a convenience wrapper that combines get_graph_metadata() and serialize_metadata() for backward compatibility. """ @@ -320,7 +320,7 @@ def get_graph_string( exclude_tables=exclude_tables, python_dir=python_dir, ) - + return serialize_metadata( metadata, format=format, @@ -393,10 +393,10 @@ def compare_metadata(actual: MetaData, expected: MetaData, omit_comments: bool = This function performs a structural comparison of two graph representations, checking tables, columns, types, constraints, and relationships. """ - + actual_tables = set(actual.tables.keys()) expected_tables = set(expected.tables.keys()) - + # Check table names if actual_tables != expected_tables: missing = expected_tables - actual_tables @@ -407,15 +407,15 @@ def compare_metadata(actual: MetaData, expected: MetaData, omit_comments: bool = if extra: error_msg += f" Extra tables: {extra}\n" raise AssertionError(error_msg) - + # Check each table's structure for table_name in expected_tables: actual_table = actual.tables[table_name] expected_table = expected.tables[table_name] - + actual_columns = {col.name: col for col in actual_table.columns} expected_columns = {col.name: col for col in expected_table.columns} - + # Check column names if set(actual_columns.keys()) != set(expected_columns.keys()): missing = set(expected_columns.keys()) - set(actual_columns.keys()) @@ -426,12 +426,12 @@ def compare_metadata(actual: MetaData, expected: MetaData, omit_comments: bool = if extra: error_msg += f" Extra columns: {extra}\n" raise AssertionError(error_msg) - + # Check each column's properties for col_name in expected_columns.keys(): actual_col = actual_columns[col_name] expected_col = expected_columns[col_name] - + # Check type actual_type_str = str(actual_col.type) expected_type_str = str(expected_col.type) @@ -440,7 +440,7 @@ def compare_metadata(actual: MetaData, expected: MetaData, omit_comments: bool = f"Type mismatch in table '{table_name}', column '{col_name}': " f"expected {expected_type_str}, got {actual_type_str}" ) - + # Check constraints actual_pk = actual_col.primary_key expected_pk = expected_col.primary_key @@ -449,7 +449,7 @@ def compare_metadata(actual: MetaData, expected: MetaData, omit_comments: bool = f"Primary key mismatch in table '{table_name}', column '{col_name}': " f"expected {expected_pk}, got {actual_pk}" ) - + actual_fk_count = len(actual_col.foreign_keys) expected_fk_count = len(expected_col.foreign_keys) if actual_fk_count != expected_fk_count: @@ -457,14 +457,14 @@ def compare_metadata(actual: MetaData, expected: MetaData, omit_comments: bool = f"Foreign key count mismatch in table '{table_name}', column '{col_name}': " f"expected {expected_fk_count}, got {actual_fk_count}" ) - + # Check nullable if actual_col.nullable != expected_col.nullable: raise AssertionError( f"Nullable mismatch in table '{table_name}', column '{col_name}': " f"expected {expected_col.nullable}, got {actual_col.nullable}" ) - + # Check comments (if not omitted) if not omit_comments: actual_comment = actual_col.comment @@ -474,7 +474,7 @@ def compare_metadata(actual: MetaData, expected: MetaData, omit_comments: bool = f"Comment mismatch in table '{table_name}', column '{col_name}': " f"expected {expected_comment!r}, got {actual_comment!r}" ) - + # Check foreign key relationships actual_fks = set() for col in actual_table.columns: @@ -484,7 +484,7 @@ def compare_metadata(actual: MetaData, expected: MetaData, omit_comments: bool = target_table = ".".join(target_parts[:-1]) target_column = target_parts[-1] actual_fks.add((table_name, col.name, target_table, target_column)) - + expected_fks = set() for col in expected_table.columns: for fk in col.foreign_keys: @@ -492,7 +492,7 @@ def compare_metadata(actual: MetaData, expected: MetaData, omit_comments: bool = target_table = ".".join(target_parts[:-1]) target_column = target_parts[-1] expected_fks.add((table_name, col.name, target_table, target_column)) - + if actual_fks != expected_fks: missing = expected_fks - actual_fks extra = actual_fks - expected_fks diff --git a/tests/test_graph.py b/tests/test_graph.py index 3193e11..e124e5b 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -23,10 +23,10 @@ def test_get_graph_string(column_sort_arg, package_path, metaclass): exclude_tables=set(), python_dir=[package_path], ) - + # Compare with expected metadata from fixture mermaid_assert(actual_metadata, expected=metaclass) - + # Also test that serialization still works graph_string = get_graph_string( base_class_path="example.base:Base", @@ -50,7 +50,7 @@ def test_get_graph_string_with_wildcard(single_level_package_path): exclude_tables=set(), python_dir=[single_level_package_path], ) - + # Build expected metadata by manually importing all matching modules sys.path.insert(0, str(single_level_package_path)) try: @@ -58,17 +58,17 @@ def test_get_graph_string_with_wildcard(single_level_package_path): pattern = Pattern(mask="example.*.models") current_root = Path.cwd() finder = ModuleFinder(current_root, pattern.tokens) - + # Import all matching modules for file_path in finder.find(): module_path = to_module_name(current_root, file_path) importlib.import_module(module_path) - + # Get expected metadata from base class base_module = importlib.import_module("example.base") base_class = getattr(base_module, "Base") expected_metadata = base_class.metadata - + # Compare metadata compare_metadata(actual_metadata, expected_metadata) finally: @@ -92,15 +92,16 @@ def test_get_graph_with_wildcard_mask_in_namespace_package(namespace_package_pat python_dir=[namespace_package_path], merge_namespace_metadata=True, ) - + # Verify that we have tables from both projects table_names = set(actual_metadata.tables.keys()) # Should have tables from both project1 and project2 assert "subpackage_a_table" in table_names or "project1_subpackage_a_table" in table_names assert "subpackage_b_table" in table_names or "project2_subpackage_b_table" in table_names - + # Verify the graph can be serialized from paracelsus.graph import serialize_metadata + graph_string = serialize_metadata( actual_metadata, format="mermaid", @@ -182,10 +183,10 @@ def test_get_graph_string_with_layout(layout_arg, package_path, metaclass): exclude_tables=set(), python_dir=[package_path], ) - + # Compare with expected metadata mermaid_assert(actual_metadata, expected=metaclass) - + # Also test serialization with layout graph_string = get_graph_string( base_class_path="example.base:Base", @@ -221,22 +222,22 @@ def test_compare_metadata(metaclass): """Test compare_metadata function directly.""" # Same metadata should compare successfully compare_metadata(metaclass, metaclass) - + # Different metadata should raise AssertionError from sqlalchemy.orm import declarative_base from sqlalchemy import String, Uuid from sqlalchemy.orm import mapped_column from uuid import uuid4 - + Base2 = declarative_base() - + class DifferentTable(Base2): __tablename__ = "different_table" id = mapped_column(Uuid, primary_key=True, default=uuid4()) name = mapped_column(String(100)) - + different_metadata = Base2.metadata - + # Should raise AssertionError when comparing different metadata with pytest.raises(AssertionError): compare_metadata(metaclass, different_metadata) diff --git a/tests/utils.py b/tests/utils.py index cdc833c..f051e50 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -10,7 +10,7 @@ def mermaid_assert( ) -> None: """ Asserts that a mermaid graph (string or MetaData) matches expected structure. - + This function supports two modes: 1. Legacy mode: If actual is a string, performs basic string assertions (for backward compatibility) 2. Dynamic mode: If actual is MetaData, compares it with expected MetaData @@ -32,7 +32,7 @@ def mermaid_assert( assert "DATETIME created" in actual trailing_newline_assert(actual) - + # Dynamic mode: MetaData comparison else: if expected is None: From c26352d784b86b0fbeb960190bcc13376224905e Mon Sep 17 00:00:00 2001 From: Apti Date: Sun, 18 Jan 2026 19:40:55 +0300 Subject: [PATCH 19/20] Fix mypy type errors --- paracelsus/graph.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/paracelsus/graph.py b/paracelsus/graph.py index 5ecdf00..4833bf1 100644 --- a/paracelsus/graph.py +++ b/paracelsus/graph.py @@ -233,7 +233,6 @@ def get_graph_metadata( if any(pattern.errors): raise ValueError(pattern.serialized_errors) - current_root = Path.cwd() finder = ModuleFinder(current_root, pattern.tokens) needs_wildcards_import = import_modifier == "*" @@ -488,7 +487,7 @@ def compare_metadata(actual: MetaData, expected: MetaData, omit_comments: bool = ) # Check foreign key relationships - actual_fks = set() + actual_fks: set[tuple[str, ...]] = set() for col in actual_table.columns: for fk in col.foreign_keys: # Format: (table_name, column_name) -> (target_table, target_column) @@ -497,7 +496,7 @@ def compare_metadata(actual: MetaData, expected: MetaData, omit_comments: bool = target_column = target_parts[-1] actual_fks.add((table_name, col.name, target_table, target_column)) - expected_fks = set() + expected_fks: set[tuple[str, ...]] = set() for col in expected_table.columns: for fk in col.foreign_keys: target_parts = fk.target_fullname.split(".") @@ -506,11 +505,11 @@ def compare_metadata(actual: MetaData, expected: MetaData, omit_comments: bool = expected_fks.add((table_name, col.name, target_table, target_column)) if actual_fks != expected_fks: - missing = expected_fks - actual_fks - extra = actual_fks - expected_fks + missing_fks: set[tuple[str, ...]] = expected_fks - actual_fks + extra_fks: set[tuple[str, ...]] = actual_fks - expected_fks error_msg = f"Foreign key mismatch in table '{table_name}':\n" if missing: - error_msg += f" Missing FKs: {missing}\n" + error_msg += f" Missing FKs: {missing_fks}\n" if extra: - error_msg += f" Extra FKs: {extra}\n" + error_msg += f" Extra FKs: {extra_fks}\n" raise AssertionError(error_msg) From 85da90acf287de93c40930a41d703dc3cdbc5ccf Mon Sep 17 00:00:00 2001 From: Apti Date: Sat, 24 Jan 2026 00:50:59 +0300 Subject: [PATCH 20/20] Refactor: make module imports single-threaded --- paracelsus/graph.py | 46 ++++++++++++--------------------------------- 1 file changed, 12 insertions(+), 34 deletions(-) diff --git a/paracelsus/graph.py b/paracelsus/graph.py index 4833bf1..844e11f 100644 --- a/paracelsus/graph.py +++ b/paracelsus/graph.py @@ -4,8 +4,6 @@ import re import sys from pathlib import Path -from queue import Queue -from threading import Thread from typing import Dict, List, Optional, Set, Union from sqlalchemy.schema import MetaData @@ -46,31 +44,6 @@ def to_module_name(root: Path, path: Path) -> str: return ".".join(clean_path.parts) -def consume_import_tasks(queue: Queue[dict], sentinel: object): - while True: - item = queue.get() - - if item is sentinel: - break - - needs_wildcards_import, module_name = item.values() - try: - # Check if already loaded to save time - if module_name in sys.modules: - continue - - if needs_wildcards_import: - exec(f"from {module_name} import *") - else: - importlib.import_module(module_name) - - except ImportError as e: - logger.error(f"Failed to load {module_name}: {e}") - raise e - finally: - queue.task_done() - - def _find_base_classes_by_pattern( base_class_path: str, python_dir: List[Path], @@ -215,10 +188,6 @@ def get_graph_metadata( base_class = getattr(base_module, class_name) metadata = base_class.metadata - import_queue_sentinel = object() - import_queue: Queue[Union[Dict[str, str], object]] = Queue() - import_worker = Thread(target=consume_import_tasks, args=(import_queue, import_queue_sentinel), daemon=True) - import_worker.start() # The modules holding the model classes have to be imported to get put in the metaclass model registry. # These modules aren't actually used in any way, so they are discarded. # They are also imported in scope of this function to prevent namespace pollution. @@ -238,10 +207,19 @@ def get_graph_metadata( for file_path in finder.find(): module_path = to_module_name(current_root, file_path) - import_queue.put({"needs_wildcards_import": needs_wildcards_import, "module_name": module_path}) - import_queue.put(import_queue_sentinel) - import_worker.join() + # Check if already loaded to save time + if module_path in sys.modules: + continue + + try: + if needs_wildcards_import: + exec(f"from {module_path} import *") + else: + importlib.import_module(module_path) + except ImportError as e: + logger.error(f"Failed to load {module_path}: {e}") + raise e # If we merged metadata from multiple base classes, we need to re-merge after models are imported # because models register themselves in the original base class metadata, not the merged one