From 334c93e5a2b54bacf1639b606c0ef68b60b7107b Mon Sep 17 00:00:00 2001 From: Ivan Malison Date: Fri, 23 Jan 2026 13:31:23 -0800 Subject: [PATCH] feat: Add schema selection support to compare method Add one_schema and two_schema parameters to Comparer.compare() method to allow specifying different database schemas for comparison. This enables comparing specific schemas within databases that have multiple schemas, which is common in PostgreSQL setups. All inspector methods have been updated to pass the schema parameter to SQLAlchemy's inspector methods (get_table_names, get_columns, get_pk_constraint, etc.). --- src/sqlalchemydiff/comparer.py | 28 ++++++- src/sqlalchemydiff/inspection/base.py | 5 +- src/sqlalchemydiff/inspection/inspectors.py | 90 +++++++++++++++------ 3 files changed, 93 insertions(+), 30 deletions(-) diff --git a/src/sqlalchemydiff/comparer.py b/src/sqlalchemydiff/comparer.py index b71c5e4..3257f9f 100644 --- a/src/sqlalchemydiff/comparer.py +++ b/src/sqlalchemydiff/comparer.py @@ -127,7 +127,19 @@ def compare( two_alias: str = "two", ignores: list[str] | None = None, ignore_inspectors: Iterable[str] | None = None, + one_schema: str | None = None, + two_schema: str | None = None, ): + """Compare two databases. + + :param one_alias: Alias for the first database in the result. + :param two_alias: Alias for the second database in the result. + :param ignores: List of ignore specifications. + :param ignore_inspectors: List of inspector keys to ignore. + :param one_schema: Schema name for the first database. If None, uses the default schema. + :param two_schema: Schema name for the second database. If None, uses the default schema. + :return: A CompareResult object with the comparison result. + """ ignore_specs = self.ignore_spec_factory_class().create_specs(register, ignores) filtered_inspectors = self._filter_inspectors(set(ignore_inspectors or set())) @@ -138,8 +150,12 @@ def compare( for key, inspector_class in filtered_inspectors: inspector = inspector_class(one_alias=one_alias, two_alias=two_alias) - db_one_info = self._get_db_info(ignore_specs, inspector, self.db_one_engine) - db_two_info = self._get_db_info(ignore_specs, inspector, self.db_two_engine) + db_one_info = self._get_db_info( + ignore_specs, inspector, self.db_one_engine, schema=one_schema + ) + db_two_info = self._get_db_info( + ignore_specs, inspector, self.db_two_engine, schema=two_schema + ) if db_one_info is not None and db_two_info is not None: result[key] = inspector.diff(db_one_info, db_two_info) @@ -163,10 +179,14 @@ def _filter_inspectors( return [(key, cls) for key, (_, cls) in register.items() if key not in ignore_inspectors] def _get_db_info( - self, ignore_specs: list[IgnoreSpecType], inspector: BaseInspector, engine: Engine + self, + ignore_specs: list[IgnoreSpecType], + inspector: BaseInspector, + engine: Engine, + schema: str | None = None, ) -> dict | None: try: - return inspector.inspect(engine, ignore_specs) + return inspector.inspect(engine, ignore_specs, schema=schema) except InspectorNotSupported as e: logger.warning({"engine": engine, "inspector": inspector.key, "error": e.message}) diff --git a/src/sqlalchemydiff/inspection/base.py b/src/sqlalchemydiff/inspection/base.py index e8182a3..f0268f5 100644 --- a/src/sqlalchemydiff/inspection/base.py +++ b/src/sqlalchemydiff/inspection/base.py @@ -59,7 +59,10 @@ def __init__(self, one_alias: str = "one", two_alias: str = "two"): @abc.abstractmethod def inspect( - self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, ) -> Any: ... # pragma: no cover @abc.abstractmethod diff --git a/src/sqlalchemydiff/inspection/inspectors.py b/src/sqlalchemydiff/inspection/inspectors.py index 20c231a..d3ee6fb 100644 --- a/src/sqlalchemydiff/inspection/inspectors.py +++ b/src/sqlalchemydiff/inspection/inspectors.py @@ -14,19 +14,24 @@ class TablesInspector(BaseInspector, DiffMixin): key = "tables" db_level = True - def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict: + def inspect( + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, + ) -> dict: ignore_clauses = self._filter_ignorers(ignore_specs) inspector = self._get_inspector(engine) def get_comment(table_name: str) -> str | None: try: - return inspector.get_table_comment(table_name)["text"] + return inspector.get_table_comment(table_name, schema=schema)["text"] except NotImplementedError: return return { table_name: self._format_table(table_name, get_comment(table_name)) - for table_name in inspector.get_table_names() + for table_name in inspector.get_table_names(schema=schema) if table_name not in ignore_clauses.tables } @@ -48,11 +53,16 @@ class ColumnsInspector(BaseInspector, DiffMixin): key = "columns" - def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict: + def inspect( + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, + ) -> dict: ignore_clauses = self._filter_ignorers(ignore_specs) inspector = self._get_inspector(engine) - table_names = inspector.get_table_names() + table_names = inspector.get_table_names(schema=schema) result = {} for table_name in table_names: @@ -61,7 +71,7 @@ def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = No result[table_name] = [ column_item - for column_item in inspector.get_columns(table_name) + for column_item in inspector.get_columns(table_name, schema=schema) if not ignore_clauses.is_clause(table_name, self.key, column_item["name"]) ] self._process_types(result[table_name], engine) @@ -90,17 +100,22 @@ class PrimaryKeysInspector(BaseInspector, DiffMixin): key = "primary_keys" - def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict: + def inspect( + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, + ) -> dict: ignore_clauses = self._filter_ignorers(ignore_specs) inspector = self._get_inspector(engine) - table_names = inspector.get_table_names() + table_names = inspector.get_table_names(schema=schema) result = {} for table_name in table_names: if table_name in ignore_clauses.tables: continue - inspection_result = inspector.get_pk_constraint(table_name) + inspection_result = inspector.get_pk_constraint(table_name, schema=schema) if not ignore_clauses.is_clause(table_name, self.key, inspection_result["name"]): result[table_name] = inspection_result @@ -121,11 +136,16 @@ class ForeignKeysInspector(BaseInspector, DiffMixin): key = "foreign_keys" - def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict: + def inspect( + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, + ) -> dict: ignore_clauses = self._filter_ignorers(ignore_specs) inspector = self._get_inspector(engine) - table_names = inspector.get_table_names() + table_names = inspector.get_table_names(schema=schema) result = {} for table_name in table_names: if table_name in ignore_clauses.tables: @@ -133,7 +153,7 @@ def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = No result[table_name] = [ self._get_fk_identifier(fk) - for fk in inspector.get_foreign_keys(table_name) + for fk in inspector.get_foreign_keys(table_name, schema=schema) if not ignore_clauses.is_clause(table_name, self.key, fk["name"]) ] return result @@ -155,10 +175,15 @@ class IndexesInspector(BaseInspector, DiffMixin): key = "indexes" - def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict: + def inspect( + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, + ) -> dict: ignore_clauses = self._filter_ignorers(ignore_specs) inspector = self._get_inspector(engine) - table_names = inspector.get_table_names() + table_names = inspector.get_table_names(schema=schema) result = {} for table_name in table_names: if table_name in ignore_clauses.tables: @@ -166,7 +191,7 @@ def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = No result[table_name] = [ index - for index in inspector.get_indexes(table_name) + for index in inspector.get_indexes(table_name, schema=schema) if not ignore_clauses.is_clause(table_name, self.key, index["name"]) ] return result @@ -183,10 +208,15 @@ class UniqueConstraintsInspector(BaseInspector, DiffMixin): key = "unique_constraints" - def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict: + def inspect( + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, + ) -> dict: ignore_clauses = self._filter_ignorers(ignore_specs) inspector = self._get_inspector(engine) - table_names = inspector.get_table_names() + table_names = inspector.get_table_names(schema=schema) result = {} for table_name in table_names: if table_name in ignore_clauses.tables: @@ -194,7 +224,7 @@ def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = No result[table_name] = [ uc - for uc in self._format_unique_constraint(inspector, table_name) + for uc in self._format_unique_constraint(inspector, table_name, schema=schema) if not ignore_clauses.is_clause(table_name, self.key, uc["name"]) ] return result @@ -202,8 +232,10 @@ def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = No def diff(self, one: dict, two: dict) -> dict: return self._listdiff(one, two) - def _format_unique_constraint(self, inspector: Inspector, table_name: str) -> list[dict]: - result = inspector.get_unique_constraints(table_name) + def _format_unique_constraint( + self, inspector: Inspector, table_name: str, schema: str | None = None + ) -> list[dict]: + result = inspector.get_unique_constraints(table_name, schema=schema) for constraint in result: name = constraint.get("name") if not name: @@ -220,10 +252,15 @@ class CheckConstraintsInspector(BaseInspector, DiffMixin): key = "check_constraints" - def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict: + def inspect( + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, + ) -> dict: ignore_clauses = self._filter_ignorers(ignore_specs) inspector = self._get_inspector(engine) - table_names = inspector.get_table_names() + table_names = inspector.get_table_names(schema=schema) result = {} for table_name in table_names: if table_name in ignore_clauses.tables: @@ -231,7 +268,7 @@ def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = No result[table_name] = [ cc - for cc in inspector.get_check_constraints(table_name) + for cc in inspector.get_check_constraints(table_name, schema=schema) if not ignore_clauses.is_clause(table_name, self.key, cc["name"]) ] return result @@ -250,12 +287,15 @@ class EnumsInspector(BaseInspector, DiffMixin): db_level = True def inspect( - self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, ) -> list[dict]: inspector = self._get_inspector(engine) ignore_clauses = self._filter_ignorers(ignore_specs) - enums = getattr(inspector, "get_enums", lambda: [])() or [] + enums = getattr(inspector, "get_enums", lambda schema=None: [])(schema=schema) or [] return [enum for enum in enums if enum["name"] not in ignore_clauses.enums] def diff(self, one: dict, two: dict) -> dict: