Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions sdv/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,3 +520,19 @@ def _validate_correct_synthesizer_loading(synthesizer, cls):

def _sort_keys(keys):
return sorted(keys, key=lambda key: key if isinstance(key, str) else key[0])


def _get_unreferenced_keys(parent_columns, child_columns):
    """Return the rows of ``child_columns`` whose key values do not appear in ``parent_columns``.

    Performs a left merge of the child key columns against the parent key columns
    and keeps only the rows that found no match, excluding rows where every key
    column is NaN (an all-NaN foreign key is not treated as a reference).

    Args:
        parent_columns (pd.DataFrame):
            The parent table's (primary) key column(s).
        child_columns (pd.DataFrame):
            The child table's foreign key column(s), in the same column order as
            ``parent_columns``.

    Returns:
        pd.DataFrame:
            The child key rows with no matching parent key, restricted to the
            child key columns.
    """
    # Use a collision-free name for the merge indicator so it cannot clash
    # with an actual column of either frame.
    indicator = _create_unique_name(
        '_merge', list(child_columns.columns) + list(parent_columns.columns)
    )
    merged = child_columns.merge(
        parent_columns,
        left_on=list(child_columns.columns),
        right_on=list(parent_columns.columns),
        how='left',
        indicator=indicator,
    )
    # Rows tagged 'left_only' have no matching key in the parent table.
    unreferenced = merged[merged[indicator] == 'left_only'][list(child_columns.columns)]
    # Single dropna suffices; the original applied it twice, which is redundant
    # because dropna(how='all') is idempotent.
    return unreferenced.dropna(how='all')
36 changes: 20 additions & 16 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@

import pandas as pd

from sdv._utils import _cast_to_iterable, _load_data_from_csv
from sdv._utils import (
_cast_to_iterable,
_format_invalid_values_string,
_get_unreferenced_keys,
_load_data_from_csv,
)
from sdv.errors import InvalidDataError
from sdv.logging import get_sdv_logger
from sdv.metadata.errors import InvalidMetadataError
Expand Down Expand Up @@ -903,22 +908,21 @@ def _validate_foreign_keys(self, data):
parent_table = data.get(relation['parent_table_name'])

if isinstance(child_table, pd.DataFrame) and isinstance(parent_table, pd.DataFrame):
child_column = child_table[relation['child_foreign_key']]
parent_column = parent_table[relation['parent_primary_key']]
missing_values = child_column[~child_column.isin(parent_column)].unique()
missing_values = missing_values[~pd.isna(missing_values)]

if any(missing_values):
message = ', '.join(missing_values[:5].astype(str))
if len(missing_values) > 5:
message = f'({message}, + more)'
else:
message = f'({message})'

child_columns = child_table[_cast_to_iterable(relation['child_foreign_key'])]
parent_columns = parent_table[_cast_to_iterable(relation['parent_primary_key'])]
missing_values = _get_unreferenced_keys(parent_columns, child_columns)
missing_values = missing_values.drop_duplicates()
if not missing_values.empty:
foreign_key = relation['child_foreign_key']
if not isinstance(foreign_key, list):
foreign_key = f"'{foreign_key}'"

message = f'\n{_format_invalid_values_string(missing_values, 5)}'
errors.append(
f"Error: foreign key column '{relation['child_foreign_key']}' contains "
f'unknown references: {message}. Please use the method'
" 'drop_unknown_references' from sdv.utils to clean the data."
f'Error: foreign key column {foreign_key} contains '
f'unknown references:{message}\n'
"Please use the method 'drop_unknown_references' from sdv.utils "
'to clean the data.'
)

if errors:
Expand Down
24 changes: 19 additions & 5 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,16 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
self.metadata.validate()

self._check_metadata_updated()
self._original_metadata = deepcopy(self.metadata)
self._modified_multi_table_metadata = deepcopy(self.metadata)
self._composite_keys_metadata = None
self._handle_composite_keys()
self.locales = locales
self.verbose = False
self.extended_columns = defaultdict(dict)
self._table_synthesizers = {}
self._table_parameters = defaultdict(dict)
self._original_table_columns = {}
self._original_metadata = deepcopy(self.metadata)
self._modified_multi_table_metadata = deepcopy(self.metadata)
self.constraints = []
self._single_table_constraints = []
if synthesizer_kwargs is not None:
Expand Down Expand Up @@ -230,6 +231,10 @@ def add_constraints(self, constraints):
"""
constraints = _validate_constraints(constraints, self._fitted)
metadata = self.metadata
if self._composite_keys_metadata is not None:
metadata = self._composite_keys_metadata
self._modified_multi_table_metadata = self._composite_keys_metadata

multi_table_constraints = []
single_table_constraints = []
idx_single_table_constraint = self._detect_single_table_constraints(constraints)
Expand All @@ -246,6 +251,9 @@ def add_constraints(self, constraints):
self._modified_multi_table_metadata = metadata

self.metadata = metadata
self._composite_key = None
self._handle_composite_keys()

self._validate_single_table_constraints(single_table_constraints)
self.constraints += multi_table_constraints
self._constraints_fitted = False
Expand Down Expand Up @@ -355,6 +363,9 @@ def _validate_transform_constraints(self, data, enforce_constraint_fitting=False

def _reverse_transform_constraints(self, sampled_data):
"""Reverse transform constraints after sampling."""
if getattr(self, '_composite_keys', None):
sampled_data = self._composite_keys.reverse_transform(sampled_data)

if not hasattr(self, 'constraints'):
return sampled_data

Expand Down Expand Up @@ -454,7 +465,7 @@ def _validate_all_tables(self, data):
def validate(self, data):
"""Validate the data.

Validate that the metadata matches the data and thta every table's constraints are valid.
Validate that the metadata matches the data and that every table's constraints are valid.

Args:
data (dict[str, pd.DataFrame]):
Expand All @@ -463,6 +474,7 @@ def validate(self, data):
errors = []
metadata = self._original_metadata
metadata.validate_data(data)
data = self._validate_transform_constraints(data, enforce_constraint_fitting=True)
for table_name in data:
if table_name in self._table_synthesizers:
# Validate rules specific to each synthesizer
Expand All @@ -471,8 +483,6 @@ def validate(self, data):
if errors:
raise InvalidDataError(errors)

self._validate_transform_constraints(data, enforce_constraint_fitting=True)

def _validate_table_name(self, table_name):
if table_name not in self._table_synthesizers:
raise ValueError(
Expand Down Expand Up @@ -574,6 +584,10 @@ def preprocess(self, data):
list_of_changed_tables = self._store_and_convert_original_cols(data)
self.validate(data)
data = self._validate_transform_constraints(data)
if getattr(self, '_composite_keys', None):
self._composite_keys.fit(data, self._composite_keys_metadata)
data = self._composite_keys.transform(data)

if self._fitted:
msg = (
'This model has already been fitted. To use the new preprocessed data, '
Expand Down
2 changes: 1 addition & 1 deletion sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def __init__(self, metadata, locales=['en_US'], verbose=True):
self, self.metadata, self._table_synthesizers, self._table_sizes
)
child_tables = set()
for relationship in metadata.relationships:
for relationship in self.metadata.relationships:
child_tables.add(relationship['child_table_name'])
for child_table_name in child_tables:
self.set_table_parameters(child_table_name, {'default_distribution': 'norm'})
Expand Down
31 changes: 16 additions & 15 deletions sdv/multi_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
import numpy as np
import pandas as pd

from sdv._utils import MODELABLE_SDTYPES, _get_root_tables
from sdv._utils import (
MODELABLE_SDTYPES,
_cast_to_iterable,
_get_root_tables,
_get_unreferenced_keys,
)
from sdv.errors import InvalidDataError, SamplingError
from sdv.multi_table import HMASynthesizer
from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS
Expand Down Expand Up @@ -393,25 +398,19 @@ def _get_rows_to_drop(data, metadata):
for root in current_roots:
parent_table = root
relationships_parent = _get_relationships_for_parent(relationships, parent_table)
parent_column = metadata.tables[parent_table].primary_key
parent_columns = _cast_to_iterable(metadata.tables[parent_table].primary_key)
valid_parent_idx = [
idx
for idx in data[parent_table].index
if idx not in table_to_idx_to_drop[parent_table]
]
valid_parent_values = set(data[parent_table].loc[valid_parent_idx, parent_column])
valid_parent_values = data[parent_table].loc[valid_parent_idx, parent_columns]
for relationship in relationships_parent:
child_table = relationship['child_table_name']
child_column = relationship['child_foreign_key']

is_nan = data[child_table][child_column].isna()
invalid_values = (
set(data[child_table].loc[~is_nan, child_column]) - valid_parent_values
)
invalid_rows = data[child_table][
data[child_table][child_column].isin(invalid_values)
]
idx_to_drop = set(invalid_rows.index)
child_foreign_key = _cast_to_iterable(relationship['child_foreign_key'])
child_columns = data[child_table][child_foreign_key]
unreferenced_rows = _get_unreferenced_keys(valid_parent_values, child_columns)
idx_to_drop = set(unreferenced_rows.index)

if idx_to_drop:
table_to_idx_to_drop[child_table] = table_to_idx_to_drop[child_table].union(
Expand All @@ -428,8 +427,10 @@ def _get_nan_fk_indices_table(data, relationships, table):
idx_with_nan_foreign_key = set()
relationships_for_table = _get_relationships_for_child(relationships, table)
for relationship in relationships_for_table:
child_column = relationship['child_foreign_key']
idx_with_nan_foreign_key.update(data[table][data[table][child_column].isna()].index)
child_columns = _cast_to_iterable(relationship['child_foreign_key'])
idx_with_nan_foreign_key.update(
data[table][data[table][child_columns].isna().all(axis=1)].index
)

return idx_with_nan_foreign_key

Expand Down
20 changes: 15 additions & 5 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,10 @@ def _validate_regex_format(self):
)
_check_regex_format(self._table_name, column_name, regex)

def _handle_composite_keys(self, single_table_metadata):
def _handle_composite_keys(self):
"""Validates that composite keys are not used in Public SDV."""
if single_table_metadata._primary_key_is_composite:
table_metadata = self.metadata._convert_to_single_table()
if table_metadata._primary_key_is_composite:
raise SynthesizerInputError(
'Your metadata contains composite keys (primary key of table '
f"'{self._table_name}' has multiple columns). Composite keys are "
Expand Down Expand Up @@ -181,11 +182,11 @@ def __init__(

self.metadata.validate()
self._check_metadata_updated()
single_table_metadata = self.metadata._convert_to_single_table()
self._handle_composite_keys(single_table_metadata)

# Points to a metadata object that conserves the initialized status of the synthesizer
self._original_metadata = deepcopy(self.metadata)
self._handle_composite_keys()
single_table_metadata = self.metadata._convert_to_single_table()

self.enforce_min_max_values = enforce_min_max_values
self.enforce_rounding = enforce_rounding
Expand Down Expand Up @@ -465,13 +466,15 @@ def add_constraints(self, constraints):
constraints (list):
A list of constraints to apply to the synthesizer.
"""
metadata = getattr(self, '_composite_keys_metadata', None) or self.metadata
constraints = _validate_constraints_single_table(constraints, self._fitted)
for constraint in constraints:
if isinstance(constraint, ProgrammableConstraint):
constraint = ProgrammableConstraintHarness(constraint)

try:
self.metadata = constraint.get_updated_metadata(self.metadata)
metadata = constraint.get_updated_metadata(metadata)
self.metadata = metadata
self._chained_constraints.append(constraint)
self._constraints_fitted = False
except ConstraintNotMetError as e:
Expand All @@ -486,6 +489,7 @@ def add_constraints(self, constraints):
raise e

self.metadata.validate()
self._handle_composite_keys()
self._data_processor = DataProcessor(
metadata=self.metadata._convert_to_single_table(),
enforce_rounding=self.enforce_rounding,
Expand Down Expand Up @@ -626,6 +630,9 @@ def _preprocess_helper(self, data):
warnings.warn(msg, RefitWarning)

data = self._validate_transform_constraints(data)
if getattr(self, '_composite_keys', None):
self._composite_keys.fit(data, self._composite_keys_metadata)
data = self._composite_keys.transform(data)

return data

Expand Down Expand Up @@ -787,6 +794,9 @@ def load(cls, filepath):

def reverse_transform_constraints(self, sampled):
"""Reverse transform the constraints."""
if getattr(self, '_composite_keys', None):
sampled = self._composite_keys.reverse_transform(sampled)

if hasattr(self, '_chained_constraints') and hasattr(self, '_reject_sampling_constraints'):
for constraint in reversed(self._chained_constraints):
sampled = constraint.reverse_transform(sampled)
Expand Down
6 changes: 4 additions & 2 deletions tests/integration/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ def test_drop_unknown_references(metadata, data, capsys):
expected_message = re.escape(
'The provided data does not match the metadata:\n'
'Relationships:\n'
"Error: foreign key column 'parent_id' contains unknown references: (5)"
". Please use the method 'drop_unknown_references' from sdv.utils to clean the data."
"Error: foreign key column 'parent_id' contains unknown references:\n"
' parent_id\n'
'4 5\n'
"Please use the method 'drop_unknown_references' from sdv.utils to clean the data."
)
with pytest.raises(InvalidDataError, match=expected_message):
metadata.validate_data(data)
Expand Down
Loading
Loading