diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index 6e77b1489..be734f047 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -325,7 +325,7 @@ def add_column(self, column_name, table_name=None, **kwargs): table_name = self._handle_table_name(table_name) super().add_column(table_name, column_name, **kwargs) - def _remove_matching_relationships(self, element, keys): + def _remove_relationships_by_table(self, element, keys): """Remove relationships where the element matches the keys to check.""" updated_relationships = [] for relationship in self.relationships: @@ -335,6 +335,22 @@ def _remove_matching_relationships(self, element, keys): self.relationships = updated_relationships + def _remove_relationships_by_column(self, table_name, column_name): + """Remove relationships where the column is a key for the given table.""" + updated_relationships = [] + for relationship in self.relationships: + should_remove = ( + relationship['child_foreign_key'] == column_name + and relationship['child_table_name'] == table_name + ) or ( + relationship['parent_primary_key'] == column_name + and relationship['parent_table_name'] == table_name + ) + if not should_remove: + updated_relationships.append(relationship) + + self.relationships = updated_relationships + def remove_table(self, table_name): """Remove a table from the metadata. @@ -348,7 +364,7 @@ def remove_table(self, table_name): self._validate_table_exists(table_name) # Remove relationships - self._remove_matching_relationships(table_name, ['parent_table_name', 'child_table_name']) + self._remove_relationships_by_table(table_name, ['parent_table_name', 'child_table_name']) del self.tables[table_name] self._multi_table_updated = True @@ -380,9 +396,7 @@ def remove_column(self, column_name, table_name=None): table_metadata._validate_column_exists(column_name) # Remove relationships - self._remove_matching_relationships( - column_name, ['parent_primary_key', 'child_foreign_key'] - ) + self._remove_relationships_by_column(table_name, column_name) updated_column_relationships = [] for column_relationship in table_metadata.column_relationships: if column_name not in column_relationship.get('column_names', []): diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py index 067ff293c..95ab8ca6a 100644 --- a/tests/unit/metadata/test_metadata.py +++ b/tests/unit/metadata/test_metadata.py @@ -1045,6 +1045,101 @@ def test_remove_column_removes_relationships(self): assert list(manufacturer_mock.columns.keys()) == ['country', 'id'] assert metadata._multi_table_updated + def test__remove_relationships_by_column_only_removes_matching_table(self): + """Test that only relationships for the given table and column are removed.""" + # Setup + metadata = Metadata() + metadata.relationships = [ + { + 'parent_table_name': 'parent', + 'parent_primary_key': 'id', + 'child_table_name': 'child_a', + 'child_foreign_key': 'fk', + }, + { + 'parent_table_name': 'parent', + 'parent_primary_key': 'id', + 'child_table_name': 'child_b', + 'child_foreign_key': 'fk', + }, + ] + + # Run + metadata._remove_relationships_by_column('child_a', 'fk') + + # Assert + assert metadata.relationships == [ + { + 'parent_table_name': 'parent', + 'parent_primary_key': 'id', + 'child_table_name': 'child_b', + 'child_foreign_key': 'fk', + }, + ] + + def test_remove_column_only_removes_relationship_for_that_table(self): + """Test removing a foreign key column only removes the relationship for that table.""" + # Setup + metadata = Metadata.load_from_dict({ + 'tables': { + 'table1': { + 'primary_key': 'id', + 'columns': { + 'id': {'sdtype': 'id'}, + 'A': {'sdtype': 'numerical'}, + 'B': {'sdtype': 'categorical'}, + }, + }, + 'table2': { + 'primary_key': 'id', + 'columns': { + 'id': {'sdtype': 'id'}, + 'fk_1': {'sdtype': 'id'}, + 'A': {'sdtype': 'numerical'}, + 'B': {'sdtype': 'categorical'}, + }, + }, + 'table3': { + 'primary_key': 'id', + 'columns': { + 'id': {'sdtype': 'id'}, + 'fk_1': {'sdtype': 'id'}, + 'A': {'sdtype': 'numerical'}, + 'B': {'sdtype': 'categorical'}, + }, + }, + }, + 'relationships': [ + { + 'parent_table_name': 'table1', + 'parent_primary_key': 'id', + 'child_table_name': 'table2', + 'child_foreign_key': 'fk_1', + }, + { + 'parent_table_name': 'table1', + 'parent_primary_key': 'id', + 'child_table_name': 'table3', + 'child_foreign_key': 'fk_1', + }, + ], + }) + + # Run + metadata.remove_column('fk_1', 'table2') + + # Assert + assert metadata.relationships == [ + { + 'parent_table_name': 'table1', + 'parent_primary_key': 'id', + 'child_table_name': 'table3', + 'child_foreign_key': 'fk_1', + }, + ] + assert 'fk_1' not in metadata.tables['table2'].columns + assert 'fk_1' in metadata.tables['table3'].columns + def test_remove_column_sequence_key(self): """Test the method also remove the sequence key if the column is one.""" # Setup