diff --git a/src/sqlalchemydiff/comparer.py b/src/sqlalchemydiff/comparer.py index b71c5e4..3257f9f 100644 --- a/src/sqlalchemydiff/comparer.py +++ b/src/sqlalchemydiff/comparer.py @@ -127,7 +127,19 @@ def compare( two_alias: str = "two", ignores: list[str] | None = None, ignore_inspectors: Iterable[str] | None = None, + one_schema: str | None = None, + two_schema: str | None = None, ): + """Compare two databases. + + :param one_alias: Alias for the first database in the result. + :param two_alias: Alias for the second database in the result. + :param ignores: List of ignore specifications. + :param ignore_inspectors: List of inspector keys to ignore. + :param one_schema: Schema name for the first database. If None, uses the default schema. + :param two_schema: Schema name for the second database. If None, uses the default schema. + :return: A CompareResult object with the comparison result. + """ ignore_specs = self.ignore_spec_factory_class().create_specs(register, ignores) filtered_inspectors = self._filter_inspectors(set(ignore_inspectors or set())) @@ -138,8 +150,12 @@ def compare( for key, inspector_class in filtered_inspectors: inspector = inspector_class(one_alias=one_alias, two_alias=two_alias) - db_one_info = self._get_db_info(ignore_specs, inspector, self.db_one_engine) - db_two_info = self._get_db_info(ignore_specs, inspector, self.db_two_engine) + db_one_info = self._get_db_info( + ignore_specs, inspector, self.db_one_engine, schema=one_schema + ) + db_two_info = self._get_db_info( + ignore_specs, inspector, self.db_two_engine, schema=two_schema + ) if db_one_info is not None and db_two_info is not None: result[key] = inspector.diff(db_one_info, db_two_info) @@ -163,10 +179,14 @@ def _filter_inspectors( return [(key, cls) for key, (_, cls) in register.items() if key not in ignore_inspectors] def _get_db_info( - self, ignore_specs: list[IgnoreSpecType], inspector: BaseInspector, engine: Engine + self, + ignore_specs: list[IgnoreSpecType], + inspector: BaseInspector, + engine: Engine, + schema: str | None = None, ) -> dict | None: try: - return inspector.inspect(engine, ignore_specs) + return inspector.inspect(engine, ignore_specs, schema=schema) except InspectorNotSupported as e: logger.warning({"engine": engine, "inspector": inspector.key, "error": e.message}) diff --git a/src/sqlalchemydiff/inspection/base.py b/src/sqlalchemydiff/inspection/base.py index e8182a3..f0268f5 100644 --- a/src/sqlalchemydiff/inspection/base.py +++ b/src/sqlalchemydiff/inspection/base.py @@ -59,7 +59,10 @@ def __init__(self, one_alias: str = "one", two_alias: str = "two"): @abc.abstractmethod def inspect( - self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, ) -> Any: ... # pragma: no cover @abc.abstractmethod diff --git a/src/sqlalchemydiff/inspection/inspectors.py b/src/sqlalchemydiff/inspection/inspectors.py index 20c231a..d3ee6fb 100644 --- a/src/sqlalchemydiff/inspection/inspectors.py +++ b/src/sqlalchemydiff/inspection/inspectors.py @@ -14,19 +14,24 @@ class TablesInspector(BaseInspector, DiffMixin): key = "tables" db_level = True - def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict: + def inspect( + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, + ) -> dict: ignore_clauses = self._filter_ignorers(ignore_specs) inspector = self._get_inspector(engine) def get_comment(table_name: str) -> str | None: try: - return inspector.get_table_comment(table_name)["text"] + return inspector.get_table_comment(table_name, schema=schema)["text"] except NotImplementedError: return return { table_name: self._format_table(table_name, get_comment(table_name)) - for table_name in inspector.get_table_names() + for table_name in inspector.get_table_names(schema=schema) if table_name not in ignore_clauses.tables } @@ -48,11 +53,16 @@ class ColumnsInspector(BaseInspector, DiffMixin): key = "columns" - def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict: + def inspect( + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, + ) -> dict: ignore_clauses = self._filter_ignorers(ignore_specs) inspector = self._get_inspector(engine) - table_names = inspector.get_table_names() + table_names = inspector.get_table_names(schema=schema) result = {} for table_name in table_names: @@ -61,7 +71,7 @@ def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = No result[table_name] = [ column_item - for column_item in inspector.get_columns(table_name) + for column_item in inspector.get_columns(table_name, schema=schema) if not ignore_clauses.is_clause(table_name, self.key, column_item["name"]) ] self._process_types(result[table_name], engine) @@ -90,17 +100,22 @@ class PrimaryKeysInspector(BaseInspector, DiffMixin): key = "primary_keys" - def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict: + def inspect( + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, + ) -> dict: ignore_clauses = self._filter_ignorers(ignore_specs) inspector = self._get_inspector(engine) - table_names = inspector.get_table_names() + table_names = inspector.get_table_names(schema=schema) result = {} for table_name in table_names: if table_name in ignore_clauses.tables: continue - inspection_result = inspector.get_pk_constraint(table_name) + inspection_result = inspector.get_pk_constraint(table_name, schema=schema) if not ignore_clauses.is_clause(table_name, self.key, inspection_result["name"]): result[table_name] = inspection_result @@ -121,11 +136,16 @@ class ForeignKeysInspector(BaseInspector, DiffMixin): key = "foreign_keys" - def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict: + def inspect( + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, + ) -> dict: ignore_clauses = self._filter_ignorers(ignore_specs) inspector = self._get_inspector(engine) - table_names = inspector.get_table_names() + table_names = inspector.get_table_names(schema=schema) result = {} for table_name in table_names: if table_name in ignore_clauses.tables: @@ -133,7 +153,7 @@ def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = No result[table_name] = [ self._get_fk_identifier(fk) - for fk in inspector.get_foreign_keys(table_name) + for fk in inspector.get_foreign_keys(table_name, schema=schema) if not ignore_clauses.is_clause(table_name, self.key, fk["name"]) ] return result @@ -155,10 +175,15 @@ class IndexesInspector(BaseInspector, DiffMixin): key = "indexes" - def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict: + def inspect( + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, + ) -> dict: ignore_clauses = self._filter_ignorers(ignore_specs) inspector = self._get_inspector(engine) - table_names = inspector.get_table_names() + table_names = inspector.get_table_names(schema=schema) result = {} for table_name in table_names: if table_name in ignore_clauses.tables: @@ -166,7 +191,7 @@ def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = No result[table_name] = [ index - for index in inspector.get_indexes(table_name) + for index in inspector.get_indexes(table_name, schema=schema) if not ignore_clauses.is_clause(table_name, self.key, index["name"]) ] return result @@ -183,10 +208,15 @@ class UniqueConstraintsInspector(BaseInspector, DiffMixin): key = "unique_constraints" - def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict: + def inspect( + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, + ) -> dict: ignore_clauses = self._filter_ignorers(ignore_specs) inspector = self._get_inspector(engine) - table_names = inspector.get_table_names() + table_names = inspector.get_table_names(schema=schema) result = {} for table_name in table_names: if table_name in ignore_clauses.tables: @@ -194,7 +224,7 @@ def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = No result[table_name] = [ uc - for uc in self._format_unique_constraint(inspector, table_name) + for uc in self._format_unique_constraint(inspector, table_name, schema=schema) if not ignore_clauses.is_clause(table_name, self.key, uc["name"]) ] return result @@ -202,8 +232,10 @@ def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = No def diff(self, one: dict, two: dict) -> dict: return self._listdiff(one, two) - def _format_unique_constraint(self, inspector: Inspector, table_name: str) -> list[dict]: - result = inspector.get_unique_constraints(table_name) + def _format_unique_constraint( + self, inspector: Inspector, table_name: str, schema: str | None = None + ) -> list[dict]: + result = inspector.get_unique_constraints(table_name, schema=schema) for constraint in result: name = constraint.get("name") if not name: @@ -220,10 +252,15 @@ class CheckConstraintsInspector(BaseInspector, DiffMixin): key = "check_constraints" - def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict: + def inspect( + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, + ) -> dict: ignore_clauses = self._filter_ignorers(ignore_specs) inspector = self._get_inspector(engine) - table_names = inspector.get_table_names() + table_names = inspector.get_table_names(schema=schema) result = {} for table_name in table_names: if table_name in ignore_clauses.tables: @@ -231,7 +268,7 @@ def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = No result[table_name] = [ cc - for cc in inspector.get_check_constraints(table_name) + for cc in inspector.get_check_constraints(table_name, schema=schema) if not ignore_clauses.is_clause(table_name, self.key, cc["name"]) ] return result @@ -250,12 +287,15 @@ class EnumsInspector(BaseInspector, DiffMixin): db_level = True def inspect( - self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None + self, + engine: Engine, + ignore_specs: list[IgnoreSpecType] | None = None, + schema: str | None = None, ) -> list[dict]: inspector = self._get_inspector(engine) ignore_clauses = self._filter_ignorers(ignore_specs) - enums = getattr(inspector, "get_enums", lambda: [])() or [] + enums = getattr(inspector, "get_enums", lambda schema=None: [])(schema=schema) or [] return [enum for enum in enums if enum["name"] not in ignore_clauses.enums] def diff(self, one: dict, two: dict) -> dict: