def generate_negative_samples(edges, random_samples=None, duplicates=True, seed=None):
    '''
    Sample negative edges that keep the source of each positive edge.

    For any positive edge **(a, b)**, create a negative edge **(a, c)** such that
    (a, c) was not present in the set of positive edges. Sources are processed in
    the sorted order produced by ``groupby``; rows whose source is NaN are dropped
    by ``groupby`` and therefore produce no negatives.

    Parameters:
        edges: Pandas DataFrame containing exactly two columns where each row is
            the edge (a, b). The column names are ignored: the first column is
            read as the source and the second as the target. The input frame is
            not modified.
        random_samples: Pandas Series containing the possible labels for c.
            None (the default) if c should be sampled from the set of b.
        duplicates: True if the edge (a, c) should be able to appear more than
            once. In that case every source receives exactly as many negatives
            as it has positive edges, drawn with replacement. If False, the
            negatives of a source are drawn without replacement and their number
            is capped by the number of available candidates.
        seed: Random generation seed for reproducibility of samples.

    :return: DataFrame with columns 'source' and 'target' containing the
        negative edges (a, c).
    :raises ValueError: if ``edges`` does not have exactly two columns.
    '''
    # Works for directed, source-anchored graphs without edge types.
    n_columns = edges.shape[1]
    if n_columns > 2:
        raise ValueError('Too many columns! \nEnsure edges only contain columns for an edge (a, b) and no other information.')
    if n_columns < 2:
        raise ValueError('Too few columns! \nEnsure edges contain one column for a and one column for b.')

    # Rename on a copy instead of set_axis(..., inplace=False): the `inplace`
    # keyword was removed in pandas 2.0, while plain column assignment works
    # on every pandas version and leaves the caller's frame untouched.
    df = edges.copy()
    df.columns = ['source', 'target']

    rng = np.random.RandomState(seed)

    if random_samples is None:
        sample_space = set(df['target'].unique())
    else:
        sample_space = set(random_samples.unique())

    all_sampled_sources = []
    all_sampled_targets = []
    for source, targets in df.groupby('source')['target']:
        # Sorting makes the candidate order (and thus the draw) reproducible.
        candidates = sorted(sample_space.difference(targets))
        if not candidates:
            # This source is already linked to every candidate, so no valid
            # negative exists; choice() would raise on an empty population.
            continue

        n_positive = len(targets)
        if duplicates:
            sample_count = n_positive
        else:
            sample_count = min(n_positive, len(candidates))

        # Draw positions rather than labels so that labels keep their original
        # Python type and are never coerced through a NumPy array.
        picked = rng.choice(len(candidates), size=sample_count, replace=bool(duplicates))
        all_sampled_targets.extend(candidates[int(pos)] for pos in picked)
        all_sampled_sources.extend([source] * sample_count)

    return pd.DataFrame({'source': all_sampled_sources, 'target': all_sampled_targets})