Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,25 +543,36 @@ def _detect_foreign_keys_by_column_name(self, data):
"""
for parent_candidate in self.tables.keys():
primary_key = self.tables[parent_candidate].primary_key
if primary_key is None:
continue

pk_sdtype = self.tables[parent_candidate].columns[primary_key]['sdtype']
for child_candidate in self.tables.keys() - {parent_candidate}:
child_meta = self.tables[child_candidate]
if primary_key in child_meta.columns.keys():
original_fk_meta = deepcopy(child_meta.columns[primary_key])
original_fk_sdtype = original_fk_meta['sdtype']
if pk_sdtype != 'id' and original_fk_sdtype != pk_sdtype:
continue

try:
original_foreign_key_sdtype = child_meta.columns[primary_key]['sdtype']
if original_foreign_key_sdtype != 'id':
if pk_sdtype == 'id' and original_fk_sdtype != 'id':
self.update_column(
table_name=child_candidate, column_name=primary_key, sdtype='id'
table_name=child_candidate,
column_name=primary_key,
sdtype='id',
)

self.add_relationship(
parent_candidate, child_candidate, primary_key, primary_key
)

except InvalidMetadataError:
self.update_column(
table_name=child_candidate,
column_name=primary_key,
sdtype=original_foreign_key_sdtype,
)
if pk_sdtype == 'id' and original_fk_sdtype != 'id':
self.update_column(
table_name=child_candidate,
column_name=primary_key,
**original_fk_meta,
)
continue

def _detect_relationships(self, data=None, foreign_key_inference_algorithm='column_name_match'):
Expand Down
105 changes: 105 additions & 0 deletions tests/unit/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from sdv.errors import InvalidDataError
from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.metadata import Metadata
from sdv.metadata.multi_table import MultiTableMetadata, SingleTableMetadata
from tests.utils import catch_sdv_logs, get_multi_table_data, get_multi_table_metadata

Expand Down Expand Up @@ -2624,6 +2625,110 @@ def test__detect_relationships(self):
assert instance.relationships == expected_relationships
assert instance.tables['sessions'].columns['user_id']['sdtype'] == 'id'

def test__detect_relationships_semantic_foreign_key(self):
"""Test semantic foreign keys are automatically detected without changing the sdtype."""
# Setup
instance = Metadata.load_from_dict({
'tables': {
'parent': {
'primary_key': 'email',
'columns': {
'email': {'sdtype': 'email'},
'user_name': {'sdtype': 'categorical'},
},
},
'child': {
'primary_key': 'child_id',
'columns': {
'child_id': {'sdtype': 'id'},
'email': {'sdtype': 'email', 'pii': True},
},
},
},
'relationships': [],
})

# Run
instance._detect_relationships()

# Assert
expected_relationships = [
{
'parent_table_name': 'parent',
'child_table_name': 'child',
'parent_primary_key': 'email',
'child_foreign_key': 'email',
}
]
assert instance.relationships == expected_relationships
assert instance.tables['child'].columns['email'] == {'sdtype': 'email', 'pii': True}
assert instance.tables['parent'].columns['email'] == {'sdtype': 'email'}
assert instance.tables['parent'].primary_key == 'email'

def test__detect_relationships_semantic_foreign_key_does_not_overwrite_mismatch(self):
"""Test semantic foreign key mismatches do not coerce the child sdtype."""
# Setup
instance = Metadata.load_from_dict({
'tables': {
'parent': {
'primary_key': 'email',
'columns': {
'email': {'sdtype': 'email'},
'user_name': {'sdtype': 'categorical'},
},
},
'child': {
'primary_key': 'child_id',
'columns': {
'child_id': {'sdtype': 'id'},
'email': {'sdtype': 'categorical'},
},
},
},
'relationships': [],
})

# Run
instance._detect_relationships()

# Assert
assert instance.relationships == []
assert instance.tables['child'].columns['email'] == {'sdtype': 'categorical'}
assert instance.tables['parent'].columns['email'] == {'sdtype': 'email'}
assert instance.tables['parent'].primary_key == 'email'

def test__detect_relationships_restores_foreign_key_metadata_after_failure(self):
"""Test failed detection restores all original metadata values in the child foreign key."""
# Setup
original_foreign_key_metadata = {'sdtype': 'email', 'pii': True}
instance = Metadata.load_from_dict({
'tables': {
'users': {
'primary_key': 'user_id',
'columns': {
'user_id': {'sdtype': 'id'},
'user_name': {'sdtype': 'categorical'},
},
},
'sessions': {
'primary_key': 'session_id',
'columns': {
'user_id': original_foreign_key_metadata.copy(),
'session_id': {'sdtype': 'id'},
},
},
},
'relationships': [],
})
instance.add_relationship = Mock(side_effect=InvalidMetadataError('bad relationship'))

# Run
instance._detect_relationships()

# Assert
instance.add_relationship.assert_called_once_with('users', 'sessions', 'user_id', 'user_id')
assert instance.tables['sessions'].columns['user_id'] == original_foreign_key_metadata

def test__detect_relationships_circular(self):
"""Test that relationships that invalidate the metadata are not added."""
# Setup
Expand Down