Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions sdv/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def add_column(self, column_name, table_name=None, **kwargs):
table_name = self._handle_table_name(table_name)
super().add_column(table_name, column_name, **kwargs)

def _remove_matching_relationships(self, element, keys):
def _remove_relationships_by_table(self, element, keys):
"""Remove relationships where the element matches the keys to check."""
updated_relationships = []
for relationship in self.relationships:
Expand All @@ -335,6 +335,22 @@ def _remove_matching_relationships(self, element, keys):

self.relationships = updated_relationships

def _remove_relationships_by_column(self, table_name, column_name):
"""Remove relationships where the column is a key for the given table."""
updated_relationships = []
for relationship in self.relationships:
should_remove = (
relationship['child_foreign_key'] == column_name
and relationship['child_table_name'] == table_name
) or (
relationship['parent_primary_key'] == column_name
and relationship['parent_table_name'] == table_name
)
if not should_remove:
updated_relationships.append(relationship)

self.relationships = updated_relationships

def remove_table(self, table_name):
"""Remove a table from the metadata.

Expand All @@ -348,7 +364,7 @@ def remove_table(self, table_name):
self._validate_table_exists(table_name)

# Remove relationships
self._remove_matching_relationships(table_name, ['parent_table_name', 'child_table_name'])
self._remove_relationships_by_table(table_name, ['parent_table_name', 'child_table_name'])
del self.tables[table_name]
self._multi_table_updated = True

Expand Down Expand Up @@ -380,9 +396,7 @@ def remove_column(self, column_name, table_name=None):
table_metadata._validate_column_exists(column_name)

# Remove relationships
self._remove_matching_relationships(
column_name, ['parent_primary_key', 'child_foreign_key']
)
self._remove_relationships_by_column(table_name, column_name)
updated_column_relationships = []
for column_relationship in table_metadata.column_relationships:
if column_name not in column_relationship.get('column_names', []):
Expand Down
95 changes: 95 additions & 0 deletions tests/unit/metadata/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,101 @@ def test_remove_column_removes_relationships(self):
assert list(manufacturer_mock.columns.keys()) == ['country', 'id']
assert metadata._multi_table_updated

def test__remove_relationships_by_column_only_removes_matching_table(self):
"""Test that only relationships for the given table and column are removed."""
# Setup
metadata = Metadata()
metadata.relationships = [
{
'parent_table_name': 'parent',
'parent_primary_key': 'id',
'child_table_name': 'child_a',
'child_foreign_key': 'fk',
},
{
'parent_table_name': 'parent',
'parent_primary_key': 'id',
'child_table_name': 'child_b',
'child_foreign_key': 'fk',
},
]

# Run
metadata._remove_relationships_by_column('child_a', 'fk')

# Assert
assert metadata.relationships == [
{
'parent_table_name': 'parent',
'parent_primary_key': 'id',
'child_table_name': 'child_b',
'child_foreign_key': 'fk',
},
]

def test_remove_column_only_removes_relationship_for_that_table(self):
"""Test removing a foreign key column only removes the relationship for that table."""
# Setup
metadata = Metadata.load_from_dict({
'tables': {
'table1': {
'primary_key': 'id',
'columns': {
'id': {'sdtype': 'id'},
'A': {'sdtype': 'numerical'},
'B': {'sdtype': 'categorical'},
},
},
'table2': {
'primary_key': 'id',
'columns': {
'id': {'sdtype': 'id'},
'fk_1': {'sdtype': 'id'},
'A': {'sdtype': 'numerical'},
'B': {'sdtype': 'categorical'},
},
},
'table3': {
'primary_key': 'id',
'columns': {
'id': {'sdtype': 'id'},
'fk_1': {'sdtype': 'id'},
'A': {'sdtype': 'numerical'},
'B': {'sdtype': 'categorical'},
},
},
},
'relationships': [
{
'parent_table_name': 'table1',
'parent_primary_key': 'id',
'child_table_name': 'table2',
'child_foreign_key': 'fk_1',
},
{
'parent_table_name': 'table1',
'parent_primary_key': 'id',
'child_table_name': 'table3',
'child_foreign_key': 'fk_1',
},
],
})

# Run
metadata.remove_column('fk_1', 'table2')

# Assert
assert metadata.relationships == [
{
'parent_table_name': 'table1',
'parent_primary_key': 'id',
'child_table_name': 'table3',
'child_foreign_key': 'fk_1',
},
]
assert 'fk_1' not in metadata.tables['table2'].columns
assert 'fk_1' in metadata.tables['table3'].columns

def test_remove_column_sequence_key(self):
"""Test the method also remove the sequence key if the column is one."""
# Setup
Expand Down