From a3e3e17d6f9aaae5d95eb78a0e417c14dfa19d3c Mon Sep 17 00:00:00 2001 From: yoid2000 Date: Tue, 11 Mar 2025 07:41:15 +0100 Subject: [PATCH] Added helper function get_cluster to stitcher.py --- syndiffix/stitcher.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/syndiffix/stitcher.py b/syndiffix/stitcher.py index 7403d4c..e1299f7 100644 --- a/syndiffix/stitcher.py +++ b/syndiffix/stitcher.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Any, Dict, List, Tuple import numpy as np import pandas as pd @@ -81,3 +81,19 @@ def stitch(df_left: pd.DataFrame, df_right: pd.DataFrame, shared: bool = True) - col_names = [col_names_all[col_id] for col_id in columns] data = [[tup[0] for tup in row] for row in microdata] return pd.DataFrame(data, columns=col_names) + + +def get_cluster(syn: Synthesizer) -> Dict[str, List[Any]]: + # Returns a friendly representation of the cluster (column names instead of IDs) + cluster: Dict[str, List[Any]] = {"initial": [], "derived": []} + for col_id in syn.clusters.initial_cluster: + cluster["initial"].append(syn.forest.columns[col_id]) + for owner, cols1, cols2 in syn.clusters.derived_clusters: + cluster["derived"].append( + { + "stitch_style": owner, + "stitch_columns": [syn.forest.columns[col_id] for col_id in cols1], + "new_columns": [syn.forest.columns[col_id] for col_id in cols2], + } + ) + return cluster