Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def __init__(self):
self.column_relationships = []
self._version = self.METADATA_SPEC_VERSION
self._updated = False
self._valid_column_relationships = []

@property
def _primary_key_is_composite(self):
Expand Down Expand Up @@ -836,6 +837,20 @@ def _validate_key(self, column_name, key_type):

self._validate_keys_sdtype([column_name], key_type)

def _validate_primary_key_not_in_column_relationship(self, primary_key_candidate):
if isinstance(primary_key_candidate, list):
primary_key_candidate = set(primary_key_candidate)
else:
primary_key_candidate = {primary_key_candidate}

for column_relationship in self.column_relationships:
column_names = set(column_relationship['column_names'])
if column_names.intersection(primary_key_candidate):
raise InvalidMetadataError(
f"Cannot set primary key '{primary_key_candidate}' because it is part "
'of a column relationship.'
)

def set_primary_key(self, column_name):
"""Set the metadata primary key.

Expand All @@ -847,6 +862,7 @@ def set_primary_key(self, column_name):
column_name = column_name[0]

self._validate_key(column_name, 'primary')
self._validate_primary_key_not_in_column_relationship(column_name)
if column_name in self.alternate_keys:
warnings.warn(
f"'{column_name}' is currently set as an alternate key and will be removed from "
Expand Down Expand Up @@ -1004,11 +1020,17 @@ def _validate_column_relationship(self, relationship):
f'Must be one of {list(self._COLUMN_RELATIONSHIP_TYPES.keys())}.'
)

primary_keys = set()
if isinstance(self.primary_key, list):
primary_keys = set(self.primary_key)
elif self.primary_key:
primary_keys = {self.primary_key}

errors = []
for column in column_names:
if column not in self.columns:
errors.append(f"Column '{column}' not in metadata.")
elif self.primary_key == column:
if column in primary_keys:
errors.append(f"Cannot use primary key '{column}' in column relationship.")

columns_to_sdtypes = {
Expand Down
77 changes: 58 additions & 19 deletions tests/integration/metadata/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import os
import re
from copy import deepcopy
from unittest.mock import Mock

import pandas as pd
import pytest

from sdv.datasets.demo import download_demo
from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.metadata import Metadata
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.multi_table.hma import HMASynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from tests.utils import download_test_demo

DEFAULT_TABLE_NAME = 'table'

Expand Down Expand Up @@ -68,7 +69,7 @@ def test_load_from_json_single_table_metadata(tmp_path):
def test_detect_from_dataframes_multi_table():
"""Test the ``detect_from_dataframes`` method works with multi-table."""
# Setup
real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
real_data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels')

# Run
metadata = Metadata.detect_from_dataframes(real_data)
Expand Down Expand Up @@ -124,7 +125,7 @@ def test_detect_from_dataframes_multi_table():
def test_detect_from_dataframes_multi_table_without_infer_sdtypes():
"""Test it when infer_sdtypes is False."""
# Setup
real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
real_data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels')

# Run
metadata = Metadata.detect_from_dataframes(real_data, infer_sdtypes=False)
Expand Down Expand Up @@ -180,7 +181,7 @@ def test_detect_from_dataframes_multi_table_without_infer_sdtypes():
def test_detect_from_dataframes_multi_table_with_infer_keys_primary_only():
"""Test it when infer_keys is 'primary_only'."""
# Setup
real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
real_data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels')

# Run
metadata = Metadata.detect_from_dataframes(real_data, infer_keys='primary_only')
Expand Down Expand Up @@ -229,7 +230,7 @@ def test_detect_from_dataframes_multi_table_with_infer_keys_primary_only():
def test_detect_from_dataframes_multi_table_with_infer_keys_none():
"""Test it when infer_keys is None."""
# Setup
real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
real_data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels')

# Run
metadata = Metadata.detect_from_dataframes(real_data, infer_keys=None)
Expand Down Expand Up @@ -276,7 +277,7 @@ def test_detect_from_dataframes_multi_table_with_infer_keys_none():
def test_detect_from_dataframes_single_table():
"""Test the ``detect_from_dataframes`` method works with a single table."""
# Setup
data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels')
metadata = Metadata.detect_from_dataframes({'table_1': data['hotels']})

# Run
Expand Down Expand Up @@ -305,7 +306,7 @@ def test_detect_from_dataframes_single_table():
def test_detect_from_dataframes_single_table_infer_sdtypes_false():
"""Test it for a single table when infer_sdtypes is False."""
# Setup
data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels')
metadata = Metadata.detect_from_dataframes({'table_1': data['hotels']}, infer_sdtypes=False)

# Run
Expand Down Expand Up @@ -334,7 +335,7 @@ def test_detect_from_dataframes_single_table_infer_sdtypes_false():
def test_detect_from_dataframes_single_table_infer_keys_primary_only():
"""Test it for a single table when infer_keys is 'primary_only'."""
# Setup
data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels')
metadata = Metadata.detect_from_dataframes(
{'table_1': data['hotels']}, infer_keys='primary_only'
)
Expand Down Expand Up @@ -365,7 +366,7 @@ def test_detect_from_dataframes_single_table_infer_keys_primary_only():
def test_detect_from_dataframes_single_table_infer_keys_none():
"""Test it for a single table when infer_keys is None."""
# Setup
data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels')
metadata = Metadata.detect_from_dataframes({'table_1': data['hotels']}, infer_keys=None)

# Run
Expand Down Expand Up @@ -393,7 +394,7 @@ def test_detect_from_dataframes_single_table_infer_keys_none():
def test_detect_from_dataframe():
"""Test that a single table can be detected as a DataFrame."""
# Setup
data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels')

metadata = Metadata.detect_from_dataframe(data['hotels'])

Expand Down Expand Up @@ -423,7 +424,7 @@ def test_detect_from_dataframe():
def test_detect_from_dataframe_infer_sdtypes_false():
"""Test it when infer_sdtypes is False."""
# Setup
data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels')
metadata = Metadata.detect_from_dataframe(data['hotels'], infer_sdtypes=False)

# Run
Expand Down Expand Up @@ -452,7 +453,7 @@ def test_detect_from_dataframe_infer_sdtypes_false():
def test_detect_from_dataframe_infer_keys_none():
"""Test it when infer_keys is None."""
# Setup
data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels')
metadata = Metadata.detect_from_dataframe(data['hotels'], infer_keys=None)

# Run
Expand Down Expand Up @@ -480,7 +481,7 @@ def test_detect_from_dataframe_infer_keys_none():
def test_detect_from_dataframe_infer_keys_none_infer_sdtypes_false():
"""Test it when infer_keys is None and infer_sdtypes is False."""
# Setup
data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels')
metadata = Metadata.detect_from_dataframe(data['hotels'], infer_keys=None, infer_sdtypes=False)

# Run
Expand Down Expand Up @@ -508,7 +509,7 @@ def test_detect_from_dataframe_infer_keys_none_infer_sdtypes_false():
def test_detect_from_csvs(tmp_path):
"""Test the ``detect_from_csvs`` method."""
# Setup
real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
real_data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels')

metadata = Metadata()

Expand Down Expand Up @@ -571,7 +572,7 @@ def test_detect_from_csvs(tmp_path):
def test_single_table_compatibility(tmp_path):
"""Test if SingleTableMetadata still has compatibility with single table synthesizers."""
# Setup
data, _ = download_demo('single_table', 'fake_hotel_guests')
data, _ = download_test_demo('single_table', 'fake_hotel_guests')
warn_msg = (
"The 'SingleTableMetadata' is deprecated. Please use the new "
"'Metadata' class for synthesizers."
Expand Down Expand Up @@ -623,7 +624,7 @@ def test_single_table_compatibility(tmp_path):
def test_multi_table_compatibility(tmp_path):
"""Test if MultiTableMetadata still has compatibility with multi table synthesizers."""
# Setup
data, _ = download_demo('multi_table', 'fake_hotels')
data, _ = download_test_demo('multi_table', 'fake_hotels')
warn_msg = re.escape(
"The 'MultiTableMetadata' is deprecated. Please use the new "
"'Metadata' class for synthesizers."
Expand Down Expand Up @@ -741,7 +742,7 @@ def test_multi_table_compatibility(tmp_path):
def test_any_metadata_update_single_table(method, args, kwargs):
"""Test that any method that updates metadata works for single-table case."""
# Setup
_, metadata = download_demo('single_table', 'fake_hotel_guests')
_, metadata = download_test_demo('single_table', 'fake_hotel_guests')
metadata.update_column(
table_name='fake_hotel_guests', column_name='billing_address', sdtype='street_address'
)
Expand All @@ -764,7 +765,7 @@ def test_any_metadata_update_single_table(method, args, kwargs):
def test_any_metadata_update_multi_table(method, args, kwargs):
"""Test that any method that updates metadata works for multi-table case."""
# Setup
_, metadata = download_demo('multi_table', 'fake_hotels')
_, metadata = download_test_demo('multi_table', 'fake_hotels')
metadata.update_column(
table_name='guests', column_name='billing_address', sdtype='street_address'
)
Expand Down Expand Up @@ -1345,7 +1346,7 @@ def test_remove_column_alternate_key():
def test_loading_invalid_single_table_metadata():
"""Test loading invalid single table metadata dict."""
# Setup
_, metadata = download_demo(modality='multi_table', dataset_name='fake_hotels')
_, metadata = download_test_demo(modality='multi_table', dataset_name='fake_hotels')
metadata_dict = metadata.to_dict()
metadata_dict['tables']['guests']['invalid_key'] = {'value1': True, 'value2': False}
expected_error = re.escape(
Expand Down Expand Up @@ -1575,3 +1576,41 @@ def test_add_relationship_pk_to_pk(
'child_foreign_key': child_foreign_key,
}
]


def test_add_column_relationship_fails_with_primary_key_column():
    """Test that adding a column relationship fails if the column is part of the primary key.

    A mocked ``billing`` relationship type is temporarily registered on the
    class-level ``SingleTableMetadata._COLUMN_RELATIONSHIP_TYPES`` mapping so that
    validation reaches the primary key check. Per the original author's note, the
    error raised otherwise is ``ImportError`` instead of ``InvalidMetadataError``.
    """
    # Setup
    _, metadata = download_test_demo(modality='single_table', dataset_name='fake_hotel_guests')
    metadata.update_column(column_name='billing_address', sdtype='street_address')
    metadata.set_primary_key(['guest_email', 'billing_address'])
    expected_msg = "Cannot use primary key 'billing_address' in column relationship."
    SingleTableMetadata._COLUMN_RELATIONSHIP_TYPES['billing'] = Mock()

    # Run and Assert
    try:
        with pytest.raises(InvalidMetadataError, match=expected_msg):
            metadata.add_column_relationship(
                column_names=['billing_address'], relationship_type='billing'
            )
    finally:
        # Always undo the class-level mutation, even when the assertion above
        # fails; otherwise 'billing' would leak into every later test that
        # touches SingleTableMetadata.
        SingleTableMetadata._COLUMN_RELATIONSHIP_TYPES.pop('billing', None)


def test_metadata_fails_for_relationship_with_set_primary_key_column_in_relationship():
    """Test ``set_primary_key`` fails when a key column is already in a column relationship."""
    # Setup
    # Only the metadata is needed; the demo data itself is discarded.
    _, metadata = download_test_demo(modality='single_table', dataset_name='fake_hotel_guests')
    metadata.update_column(column_name='billing_address', sdtype='street_address')
    metadata.add_column_relationship(column_names=['billing_address'], relationship_type='address')
    expected_msg = r"Cannot set primary key '.*' because it is part of a column relationship\."

    # Run and Assert
    with pytest.raises(InvalidMetadataError, match=expected_msg):
        metadata.set_primary_key(['guest_email', 'billing_address'])
70 changes: 70 additions & 0 deletions tests/unit/metadata/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1848,6 +1848,30 @@ def test_set_primary_key(self):
# Assert
assert instance.primary_key == 'column'

@pytest.mark.parametrize('primary_key', ['column_b', ['column_b', 'column_c']])
def test_set_primary_key_column_in_relationship(self, primary_key):
    """Test ``set_primary_key`` rejects keys whose columns belong to a column relationship.

    Covers both a single column name and a composite (list) key.
    """
    # Setup
    expected_pattern = (
        r"Cannot set primary key '.*' because it is part of a column relationship\."
    )
    address_relationship = {
        'column_names': ['column_b', 'column_c'],
        'type': 'address',
    }
    metadata = SingleTableMetadata()
    metadata.columns = {
        'column': {'sdtype': 'id'},
        'column_b': {'sdtype': 'address'},
        'column_c': {'sdtype': 'address'},
    }
    metadata.column_relationships = [address_relationship]

    # Run and Assert
    with pytest.raises(InvalidMetadataError, match=expected_pattern):
        metadata.set_primary_key(primary_key)

def test_set_primary_key_singleton_composite_key(self):
"""Test a composite key with one element is set as a single primary key."""
# Setup
Expand Down Expand Up @@ -2352,6 +2376,52 @@ def test__validate_column_relationship_with_other_relationships(self):
relationship_invalid, column_relationships
)

def test__validate_column_relationship_raises_import_error(self, recwarn):
    """Test ``_validate_column_relationship`` raises ``ImportError`` for ``address``.

    Also checks that exactly one ``UserWarning`` about the missing address
    add-on is recorded.
    """
    # Setup
    expected_warning = (
        "The metadata contains a column relationship of type 'address' "
        'which requires the address add-on. '
        'This relationship will be ignored. For higher quality data in this'
        ' relationship, please inquire about the SDV Enterprise tier.'
    )
    metadata = SingleTableMetadata()
    metadata.columns = {
        'a': {'sdtype': 'street_address'},
        'b': {'sdtype': 'city'},
        'c': {'sdtype': 'datetime'},
    }
    address_relationship = {'type': 'address', 'column_names': ['a', 'b']}

    # Run
    with pytest.raises(ImportError):
        metadata._validate_column_relationship(address_relationship)

    # Assert
    assert len(recwarn) == 1
    recorded = recwarn.pop(UserWarning)
    assert str(recorded.message) == expected_warning

@pytest.mark.parametrize('primary_key', ['a', ['a', 'b']])
def test__validate_column_relationship_column_belongs_to_primary_key(self, primary_key):
    """Test relationship validation rejects columns that belong to the primary key.

    Runs once with a single-column key and once with a composite key.
    """
    # Setup
    expected_message = "Cannot use primary key 'a' in column relationship."
    metadata = SingleTableMetadata()
    # Shadow the relationship types on the instance only, so the class-level
    # mapping is left untouched.
    metadata._COLUMN_RELATIONSHIP_TYPES = {'mock_relationship': Mock()}
    metadata.columns = {
        'a': {'sdtype': 'street_address'},
        'b': {'sdtype': 'street_address'},
        'c': {'sdtype': 'datetime'},
    }
    metadata.set_primary_key(primary_key)
    candidate_relationship = {'type': 'mock_relationship', 'column_names': ['a', 'b']}

    # Run and Assert
    with pytest.raises(InvalidMetadataError, match=expected_message):
        metadata._validate_column_relationship(candidate_relationship)

def test__validate_all_column_relationships(self):
"""Test ``_validate_all_column_relationships`` method."""
# Setup
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1593,7 +1593,7 @@ def test_add_constraints(self, mock_validate_constraints, mock_programmable_cons
"""Test adding data constraints to the synthesizer."""
# Setup
instance = Mock()
del instance._composite_keys_metadata
instance._composite_keys_metadata = None
original_metadata = get_multi_table_metadata()
instance.metadata = original_metadata
instance._original_metadata = original_metadata
Expand Down Expand Up @@ -1646,7 +1646,7 @@ def test_add_constraints_single_table_overlap(self, mock_validate_constraints):
"""Test adding overlapping single-table constraints to the synthesizer."""
# Setup
instance = Mock()
del instance._composite_keys_metadata
instance._composite_keys_metadata = None
original_metadata = get_multi_table_metadata()
instance.metadata = original_metadata
instance._original_metadata = original_metadata
Expand Down Expand Up @@ -1695,7 +1695,7 @@ def test_updating_constraints_keeps_original_metadata(self, mock_validate_constr
"""Test adding data constraints to the synthesizer."""
# Setup
instance = Mock()
del instance._composite_keys_metadata
instance._composite_keys_metadata = None
delattr(instance, 'constraints')
metadata = get_multi_table_metadata()
original_metadata = Mock()
Expand Down
Loading