diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 5ea9562c3..2a5fdff00 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -536,6 +536,9 @@ def _validate_all_tables_connected(self, parent_map, child_map): def _detect_foreign_keys_by_column_name(self, data): """Detect the foreign keys based on if a column name matches a primary key. + If a column name (a child table) is a primary key, it will also be considered + to be a valid candidate for a foreign key. + Args: data (dict): Dictionary of table names to dataframes. @@ -567,6 +570,7 @@ def _detect_foreign_keys_by_column_name(self, data): ) except InvalidMetadataError: + # circular relationship if pk_sdtype == 'id' and original_fk_sdtype != 'id': self.update_column( table_name=child_candidate, diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 8742e5f99..8fde63e8e 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -439,7 +439,7 @@ def _validate_all_tables(self, data): def validate(self, data): """Validate the data. - Validate that the metadata matches the data and thta every table's constraints are valid. + Validate that the metadata matches the data and that every table's constraints are valid. Args: data (dict[str, pd.DataFrame]): diff --git a/tests/integration/metadata/test_metadata.py b/tests/integration/metadata/test_metadata.py index 17ce74083..2d1e1ed89 100644 --- a/tests/integration/metadata/test_metadata.py +++ b/tests/integration/metadata/test_metadata.py @@ -879,6 +879,80 @@ def test_detect_from_dataframes_invalid_format(): Metadata.detect_from_dataframes(data) +def test_detect_from_dataframes__primary_to_primary(): + """Test metadata auto-detection works for primary to primary relationships.""" + # Setup + data = { + 'tableA': pd.DataFrame({ + 'table_A_id': range(5), + 'column_1': ['A', 'B', 'B', 'C', 'C'], + }), + 'tableB': pd.DataFrame({ + 'table_A_id': range(5), + 'column_2': ['A', 'B', 'B', 'C', 'C'], + }), + } + + # Run + detected_metadata = Metadata().detect_from_dataframes( + data, foreign_key_inference_algorithm='column_name_match' + ) + + # Assert + assert detected_metadata.tables['tableA'].primary_key == 'table_A_id' + assert detected_metadata.tables['tableB'].primary_key == 'table_A_id' + assert detected_metadata.relationships == [ + { + 'parent_table_name': 'tableA', + 'child_table_name': 'tableB', + 'parent_primary_key': 'table_A_id', + 'child_foreign_key': 'table_A_id', + } + ] + + +def test_detect_from_dataframes__primary_to_primary_no_cycles(): + """Test metadata auto-detection does not create cycles with PK to PK.""" + # Setup + data = { + 'tableA': pd.DataFrame({ + 'table_A_id': range(5), + 'column_1': ['A', 'B', 'B', 'C', 'C'], + }), + 'tableB': pd.DataFrame({ + 'table_A_id': range(5), + 'column_2': ['A', 'B', 'B', 'C', 'C'], + }), + 'tableC': pd.DataFrame({ + 'table_A_id': range(5), + 'column_2': ['A', 'B', 'B', 'C', 'C'], + }), + } + + # Run + detected_metadata = Metadata().detect_from_dataframes( + data, foreign_key_inference_algorithm='column_name_match' + ) + + # Assert + assert detected_metadata.tables['tableA'].primary_key == 'table_A_id' + assert detected_metadata.tables['tableB'].primary_key == 'table_A_id' + assert detected_metadata.tables['tableC'].primary_key == 'table_A_id' + assert len(detected_metadata.relationships) == 2 + assert { + 'parent_table_name': 'tableA', # PK to PK + 'child_table_name': 'tableC', + 'parent_primary_key': 'table_A_id', + 'child_foreign_key': 'table_A_id', + } in detected_metadata.relationships + assert { + 'parent_table_name': 'tableA', # PK to PK + 'child_table_name': 'tableB', + 'parent_primary_key': 'table_A_id', + 'child_foreign_key': 'table_A_id', + } in detected_metadata.relationships + + def test_validate_metadata_with_reused_foreign_keys(): # Setup metadata_dict = { diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py index 95ab8ca6a..dd5a33708 100644 --- a/tests/unit/metadata/test_metadata.py +++ b/tests/unit/metadata/test_metadata.py @@ -727,6 +727,75 @@ def test_detect_from_dataframes_bad_input_infer_keys(self): with pytest.raises(ValueError, match=expected_message): Metadata.detect_from_dataframes(data, infer_keys=infer_keys) + def test_detect_from_dataframe_primary_key_to_primary_key(self): + """Test primary to primary key relationship is detected if column name match.""" + # Setup + data = { + 'table1': pd.DataFrame({ + 'id': [1, 2, 3], + }), + 'table2': pd.DataFrame({ + 'id': [1, 2, 3], + }), + } + instance = Metadata() + instance.detect_table_from_dataframe('table1', data['table1']) + instance.detect_table_from_dataframe('table2', data['table2']) + + # Run + instance._detect_foreign_keys_by_column_name(data) + + # Assert + assert instance.to_dict()['relationships'] == [ + { + 'parent_table_name': 'table1', + 'child_table_name': 'table2', + 'parent_primary_key': 'id', + 'child_foreign_key': 'id', + } + ] + + def test_detect_from_dataframe_primary_key_to_primary_key_no_cycles(self): + """Test no cycles are created with primary to primary key relationship.""" + # Setup + data = { + 'table1': pd.DataFrame({ + 'id': [1, 2, 3], + }), + 'table2': pd.DataFrame({ + 'id': [1, 2, 3], + }), + 'table3': pd.DataFrame({ + 'id': [1, 2, 3], + }), + } + instance = Metadata() + instance.detect_table_from_dataframe('table1', data['table1']) + instance.detect_table_from_dataframe('table2', data['table2']) + instance.detect_table_from_dataframe('table3', data['table3']) + + # Run + instance._detect_foreign_keys_by_column_name(data) + + # Assert + expected = [ + { + 'parent_table_name': 'table1', + 'child_table_name': 'table2', + 'parent_primary_key': 'id', + 'child_foreign_key': 'id', + }, + { + 'parent_table_name': 'table1', + 'child_table_name': 'table3', + 'parent_primary_key': 'id', + 'child_foreign_key': 'id', + }, + ] + for rel in expected: + assert rel in instance.to_dict()['relationships'] + assert len(instance.relationships) == 2 + @patch('sdv.metadata.metadata.Metadata') def test_detect_from_dataframe(self, mock_metadata): """Test that the method calls the detection method and returns the metadata.