diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index 0c9df7704..91a336e17 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -219,6 +219,7 @@ def __init__(self): self.column_relationships = [] self._version = self.METADATA_SPEC_VERSION self._updated = False + self._valid_column_relationships = [] @property def _primary_key_is_composite(self): @@ -836,6 +837,20 @@ def _validate_key(self, column_name, key_type): self._validate_keys_sdtype([column_name], key_type) + def _validate_primary_key_not_in_column_relationship(self, primary_key_candidate): + if isinstance(primary_key_candidate, list): + primary_key_candidate = set(primary_key_candidate) + else: + primary_key_candidate = {primary_key_candidate} + + for column_relationship in self.column_relationships: + column_names = set(column_relationship['column_names']) + if column_names.intersection(primary_key_candidate): + raise InvalidMetadataError( + f"Cannot set primary key '{primary_key_candidate}' because it is part " + 'of a column relationship.' + ) + def set_primary_key(self, column_name): """Set the metadata primary key. @@ -847,6 +862,7 @@ def set_primary_key(self, column_name): column_name = column_name[0] self._validate_key(column_name, 'primary') + self._validate_primary_key_not_in_column_relationship(column_name) if column_name in self.alternate_keys: warnings.warn( f"'{column_name}' is currently set as an alternate key and will be removed from " @@ -1004,11 +1020,17 @@ def _validate_column_relationship(self, relationship): f'Must be one of {list(self._COLUMN_RELATIONSHIP_TYPES.keys())}.' 
) + primary_keys = set() + if isinstance(self.primary_key, list): + primary_keys = set(self.primary_key) + elif self.primary_key: + primary_keys = {self.primary_key} + errors = [] for column in column_names: if column not in self.columns: errors.append(f"Column '{column}' not in metadata.") - elif self.primary_key == column: + if column in primary_keys: errors.append(f"Cannot use primary key '{column}' in column relationship.") columns_to_sdtypes = { diff --git a/tests/integration/metadata/test_metadata.py b/tests/integration/metadata/test_metadata.py index 95578c244..bd1ea7c91 100644 --- a/tests/integration/metadata/test_metadata.py +++ b/tests/integration/metadata/test_metadata.py @@ -1,17 +1,18 @@ import os import re from copy import deepcopy +from unittest.mock import Mock import pandas as pd import pytest -from sdv.datasets.demo import download_demo from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.metadata import Metadata from sdv.metadata.multi_table import MultiTableMetadata from sdv.metadata.single_table import SingleTableMetadata from sdv.multi_table.hma import HMASynthesizer from sdv.single_table.copulas import GaussianCopulaSynthesizer +from tests.utils import download_test_demo DEFAULT_TABLE_NAME = 'table' @@ -68,7 +69,7 @@ def test_load_from_json_single_table_metadata(tmp_path): def test_detect_from_dataframes_multi_table(): """Test the ``detect_from_dataframes`` method works with multi-table.""" # Setup - real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + real_data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels') # Run metadata = Metadata.detect_from_dataframes(real_data) @@ -124,7 +125,7 @@ def test_detect_from_dataframes_multi_table(): def test_detect_from_dataframes_multi_table_without_infer_sdtypes(): """Test it when infer_sdtypes is False.""" # Setup - real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + real_data, _ = 
download_test_demo(modality='multi_table', dataset_name='fake_hotels') # Run metadata = Metadata.detect_from_dataframes(real_data, infer_sdtypes=False) @@ -180,7 +181,7 @@ def test_detect_from_dataframes_multi_table_without_infer_sdtypes(): def test_detect_from_dataframes_multi_table_with_infer_keys_primary_only(): """Test it when infer_keys is 'primary_only'.""" # Setup - real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + real_data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels') # Run metadata = Metadata.detect_from_dataframes(real_data, infer_keys='primary_only') @@ -229,7 +230,7 @@ def test_detect_from_dataframes_multi_table_with_infer_keys_primary_only(): def test_detect_from_dataframes_multi_table_with_infer_keys_none(): """Test it when infer_keys is None.""" # Setup - real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + real_data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels') # Run metadata = Metadata.detect_from_dataframes(real_data, infer_keys=None) @@ -276,7 +277,7 @@ def test_detect_from_dataframes_multi_table_with_infer_keys_none(): def test_detect_from_dataframes_single_table(): """Test the ``detect_from_dataframes`` method works with a single table.""" # Setup - data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels') metadata = Metadata.detect_from_dataframes({'table_1': data['hotels']}) # Run @@ -305,7 +306,7 @@ def test_detect_from_dataframes_single_table(): def test_detect_from_dataframes_single_table_infer_sdtypes_false(): """Test it for a single table when infer_sdtypes is False.""" # Setup - data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels') metadata = Metadata.detect_from_dataframes({'table_1': data['hotels']}, 
infer_sdtypes=False) # Run @@ -334,7 +335,7 @@ def test_detect_from_dataframes_single_table_infer_sdtypes_false(): def test_detect_from_dataframes_single_table_infer_keys_primary_only(): """Test it for a single table when infer_keys is 'primary_only'.""" # Setup - data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels') metadata = Metadata.detect_from_dataframes( {'table_1': data['hotels']}, infer_keys='primary_only' ) @@ -365,7 +366,7 @@ def test_detect_from_dataframes_single_table_infer_keys_primary_only(): def test_detect_from_dataframes_single_table_infer_keys_none(): """Test it for a single table when infer_keys is None.""" # Setup - data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels') metadata = Metadata.detect_from_dataframes({'table_1': data['hotels']}, infer_keys=None) # Run @@ -393,7 +394,7 @@ def test_detect_from_dataframes_single_table_infer_keys_none(): def test_detect_from_dataframe(): """Test that a single table can be detected as a DataFrame.""" # Setup - data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels') metadata = Metadata.detect_from_dataframe(data['hotels']) @@ -423,7 +424,7 @@ def test_detect_from_dataframe(): def test_detect_from_dataframe_infer_sdtypes_false(): """Test it when infer_sdtypes is False.""" # Setup - data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels') metadata = Metadata.detect_from_dataframe(data['hotels'], infer_sdtypes=False) # Run @@ -452,7 +453,7 @@ def test_detect_from_dataframe_infer_sdtypes_false(): def test_detect_from_dataframe_infer_keys_none(): """Test it when infer_keys is None.""" # Setup - data, _ = 
download_demo(modality='multi_table', dataset_name='fake_hotels') + data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels') metadata = Metadata.detect_from_dataframe(data['hotels'], infer_keys=None) # Run @@ -480,7 +481,7 @@ def test_detect_from_dataframe_infer_keys_none(): def test_detect_from_dataframe_infer_keys_none_infer_sdtypes_false(): """Test it when infer_keys is None and infer_sdtypes is False.""" # Setup - data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels') metadata = Metadata.detect_from_dataframe(data['hotels'], infer_keys=None, infer_sdtypes=False) # Run @@ -508,7 +509,7 @@ def test_detect_from_dataframe_infer_keys_none_infer_sdtypes_false(): def test_detect_from_csvs(tmp_path): """Test the ``detect_from_csvs`` method.""" # Setup - real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + real_data, _ = download_test_demo(modality='multi_table', dataset_name='fake_hotels') metadata = Metadata() @@ -571,7 +572,7 @@ def test_detect_from_csvs(tmp_path): def test_single_table_compatibility(tmp_path): """Test if SingleTableMetadata still has compatibility with single table synthesizers.""" # Setup - data, _ = download_demo('single_table', 'fake_hotel_guests') + data, _ = download_test_demo('single_table', 'fake_hotel_guests') warn_msg = ( "The 'SingleTableMetadata' is deprecated. Please use the new " "'Metadata' class for synthesizers." @@ -623,7 +624,7 @@ def test_single_table_compatibility(tmp_path): def test_multi_table_compatibility(tmp_path): """Test if MultiTableMetadata still has compatibility with multi table synthesizers.""" # Setup - data, _ = download_demo('multi_table', 'fake_hotels') + data, _ = download_test_demo('multi_table', 'fake_hotels') warn_msg = re.escape( "The 'MultiTableMetadata' is deprecated. Please use the new " "'Metadata' class for synthesizers." 
# --- tests/integration/metadata/test_metadata.py (module-level tests) ---


def test_add_column_relationship_fails_with_primary_key_column():
    """Test that adding a column relationship fails if the column is part of the primary key.

    This test temporarily registers a ``billing`` entry in the class-level
    ``SingleTableMetadata._COLUMN_RELATIONSHIP_TYPES``. The error that is being
    raised otherwise is ``ImportError`` instead of ``InvalidMetadataError``.
    """
    # Setup
    _, metadata = download_test_demo(modality='single_table', dataset_name='fake_hotel_guests')
    metadata.update_column(column_name='billing_address', sdtype='street_address')
    metadata.set_primary_key(['guest_email', 'billing_address'])
    expected_msg = re.escape("Cannot use primary key 'billing_address' in column relationship.")
    SingleTableMetadata._COLUMN_RELATIONSHIP_TYPES['billing'] = Mock()

    try:
        # Run and Assert
        with pytest.raises(InvalidMetadataError, match=expected_msg):
            metadata.add_column_relationship(
                column_names=['billing_address'], relationship_type='billing'
            )
    finally:
        # Test cleanup: remove 'billing' from the class-level relationship types even
        # when the assertion above fails. Without this, the mutation would leak into
        # later SingleTableMetadata instances.
        SingleTableMetadata._COLUMN_RELATIONSHIP_TYPES.pop('billing', None)


def test_metadata_fails_for_relationship_with_set_primary_key_column_in_relationship():
    """Test metadata set_primary_key fails if a column relationship includes primary key column."""
    # Setup
    _, metadata = download_test_demo(modality='single_table', dataset_name='fake_hotel_guests')
    metadata.update_column(column_name='billing_address', sdtype='street_address')
    metadata.add_column_relationship(column_names=['billing_address'], relationship_type='address')
    expected_msg = r"Cannot set primary key '.*' because it is part of a column relationship\."

    # Run and Assert
    with pytest.raises(InvalidMetadataError, match=expected_msg):
        metadata.set_primary_key(['guest_email', 'billing_address'])


# --- tests/unit/metadata/test_single_table.py ---
# The tests below take ``self``: in the patch they are methods of a test class
# whose header is not part of this view.


@pytest.mark.parametrize('primary_key', ['column_b', ['column_b', 'column_c']])
def test_set_primary_key_column_in_relationship(self, primary_key):
    """Test that ``set_primary_key`` raises an error if the column is in relationship.

    Covers both a single-column key and a composite key whose columns are all
    part of an existing ``address`` column relationship.
    """
    # Setup
    instance = SingleTableMetadata()
    instance.columns = {
        'column': {'sdtype': 'id'},
        'column_b': {'sdtype': 'address'},
        'column_c': {'sdtype': 'address'},
    }
    instance.column_relationships = [
        {
            'column_names': ['column_b', 'column_c'],
            'type': 'address',
        }
    ]

    error_msg_pattern = (
        r"Cannot set primary key '.*' because it is part of a column relationship\."
    )
    # Run and Assert
    with pytest.raises(InvalidMetadataError, match=error_msg_pattern):
        instance.set_primary_key(primary_key)


def test__validate_column_relationship_raises_import_error(self, recwarn):
    """Test that ``_validate_column_relationship`` raises an `ImportError`.

    Besides the ``ImportError``, exactly one ``UserWarning`` about the missing
    address add-on is expected to be recorded.
    """
    # Setup
    instance = SingleTableMetadata()
    relationship = {'type': 'address', 'column_names': ['a', 'b']}
    instance.columns = {
        'a': {'sdtype': 'street_address'},
        'b': {'sdtype': 'city'},
        'c': {'sdtype': 'datetime'},
    }

    # Run
    with pytest.raises(ImportError):
        instance._validate_column_relationship(relationship)

    # Assert
    assert len(recwarn) == 1
    warning_msg = recwarn.pop(UserWarning)
    expected_msg = (
        "The metadata contains a column relationship of type 'address' "
        'which requires the address add-on. '
        'This relationship will be ignored. For higher quality data in this'
        ' relationship, please inquire about the SDV Enterprise tier.'
    )
    assert str(warning_msg.message) == expected_msg


@pytest.mark.parametrize('primary_key', ['a', ['a', 'b']])
def test__validate_column_relationship_column_belongs_to_primary_key(self, primary_key):
    """Test validation fails for columns that are in the primary keys.

    The relationship types are overridden on the instance only, so the
    class-level mapping is left untouched.
    """
    # Setup
    instance = SingleTableMetadata()
    mock_relationship_validation = Mock()
    instance._COLUMN_RELATIONSHIP_TYPES = {'mock_relationship': mock_relationship_validation}
    relationship = {'type': 'mock_relationship', 'column_names': ['a', 'b']}
    instance.columns = {
        'a': {'sdtype': 'street_address'},
        'b': {'sdtype': 'street_address'},
        'c': {'sdtype': 'datetime'},
    }
    instance.set_primary_key(primary_key)

    expected_message = "Cannot use primary key 'a' in column relationship."
    # Run and Assert
    with pytest.raises(InvalidMetadataError, match=expected_message):
        instance._validate_column_relationship(relationship)
@lru_cache
def _download_demo(modality, dataset_name):
    """Memoized pass-through to ``download_demo``.

    Arguments are forwarded positionally, so each ``(modality, dataset_name)``
    pair is fetched once and later served from the ``lru_cache``.
    """
    return download_demo(modality, dataset_name)


def download_test_demo(modality, dataset_name):
    """Download demo datasets with caching.

    The underlying download is cached per ``(modality, dataset_name)``; every
    call hands back fresh deep copies so a test that mutates its data or
    metadata cannot affect the cached objects seen by other tests.

    Args:
        modality:
            The modality of the dataset: 'single_table', 'multi_table', 'sequential'.
        dataset_name:
            Name of the dataset to download.

    Returns:
        tuple:
            Deep copies of the two values returned by ``download_demo``
            (unpacked here as data and metadata).
    """
    cached_data, cached_metadata = _download_demo(modality, dataset_name)
    data_copy = deepcopy(cached_data)
    metadata_copy = deepcopy(cached_metadata)
    return data_copy, metadata_copy