Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,9 @@ def _validate_all_tables_connected(self, parent_map, child_map):
def _detect_foreign_keys_by_column_name(self, data):
"""Detect the foreign keys based on if a column name matches a primary key.
If a column name (a child table) is a primary key, it will also be considered
to be a valid candidate for a foreign key.
Args:
data (dict):
Dictionary of table names to dataframes.
Expand Down Expand Up @@ -567,6 +570,7 @@ def _detect_foreign_keys_by_column_name(self, data):
)

except InvalidMetadataError:
# circular relationship
if pk_sdtype == 'id' and original_fk_sdtype != 'id':
self.update_column(
table_name=child_candidate,
Expand Down
2 changes: 1 addition & 1 deletion sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def _validate_all_tables(self, data):
def validate(self, data):
"""Validate the data.

Validate that the metadata matches the data and thta every table's constraints are valid.
Validate that the metadata matches the data and that every table's constraints are valid.

Args:
data (dict[str, pd.DataFrame]):
Expand Down
74 changes: 74 additions & 0 deletions tests/integration/metadata/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,80 @@ def test_detect_from_dataframes_invalid_format():
Metadata.detect_from_dataframes(data)


def test_detect_from_dataframes__primary_to_primary():
"""Test metadata auto-detection works for primary to primary relationships."""
# Setup
data = {
'tableA': pd.DataFrame({
'table_A_id': range(5),
'column_1': ['A', 'B', 'B', 'C', 'C'],
}),
'tableB': pd.DataFrame({
'table_A_id': range(5),
'column_2': ['A', 'B', 'B', 'C', 'C'],
}),
}

# Run
detected_metadata = Metadata().detect_from_dataframes(
data, foreign_key_inference_algorithm='column_name_match'
)

# Assert
assert detected_metadata.tables['tableA'].primary_key == 'table_A_id'
assert detected_metadata.tables['tableB'].primary_key == 'table_A_id'
assert detected_metadata.relationships == [
{
'parent_table_name': 'tableA',
'child_table_name': 'tableB',
'parent_primary_key': 'table_A_id',
'child_foreign_key': 'table_A_id',
}
]


def test_detect_from_dataframes__primary_to_primary_no_cycles():
"""Test metadata auto-detection does not create cycles with PK to PK."""
# Setup
data = {
'tableA': pd.DataFrame({
'table_A_id': range(5),
'column_1': ['A', 'B', 'B', 'C', 'C'],
}),
'tableB': pd.DataFrame({
'table_A_id': range(5),
'column_2': ['A', 'B', 'B', 'C', 'C'],
}),
'tableC': pd.DataFrame({
'table_A_id': range(5),
'column_2': ['A', 'B', 'B', 'C', 'C'],
}),
}

# Run
detected_metadata = Metadata().detect_from_dataframes(
data, foreign_key_inference_algorithm='column_name_match'
)

# Assert
assert detected_metadata.tables['tableA'].primary_key == 'table_A_id'
assert detected_metadata.tables['tableB'].primary_key == 'table_A_id'
assert detected_metadata.tables['tableC'].primary_key == 'table_A_id'
assert len(detected_metadata.relationships) == 2
assert {
'parent_table_name': 'tableA', # PK to PK
'child_table_name': 'tableC',
'parent_primary_key': 'table_A_id',
'child_foreign_key': 'table_A_id',
} in detected_metadata.relationships
assert {
'parent_table_name': 'tableA', # PK to PK
'child_table_name': 'tableB',
'parent_primary_key': 'table_A_id',
'child_foreign_key': 'table_A_id',
} in detected_metadata.relationships


def test_validate_metadata_with_reused_foreign_keys():
# Setup
metadata_dict = {
Expand Down
69 changes: 69 additions & 0 deletions tests/unit/metadata/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,75 @@ def test_detect_from_dataframes_bad_input_infer_keys(self):
with pytest.raises(ValueError, match=expected_message):
Metadata.detect_from_dataframes(data, infer_keys=infer_keys)

def test_detect_from_dataframe_primary_key_to_primary_key(self):
"""Test primary to primary key relationship is detected if column name match."""
# Setup
data = {
'table1': pd.DataFrame({
'id': [1, 2, 3],
}),
'table2': pd.DataFrame({
'id': [1, 2, 3],
}),
}
instance = Metadata()
instance.detect_table_from_dataframe('table1', data['table1'])
instance.detect_table_from_dataframe('table2', data['table2'])

# Run
instance._detect_foreign_keys_by_column_name(data)

# Assert
assert instance.to_dict()['relationships'] == [
{
'parent_table_name': 'table1',
'child_table_name': 'table2',
'parent_primary_key': 'id',
'child_foreign_key': 'id',
}
]

def test_detect_from_dataframe_primary_key_to_primary_key_no_cycles(self):
"""Test no cycles are created with primary to primary key relationship."""
# Setup
data = {
'table1': pd.DataFrame({
'id': [1, 2, 3],
}),
'table2': pd.DataFrame({
'id': [1, 2, 3],
}),
'table3': pd.DataFrame({
'id': [1, 2, 3],
}),
}
instance = Metadata()
instance.detect_table_from_dataframe('table1', data['table1'])
instance.detect_table_from_dataframe('table2', data['table2'])
instance.detect_table_from_dataframe('table3', data['table3'])

# Run
instance._detect_foreign_keys_by_column_name(data)

# Assert
expected = [
{
'parent_table_name': 'table1',
'child_table_name': 'table2',
'parent_primary_key': 'id',
'child_foreign_key': 'id',
},
{
'parent_table_name': 'table1',
'child_table_name': 'table3',
'parent_primary_key': 'id',
'child_foreign_key': 'id',
},
]
for rel in expected:
assert rel in instance.to_dict()['relationships']
assert len(instance.relationships) == 2

@patch('sdv.metadata.metadata.Metadata')
def test_detect_from_dataframe(self, mock_metadata):
"""Test that the method calls the detection method and returns the metadata.
Expand Down