From 579fd6a1710eeddc78f4135c8a4adfaa233b531a Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Tue, 10 Feb 2026 14:47:25 -0500 Subject: [PATCH 1/5] Updates for Enterprise composite keys --- sdv/metadata/multi_table.py | 46 +++++++++++------ sdv/multi_table/base.py | 22 +++++--- sdv/multi_table/hma.py | 2 +- sdv/single_table/base.py | 19 +++++-- tests/unit/metadata/test_multi_table.py | 69 +++++++++++++++++++++++-- tests/unit/multi_table/test_utils.py | 12 +++-- 6 files changed, 132 insertions(+), 38 deletions(-) diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index e04ea30c0..9f86389cf 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -10,7 +10,12 @@ import pandas as pd -from sdv._utils import _cast_to_iterable, _load_data_from_csv +from sdv._utils import ( + _cast_to_iterable, + _create_unique_name, + _format_invalid_values_string, + _load_data_from_csv, +) from sdv.errors import InvalidDataError from sdv.logging import get_sdv_logger from sdv.metadata.errors import InvalidMetadataError @@ -903,22 +908,31 @@ def _validate_foreign_keys(self, data): parent_table = data.get(relation['parent_table_name']) if isinstance(child_table, pd.DataFrame) and isinstance(parent_table, pd.DataFrame): - child_column = child_table[relation['child_foreign_key']] - parent_column = parent_table[relation['parent_primary_key']] - missing_values = child_column[~child_column.isin(parent_column)].unique() - missing_values = missing_values[~pd.isna(missing_values)] - - if any(missing_values): - message = ', '.join(missing_values[:5].astype(str)) - if len(missing_values) > 5: - message = f'({message}, + more)' - else: - message = f'({message})' - + child_columns = child_table[_cast_to_iterable(relation['child_foreign_key'])] + parent_columns = parent_table[_cast_to_iterable(relation['parent_primary_key'])] + indicator = _create_unique_name( + '_merge', list(child_columns.columns) + list(parent_columns.columns) + ) + merged_columns = parent_columns.merge( + child_columns.drop_duplicates(), + left_on=list(parent_columns.columns), + right_on=list(child_columns.columns), + how='right', + indicator=indicator, + ) + missing_values = merged_columns[merged_columns[indicator] == 'right_only'] + missing_values = missing_values[list(child_columns.columns)] + if not missing_values.empty: + foreign_key = relation['child_foreign_key'] + if not isinstance(foreign_key, list): + foreign_key = f"'{foreign_key}'" + + message = f'\n{_format_invalid_values_string(missing_values, 5)}' errors.append( - f"Error: foreign key column '{relation['child_foreign_key']}' contains " - f'unknown references: {message}. Please use the method' - " 'drop_unknown_references' from sdv.utils to clean the data." + f"Error: foreign key column {foreign_key} contains " + f'unknown references:{message}\n' + "Please use the method 'drop_unknown_references' from sdv.utils " + 'to clean the data.' ) if errors: diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index ff8bf3058..297726a8c 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -137,6 +137,8 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self.metadata.validate() self._check_metadata_updated() + self._original_metadata = deepcopy(self.metadata) + self._modified_multi_table_metadata = deepcopy(self.metadata) self._handle_composite_keys() self.locales = locales self.verbose = False @@ -144,8 +146,6 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self._table_synthesizers = {} self._table_parameters = defaultdict(dict) self._original_table_columns = {} - self._original_metadata = deepcopy(self.metadata) - self._modified_multi_table_metadata = deepcopy(self.metadata) self.constraints = [] self._single_table_constraints = [] if synthesizer_kwargs is not None: @@ -229,7 +229,7 @@ def add_constraints(self, constraints): A list of constraints to apply to the synthesizer. """ constraints = _validate_constraints(constraints, self._fitted) - metadata = self.metadata + metadata = getattr(self, '_composite_keys_metadata', None) or self.metadata multi_table_constraints = [] single_table_constraints = [] idx_single_table_constraint = self._detect_single_table_constraints(constraints) @@ -243,9 +243,11 @@ def add_constraints(self, constraints): multi_table_constraints.append(constraint) metadata = constraint.get_updated_metadata(metadata) - self._modified_multi_table_metadata = metadata self.metadata = metadata + self._modified_multi_table_metadata = self.metadata + self._handle_composite_keys() + self._validate_single_table_constraints(single_table_constraints) self.constraints += multi_table_constraints self._constraints_fitted = False @@ -355,6 +357,9 @@ def _validate_transform_constraints(self, data, enforce_constraint_fitting=False def _reverse_transform_constraints(self, sampled_data): """Reverse transform constraints after sampling.""" + if getattr(self, '_composite_keys', None): + sampled_data = self._composite_keys.reverse_transform(sampled_data) + if not hasattr(self, 'constraints'): return sampled_data @@ -454,7 +459,7 @@ def _validate_all_tables(self, data): def validate(self, data): """Validate the data. - Validate that the metadata matches the data and thta every table's constraints are valid. + Validate that the metadata matches the data and that every table's constraints are valid. Args: data (dict[str, pd.DataFrame]): @@ -463,6 +468,7 @@ def validate(self, data): errors = [] metadata = self._original_metadata metadata.validate_data(data) + data = self._validate_transform_constraints(data, enforce_constraint_fitting=True) for table_name in data: if table_name in self._table_synthesizers: # Validate rules specific to each synthesizer @@ -471,8 +477,6 @@ def validate(self, data): if errors: raise InvalidDataError(errors) - self._validate_transform_constraints(data, enforce_constraint_fitting=True) - def _validate_table_name(self, table_name): if table_name not in self._table_synthesizers: raise ValueError( @@ -574,6 +578,10 @@ def preprocess(self, data): list_of_changed_tables = self._store_and_convert_original_cols(data) self.validate(data) data = self._validate_transform_constraints(data) + if getattr(self, '_composite_keys', None): + self._composite_keys.fit(data, self._composite_keys_metadata) + data = self._composite_keys.transform(data) + if self._fitted: msg = ( 'This model has already been fitted. To use the new preprocessed data, ' diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index de037db1c..729d9fb04 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -199,7 +199,7 @@ def __init__(self, metadata, locales=['en_US'], verbose=True): self, self.metadata, self._table_synthesizers, self._table_sizes ) child_tables = set() - for relationship in metadata.relationships: + for relationship in self.metadata.relationships: child_tables.add(relationship['child_table_name']) for child_table_name in child_tables: self.set_table_parameters(child_table_name, {'default_distribution': 'norm'}) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 166dfefb0..f036b2b00 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -150,9 +150,9 @@ def _validate_regex_format(self): ) _check_regex_format(self._table_name, column_name, regex) - def _handle_composite_keys(self, single_table_metadata): + def _handle_composite_keys(self): """Validates that composite keys are not used in Public SDV.""" - if single_table_metadata._primary_key_is_composite: + if self.metadata.tables[self._table_name]._primary_key_is_composite: raise SynthesizerInputError( 'Your metadata contains composite keys (primary key of table ' f"'{self._table_name}' has multiple columns). Composite keys are " @@ -181,11 +181,11 @@ def __init__( self.metadata.validate() self._check_metadata_updated() - single_table_metadata = self.metadata._convert_to_single_table() - self._handle_composite_keys(single_table_metadata) # Points to a metadata object that conserves the initialized status of the synthesizer self._original_metadata = deepcopy(self.metadata) + self._handle_composite_keys() + single_table_metadata = self.metadata._convert_to_single_table() self.enforce_min_max_values = enforce_min_max_values self.enforce_rounding = enforce_rounding @@ -465,13 +465,15 @@ def add_constraints(self, constraints): constraints (list): A list of constraints to apply to the synthesizer. """ + metadata = getattr(self, '_composite_keys_metadata', None) or self.metadata constraints = _validate_constraints_single_table(constraints, self._fitted) for constraint in constraints: if isinstance(constraint, ProgrammableConstraint): constraint = ProgrammableConstraintHarness(constraint) try: - self.metadata = constraint.get_updated_metadata(self.metadata) + metadata = constraint.get_updated_metadata(metadata) + self.metadata = metadata self._chained_constraints.append(constraint) self._constraints_fitted = False except ConstraintNotMetError as e: @@ -486,6 +488,7 @@ def add_constraints(self, constraints): raise e self.metadata.validate() + self._handle_composite_keys() self._data_processor = DataProcessor( metadata=self.metadata._convert_to_single_table(), enforce_rounding=self.enforce_rounding, @@ -626,6 +629,9 @@ def _preprocess_helper(self, data): warnings.warn(msg, RefitWarning) data = self._validate_transform_constraints(data) + if getattr(self, '_composite_keys', None): + self._composite_keys.fit(data, self._composite_keys_metadata) + data = self._composite_keys.transform(data) return data @@ -787,6 +793,9 @@ def load(cls, filepath): def reverse_transform_constraints(self, sampled): """Reverse transform the constraints.""" + if getattr(self, '_composite_keys', None): + sampled = self._composite_keys.reverse_transform(sampled) + if hasattr(self, '_chained_constraints') and hasattr(self, '_reject_sampling_constraints'): for constraint in reversed(self._chained_constraints): sampled = constraint.reverse_transform(sampled) diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index f03863b86..f4acdbcf9 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -1322,14 +1322,75 @@ def test__validate_foreign_keys_missing_keys(self): # Assert missing_upravna_enota = [ 'Relationships:\n' - "Error: foreign key column 'upravna_enota' contains unknown references: " - '(10, 11, 12, 13, 14, + more). ' + "Error: foreign key column 'upravna_enota' contains unknown references:\n" + ' upravna_enota\n' + '0 10\n' + '1 11\n' + '2 12\n' + '3 13\n' + '4 14\n' + '+5 more\n' "Please use the method 'drop_unknown_references' from sdv.utils to clean the data.\n" - "Error: foreign key column 'id_nesreca' contains unknown references: (1, 3, 5, 7, 9)." - " Please use the method 'drop_unknown_references' from sdv.utils to clean the data." + "Error: foreign key column 'id_nesreca' contains unknown references:\n" + ' id_nesreca\n' + '1 1\n' + '3 3\n' + '5 5\n' + '7 7\n' + '9 9\n' + "Please use the method 'drop_unknown_references' from sdv.utils to clean the data." ] assert result == missing_upravna_enota + def test__validate_foreign_keys_missing_composite_keys(self): + """Test that errors are being returned. + + When the values of the foreign keys are not within the values of the parent + primary key, a list of errors must be returned indicating the values that are missing. + """ + # Setup + metadata = get_multi_table_metadata() + metadata.remove_relationship('nesreca', 'oseba') + metadata.add_column('id_nesreca2', 'nesreca', sdtype='id') + metadata.add_column('nesreca_fk2', 'oseba', sdtype='id') + metadata.set_primary_key(['id_nesreca', 'id_nesreca2'], 'nesreca') + metadata.add_relationship( + parent_table_name='nesreca', + child_table_name='oseba', + parent_primary_key=['id_nesreca', 'id_nesreca2'], + child_foreign_key=['id_nesreca', 'nesreca_fk2'], + ) + data = { + 'nesreca': pd.DataFrame({ + 'id_nesreca': ['id0', 'id1', 'id2', 'id3', 'id4'] * 2, + 'id_nesreca2': ['A'] * 5 + ['B'] * 5, + 'upravna_enota': np.arange(10), + }), + 'oseba': pd.DataFrame({ + 'upravna_enota': np.arange(9), + 'id_nesreca': ['id0', 'id9', 'id9'] + ['id0', 'id1', 'id2'] * 2, + 'nesreca_fk2': ['X', 'A', 'X'] + ['A', 'B'] * 3, + }), + 'upravna_enota': pd.DataFrame({ + 'id_upravna_enota': np.arange(10), + }), + } + + # Run + result = metadata._validate_foreign_keys(data) + + # Assert + missing_oseba = [ + 'Relationships:\n' + "Error: foreign key column ['id_nesreca', 'nesreca_fk2'] contains unknown references:\n" + ' id_nesreca nesreca_fk2\n' + '0 id0 X\n' + '1 id9 A\n' + '2 id9 X\n' + "Please use the method 'drop_unknown_references' from sdv.utils to clean the data." + ] + assert result == missing_oseba + def test_validate_data(self): """Test that no error is being raised when the data is valid.""" # Setup diff --git a/tests/unit/multi_table/test_utils.py b/tests/unit/multi_table/test_utils.py index 84839c19d..3e4804865 100644 --- a/tests/unit/multi_table/test_utils.py +++ b/tests/unit/multi_table/test_utils.py @@ -104,8 +104,8 @@ def test__get_rows_to_drop(): { 'parent_table_name': 'child', 'child_table_name': 'grandchild', - 'parent_primary_key': 'id_child', - 'child_foreign_key': 'child_foreign_key', + 'parent_primary_key': ['id_child1', 'id_child2'], + 'child_foreign_key': ['child_fk1', 'child_fk2'], }, { 'parent_table_name': 'parent', @@ -119,7 +119,7 @@ def test__get_rows_to_drop(): metadata.relationships = relationships metadata.tables = { 'parent': Mock(primary_key='id_parent'), - 'child': Mock(primary_key='id_child'), + 'child': Mock(primary_key=['id_child1', 'id_child2']), 'grandchild': Mock(primary_key='id_grandchild'), } @@ -130,12 +130,14 @@ def test__get_rows_to_drop(): }), 'child': pd.DataFrame({ 'parent_foreign_key': [0, 1, 2, 2, 5], - 'id_child': [5, 6, 7, 8, 9], + 'id_child1': [5, 6, 7, 8, 9], + 'id_child2': ['A', 'B', 'A', 'B', 'A'], 'B': ['Yes', 'No', 'No', 'No', 'No'], }), 'grandchild': pd.DataFrame({ 'parent_foreign_key': [0, 1, 2, 2, 6], - 'child_foreign_key': [9, 5, 11, 6, 4], + 'child_fk1': [9, 5, 11, 6, 6], + 'child_fk2': ['A', 'A', 'A', 'B', 'X'], 'C': ['Yes', 'No', 'No', 'No', 'No'], }), } From 7ed633b3837fa06b9ac05a9e07d94fbe2d2a8b2a Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Tue, 10 Feb 2026 15:09:10 -0500 Subject: [PATCH 2/5] Make drop_unknown_references compatible with composite keys --- sdv/_utils.py | 14 ++++++++++++++ sdv/metadata/multi_table.py | 18 ++++-------------- sdv/multi_table/utils.py | 25 ++++++++++++------------- 3 files changed, 30 insertions(+), 27 deletions(-) diff --git a/sdv/_utils.py b/sdv/_utils.py index ec28669ec..2881885ce 100644 --- a/sdv/_utils.py +++ b/sdv/_utils.py @@ -520,3 +520,17 @@ def _validate_correct_synthesizer_loading(synthesizer, cls): def _sort_keys(keys): return sorted(keys, key=lambda key: key if isinstance(key, str) else key[0]) + + +def _get_unreferenced_keys(parent_columns, child_columns): + indicator = _create_unique_name( + '_merge', list(child_columns.columns) + list(parent_columns.columns) + ) + merged_columns = child_columns.merge( + parent_columns, + left_on=list(child_columns.columns), + right_on=list(parent_columns.columns), + how='left', + indicator=indicator, + ) + return merged_columns[merged_columns[indicator] == 'left_only'][list(child_columns.columns)] diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 9f86389cf..5f5e877d7 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -12,8 +12,8 @@ from sdv._utils import ( _cast_to_iterable, - _create_unique_name, _format_invalid_values_string, + _get_unreferenced_keys, _load_data_from_csv, ) from sdv.errors import InvalidDataError @@ -910,18 +910,8 @@ def _validate_foreign_keys(self, data): if isinstance(child_table, pd.DataFrame) and isinstance(parent_table, pd.DataFrame): child_columns = child_table[_cast_to_iterable(relation['child_foreign_key'])] parent_columns = parent_table[_cast_to_iterable(relation['parent_primary_key'])] - indicator = _create_unique_name( - '_merge', list(child_columns.columns) + list(parent_columns.columns) - ) - merged_columns = parent_columns.merge( - child_columns.drop_duplicates(), - left_on=list(parent_columns.columns), - right_on=list(child_columns.columns), - how='right', - indicator=indicator, - ) - missing_values = merged_columns[merged_columns[indicator] == 'right_only'] - missing_values = missing_values[list(child_columns.columns)] + missing_values = _get_unreferenced_keys(parent_columns, child_columns) + missing_values = missing_values.drop_duplicates() if not missing_values.empty: foreign_key = relation['child_foreign_key'] if not isinstance(foreign_key, list): @@ -929,7 +919,7 @@ def _validate_foreign_keys(self, data): message = f'\n{_format_invalid_values_string(missing_values, 5)}' errors.append( - f"Error: foreign key column {foreign_key} contains " + f'Error: foreign key column {foreign_key} contains ' f'unknown references:{message}\n' "Please use the method 'drop_unknown_references' from sdv.utils " 'to clean the data.' diff --git a/sdv/multi_table/utils.py b/sdv/multi_table/utils.py index f998818f6..b429b8879 100644 --- a/sdv/multi_table/utils.py +++ b/sdv/multi_table/utils.py @@ -7,7 +7,12 @@ import numpy as np import pandas as pd -from sdv._utils import MODELABLE_SDTYPES, _get_root_tables +from sdv._utils import ( + MODELABLE_SDTYPES, + _cast_to_iterable, + _get_root_tables, + _get_unreferenced_keys, +) from sdv.errors import InvalidDataError, SamplingError from sdv.multi_table import HMASynthesizer from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS @@ -393,25 +398,19 @@ def _get_rows_to_drop(data, metadata): for root in current_roots: parent_table = root relationships_parent = _get_relationships_for_parent(relationships, parent_table) - parent_column = metadata.tables[parent_table].primary_key + parent_columns = _cast_to_iterable(metadata.tables[parent_table].primary_key) valid_parent_idx = [ idx for idx in data[parent_table].index if idx not in table_to_idx_to_drop[parent_table] ] - valid_parent_values = set(data[parent_table].loc[valid_parent_idx, parent_column]) + valid_parent_values = data[parent_table].loc[valid_parent_idx, parent_columns] for relationship in relationships_parent: child_table = relationship['child_table_name'] - child_column = relationship['child_foreign_key'] - - is_nan = data[child_table][child_column].isna() - invalid_values = ( - set(data[child_table].loc[~is_nan, child_column]) - valid_parent_values - ) - invalid_rows = data[child_table][ - data[child_table][child_column].isin(invalid_values) - ] - idx_to_drop = set(invalid_rows.index) + child_foreign_key = _cast_to_iterable(relationship['child_foreign_key']) + child_columns = data[child_table][child_foreign_key] + unreferenced_rows = _get_unreferenced_keys(valid_parent_values, child_columns) + idx_to_drop = set(unreferenced_rows.index) if idx_to_drop: table_to_idx_to_drop[child_table] = table_to_idx_to_drop[child_table].union( From 6463881ec3f7aae2273489176bf3883b2da7631c Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Tue, 10 Feb 2026 16:20:29 -0500 Subject: [PATCH 3/5] Fix tests --- sdv/_utils.py | 6 ++++-- sdv/multi_table/base.py | 8 ++++++-- sdv/single_table/base.py | 3 ++- tests/integration/utils/test_utils.py | 6 ++++-- tests/unit/metadata/test_multi_table.py | 8 +++++++- tests/unit/multi_table/test_base.py | 16 ++++++++++++++-- tests/unit/multi_table/test_utils.py | 2 +- tests/unit/single_table/test_base.py | 1 + 8 files changed, 39 insertions(+), 11 deletions(-) diff --git a/sdv/_utils.py b/sdv/_utils.py index 2881885ce..1d5c983e1 100644 --- a/sdv/_utils.py +++ b/sdv/_utils.py @@ -526,11 +526,13 @@ def _get_unreferenced_keys(parent_columns, child_columns): indicator = _create_unique_name( '_merge', list(child_columns.columns) + list(parent_columns.columns) ) - merged_columns = child_columns.merge( + merged = child_columns.merge( parent_columns, left_on=list(child_columns.columns), right_on=list(parent_columns.columns), how='left', indicator=indicator, ) - return merged_columns[merged_columns[indicator] == 'left_only'][list(child_columns.columns)] + merged = merged[merged[indicator] == 'left_only'][list(child_columns.columns)] + merged = merged.dropna(how='all') + return merged.dropna(how='all') diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 297726a8c..b2ab2d4e1 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -229,7 +229,11 @@ def add_constraints(self, constraints): A list of constraints to apply to the synthesizer. """ constraints = _validate_constraints(constraints, self._fitted) - metadata = getattr(self, '_composite_keys_metadata', None) or self.metadata + metadata = self.metadata + if hasattr(self, '_composite_keys_metadata'): + metadata = self._composite_keys_metadata + self._modified_multi_table_metadata = self._composite_keys_metadata + multi_table_constraints = [] single_table_constraints = [] idx_single_table_constraint = self._detect_single_table_constraints(constraints) @@ -243,9 +247,9 @@ def add_constraints(self, constraints): multi_table_constraints.append(constraint) metadata = constraint.get_updated_metadata(metadata) + self._modified_multi_table_metadata = metadata self.metadata = metadata - self._modified_multi_table_metadata = self.metadata self._handle_composite_keys() self._validate_single_table_constraints(single_table_constraints) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index f036b2b00..97d1ad1f7 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -152,7 +152,8 @@ def _validate_regex_format(self): def _handle_composite_keys(self): """Validates that composite keys are not used in Public SDV.""" - if self.metadata.tables[self._table_name]._primary_key_is_composite: + table_metadata = self.metadata._convert_to_single_table() + if table_metadata._primary_key_is_composite: raise SynthesizerInputError( 'Your metadata contains composite keys (primary key of table ' f"'{self._table_name}' has multiple columns). Composite keys are " diff --git a/tests/integration/utils/test_utils.py b/tests/integration/utils/test_utils.py index 79d2e8655..c42e635b2 100644 --- a/tests/integration/utils/test_utils.py +++ b/tests/integration/utils/test_utils.py @@ -60,8 +60,10 @@ def test_drop_unknown_references(metadata, data, capsys): expected_message = re.escape( 'The provided data does not match the metadata:\n' 'Relationships:\n' - "Error: foreign key column 'parent_id' contains unknown references: (5)" - ". Please use the method 'drop_unknown_references' from sdv.utils to clean the data." + "Error: foreign key column 'parent_id' contains unknown references:\n" + ' parent_id\n' + '4 5\n' + "Please use the method 'drop_unknown_references' from sdv.utils to clean the data." ) with pytest.raises(InvalidDataError, match=expected_message): metadata.validate_data(data) diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index f4acdbcf9..b6db4f496 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -1499,7 +1499,13 @@ def test_validate_data_missing_foreign_keys(self): error_msg = re.escape( 'The provided data does not match the metadata:\n' 'Relationships:\n' - "Error: foreign key column 'id_nesreca' contains unknown references: (1, 3, 5, 7, 9). " + "Error: foreign key column 'id_nesreca' contains unknown references:\n" + ' id_nesreca\n' + '1 1\n' + '3 3\n' + '5 5\n' + '7 7\n' + '9 9\n' "Please use the method 'drop_unknown_references' from sdv.utils to clean the data." ) with pytest.raises(InvalidDataError, match=error_msg): diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index a20ff73f6..03131ceaa 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -542,7 +542,7 @@ def test_validate(self): metadata = get_multi_table_metadata() data = get_multi_table_data() instance = BaseMultiTableSynthesizer(metadata) - instance._validate_transform_constraints = Mock() + instance._validate_transform_constraints = Mock(return_value=data) # Run instance.validate(data) @@ -659,7 +659,13 @@ def test_validate_missing_foreign_keys(self): error_msg = re.escape( 'The provided data does not match the metadata:\n' 'Relationships:\n' - "Error: foreign key column 'id_nesreca' contains unknown references: (1, 3, 5, 7, 9). " + "Error: foreign key column 'id_nesreca' contains unknown references:\n" + ' id_nesreca\n' + '1 1\n' + '3 3\n' + '5 5\n' + '7 7\n' + '9 9\n' "Please use the method 'drop_unknown_references' from sdv.utils to clean the data." ) with pytest.raises(InvalidDataError, match=error_msg): @@ -1567,6 +1573,7 @@ def test_add_constraints(self, mock_validate_constraints, mock_programmable_cons """Test adding data constraints to the synthesizer.""" # Setup instance = Mock() + del instance._composite_keys_metadata original_metadata = get_multi_table_metadata() instance.metadata = original_metadata instance._original_metadata = original_metadata @@ -1618,6 +1625,7 @@ def test_add_constraints_single_table_overlap(self, mock_validate_constraints): """Test adding overlapping single-table constraints to the synthesizer.""" # Setup instance = Mock() + del instance._composite_keys_metadata original_metadata = get_multi_table_metadata() instance.metadata = original_metadata instance._original_metadata = original_metadata @@ -1665,6 +1673,7 @@ def test_updating_constraints_keeps_original_metadata(self, mock_validate_constr """Test adding data constraints to the synthesizer.""" # Setup instance = Mock() + del instance._composite_keys_metadata delattr(instance, 'constraints') metadata = get_multi_table_metadata() original_metadata = Mock() @@ -1821,6 +1830,7 @@ def test__validate_transform_constraints_with_constraints(self): """Test validating and transforming the data constraints.""" # Setup instance = Mock() + del instance._composite_keys data = {'table1': Mock(), 'table2': Mock()} constraint1 = Mock() constraint2 = Mock() @@ -1841,6 +1851,7 @@ def test__reverse_validate_transform_constraints_no_constraints(self): """Test reverse transforming when no data constraints have been set.""" # Setup instance = Mock() + del instance._composite_keys data = get_multi_table_data() delattr(instance, 'constraints') @@ -1859,6 +1870,7 @@ def test__reverse_validate_transform_constraints(self, drop_unknown_references): """Test reverse transforming the data constraints.""" # Setup instance = Mock() + del instance._composite_keys data = {'table1': Mock(), 'table2': Mock()} constraint1 = Mock() constraint2 = Mock() diff --git a/tests/unit/multi_table/test_utils.py b/tests/unit/multi_table/test_utils.py index 3e4804865..95dceb15b 100644 --- a/tests/unit/multi_table/test_utils.py +++ b/tests/unit/multi_table/test_utils.py @@ -1990,7 +1990,7 @@ def test__subsample_data( assert result == data -def test__subsample_data_with_null_foreing_keys(): +def test__subsample_data_with_null_foreign_keys(): """Test the ``_subsample_data`` method when there are null foreign keys.""" # Setup metadata = Metadata.load_from_dict({ diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 7d26cab74..6d2ecd581 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -727,6 +727,7 @@ def test__preprocess_helper(self, mock_warnings): instance = Mock() instance._fitted = True data = pd.DataFrame({'name': ['John', 'Doe', 'John Doe']}) + instance._composite_keys.transform.return_value = data instance._validate_transform_constraints.side_effect = lambda x: x expected_warning = ( 'This model has already been fitted. To use the new preprocessed data, please ' From 95ae2259c056dba4fe7a97ee6c8834c1ca47e03a Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Wed, 11 Feb 2026 10:52:27 -0500 Subject: [PATCH 4/5] Fix null fk detection --- sdv/multi_table/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sdv/multi_table/utils.py b/sdv/multi_table/utils.py index b429b8879..36069d83a 100644 --- a/sdv/multi_table/utils.py +++ b/sdv/multi_table/utils.py @@ -427,8 +427,10 @@ def _get_nan_fk_indices_table(data, relationships, table): idx_with_nan_foreign_key = set() relationships_for_table = _get_relationships_for_child(relationships, table) for relationship in relationships_for_table: - child_column = relationship['child_foreign_key'] - idx_with_nan_foreign_key.update(data[table][data[table][child_column].isna()].index) + child_columns = _cast_to_iterable(relationship['child_foreign_key']) + idx_with_nan_foreign_key.update( + data[table][data[table][child_columns].isna().all(axis=1)].index + ) return idx_with_nan_foreign_key From 2cff3bf2da42e1a88d582a7132f80caf93c4162e Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Tue, 17 Feb 2026 10:24:51 -0500 Subject: [PATCH 5/5] Comments --- sdv/multi_table/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index b2ab2d4e1..81c32b2a1 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -139,6 +139,7 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self._check_metadata_updated() self._original_metadata = deepcopy(self.metadata) self._modified_multi_table_metadata = deepcopy(self.metadata) + self._composite_keys_metadata = None self._handle_composite_keys() self.locales = locales self.verbose = False @@ -230,7 +231,7 @@ def add_constraints(self, constraints): """ constraints = _validate_constraints(constraints, self._fitted) metadata = self.metadata - if hasattr(self, '_composite_keys_metadata'): + if self._composite_keys_metadata is not None: metadata = self._composite_keys_metadata self._modified_multi_table_metadata = self._composite_keys_metadata @@ -250,6 +251,7 @@ def add_constraints(self, constraints): self._modified_multi_table_metadata = metadata self.metadata = metadata + self._composite_key = None self._handle_composite_keys() self._validate_single_table_constraints(single_table_constraints)