Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions sdv/cag/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,30 @@
import numpy as np
import pandas as pd

from sdv._utils import _cast_to_iterable
from sdv.cag._errors import ConstraintNotMetError
from sdv.errors import RefitWarning, SynthesizerInputError, TableNameError
from sdv.metadata import Metadata


def _validate_columns_not_primary_key(table_name, columns, metadata):
"""Validate that none of the columns are in the primary key for the table."""
primary_key = metadata.tables[table_name].primary_key
if metadata.tables[table_name]._primary_key_is_composite:
key_columns = set(primary_key).intersection(set(columns))
if key_columns:
pk_columns = "', '".join(sorted(key_columns))
raise ConstraintNotMetError(
f"Cannot apply constraint because ['{pk_columns}'] are "
f"part of the primary key for table '{table_name}'."
)
elif primary_key in columns:
raise ConstraintNotMetError(
f"Cannot apply constraint because '{primary_key}' is the "
f"primary key of table '{table_name}'."
)


def _validate_columns_in_metadata(table_name, columns, metadata):
"""Validates that the columns are in the metadata.

Expand Down Expand Up @@ -137,9 +156,9 @@ def _remove_columns_from_metadata(metadata, table_name, columns_to_drop):
if isinstance(metadata, Metadata):
metadata = metadata.to_dict()
column_set = set(columns_to_drop)
primary_key = metadata['tables'][table_name].get('primary_key')
primary_key = _cast_to_iterable(metadata['tables'][table_name].get('primary_key'))
for column in column_set:
if primary_key and primary_key == column:
if primary_key and column in primary_key:
raise ValueError('Cannot remove primary key from Metadata')
del metadata['tables'][table_name]['columns'][column]

Expand Down
2 changes: 2 additions & 0 deletions sdv/cag/fixed_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
_get_is_valid_dict,
_is_list_of_type,
_remove_columns_from_metadata,
_validate_columns_not_primary_key,
_validate_table_and_column_names,
_validate_table_name_if_defined,
)
Expand Down Expand Up @@ -67,6 +68,7 @@ def _validate_constraint_with_metadata(self, metadata):
"""
_validate_table_and_column_names(self.table_name, self.column_names, metadata)
table_name = self._get_single_table_name(metadata)
_validate_columns_not_primary_key(table_name, self.column_names, metadata)
for column in self.column_names:
col_sdtype = metadata.tables[table_name].columns[column]['sdtype']
if col_sdtype not in ['boolean', 'categorical']:
Expand Down
2 changes: 2 additions & 0 deletions sdv/cag/fixed_increments.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sdv.cag._utils import (
_get_is_valid_dict,
_remove_columns_from_metadata,
_validate_columns_not_primary_key,
_validate_table_and_column_names,
_validate_table_name_if_defined,
)
Expand Down Expand Up @@ -67,6 +68,7 @@ def _validate_constraint_with_metadata(self, metadata):
self.table_name, columns=[self.column_name], metadata=metadata
)
table_name = self._get_single_table_name(metadata)
_validate_columns_not_primary_key(table_name, [self.column_name], metadata)
col_sdtype = metadata.tables[table_name].columns[self.column_name]['sdtype']
if col_sdtype != 'numerical':
raise ConstraintNotMetError(
Expand Down
2 changes: 2 additions & 0 deletions sdv/cag/inequality.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_get_is_valid_dict,
_is_list_of_type,
_remove_columns_from_metadata,
_validate_columns_not_primary_key,
_validate_table_and_column_names,
_validate_table_name_if_defined,
)
Expand Down Expand Up @@ -93,6 +94,7 @@ def _validate_constraint_with_metadata(self, metadata):
columns = [self._low_column_name, self._high_column_name]
_validate_table_and_column_names(self.table_name, columns, metadata)
table_name = self._get_single_table_name(metadata)
_validate_columns_not_primary_key(table_name, columns, metadata)
for column in columns:
col_sdtype = metadata.tables[table_name].columns[column]['sdtype']
if col_sdtype not in ['numerical', 'datetime']:
Expand Down
3 changes: 3 additions & 0 deletions sdv/cag/one_hot_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_get_is_valid_dict,
_is_list_of_type,
_remove_columns_from_metadata,
_validate_columns_not_primary_key,
_validate_table_and_column_names,
_validate_table_name_if_defined,
)
Expand Down Expand Up @@ -73,6 +74,8 @@ def _validate_constraint_with_metadata(self, metadata):
If any of the validations fail.
"""
_validate_table_and_column_names(self.table_name, self._column_names, metadata)
table_name = self._get_single_table_name(metadata)
_validate_columns_not_primary_key(table_name, self._column_names, metadata)

def _get_valid_table_data(self, table_data):
one_hot_data = table_data[self._column_names]
Expand Down
2 changes: 2 additions & 0 deletions sdv/cag/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
_get_is_valid_dict,
_is_list_of_type,
_remove_columns_from_metadata,
_validate_columns_not_primary_key,
_validate_table_and_column_names,
_validate_table_name_if_defined,
)
Expand Down Expand Up @@ -126,6 +127,7 @@ def _validate_constraint_with_metadata(self, metadata):
columns = [self._low_column_name, self._middle_column_name, self._high_column_name]
_validate_table_and_column_names(self.table_name, columns, metadata)
table_name = self._get_single_table_name(metadata)
_validate_columns_not_primary_key(table_name, columns, metadata)
for column in columns:
col_sdtype = metadata.tables[table_name].columns[column]['sdtype']
if col_sdtype not in ['numerical', 'datetime']:
Expand Down
6 changes: 5 additions & 1 deletion tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2906,7 +2906,11 @@ def test_1_to_1_or_0_not_superset(self, data_metadata_1_to_1_or_0):
child_foreign_key='user_id',
)
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
match_ = re.escape("Error: foreign key column 'user_id' contains unknown references: (9).")
match_ = re.escape(
"Error: foreign key column 'user_id' contains unknown references:\n"
' user_id\n'
'9 9\n'
)

# Run and Assert
with pytest.raises(InvalidDataError, match=match_):
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/cag/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
_filter_old_style_constraints,
_is_list_of_type,
_remove_columns_from_metadata,
_validate_columns_not_primary_key,
_validate_constraints,
_validate_constraints_single_table,
_validate_table_and_column_names,
Expand All @@ -22,6 +23,34 @@
from sdv.metadata.metadata import Metadata


def test__validate_columns_not_primary_key():
"""Test validating columns do not appear in primary key."""
# Setup
metadata = Metadata.load_from_dict({
'tables': {
'table': {'primary_key': 'col1'},
'composite_table': {
'primary_key': ['col1', 'col2'],
},
}
})
columns = ['col1', 'col2', 'col3']
expected_single_key_error = re.escape(
"Cannot apply constraint because 'col1' is the primary key of table 'table'."
)
expected_composite_key_error = re.escape(
"Cannot apply constraint because ['col1', 'col2'] are "
"part of the primary key for table 'composite_table'."
)

# Run and Assert
with pytest.raises(ConstraintNotMetError, match=expected_single_key_error):
_validate_columns_not_primary_key('table', columns, metadata)

with pytest.raises(ConstraintNotMetError, match=expected_composite_key_error):
_validate_columns_not_primary_key('composite_table', columns, metadata)


def test__validate_table_and_column_names():
"""Test `_validate_table_and_column_names` method."""
# Setup
Expand Down Expand Up @@ -193,6 +222,9 @@ def test__remove_columns_from_metadata_raises_pk():
'primary_key': 'id',
'columns': {'id': {'sdtype': 'id'}},
},
'child': {
'primary_key': ['pk1', 'pk2'],
},
},
'relationships': [
{
Expand All @@ -212,6 +244,12 @@ def test__remove_columns_from_metadata_raises_pk():
table_name='parent',
columns_to_drop=['id'],
)
with pytest.raises(ValueError, match=cannot_remove_pk):
_remove_columns_from_metadata(
metadata=original_metadata,
table_name='child',
columns_to_drop=['pk1'],
)


def test__remove_columns_from_metadata_multiple_duplicate_columns():
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,6 +1597,7 @@ def test_add_constraints(self, mock_validate_constraints, mock_programmable_cons
original_metadata = get_multi_table_metadata()
instance.metadata = original_metadata
instance._original_metadata = original_metadata
instance._composite_keys_metadata = None
instance.constraints = []
instance._single_table_constraints = []
constraint1 = Mock()
Expand Down Expand Up @@ -1649,6 +1650,7 @@ def test_add_constraints_single_table_overlap(self, mock_validate_constraints):
original_metadata = get_multi_table_metadata()
instance.metadata = original_metadata
instance._original_metadata = original_metadata
instance._composite_keys_metadata = None
instance.constraints = []
instance._single_table_constraints = []
constraint1 = Mock()
Expand Down Expand Up @@ -1698,6 +1700,7 @@ def test_updating_constraints_keeps_original_metadata(self, mock_validate_constr
metadata = get_multi_table_metadata()
original_metadata = Mock()
instance._original_metadata = original_metadata
instance._composite_keys_metadata = None
instance.metadata = metadata
constraint1 = Mock()
constraint2 = Mock()
Expand Down