diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 097dabe94..99aed6443 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -1143,7 +1143,11 @@ def visualize( foreign_key = relationship.get('child_foreign_key') primary_key = self.tables.get(parent).primary_key edge_label = f' {foreign_key} → {primary_key}' if show_relationship_labels else '' - edges.append((parent, child, edge_label)) + child_primary_key = self.tables.get(child).primary_key + if foreign_key == child_primary_key: + edges.append((parent, child, edge_label, 'one-to-one')) + else: + edges.append((parent, child, edge_label)) if show_table_details is not None: child_node = nodes.get(child) diff --git a/sdv/metadata/visualization.py b/sdv/metadata/visualization.py index 9db0d9734..3a3c9da7e 100644 --- a/sdv/metadata/visualization.py +++ b/sdv/metadata/visualization.py @@ -102,8 +102,20 @@ def visualize_graph(nodes, edges, filepath=None): for name, label in nodes.items(): digraph.node(name, label=_replace_special_characters(label)) - for parent, child, label in edges: - digraph.edge(parent, child, label=_replace_special_characters(label), arrowhead='oinv') + for edge in edges: + parent, child, label = edge[0], edge[1], edge[2] + relation_type = edge[3] if len(edge) > 3 else 'one-to-many' + if relation_type == 'one-to-one': + digraph.edge( + parent, + child, + label=_replace_special_characters(label), + arrowhead='noneteeodot', + arrowtail='nonetee', + dir='both', + ) + else: + digraph.edge(parent, child, label=_replace_special_characters(label), arrowhead='oinv') if filename: digraph.render(filename=filename, cleanup=True, format=graphviz_extension) diff --git a/tests/integration/metadata/test_visualization.py b/tests/integration/metadata/test_visualization.py index beb936216..768ff9116 100644 --- a/tests/integration/metadata/test_visualization.py +++ b/tests/integration/metadata/test_visualization.py @@ -63,5 +63,10 @@ def test_visualize_pk_to_pk(primary_key_to_primary_key): # Setup _, metadata = primary_key_to_primary_key - # Run and Assert - metadata.visualize() + # Run + graph = metadata.visualize() + + # Assert + assert 'arrowhead=noneteeodot' in graph.source + assert 'arrowtail=nonetee' in graph.source + assert 'dir=both' in graph.source diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index 943b5c547..a4c9c94f9 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -2429,6 +2429,53 @@ def test_visualize_show_table_details_only(self, visualize_graph_mock): ] visualize_graph_mock.assert_called_once_with(expected_nodes, expected_edges, 'output.jpg') + @patch('sdv.metadata.multi_table.visualize_graph') + def test_visualize_pk_to_pk_relationship(self, visualize_graph_mock): + """Test that PK-to-PK relationships produce a 'one-to-one' edge type.""" + # Setup + metadata = MultiTableMetadata.load_from_dict({ + 'tables': { + 'parent_table': { + 'columns': { + 'pk': {'sdtype': 'id'}, + 'col': {'sdtype': 'categorical'}, + }, + 'primary_key': 'pk', + }, + 'child_table': { + 'columns': { + 'pk': {'sdtype': 'id'}, + 'col': {'sdtype': 'numerical'}, + }, + 'primary_key': 'pk', + }, + }, + 'relationships': [ + { + 'parent_table_name': 'parent_table', + 'parent_primary_key': 'pk', + 'child_table_name': 'child_table', + 'child_foreign_key': 'pk', + } + ], + }) + + # Run + metadata.visualize('full', True) + + # Assert + expected_nodes = { + 'parent_table': ('{parent_table|pk : id\\lcol : categorical\\l|Primary key: pk\\l}'), + 'child_table': ( + '{child_table|pk : id\\lcol : numerical\\l|' + 'Primary key: pk\\lForeign key (parent_table): pk\\l}' + ), + } + expected_edges = [ + ('parent_table', 'child_table', ' pk → pk', 'one-to-one'), + ] + visualize_graph_mock.assert_called_once_with(expected_nodes, expected_edges, None) + def test_add_column(self): """Test the ``add_column`` method.