def _validate_columns_not_primary_key(table_name, columns, metadata):
    """Validate that none of the columns are in the primary key for the table.

    Args:
        table_name (str):
            Name of the table whose primary key is checked.
        columns (str or iterable of str):
            Column name(s) the constraint is applied to. A single column name given as a
            bare string is treated as a one-element list.
        metadata (sdv.metadata.Metadata):
            Metadata whose ``tables[table_name]`` entry provides ``primary_key`` and
            ``_primary_key_is_composite``.

    Raises:
        ConstraintNotMetError:
            If any of ``columns`` is the primary key of the table, or is part of its
            composite primary key.
    """
    # A bare string must be wrapped: otherwise ``primary_key in columns`` is a substring
    # test ('id' in 'user_id') and ``set(columns)`` splits the name into characters.
    if isinstance(columns, str):
        columns = [columns]

    table_metadata = metadata.tables[table_name]
    primary_key = table_metadata.primary_key
    if table_metadata._primary_key_is_composite:
        key_columns = set(primary_key).intersection(columns)
        if key_columns:
            # Sorted so the error message is deterministic regardless of set ordering.
            pk_columns = "', '".join(sorted(key_columns))
            raise ConstraintNotMetError(
                f"Cannot apply constraint because ['{pk_columns}'] are "
                f"part of the primary key for table '{table_name}'."
            )
    elif primary_key in columns:
        raise ConstraintNotMetError(
            f"Cannot apply constraint because '{primary_key}' is the "
            f"primary key of table '{table_name}'."
        )
@@ -137,9 +156,9 @@ def _remove_columns_from_metadata(metadata, table_name, columns_to_drop): if isinstance(metadata, Metadata): metadata = metadata.to_dict() column_set = set(columns_to_drop) - primary_key = metadata['tables'][table_name].get('primary_key') + primary_key = _cast_to_iterable(metadata['tables'][table_name].get('primary_key')) for column in column_set: - if primary_key and primary_key == column: + if primary_key and column in primary_key: raise ValueError('Cannot remove primary key from Metadata') del metadata['tables'][table_name]['columns'][column] diff --git a/sdv/cag/fixed_combinations.py b/sdv/cag/fixed_combinations.py index 84780f015..aa9d71334 100644 --- a/sdv/cag/fixed_combinations.py +++ b/sdv/cag/fixed_combinations.py @@ -11,6 +11,7 @@ _get_is_valid_dict, _is_list_of_type, _remove_columns_from_metadata, + _validate_columns_not_primary_key, _validate_table_and_column_names, _validate_table_name_if_defined, ) @@ -67,6 +68,7 @@ def _validate_constraint_with_metadata(self, metadata): """ _validate_table_and_column_names(self.table_name, self.column_names, metadata) table_name = self._get_single_table_name(metadata) + _validate_columns_not_primary_key(table_name, self.column_names, metadata) for column in self.column_names: col_sdtype = metadata.tables[table_name].columns[column]['sdtype'] if col_sdtype not in ['boolean', 'categorical']: diff --git a/sdv/cag/fixed_increments.py b/sdv/cag/fixed_increments.py index 4ef002c55..4858da229 100644 --- a/sdv/cag/fixed_increments.py +++ b/sdv/cag/fixed_increments.py @@ -7,6 +7,7 @@ from sdv.cag._utils import ( _get_is_valid_dict, _remove_columns_from_metadata, + _validate_columns_not_primary_key, _validate_table_and_column_names, _validate_table_name_if_defined, ) @@ -67,6 +68,7 @@ def _validate_constraint_with_metadata(self, metadata): self.table_name, columns=[self.column_name], metadata=metadata ) table_name = self._get_single_table_name(metadata) + _validate_columns_not_primary_key(table_name, 
[self.column_name], metadata) col_sdtype = metadata.tables[table_name].columns[self.column_name]['sdtype'] if col_sdtype != 'numerical': raise ConstraintNotMetError( diff --git a/sdv/cag/inequality.py b/sdv/cag/inequality.py index 7ce300496..6fd2c2beb 100644 --- a/sdv/cag/inequality.py +++ b/sdv/cag/inequality.py @@ -10,6 +10,7 @@ _get_is_valid_dict, _is_list_of_type, _remove_columns_from_metadata, + _validate_columns_not_primary_key, _validate_table_and_column_names, _validate_table_name_if_defined, ) @@ -93,6 +94,7 @@ def _validate_constraint_with_metadata(self, metadata): columns = [self._low_column_name, self._high_column_name] _validate_table_and_column_names(self.table_name, columns, metadata) table_name = self._get_single_table_name(metadata) + _validate_columns_not_primary_key(table_name, columns, metadata) for column in columns: col_sdtype = metadata.tables[table_name].columns[column]['sdtype'] if col_sdtype not in ['numerical', 'datetime']: diff --git a/sdv/cag/one_hot_encoding.py b/sdv/cag/one_hot_encoding.py index a38979bef..181d77df8 100644 --- a/sdv/cag/one_hot_encoding.py +++ b/sdv/cag/one_hot_encoding.py @@ -10,6 +10,7 @@ _get_is_valid_dict, _is_list_of_type, _remove_columns_from_metadata, + _validate_columns_not_primary_key, _validate_table_and_column_names, _validate_table_name_if_defined, ) @@ -73,6 +74,8 @@ def _validate_constraint_with_metadata(self, metadata): If any of the validations fail. 
""" _validate_table_and_column_names(self.table_name, self._column_names, metadata) + table_name = self._get_single_table_name(metadata) + _validate_columns_not_primary_key(table_name, self._column_names, metadata) def _get_valid_table_data(self, table_data): one_hot_data = table_data[self._column_names] diff --git a/sdv/cag/range.py b/sdv/cag/range.py index 434a9d3a5..91602b214 100644 --- a/sdv/cag/range.py +++ b/sdv/cag/range.py @@ -12,6 +12,7 @@ _get_is_valid_dict, _is_list_of_type, _remove_columns_from_metadata, + _validate_columns_not_primary_key, _validate_table_and_column_names, _validate_table_name_if_defined, ) @@ -126,6 +127,7 @@ def _validate_constraint_with_metadata(self, metadata): columns = [self._low_column_name, self._middle_column_name, self._high_column_name] _validate_table_and_column_names(self.table_name, columns, metadata) table_name = self._get_single_table_name(metadata) + _validate_columns_not_primary_key(table_name, columns, metadata) for column in columns: col_sdtype = metadata.tables[table_name].columns[column]['sdtype'] if col_sdtype not in ['numerical', 'datetime']: diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 2e7ec1021..5e3d2575e 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -2906,7 +2906,11 @@ def test_1_to_1_or_0_not_superset(self, data_metadata_1_to_1_or_0): child_foreign_key='user_id', ) synthesizer = HMASynthesizer(metadata=metadata, verbose=False) - match_ = re.escape("Error: foreign key column 'user_id' contains unknown references: (9).") + match_ = re.escape( + "Error: foreign key column 'user_id' contains unknown references:\n" + ' user_id\n' + '9 9\n' + ) # Run and Assert with pytest.raises(InvalidDataError, match=match_): diff --git a/tests/unit/cag/test__utils.py b/tests/unit/cag/test__utils.py index 0b498dfba..4a58da1f2 100644 --- a/tests/unit/cag/test__utils.py +++ b/tests/unit/cag/test__utils.py @@ 
def test__validate_columns_not_primary_key():
    """Check that constraining primary key columns raises ``ConstraintNotMetError``."""
    # Setup
    single_key_table = {'primary_key': 'col1'}
    composite_key_table = {'primary_key': ['col1', 'col2']}
    metadata = Metadata.load_from_dict({
        'tables': {
            'table': single_key_table,
            'composite_table': composite_key_table,
        }
    })
    constrained_columns = ['col1', 'col2', 'col3']
    # Expected message per table; single key case first, then the composite key case.
    expected_errors = {
        'table': re.escape(
            "Cannot apply constraint because 'col1' is the primary key of table 'table'."
        ),
        'composite_table': re.escape(
            "Cannot apply constraint because ['col1', 'col2'] are "
            "part of the primary key for table 'composite_table'."
        ),
    }

    # Run and Assert
    for table_name, error_message in expected_errors.items():
        with pytest.raises(ConstraintNotMetError, match=error_message):
            _validate_columns_not_primary_key(table_name, constrained_columns, metadata)
e174e90c1..6cc3651c1 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -1597,6 +1597,7 @@ def test_add_constraints(self, mock_validate_constraints, mock_programmable_cons original_metadata = get_multi_table_metadata() instance.metadata = original_metadata instance._original_metadata = original_metadata + instance._composite_keys_metadata = None instance.constraints = [] instance._single_table_constraints = [] constraint1 = Mock() @@ -1649,6 +1650,7 @@ def test_add_constraints_single_table_overlap(self, mock_validate_constraints): original_metadata = get_multi_table_metadata() instance.metadata = original_metadata instance._original_metadata = original_metadata + instance._composite_keys_metadata = None instance.constraints = [] instance._single_table_constraints = [] constraint1 = Mock() @@ -1698,6 +1700,7 @@ def test_updating_constraints_keeps_original_metadata(self, mock_validate_constr metadata = get_multi_table_metadata() original_metadata = Mock() instance._original_metadata = original_metadata + instance._composite_keys_metadata = None instance.metadata = metadata constraint1 = Mock() constraint2 = Mock()