diff --git a/train.py b/train.py index 5c160e8..176b30c 100644 --- a/train.py +++ b/train.py @@ -86,11 +86,19 @@ def get_inductive_links(df, train_edge_end, val_edge_end): if args.use_inductive: test_df = df[val_edge_end:] - inductive_nodes = set(test_df.src.values).union(test_df.src.values) + train_df = df[:train_edge_end] + train_nodes = set(train_df.src.values).union(train_df.dst.values) + all_nodes = set(df.src.values).union(df.dst.values) + inductive_nodes = set(test_df.src.values).union(test_df.dst.values) print("inductive nodes", len(inductive_nodes)) - neg_link_sampler = NegLinkInductiveSampler(inductive_nodes) + neg_link_sampler_train = NegLinkInductiveSampler(train_nodes, seed=args.exp_seed) + neg_link_sampler_val = NegLinkInductiveSampler(all_nodes, seed=args.exp_seed) + neg_link_sampler_test = NegLinkInductiveSampler(inductive_nodes, seed=args.exp_seed) else: neg_link_sampler = NegLinkSampler(g['indptr'].shape[0] - 1) + neg_link_sampler_train = neg_link_sampler + neg_link_sampler_val = neg_link_sampler + neg_link_sampler_test = neg_link_sampler def eval(mode='val'): neg_samples = 1 @@ -98,11 +106,14 @@ def eval(mode='val'): aps = list() aucs_mrrs = list() if mode == 'val': + neg_link_sampler = neg_link_sampler_val eval_df = df[train_edge_end:val_edge_end] elif mode == 'test': + neg_link_sampler = neg_link_sampler_test eval_df = df[val_edge_end:] neg_samples = args.eval_neg_samples elif mode == 'train': + neg_link_sampler = neg_link_sampler_train eval_df = df[:train_edge_end] with torch.no_grad(): total_loss = 0 @@ -189,7 +200,7 @@ def eval(mode='val'): model.memory_updater.last_updated_nid = None for _, rows in df[:train_edge_end].groupby(group_indexes[random.randint(0, len(group_indexes) - 1)]): t_tot_s = time.time() - root_nodes = np.concatenate([rows.src.values, rows.dst.values, neg_link_sampler.sample(len(rows))]).astype(np.int32) + root_nodes = np.concatenate([rows.src.values, rows.dst.values, neg_link_sampler_train.sample(len(rows))]).astype(np.int32) ts = np.concatenate([rows.time.values, rows.time.values, rows.time.values]).astype(np.float32) if sampler is not None: if 'no_neg' in sample_param and sample_param['no_neg']: