Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,11 @@ def visualize(
foreign_key = relationship.get('child_foreign_key')
primary_key = self.tables.get(parent).primary_key
edge_label = f' {foreign_key} → {primary_key}' if show_relationship_labels else ''
edges.append((parent, child, edge_label))
child_primary_key = self.tables.get(child).primary_key
if foreign_key == child_primary_key:
edges.append((parent, child, edge_label, 'one-to-one'))
else:
edges.append((parent, child, edge_label))

if show_table_details is not None:
child_node = nodes.get(child)
Expand Down
16 changes: 14 additions & 2 deletions sdv/metadata/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,20 @@ def visualize_graph(nodes, edges, filepath=None):
for name, label in nodes.items():
digraph.node(name, label=_replace_special_characters(label))

for parent, child, label in edges:
digraph.edge(parent, child, label=_replace_special_characters(label), arrowhead='oinv')
for edge in edges:
parent, child, label = edge[0], edge[1], edge[2]
relation_type = edge[3] if len(edge) > 3 else 'one-to-many'
if relation_type == 'one-to-one':
digraph.edge(
parent,
child,
label=_replace_special_characters(label),
arrowhead='noneteeodot',
arrowtail='nonetee',
dir='both',
)
else:
digraph.edge(parent, child, label=_replace_special_characters(label), arrowhead='oinv')

if filename:
digraph.render(filename=filename, cleanup=True, format=graphviz_extension)
Expand Down
9 changes: 7 additions & 2 deletions tests/integration/metadata/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,10 @@ def test_visualize_pk_to_pk(primary_key_to_primary_key):
# Setup
_, metadata = primary_key_to_primary_key

# Run and Assert
metadata.visualize()
# Run
graph = metadata.visualize()

# Assert
assert 'arrowhead=noneteeodot' in graph.source
assert 'arrowtail=nonetee' in graph.source
assert 'dir=both' in graph.source
47 changes: 47 additions & 0 deletions tests/unit/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2429,6 +2429,53 @@ def test_visualize_show_table_details_only(self, visualize_graph_mock):
]
visualize_graph_mock.assert_called_once_with(expected_nodes, expected_edges, 'output.jpg')

@patch('sdv.metadata.multi_table.visualize_graph')
def test_visualize_pk_to_pk_relationship(self, visualize_graph_mock):
"""Test that PK-to-PK relationships produce a 'one-to-one' edge type."""
# Setup
metadata = MultiTableMetadata.load_from_dict({
'tables': {
'parent_table': {
'columns': {
'pk': {'sdtype': 'id'},
'col': {'sdtype': 'categorical'},
},
'primary_key': 'pk',
},
'child_table': {
'columns': {
'pk': {'sdtype': 'id'},
'col': {'sdtype': 'numerical'},
},
'primary_key': 'pk',
},
},
'relationships': [
{
'parent_table_name': 'parent_table',
'parent_primary_key': 'pk',
'child_table_name': 'child_table',
'child_foreign_key': 'pk',
}
],
})

# Run
metadata.visualize('full', True)

# Assert
expected_nodes = {
'parent_table': ('{parent_table|pk : id\\lcol : categorical\\l|Primary key: pk\\l}'),
'child_table': (
'{child_table|pk : id\\lcol : numerical\\l|'
'Primary key: pk\\lForeign key (parent_table): pk\\l}'
),
}
expected_edges = [
('parent_table', 'child_table', ' pk → pk', 'one-to-one'),
]
visualize_graph_mock.assert_called_once_with(expected_nodes, expected_edges, None)

def test_add_column(self):
"""Test the ``add_column`` method.

Expand Down