diff --git a/sdv/_utils.py b/sdv/_utils.py index ec28669ec..1d5c983e1 100644 --- a/sdv/_utils.py +++ b/sdv/_utils.py @@ -520,3 +520,19 @@ def _validate_correct_synthesizer_loading(synthesizer, cls): def _sort_keys(keys): return sorted(keys, key=lambda key: key if isinstance(key, str) else key[0]) + + + def _get_unreferenced_keys(parent_columns, child_columns): + indicator = _create_unique_name( + '_merge', list(child_columns.columns) + list(parent_columns.columns) + ) + merged = child_columns.merge( + parent_columns, + left_on=list(child_columns.columns), + right_on=list(parent_columns.columns), + how='left', + indicator=indicator, + ) + merged = merged[merged[indicator] == 'left_only'][list(child_columns.columns)] + merged = merged.dropna(how='all') + return merged diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index e04ea30c0..5f5e877d7 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -10,7 +10,12 @@ import pandas as pd -from sdv._utils import _cast_to_iterable, _load_data_from_csv +from sdv._utils import ( + _cast_to_iterable, + _format_invalid_values_string, + _get_unreferenced_keys, + _load_data_from_csv, +) from sdv.errors import InvalidDataError from sdv.logging import get_sdv_logger from sdv.metadata.errors import InvalidMetadataError @@ -903,22 +908,21 @@ def _validate_foreign_keys(self, data): parent_table = data.get(relation['parent_table_name']) if isinstance(child_table, pd.DataFrame) and isinstance(parent_table, pd.DataFrame): - child_column = child_table[relation['child_foreign_key']] - parent_column = parent_table[relation['parent_primary_key']] - missing_values = child_column[~child_column.isin(parent_column)].unique() - missing_values = missing_values[~pd.isna(missing_values)] - - if any(missing_values): - message = ', '.join(missing_values[:5].astype(str)) - if len(missing_values) > 5: - message = f'({message}, + more)' - else: - message = f'({message})' - + child_columns = 
child_table[_cast_to_iterable(relation['child_foreign_key'])] + parent_columns = parent_table[_cast_to_iterable(relation['parent_primary_key'])] + missing_values = _get_unreferenced_keys(parent_columns, child_columns) + missing_values = missing_values.drop_duplicates() + if not missing_values.empty: + foreign_key = relation['child_foreign_key'] + if not isinstance(foreign_key, list): + foreign_key = f"'{foreign_key}'" + + message = f'\n{_format_invalid_values_string(missing_values, 5)}' errors.append( - f"Error: foreign key column '{relation['child_foreign_key']}' contains " - f'unknown references: {message}. Please use the method' - " 'drop_unknown_references' from sdv.utils to clean the data." + f'Error: foreign key column {foreign_key} contains ' + f'unknown references:{message}\n' + "Please use the method 'drop_unknown_references' from sdv.utils " + 'to clean the data.' ) if errors: diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index ff8bf3058..81c32b2a1 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -137,6 +137,9 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self.metadata.validate() self._check_metadata_updated() + self._original_metadata = deepcopy(self.metadata) + self._modified_multi_table_metadata = deepcopy(self.metadata) + self._composite_keys_metadata = None self._handle_composite_keys() self.locales = locales self.verbose = False @@ -144,8 +147,6 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self._table_synthesizers = {} self._table_parameters = defaultdict(dict) self._original_table_columns = {} - self._original_metadata = deepcopy(self.metadata) - self._modified_multi_table_metadata = deepcopy(self.metadata) self.constraints = [] self._single_table_constraints = [] if synthesizer_kwargs is not None: @@ -230,6 +231,10 @@ def add_constraints(self, constraints): """ constraints = _validate_constraints(constraints, self._fitted) metadata = self.metadata 
+ if getattr(self, '_composite_keys_metadata', None) is not None: + metadata = self._composite_keys_metadata + self._modified_multi_table_metadata = self._composite_keys_metadata + multi_table_constraints = [] single_table_constraints = [] idx_single_table_constraint = self._detect_single_table_constraints(constraints) @@ -246,6 +251,9 @@ def add_constraints(self, constraints): self._modified_multi_table_metadata = metadata self.metadata = metadata + self._composite_keys = None + self._handle_composite_keys() + self._validate_single_table_constraints(single_table_constraints) self.constraints += multi_table_constraints self._constraints_fitted = False @@ -355,6 +363,9 @@ def _validate_transform_constraints(self, data, enforce_constraint_fitting=False def _reverse_transform_constraints(self, sampled_data): """Reverse transform constraints after sampling.""" + if getattr(self, '_composite_keys', None): + sampled_data = self._composite_keys.reverse_transform(sampled_data) + if not hasattr(self, 'constraints'): return sampled_data @@ -454,7 +465,7 @@ def _validate_all_tables(self, data): def validate(self, data): """Validate the data. - Validate that the metadata matches the data and thta every table's constraints are valid. + Validate that the metadata matches the data and that every table's constraints are valid. 
Args: data (dict[str, pd.DataFrame]): @@ -463,6 +474,7 @@ def validate(self, data): errors = [] metadata = self._original_metadata metadata.validate_data(data) + data = self._validate_transform_constraints(data, enforce_constraint_fitting=True) for table_name in data: if table_name in self._table_synthesizers: # Validate rules specific to each synthesizer @@ -471,8 +483,6 @@ def validate(self, data): if errors: raise InvalidDataError(errors) - self._validate_transform_constraints(data, enforce_constraint_fitting=True) - def _validate_table_name(self, table_name): if table_name not in self._table_synthesizers: raise ValueError( @@ -574,6 +584,10 @@ def preprocess(self, data): list_of_changed_tables = self._store_and_convert_original_cols(data) self.validate(data) data = self._validate_transform_constraints(data) + if getattr(self, '_composite_keys', None): + self._composite_keys.fit(data, self._composite_keys_metadata) + data = self._composite_keys.transform(data) + if self._fitted: msg = ( 'This model has already been fitted. 
To use the new preprocessed data, ' diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index de037db1c..729d9fb04 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -199,7 +199,7 @@ def __init__(self, metadata, locales=['en_US'], verbose=True): self, self.metadata, self._table_synthesizers, self._table_sizes ) child_tables = set() - for relationship in metadata.relationships: + for relationship in self.metadata.relationships: child_tables.add(relationship['child_table_name']) for child_table_name in child_tables: self.set_table_parameters(child_table_name, {'default_distribution': 'norm'}) diff --git a/sdv/multi_table/utils.py b/sdv/multi_table/utils.py index f998818f6..36069d83a 100644 --- a/sdv/multi_table/utils.py +++ b/sdv/multi_table/utils.py @@ -7,7 +7,12 @@ import numpy as np import pandas as pd -from sdv._utils import MODELABLE_SDTYPES, _get_root_tables +from sdv._utils import ( + MODELABLE_SDTYPES, + _cast_to_iterable, + _get_root_tables, + _get_unreferenced_keys, +) from sdv.errors import InvalidDataError, SamplingError from sdv.multi_table import HMASynthesizer from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS @@ -393,25 +398,19 @@ def _get_rows_to_drop(data, metadata): for root in current_roots: parent_table = root relationships_parent = _get_relationships_for_parent(relationships, parent_table) - parent_column = metadata.tables[parent_table].primary_key + parent_columns = _cast_to_iterable(metadata.tables[parent_table].primary_key) valid_parent_idx = [ idx for idx in data[parent_table].index if idx not in table_to_idx_to_drop[parent_table] ] - valid_parent_values = set(data[parent_table].loc[valid_parent_idx, parent_column]) + valid_parent_values = data[parent_table].loc[valid_parent_idx, parent_columns] for relationship in relationships_parent: child_table = relationship['child_table_name'] - child_column = relationship['child_foreign_key'] - - is_nan = data[child_table][child_column].isna() - invalid_values = ( - 
set(data[child_table].loc[~is_nan, child_column]) - valid_parent_values - ) - invalid_rows = data[child_table][ - data[child_table][child_column].isin(invalid_values) - ] - idx_to_drop = set(invalid_rows.index) + child_foreign_key = _cast_to_iterable(relationship['child_foreign_key']) + child_columns = data[child_table][child_foreign_key] + unreferenced_rows = _get_unreferenced_keys(valid_parent_values, child_columns) + idx_to_drop = set(unreferenced_rows.index) if idx_to_drop: table_to_idx_to_drop[child_table] = table_to_idx_to_drop[child_table].union( @@ -428,8 +427,10 @@ def _get_nan_fk_indices_table(data, relationships, table): idx_with_nan_foreign_key = set() relationships_for_table = _get_relationships_for_child(relationships, table) for relationship in relationships_for_table: - child_column = relationship['child_foreign_key'] - idx_with_nan_foreign_key.update(data[table][data[table][child_column].isna()].index) + child_columns = _cast_to_iterable(relationship['child_foreign_key']) + idx_with_nan_foreign_key.update( + data[table][data[table][child_columns].isna().all(axis=1)].index + ) return idx_with_nan_foreign_key diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 166dfefb0..97d1ad1f7 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -150,9 +150,10 @@ def _validate_regex_format(self): ) _check_regex_format(self._table_name, column_name, regex) - def _handle_composite_keys(self, single_table_metadata): + def _handle_composite_keys(self): """Validates that composite keys are not used in Public SDV.""" - if single_table_metadata._primary_key_is_composite: + table_metadata = self.metadata._convert_to_single_table() + if table_metadata._primary_key_is_composite: raise SynthesizerInputError( 'Your metadata contains composite keys (primary key of table ' f"'{self._table_name}' has multiple columns). 
Composite keys are " @@ -181,11 +182,11 @@ def __init__( self.metadata.validate() self._check_metadata_updated() - single_table_metadata = self.metadata._convert_to_single_table() - self._handle_composite_keys(single_table_metadata) # Points to a metadata object that conserves the initialized status of the synthesizer self._original_metadata = deepcopy(self.metadata) + self._handle_composite_keys() + single_table_metadata = self.metadata._convert_to_single_table() self.enforce_min_max_values = enforce_min_max_values self.enforce_rounding = enforce_rounding @@ -465,13 +466,15 @@ def add_constraints(self, constraints): constraints (list): A list of constraints to apply to the synthesizer. """ + metadata = getattr(self, '_composite_keys_metadata', None) or self.metadata constraints = _validate_constraints_single_table(constraints, self._fitted) for constraint in constraints: if isinstance(constraint, ProgrammableConstraint): constraint = ProgrammableConstraintHarness(constraint) try: - self.metadata = constraint.get_updated_metadata(self.metadata) + metadata = constraint.get_updated_metadata(metadata) + self.metadata = metadata self._chained_constraints.append(constraint) self._constraints_fitted = False except ConstraintNotMetError as e: @@ -486,6 +489,7 @@ def add_constraints(self, constraints): raise e self.metadata.validate() + self._handle_composite_keys() self._data_processor = DataProcessor( metadata=self.metadata._convert_to_single_table(), enforce_rounding=self.enforce_rounding, @@ -626,6 +630,9 @@ def _preprocess_helper(self, data): warnings.warn(msg, RefitWarning) data = self._validate_transform_constraints(data) + if getattr(self, '_composite_keys', None): + self._composite_keys.fit(data, self._composite_keys_metadata) + data = self._composite_keys.transform(data) return data @@ -787,6 +794,9 @@ def load(cls, filepath): def reverse_transform_constraints(self, sampled): """Reverse transform the constraints.""" + if getattr(self, '_composite_keys', None): + 
sampled = self._composite_keys.reverse_transform(sampled) + if hasattr(self, '_chained_constraints') and hasattr(self, '_reject_sampling_constraints'): for constraint in reversed(self._chained_constraints): sampled = constraint.reverse_transform(sampled) diff --git a/tests/integration/utils/test_utils.py b/tests/integration/utils/test_utils.py index 79d2e8655..c42e635b2 100644 --- a/tests/integration/utils/test_utils.py +++ b/tests/integration/utils/test_utils.py @@ -60,8 +60,10 @@ def test_drop_unknown_references(metadata, data, capsys): expected_message = re.escape( 'The provided data does not match the metadata:\n' 'Relationships:\n' - "Error: foreign key column 'parent_id' contains unknown references: (5)" - ". Please use the method 'drop_unknown_references' from sdv.utils to clean the data." + "Error: foreign key column 'parent_id' contains unknown references:\n" + ' parent_id\n' + '4 5\n' + "Please use the method 'drop_unknown_references' from sdv.utils to clean the data." ) with pytest.raises(InvalidDataError, match=expected_message): metadata.validate_data(data) diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index f03863b86..b6db4f496 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -1322,14 +1322,75 @@ def test__validate_foreign_keys_missing_keys(self): # Assert missing_upravna_enota = [ 'Relationships:\n' - "Error: foreign key column 'upravna_enota' contains unknown references: " - '(10, 11, 12, 13, 14, + more). ' + "Error: foreign key column 'upravna_enota' contains unknown references:\n" + ' upravna_enota\n' + '0 10\n' + '1 11\n' + '2 12\n' + '3 13\n' + '4 14\n' + '+5 more\n' "Please use the method 'drop_unknown_references' from sdv.utils to clean the data.\n" - "Error: foreign key column 'id_nesreca' contains unknown references: (1, 3, 5, 7, 9)." - " Please use the method 'drop_unknown_references' from sdv.utils to clean the data." 
+ "Error: foreign key column 'id_nesreca' contains unknown references:\n" + ' id_nesreca\n' + '1 1\n' + '3 3\n' + '5 5\n' + '7 7\n' + '9 9\n' + "Please use the method 'drop_unknown_references' from sdv.utils to clean the data." ] assert result == missing_upravna_enota + def test__validate_foreign_keys_missing_composite_keys(self): + """Test that errors are being returned. + + When the values of the foreign keys are not within the values of the parent + primary key, a list of errors must be returned indicating the values that are missing. + """ + # Setup + metadata = get_multi_table_metadata() + metadata.remove_relationship('nesreca', 'oseba') + metadata.add_column('id_nesreca2', 'nesreca', sdtype='id') + metadata.add_column('nesreca_fk2', 'oseba', sdtype='id') + metadata.set_primary_key(['id_nesreca', 'id_nesreca2'], 'nesreca') + metadata.add_relationship( + parent_table_name='nesreca', + child_table_name='oseba', + parent_primary_key=['id_nesreca', 'id_nesreca2'], + child_foreign_key=['id_nesreca', 'nesreca_fk2'], + ) + data = { + 'nesreca': pd.DataFrame({ + 'id_nesreca': ['id0', 'id1', 'id2', 'id3', 'id4'] * 2, + 'id_nesreca2': ['A'] * 5 + ['B'] * 5, + 'upravna_enota': np.arange(10), + }), + 'oseba': pd.DataFrame({ + 'upravna_enota': np.arange(9), + 'id_nesreca': ['id0', 'id9', 'id9'] + ['id0', 'id1', 'id2'] * 2, + 'nesreca_fk2': ['X', 'A', 'X'] + ['A', 'B'] * 3, + }), + 'upravna_enota': pd.DataFrame({ + 'id_upravna_enota': np.arange(10), + }), + } + + # Run + result = metadata._validate_foreign_keys(data) + + # Assert + missing_oseba = [ + 'Relationships:\n' + "Error: foreign key column ['id_nesreca', 'nesreca_fk2'] contains unknown references:\n" + ' id_nesreca nesreca_fk2\n' + '0 id0 X\n' + '1 id9 A\n' + '2 id9 X\n' + "Please use the method 'drop_unknown_references' from sdv.utils to clean the data." 
+ ] + assert result == missing_oseba + def test_validate_data(self): """Test that no error is being raised when the data is valid.""" # Setup @@ -1438,7 +1499,13 @@ def test_validate_data_missing_foreign_keys(self): error_msg = re.escape( 'The provided data does not match the metadata:\n' 'Relationships:\n' - "Error: foreign key column 'id_nesreca' contains unknown references: (1, 3, 5, 7, 9). " + "Error: foreign key column 'id_nesreca' contains unknown references:\n" + ' id_nesreca\n' + '1 1\n' + '3 3\n' + '5 5\n' + '7 7\n' + '9 9\n' "Please use the method 'drop_unknown_references' from sdv.utils to clean the data." ) with pytest.raises(InvalidDataError, match=error_msg): diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index a20ff73f6..03131ceaa 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -542,7 +542,7 @@ def test_validate(self): metadata = get_multi_table_metadata() data = get_multi_table_data() instance = BaseMultiTableSynthesizer(metadata) - instance._validate_transform_constraints = Mock() + instance._validate_transform_constraints = Mock(return_value=data) # Run instance.validate(data) @@ -659,7 +659,13 @@ def test_validate_missing_foreign_keys(self): error_msg = re.escape( 'The provided data does not match the metadata:\n' 'Relationships:\n' - "Error: foreign key column 'id_nesreca' contains unknown references: (1, 3, 5, 7, 9). " + "Error: foreign key column 'id_nesreca' contains unknown references:\n" + ' id_nesreca\n' + '1 1\n' + '3 3\n' + '5 5\n' + '7 7\n' + '9 9\n' "Please use the method 'drop_unknown_references' from sdv.utils to clean the data." 
) with pytest.raises(InvalidDataError, match=error_msg): @@ -1567,6 +1573,7 @@ def test_add_constraints(self, mock_validate_constraints, mock_programmable_cons """Test adding data constraints to the synthesizer.""" # Setup instance = Mock() + del instance._composite_keys_metadata original_metadata = get_multi_table_metadata() instance.metadata = original_metadata instance._original_metadata = original_metadata @@ -1618,6 +1625,7 @@ def test_add_constraints_single_table_overlap(self, mock_validate_constraints): """Test adding overlapping single-table constraints to the synthesizer.""" # Setup instance = Mock() + del instance._composite_keys_metadata original_metadata = get_multi_table_metadata() instance.metadata = original_metadata instance._original_metadata = original_metadata @@ -1665,6 +1673,7 @@ def test_updating_constraints_keeps_original_metadata(self, mock_validate_constr """Test adding data constraints to the synthesizer.""" # Setup instance = Mock() + del instance._composite_keys_metadata delattr(instance, 'constraints') metadata = get_multi_table_metadata() original_metadata = Mock() @@ -1821,6 +1830,7 @@ def test__validate_transform_constraints_with_constraints(self): """Test validating and transforming the data constraints.""" # Setup instance = Mock() + del instance._composite_keys data = {'table1': Mock(), 'table2': Mock()} constraint1 = Mock() constraint2 = Mock() @@ -1841,6 +1851,7 @@ def test__reverse_validate_transform_constraints_no_constraints(self): """Test reverse transforming when no data constraints have been set.""" # Setup instance = Mock() + del instance._composite_keys data = get_multi_table_data() delattr(instance, 'constraints') @@ -1859,6 +1870,7 @@ def test__reverse_validate_transform_constraints(self, drop_unknown_references): """Test reverse transforming the data constraints.""" # Setup instance = Mock() + del instance._composite_keys data = {'table1': Mock(), 'table2': Mock()} constraint1 = Mock() constraint2 = Mock() diff --git 
a/tests/unit/multi_table/test_utils.py b/tests/unit/multi_table/test_utils.py index 84839c19d..95dceb15b 100644 --- a/tests/unit/multi_table/test_utils.py +++ b/tests/unit/multi_table/test_utils.py @@ -104,8 +104,8 @@ def test__get_rows_to_drop(): { 'parent_table_name': 'child', 'child_table_name': 'grandchild', - 'parent_primary_key': 'id_child', - 'child_foreign_key': 'child_foreign_key', + 'parent_primary_key': ['id_child1', 'id_child2'], + 'child_foreign_key': ['child_fk1', 'child_fk2'], }, { 'parent_table_name': 'parent', @@ -119,7 +119,7 @@ def test__get_rows_to_drop(): metadata.relationships = relationships metadata.tables = { 'parent': Mock(primary_key='id_parent'), - 'child': Mock(primary_key='id_child'), + 'child': Mock(primary_key=['id_child1', 'id_child2']), 'grandchild': Mock(primary_key='id_grandchild'), } @@ -130,12 +130,14 @@ def test__get_rows_to_drop(): }), 'child': pd.DataFrame({ 'parent_foreign_key': [0, 1, 2, 2, 5], - 'id_child': [5, 6, 7, 8, 9], + 'id_child1': [5, 6, 7, 8, 9], + 'id_child2': ['A', 'B', 'A', 'B', 'A'], 'B': ['Yes', 'No', 'No', 'No', 'No'], }), 'grandchild': pd.DataFrame({ 'parent_foreign_key': [0, 1, 2, 2, 6], - 'child_foreign_key': [9, 5, 11, 6, 4], + 'child_fk1': [9, 5, 11, 6, 6], + 'child_fk2': ['A', 'A', 'A', 'B', 'X'], 'C': ['Yes', 'No', 'No', 'No', 'No'], }), } @@ -1988,7 +1990,7 @@ def test__subsample_data( assert result == data -def test__subsample_data_with_null_foreing_keys(): +def test__subsample_data_with_null_foreign_keys(): """Test the ``_subsample_data`` method when there are null foreign keys.""" # Setup metadata = Metadata.load_from_dict({ diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 7d26cab74..6d2ecd581 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -727,6 +727,7 @@ def test__preprocess_helper(self, mock_warnings): instance = Mock() instance._fitted = True data = pd.DataFrame({'name': ['John', 'Doe', 'John 
Doe']}) + instance._composite_keys.transform.return_value = data instance._validate_transform_constraints.side_effect = lambda x: x expected_warning = ( 'This model has already been fitted. To use the new preprocessed data, please '