Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions src/sqlalchemydiff/comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,19 @@ def compare(
two_alias: str = "two",
ignores: list[str] | None = None,
ignore_inspectors: Iterable[str] | None = None,
one_schema: str | None = None,
two_schema: str | None = None,
):
"""Compare two databases.

:param one_alias: Alias for the first database in the result.
:param two_alias: Alias for the second database in the result.
:param ignores: List of ignore specifications.
:param ignore_inspectors: List of inspector keys to ignore.
:param one_schema: Schema name for the first database. If None, uses the default schema.
:param two_schema: Schema name for the second database. If None, uses the default schema.
:return: A CompareResult object with the comparison result.
"""
ignore_specs = self.ignore_spec_factory_class().create_specs(register, ignores)

filtered_inspectors = self._filter_inspectors(set(ignore_inspectors or set()))
Expand All @@ -138,8 +150,12 @@ def compare(
for key, inspector_class in filtered_inspectors:
inspector = inspector_class(one_alias=one_alias, two_alias=two_alias)

db_one_info = self._get_db_info(ignore_specs, inspector, self.db_one_engine)
db_two_info = self._get_db_info(ignore_specs, inspector, self.db_two_engine)
db_one_info = self._get_db_info(
ignore_specs, inspector, self.db_one_engine, schema=one_schema
)
db_two_info = self._get_db_info(
ignore_specs, inspector, self.db_two_engine, schema=two_schema
)

if db_one_info is not None and db_two_info is not None:
result[key] = inspector.diff(db_one_info, db_two_info)
Expand All @@ -163,10 +179,14 @@ def _filter_inspectors(
return [(key, cls) for key, (_, cls) in register.items() if key not in ignore_inspectors]

def _get_db_info(
self, ignore_specs: list[IgnoreSpecType], inspector: BaseInspector, engine: Engine
self,
ignore_specs: list[IgnoreSpecType],
inspector: BaseInspector,
engine: Engine,
schema: str | None = None,
) -> dict | None:
try:
return inspector.inspect(engine, ignore_specs)
return inspector.inspect(engine, ignore_specs, schema=schema)
except InspectorNotSupported as e:
logger.warning({"engine": engine, "inspector": inspector.key, "error": e.message})

Expand Down
5 changes: 4 additions & 1 deletion src/sqlalchemydiff/inspection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ def __init__(self, one_alias: str = "one", two_alias: str = "two"):

@abc.abstractmethod
def inspect(
self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None
self,
engine: Engine,
ignore_specs: list[IgnoreSpecType] | None = None,
schema: str | None = None,
) -> Any: ... # pragma: no cover

@abc.abstractmethod
Expand Down
90 changes: 65 additions & 25 deletions src/sqlalchemydiff/inspection/inspectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,24 @@ class TablesInspector(BaseInspector, DiffMixin):
key = "tables"
db_level = True

def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict:
def inspect(
self,
engine: Engine,
ignore_specs: list[IgnoreSpecType] | None = None,
schema: str | None = None,
) -> dict:
ignore_clauses = self._filter_ignorers(ignore_specs)
inspector = self._get_inspector(engine)

def get_comment(table_name: str) -> str | None:
try:
return inspector.get_table_comment(table_name)["text"]
return inspector.get_table_comment(table_name, schema=schema)["text"]
except NotImplementedError:
return

return {
table_name: self._format_table(table_name, get_comment(table_name))
for table_name in inspector.get_table_names()
for table_name in inspector.get_table_names(schema=schema)
if table_name not in ignore_clauses.tables
}

Expand All @@ -48,11 +53,16 @@ class ColumnsInspector(BaseInspector, DiffMixin):

key = "columns"

def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict:
def inspect(
self,
engine: Engine,
ignore_specs: list[IgnoreSpecType] | None = None,
schema: str | None = None,
) -> dict:
ignore_clauses = self._filter_ignorers(ignore_specs)

inspector = self._get_inspector(engine)
table_names = inspector.get_table_names()
table_names = inspector.get_table_names(schema=schema)

result = {}
for table_name in table_names:
Expand All @@ -61,7 +71,7 @@ def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = No

result[table_name] = [
column_item
for column_item in inspector.get_columns(table_name)
for column_item in inspector.get_columns(table_name, schema=schema)
if not ignore_clauses.is_clause(table_name, self.key, column_item["name"])
]
self._process_types(result[table_name], engine)
Expand Down Expand Up @@ -90,17 +100,22 @@ class PrimaryKeysInspector(BaseInspector, DiffMixin):

key = "primary_keys"

def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict:
def inspect(
self,
engine: Engine,
ignore_specs: list[IgnoreSpecType] | None = None,
schema: str | None = None,
) -> dict:
ignore_clauses = self._filter_ignorers(ignore_specs)

inspector = self._get_inspector(engine)
table_names = inspector.get_table_names()
table_names = inspector.get_table_names(schema=schema)
result = {}

for table_name in table_names:
if table_name in ignore_clauses.tables:
continue
inspection_result = inspector.get_pk_constraint(table_name)
inspection_result = inspector.get_pk_constraint(table_name, schema=schema)

if not ignore_clauses.is_clause(table_name, self.key, inspection_result["name"]):
result[table_name] = inspection_result
Expand All @@ -121,19 +136,24 @@ class ForeignKeysInspector(BaseInspector, DiffMixin):

key = "foreign_keys"

def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict:
def inspect(
self,
engine: Engine,
ignore_specs: list[IgnoreSpecType] | None = None,
schema: str | None = None,
) -> dict:
ignore_clauses = self._filter_ignorers(ignore_specs)

inspector = self._get_inspector(engine)
table_names = inspector.get_table_names()
table_names = inspector.get_table_names(schema=schema)
result = {}
for table_name in table_names:
if table_name in ignore_clauses.tables:
continue

result[table_name] = [
self._get_fk_identifier(fk)
for fk in inspector.get_foreign_keys(table_name)
for fk in inspector.get_foreign_keys(table_name, schema=schema)
if not ignore_clauses.is_clause(table_name, self.key, fk["name"])
]
return result
Expand All @@ -155,18 +175,23 @@ class IndexesInspector(BaseInspector, DiffMixin):

key = "indexes"

def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict:
def inspect(
self,
engine: Engine,
ignore_specs: list[IgnoreSpecType] | None = None,
schema: str | None = None,
) -> dict:
ignore_clauses = self._filter_ignorers(ignore_specs)
inspector = self._get_inspector(engine)
table_names = inspector.get_table_names()
table_names = inspector.get_table_names(schema=schema)
result = {}
for table_name in table_names:
if table_name in ignore_clauses.tables:
continue

result[table_name] = [
index
for index in inspector.get_indexes(table_name)
for index in inspector.get_indexes(table_name, schema=schema)
if not ignore_clauses.is_clause(table_name, self.key, index["name"])
]
return result
Expand All @@ -183,27 +208,34 @@ class UniqueConstraintsInspector(BaseInspector, DiffMixin):

key = "unique_constraints"

def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict:
def inspect(
self,
engine: Engine,
ignore_specs: list[IgnoreSpecType] | None = None,
schema: str | None = None,
) -> dict:
ignore_clauses = self._filter_ignorers(ignore_specs)
inspector = self._get_inspector(engine)
table_names = inspector.get_table_names()
table_names = inspector.get_table_names(schema=schema)
result = {}
for table_name in table_names:
if table_name in ignore_clauses.tables:
continue

result[table_name] = [
uc
for uc in self._format_unique_constraint(inspector, table_name)
for uc in self._format_unique_constraint(inspector, table_name, schema=schema)
if not ignore_clauses.is_clause(table_name, self.key, uc["name"])
]
return result

def diff(self, one: dict, two: dict) -> dict:
return self._listdiff(one, two)

def _format_unique_constraint(self, inspector: Inspector, table_name: str) -> list[dict]:
result = inspector.get_unique_constraints(table_name)
def _format_unique_constraint(
self, inspector: Inspector, table_name: str, schema: str | None = None
) -> list[dict]:
result = inspector.get_unique_constraints(table_name, schema=schema)
for constraint in result:
name = constraint.get("name")
if not name:
Expand All @@ -220,18 +252,23 @@ class CheckConstraintsInspector(BaseInspector, DiffMixin):

key = "check_constraints"

def inspect(self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None) -> dict:
def inspect(
self,
engine: Engine,
ignore_specs: list[IgnoreSpecType] | None = None,
schema: str | None = None,
) -> dict:
ignore_clauses = self._filter_ignorers(ignore_specs)
inspector = self._get_inspector(engine)
table_names = inspector.get_table_names()
table_names = inspector.get_table_names(schema=schema)
result = {}
for table_name in table_names:
if table_name in ignore_clauses.tables:
continue

result[table_name] = [
cc
for cc in inspector.get_check_constraints(table_name)
for cc in inspector.get_check_constraints(table_name, schema=schema)
if not ignore_clauses.is_clause(table_name, self.key, cc["name"])
]
return result
Expand All @@ -250,12 +287,15 @@ class EnumsInspector(BaseInspector, DiffMixin):
db_level = True

def inspect(
self, engine: Engine, ignore_specs: list[IgnoreSpecType] | None = None
self,
engine: Engine,
ignore_specs: list[IgnoreSpecType] | None = None,
schema: str | None = None,
) -> list[dict]:
inspector = self._get_inspector(engine)

ignore_clauses = self._filter_ignorers(ignore_specs)
enums = getattr(inspector, "get_enums", lambda: [])() or []
enums = getattr(inspector, "get_enums", lambda schema=None: [])(schema=schema) or []
return [enum for enum in enums if enum["name"] not in ignore_clauses.enums]

def diff(self, one: dict, two: dict) -> dict:
Expand Down