diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..723ef36 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea \ No newline at end of file diff --git a/ICEWS05-15/ICEWS05-15_predicate_preprocess.py b/ICEWS05-15/ICEWS05-15_predicate_preprocess.py index fda6b7b..a801904 100644 --- a/ICEWS05-15/ICEWS05-15_predicate_preprocess.py +++ b/ICEWS05-15/ICEWS05-15_predicate_preprocess.py @@ -5,8 +5,9 @@ import pandas as pd from collections import defaultdict as ddict + def load_quadruples(inPath, fileName, fileName2=None): - with open(os.path.join(inPath, fileName), 'r') as fr: + with open(os.path.join(inPath, fileName), "r") as fr: quadrupleList = [] times = set() for line in fr: @@ -15,14 +16,14 @@ def load_quadruples(inPath, fileName, fileName2=None): tail = int(line_split[2]) rel = int(line_split[1]) time = int(line_split[3]) - #ent_set.add(head) - #rel_set.add(rel) - #ent_set.add(tail) + # ent_set.add(head) + # rel_set.add(rel) + # ent_set.add(tail) quadrupleList.append([head, rel, tail, time]) times.add(time) if fileName2 is not None: - with open(os.path.join(inPath, fileName2), 'r') as fr: + with open(os.path.join(inPath, fileName2), "r") as fr: for line in fr: line_split = line.split() head = int(line_split[0]) @@ -36,21 +37,23 @@ def load_quadruples(inPath, fileName, fileName2=None): return np.asarray(quadrupleList), np.asarray(times) + def get_total_number(inPath, fileName): - with open(os.path.join(inPath, fileName), 'r') as fr: + with open(os.path.join(inPath, fileName), "r") as fr: for line in fr: line_split = line.split() return int(line_split[0]), int(line_split[1]) + def get_data_with_t(data, tim, split): - e1 = [quad[0] for quad in data if quad[3] == tim] # subject in this ts - rel = [quad[1] for quad in data if quad[3] == tim] # relation in this ts - e2 = [quad[2] for quad in data if quad[3] == tim] # object in this ts + e1 = [quad[0] for quad in data if quad[3] == tim] # subject in this ts + rel = [quad[1] for quad in data if quad[3] 
== tim] # relation in this ts + e2 = [quad[2] for quad in data if quad[3] == tim] # object in this ts triplet = np.array([e1, rel, e2]).transpose() - triplet_unique = np.unique(triplet, axis=0) # data without inv rel + triplet_unique = np.unique(triplet, axis=0) # data without inv rel - adj_mtx_idx = [] # adjacency matrix element index whose value is 1 + adj_mtx_idx = [] # adjacency matrix element index whose value is 1 sr2o = ddict(set) neib = ddict(set) so2r = ddict(set) @@ -59,8 +62,8 @@ def get_data_with_t(data, tim, split): sr2o[(e2[trp_idx], rel[trp_idx] + num_rel)].add(e1[trp_idx]) neib[e1[trp_idx]].add(e2[trp_idx]) neib[e2[trp_idx]].add(e1[trp_idx]) - #adj_mtx[e1[trp_idx], e2[trp_idx]] = 1 # adjacency matrix - #adj_mtx[e2[trp_idx], e1[trp_idx]] = 1 + # adj_mtx[e1[trp_idx], e2[trp_idx]] = 1 # adjacency matrix + # adj_mtx[e2[trp_idx], e1[trp_idx]] = 1 adj_mtx_idx.append([e1[trp_idx], rel[trp_idx], e2[trp_idx]]) adj_mtx_idx.append([e2[trp_idx], rel[trp_idx] + num_rel, e1[trp_idx]]) so2r[(e1[trp_idx], e2[trp_idx])].add(rel[trp_idx]) @@ -73,30 +76,65 @@ def get_data_with_t(data, tim, split): adj_mtx_idx = torch.tensor(adj_mtx_idx_unique, dtype=int).t() adj_one = torch.ones((adj_mtx_idx.shape[1],)) - trp = [] # as input for the model - trp_eval = [] # for evaluation - if split == 'train': + trp = [] # as input for the model + trp_eval = [] # for evaluation + if split == "train": for (sub, pre), obj in sr2o_tmp.items(): - trp.extend([{'triple':(sub, pre, o), 'label': sr2o_tmp[(sub, pre)], 'sub_samp': 1} for o in obj]) + trp.extend( + [ + { + "triple": (sub, pre, o), + "label": sr2o_tmp[(sub, pre)], + "sub_samp": 1, + } + for o in obj + ] + ) else: trp1 = [] trp2 = [] for trp_idx in range(triplet_unique.shape[0]): - sub, pre, obj = triplet_unique[trp_idx,:] - trp.append({'triple':(sub, pre, obj), 'label': sr2o_tmp[(sub, pre)], 'sub_samp': 1}) - trp.append({'triple': (obj, pre + num_rel, sub), 'label': sr2o_tmp[(obj, pre + num_rel)], 'sub_samp': 1}) - 
trp1.append({'triple': (sub, pre, obj), 'label': sr2o_tmp[(sub, pre)]}) - trp2.append({'triple': (obj, pre + num_rel, sub), 'label': sr2o_tmp[(obj, pre + num_rel)]}) + sub, pre, obj = triplet_unique[trp_idx, :] + trp.append( + { + "triple": (sub, pre, obj), + "label": sr2o_tmp[(sub, pre)], + "sub_samp": 1, + } + ) + trp.append( + { + "triple": (obj, pre + num_rel, sub), + "label": sr2o_tmp[(obj, pre + num_rel)], + "sub_samp": 1, + } + ) + trp1.append({"triple": (sub, pre, obj), "label": sr2o_tmp[(sub, pre)]}) + trp2.append( + { + "triple": (obj, pre + num_rel, sub), + "label": sr2o_tmp[(obj, pre + num_rel)], + } + ) trp_eval = [trp1, trp2] - return triplet_unique.transpose(), sr2o_tmp, trp, trp_eval, neib_tmp, torch.sparse_coo_tensor(adj_mtx_idx, adj_one, [num_e, 2 * num_rel, num_e]), so2r_tmp + return ( + triplet_unique.transpose(), + sr2o_tmp, + trp, + trp_eval, + neib_tmp, + torch.sparse_coo_tensor(adj_mtx_idx, adj_one, [num_e, 2 * num_rel, num_e]), + so2r_tmp, + ) + def construct_adj(data, num_rel): edge_index, edge_type = [], [] # Adding edges for trp_idx in range(data.shape[0]): - sub, rel, obj = data[trp_idx,:] + sub, rel, obj = data[trp_idx, :] edge_index.append((sub, obj)) edge_type.append(rel) @@ -111,15 +149,16 @@ def construct_adj(data, num_rel): return edge_index, edge_type + def load_static(num_rel): - #data = [] + # data = [] sr2o = ddict(set) - for split in ['train', 'valid', 'test']: - for line in open('{}.txt'.format(split)): - sub, rel, obj, _ = map(str.lower, line.strip().split('\t')) + for split in ["train", "valid", "test"]: + for line in open("{}.txt".format(split)): + sub, rel, obj, _ = map(str.lower, line.strip().split("\t")) sub, rel, obj = int(sub), int(rel), int(obj) - #data.append((sub, rel, obj)) + # data.append((sub, rel, obj)) sr2o[(sub, rel)].add(obj) sr2o[(obj, rel + num_rel)].add(sub) @@ -137,33 +176,37 @@ def load_static(num_rel): nei = ddict(list) so2r_all = ddict(list) adjlist = [] -num_e, num_rel = get_total_number('', 
'stat.txt') +num_e, num_rel = get_total_number("", "stat.txt") t_indep_trp = load_static(num_rel) -for split in ['train', 'valid', 'test']: - quadruple, ts = load_quadruples('', '{}.txt'.format(split)) +for split in ["train", "valid", "test"]: + quadruple, ts = load_quadruples("", "{}.txt".format(split)) for ts_ in ts: print(ts_) timestamp[split].append(ts_) - data_ts_, sr2o, trp, trp_eval, neib, adj_mtx, so2r = get_data_with_t(quadruple, ts_, split) - data[split].append(data_ts_) # data without inv rel - sr2o_all[split].append(sr2o) # with inv rel + data_ts_, sr2o, trp, trp_eval, neib, adj_mtx, so2r = get_data_with_t( + quadruple, ts_, split + ) + data[split].append(data_ts_) # data without inv rel + sr2o_all[split].append(sr2o) # with inv rel so2r_all[split].append(so2r) nei[split].append(neib) adjlist.append(adj_mtx) edge_info = ddict(torch.Tensor) - edge_info['edge_index'], edge_info['edge_type'] = construct_adj(data_ts_.transpose(), num_rel) + edge_info["edge_index"], edge_info["edge_type"] = construct_adj( + data_ts_.transpose(), num_rel + ) edge_info = dict(edge_info) adjs[split].append(edge_info) - if split =='train': - triples[split].append(trp) # with inv rel + if split == "train": + triples[split].append(trp) # with inv rel else: triples[split].append(trp) - triples['{}_{}'.format(split, 'tail')].append(trp_eval[0]) - triples['{}_{}'.format(split, 'head')].append(trp_eval[1]) + triples["{}_{}".format(split, "tail")].append(trp_eval[0]) + triples["{}_{}".format(split, "head")].append(trp_eval[1]) data = dict(data) @@ -173,29 +216,29 @@ def load_static(num_rel): timestamp = dict(timestamp) nei = dict(nei) -with open('t_indep_trp.pkl', 'wb') as fp: +with open("t_indep_trp.pkl", "wb") as fp: pickle.dump(t_indep_trp, fp) -with open('data_tKG.pkl', 'wb') as fp: +with open("data_tKG.pkl", "wb") as fp: pickle.dump(data, fp) -with open('sr2o_all_tKG.pkl', 'wb') as fp: +with open("sr2o_all_tKG.pkl", "wb") as fp: pickle.dump(sr2o_all, fp) -with 
open('triples_tKG.pkl', 'wb') as fp: +with open("triples_tKG.pkl", "wb") as fp: pickle.dump(triples, fp) -with open('adjs_tKG.pkl', 'wb') as fp: +with open("adjs_tKG.pkl", "wb") as fp: pickle.dump(adjs, fp) -with open('timestamp_tKG.pkl', 'wb') as fp: +with open("timestamp_tKG.pkl", "wb") as fp: pickle.dump(timestamp, fp) -with open('neighbor_tKG.pkl', 'wb') as fp: +with open("neighbor_tKG.pkl", "wb") as fp: pickle.dump(nei, fp) -with open('adjlist_tKG.pkl', 'wb') as fp: +with open("adjlist_tKG.pkl", "wb") as fp: pickle.dump(adjlist, fp) - -with open('so2r_all_tKG.pkl', 'wb') as fp: + +with open("so2r_all_tKG.pkl", "wb") as fp: pickle.dump(so2r_all, fp) diff --git a/ICEWS05-15/mapping.py b/ICEWS05-15/mapping.py index bf94d7e..103c3f0 100644 --- a/ICEWS05-15/mapping.py +++ b/ICEWS05-15/mapping.py @@ -2,27 +2,32 @@ import pandas as pd from ordered_set import OrderedSet -file1 = pd.read_table("05-15train.txt", sep='\t') -file2 = pd.read_table("05-15valid.txt", sep='\t') -file3 = pd.read_table("05-15test.txt", sep='\t') -file = pd.concat([file1,file2,file3]) +file1 = pd.read_table("05-15train.txt", sep="\t") +file2 = pd.read_table("05-15valid.txt", sep="\t") +file3 = pd.read_table("05-15test.txt", sep="\t") +file = pd.concat([file1, file2, file3]) file = file.drop_duplicates() -nodes = np.unique(file['sub'].to_list() + file['obj'].to_list()) +nodes = np.unique(file["sub"].to_list() + file["obj"].to_list()) print("number of entities:", len(nodes)) print("number of interactions:", file.shape[0]) -len_train, len_val = int(file.shape[0] * 0.8)+41, int(file.shape[0] * 0.1)+56 +len_train, len_val = int(file.shape[0] * 0.8) + 41, int(file.shape[0] * 0.1) + 56 len_test = file.shape[0] - len_train - len_val -time_s = file['time'].to_list() +time_s = file["time"].to_list() print(len(np.unique(time_s))) print(np.unique((time_s))[0], np.unique((time_s))[-1]) print(file) -quadruples = file.sort_values(by="time" , ascending=True) +quadruples = file.sort_values(by="time", 
ascending=True) print(quadruples) ent_set, rel_set, t_set = OrderedSet(), OrderedSet(), OrderedSet() for quad in quadruples.itertuples(): - sub, rel, obj, t = getattr(quad, 'sub'), getattr(quad, 'rel'), getattr(quad, 'obj'), getattr(quad, 'time') + sub, rel, obj, t = ( + getattr(quad, "sub"), + getattr(quad, "rel"), + getattr(quad, "obj"), + getattr(quad, "time"), + ) ent_set.add(sub) rel_set.add(rel) ent_set.add(obj) @@ -30,8 +35,8 @@ ent2id = {ent: idx for idx, ent in enumerate(ent_set)} rel2id = {rel: idx for idx, rel in enumerate(rel_set)} -rel2id.update({rel+'_reverse': idx+len(rel2id) for idx, rel in enumerate(rel_set)}) -t2id = {t: idx*24 for idx, t in enumerate(t_set)} +rel2id.update({rel + "_reverse": idx + len(rel2id) for idx, rel in enumerate(rel_set)}) +t2id = {t: idx * 24 for idx, t in enumerate(t_set)} id2ent = {idx: ent for ent, idx in ent2id.items()} id2rel = {idx: rel for rel, idx in rel2id.items()} @@ -39,46 +44,65 @@ stat = open("stat.txt", "w") stat.write(str(len(ent2id))) -stat.write('\t') -stat.write(str(len(rel2id)//2)) +stat.write("\t") +stat.write(str(len(rel2id) // 2)) stat.close() -train_quad, val_quad, test_quad = quadruples.iloc[0:len_train], quadruples.iloc[len_train:len_train+len_val], quadruples[len_train+len_val:] -tr = open('train.txt','w') +train_quad, val_quad, test_quad = ( + quadruples.iloc[0:len_train], + quadruples.iloc[len_train : len_train + len_val], + quadruples[len_train + len_val :], +) +tr = open("train.txt", "w") for quad_tr in train_quad.itertuples(): - sub, rel, obj, t = getattr(quad_tr, 'sub'), getattr(quad_tr, 'rel'), getattr(quad_tr, 'obj'), getattr(quad_tr, 'time') + sub, rel, obj, t = ( + getattr(quad_tr, "sub"), + getattr(quad_tr, "rel"), + getattr(quad_tr, "obj"), + getattr(quad_tr, "time"), + ) tr.write(str(ent2id[sub])) - tr.write('\t') + tr.write("\t") tr.write(str(rel2id[rel])) - tr.write('\t') + tr.write("\t") tr.write(str(ent2id[obj])) - tr.write('\t') + tr.write("\t") tr.write(str(t2id[t])) - 
tr.write('\n') + tr.write("\n") tr.close() -val = open('valid.txt','w') +val = open("valid.txt", "w") for quad_val in val_quad.itertuples(): - sub, rel, obj, t = getattr(quad_val, 'sub'), getattr(quad_val, 'rel'), getattr(quad_val, 'obj'), getattr(quad_val, 'time') + sub, rel, obj, t = ( + getattr(quad_val, "sub"), + getattr(quad_val, "rel"), + getattr(quad_val, "obj"), + getattr(quad_val, "time"), + ) val.write(str(ent2id[sub])) - val.write('\t') + val.write("\t") val.write(str(rel2id[rel])) - val.write('\t') + val.write("\t") val.write(str(ent2id[obj])) - val.write('\t') + val.write("\t") val.write(str(t2id[t])) - val.write('\n') + val.write("\n") val.close() -te = open('test.txt','w') +te = open("test.txt", "w") for quad_te in test_quad.itertuples(): - sub, rel, obj, t = getattr(quad_te, 'sub'), getattr(quad_te, 'rel'), getattr(quad_te, 'obj'), getattr(quad_te, 'time') + sub, rel, obj, t = ( + getattr(quad_te, "sub"), + getattr(quad_te, "rel"), + getattr(quad_te, "obj"), + getattr(quad_te, "time"), + ) te.write(str(ent2id[sub])) - te.write('\t') + te.write("\t") te.write(str(rel2id[rel])) - te.write('\t') + te.write("\t") te.write(str(ent2id[obj])) - te.write('\t') + te.write("\t") te.write(str(t2id[t])) - te.write('\n') -te.close() \ No newline at end of file + te.write("\n") +te.close() diff --git a/TANGO.py b/TANGO.py index 7e85862..b90d4ec 100644 --- a/TANGO.py +++ b/TANGO.py @@ -10,88 +10,176 @@ from utils import * from eval import * + def save_model(model, args, best_val, best_epoch, optimizer, save_path): state = { - 'state_dict': model.state_dict(), - 'best_val': best_val, - 'best_epoch': best_epoch, - 'optimizer': optimizer.state_dict(), - 'args' : vars(args) + "state_dict": model.state_dict(), + "best_val": best_val, + "best_epoch": best_epoch, + "optimizer": optimizer.state_dict(), + "args": vars(args), } torch.save(state, save_path) + def load_model(load_path, optimizer, model): - state = torch.load(load_path, map_location={'cuda:3': 'cuda:1'}) - 
state_dict = state['state_dict'] - best_val = state['best_val'] - best_val_mrr = best_val['mrr'] + state = torch.load(load_path, map_location={"cuda:3": "cuda:1"}) + state_dict = state["state_dict"] + best_val = state["best_val"] + best_val_mrr = best_val["mrr"] model.load_state_dict(state_dict) - optimizer.load_state_dict(state['optimizer']) + optimizer.load_state_dict(state["optimizer"]) return best_val_mrr + def load_emb(load_path, model): - state = torch.load(load_path, map_location={'cuda:3': 'cuda:1'}) - state_dict = state['state_dict'] + state = torch.load(load_path, map_location={"cuda:3": "cuda:1"}) + state_dict = state["state_dict"] model.load_state_dict(state_dict) + def adjust_learning_rate(optimizer, lr, gamma): """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" lr_ = lr * gamma for param_group in optimizer.param_groups: - param_group['lr'] = lr_ + param_group["lr"] = lr_ return lr_ -if __name__ == '__main__': - modelpth = './checkpoints/' - parser = argparse.ArgumentParser(description='TANGO Training Parameters') - parser.add_argument('--gde_core', type=str, default='mgcn', help='core layer function of the TANGO model') - parser.add_argument('--score_func', type=str, default='tucker', help='score function') - parser.add_argument('--core_layer', type=int, default=2, help='number of core function layers') - parser.add_argument('--num_epoch', type=int, default=100, help='number of maximum epoch') - parser.add_argument('--test_step', type=int, default=1, help='number of epochs after which we do evaluation') - parser.add_argument('--input_step', type=int, default=4, help='number of input steps for ODEblock') - parser.add_argument('--delta_step', type=int, default=0, help='number of steps between the last input snapshot and the prediction snapshot') - parser.add_argument('--target_step', type=int, default=1, help='number of prediction snapshots') - parser.add_argument('--initsize', type=int, default=200, help='size of initial 
representation dimension') - parser.add_argument('--embsize', type=int, default=200, help='size of output embeddings') - parser.add_argument('--hidsize', type=int, default= 200, help='size of representation dimension in the core function') - parser.add_argument('--lr', type=float, default=0.001, help='learning rate') - parser.add_argument('--solver', type=str, default='rk4', help='ODE solver') - parser.add_argument('--atol', type=float, default='1e-4', help='lower bound of the tolerance') - parser.add_argument('--rtol', type=float, default='1e-3', help='higher bound of the tolerance') - parser.add_argument('--device', type=str, default='cuda:0', help='device name') - parser.add_argument('--dataset', type=str, default='ICEWS05-15', help='dataset name') - parser.add_argument('--scale', type=float, default=0.1, help='scale the length of integration') - parser.add_argument('--dropout', type=float, default=0.3, help='dropout') - parser.add_argument('--bias', action='store_false', help='whether to use bias in relation specific transformation') - parser.add_argument('--adjoint_flag', action='store_false', help='whether to use adjoint method') - parser.add_argument('--opn', type=str, default='mult', help='composition operation to be used in MGCN') - parser.add_argument('--shuffle', action='store_false', help='shuffle in dataloader') - parser.add_argument('--cheby_grid', type=int, default=3, help='number of chebyshev nodes, without chebyshev approximation if cheby_grid=0') - parser.add_argument('--resume', action='store_true', help='retore a model') - parser.add_argument('--name', type=str, default='TANGO', help='name of the run') - parser.add_argument('--jump', action='store_true', help='whether to use graph transition layer') - parser.add_argument('--jump_init', type=float, default=0.01, help='weight of transition term') - parser.add_argument('--activation', type=str, default='relu', help='activation function') - parser.add_argument('--res', action='store_true', 
help='include residual MGCN layer') - parser.add_argument('--rel_jump', action='store_true', help='include transition tensor') - parser.add_argument('--induct_test', action='store_true', help='inductive link prediction') - parser.add_argument('--test', action='store_true', help='store to start the test, otherwise start training') + +if __name__ == "__main__": + modelpth = "./checkpoints/" + parser = argparse.ArgumentParser(description="TANGO Training Parameters") + parser.add_argument( + "--gde_core", + type=str, + default="mgcn", + help="core layer function of the TANGO model", + ) + parser.add_argument( + "--score_func", type=str, default="tucker", help="score function" + ) + parser.add_argument( + "--core_layer", type=int, default=2, help="number of core function layers" + ) + parser.add_argument( + "--num_epoch", type=int, default=100, help="number of maximum epoch" + ) + parser.add_argument( + "--test_step", + type=int, + default=1, + help="number of epochs after which we do evaluation", + ) + parser.add_argument( + "--input_step", type=int, default=4, help="number of input steps for ODEblock" + ) + parser.add_argument( + "--delta_step", + type=int, + default=0, + help="number of steps between the last input snapshot and the prediction snapshot", + ) + parser.add_argument( + "--target_step", type=int, default=1, help="number of prediction snapshots" + ) + parser.add_argument( + "--initsize", + type=int, + default=200, + help="size of initial representation dimension", + ) + parser.add_argument( + "--embsize", type=int, default=200, help="size of output embeddings" + ) + parser.add_argument( + "--hidsize", + type=int, + default=200, + help="size of representation dimension in the core function", + ) + parser.add_argument("--lr", type=float, default=0.001, help="learning rate") + parser.add_argument("--solver", type=str, default="rk4", help="ODE solver") + parser.add_argument( + "--atol", type=float, default="1e-4", help="lower bound of the tolerance" + ) + 
parser.add_argument( + "--rtol", type=float, default="1e-3", help="higher bound of the tolerance" + ) + parser.add_argument("--device", type=str, default="cuda:0", help="device name") + parser.add_argument( + "--dataset", type=str, default="ICEWS05-15", help="dataset name" + ) + parser.add_argument( + "--scale", type=float, default=0.1, help="scale the length of integration" + ) + parser.add_argument("--dropout", type=float, default=0.3, help="dropout") + parser.add_argument( + "--bias", + action="store_false", + help="whether to use bias in relation specific transformation", + ) + parser.add_argument( + "--adjoint_flag", action="store_false", help="whether to use adjoint method" + ) + parser.add_argument( + "--opn", + type=str, + default="mult", + help="composition operation to be used in MGCN", + ) + parser.add_argument("--shuffle", action="store_false", help="shuffle in dataloader") + parser.add_argument( + "--cheby_grid", + type=int, + default=3, + help="number of chebyshev nodes, without chebyshev approximation if cheby_grid=0", + ) + parser.add_argument("--resume", action="store_true", help="retore a model") + parser.add_argument("--name", type=str, default="TANGO", help="name of the run") + parser.add_argument( + "--jump", action="store_true", help="whether to use graph transition layer" + ) + parser.add_argument( + "--jump_init", type=float, default=0.01, help="weight of transition term" + ) + parser.add_argument( + "--activation", type=str, default="relu", help="activation function" + ) + parser.add_argument( + "--res", action="store_true", help="include residual MGCN layer" + ) + parser.add_argument( + "--rel_jump", action="store_true", help="include transition tensor" + ) + parser.add_argument( + "--induct_test", action="store_true", help="inductive link prediction" + ) + parser.add_argument( + "--test", + action="store_true", + help="store to start the test, otherwise start training", + ) args = parser.parse_args() - if not args.resume: args.name = 
args.name + '_' + time.strftime('%Y_%m_%d') + '_' + time.strftime('%H:%M:%S') + if not args.resume: + args.name = ( + args.name + + "_" + + time.strftime("%Y_%m_%d") + + "_" + + time.strftime("%H:%M:%S") + ) logger = setup_logger(args.name) - if not os.path.exists(modelpth): os.mkdir(modelpth) loadpth = modelpth + args.name - device = args.device if torch.cuda.is_available() else 'cpu' + device = args.device if torch.cuda.is_available() else "cpu" args.device = device print("Using device: ", device) @@ -104,118 +192,179 @@ def adjust_learning_rate(optimizer, lr, gamma): torch.manual_seed(0) np.random.seed(0) - if args.dataset == 'ICEWS14': - val_exist = 0 # ICEWS14 does not have validation set + if args.dataset == "ICEWS14": + val_exist = 0 # ICEWS14 does not have validation set else: val_exist = 1 if val_exist: if args.induct_test: - num_e, num_rel, test_timestamps, test_adj, test_triple, test_1nei, t_indep_trp, test_so2r, induct_tar = setup_induct_test( - args.dataset, logger, args.scale, args.input_step) + ( + num_e, + num_rel, + test_timestamps, + test_adj, + test_triple, + test_1nei, + t_indep_trp, + test_so2r, + induct_tar, + ) = setup_induct_test(args.dataset, logger, args.scale, args.input_step) adjlist = load_adjmtx(args.dataset) - test_adjmtx = adjlist[-len(test_timestamps):] + test_adjmtx = adjlist[-len(test_timestamps) :] else: induct_tar = None - num_e, num_rel, train_timestamps, test_timestamps, val_timestamps, train_adj, test_adj, val_adj, train_triple, test_triple, val_triple, \ - train_1nei, test_1nei, val_1nei, t_indep_trp, train_so2r, val_so2r, test_so2r = setup_tKG(args.dataset, - logger, - args.initsize, - args.scale, - val_exist, - args.input_step) - trainl, testl, vall = len(train_timestamps), len(test_timestamps)-args.input_step, len(val_timestamps)-args.input_step + ( + num_e, + num_rel, + train_timestamps, + test_timestamps, + val_timestamps, + train_adj, + test_adj, + val_adj, + train_triple, + test_triple, + val_triple, + train_1nei, + 
test_1nei, + val_1nei, + t_indep_trp, + train_so2r, + val_so2r, + test_so2r, + ) = setup_tKG( + args.dataset, + logger, + args.initsize, + args.scale, + val_exist, + args.input_step, + ) + trainl, testl, vall = ( + len(train_timestamps), + len(test_timestamps) - args.input_step, + len(val_timestamps) - args.input_step, + ) adjlist = load_adjmtx(args.dataset) - train_adjmtx, test_adjmtx, val_adjmtx = adjlist[:trainl], adjlist[trainl+vall-args.input_step:], adjlist[trainl-args.input_step:trainl+vall] + train_adjmtx, test_adjmtx, val_adjmtx = ( + adjlist[:trainl], + adjlist[trainl + vall - args.input_step :], + adjlist[trainl - args.input_step : trainl + vall], + ) else: induct_tar = None - num_e, num_rel, train_timestamps, test_timestamps, train_adj, test_adj, train_triple, test_triple, train_1nei, test_1nei, t_indep_trp, train_so2r, test_so2r \ - = setup_tKG(args.dataset, logger, args.initsize, args.scale, val_exist, args.input_step) - trainl, testl = len(train_timestamps), len(test_timestamps)-args.input_step + ( + num_e, + num_rel, + train_timestamps, + test_timestamps, + train_adj, + test_adj, + train_triple, + test_triple, + train_1nei, + test_1nei, + t_indep_trp, + train_so2r, + test_so2r, + ) = setup_tKG( + args.dataset, logger, args.initsize, args.scale, val_exist, args.input_step + ) + trainl, testl = len(train_timestamps), len(test_timestamps) - args.input_step adjlist = load_adjmtx(args.dataset) - train_adjmtx, test_adjmtx = adjlist[:trainl], adjlist[trainl-args.input_step:trainl+testl] - + train_adjmtx, test_adjmtx = ( + adjlist[:trainl], + adjlist[trainl - args.input_step : trainl + testl], + ) if args.induct_test: - test_dataset = TANGOtestDataset(args, - test_triple, - test_adj, - test_adjmtx, - test_so2r, - num_e, - input_steps=args.input_step, - target_steps=args.target_step, - delta_steps=args.delta_step, - time_stamps=test_timestamps, - t_indep_trp=t_indep_trp, - induct_tar=induct_tar) - test_loader = TANGOtestDataLoader(dataset=test_dataset, - 
batch_size=1, - shuffle=args.shuffle) + test_dataset = TANGOtestDataset( + args, + test_triple, + test_adj, + test_adjmtx, + test_so2r, + num_e, + input_steps=args.input_step, + target_steps=args.target_step, + delta_steps=args.delta_step, + time_stamps=test_timestamps, + t_indep_trp=t_indep_trp, + induct_tar=induct_tar, + ) + test_loader = TANGOtestDataLoader( + dataset=test_dataset, batch_size=1, shuffle=args.shuffle + ) else: - train_dataset = TANGOtrainDataset(args, - train_triple, - train_adj, - train_adjmtx, - train_so2r, - num_e, - input_steps=args.input_step, - target_steps=args.target_step, - delta_steps=args.delta_step, - time_stamps=train_timestamps, - neg_samp=False) - - test_dataset = TANGOtestDataset(args, - test_triple, - test_adj, - test_adjmtx, - test_so2r, - num_e, - input_steps=args.input_step, - target_steps=args.target_step, - delta_steps=args.delta_step, - time_stamps=test_timestamps, - t_indep_trp=t_indep_trp) - - train_loader = TANGOtrainDataLoader(dataset=train_dataset, - batch_size=1, - shuffle=args.shuffle) - - test_loader = TANGOtestDataLoader(dataset=test_dataset, - batch_size=1, - shuffle=args.shuffle) + train_dataset = TANGOtrainDataset( + args, + train_triple, + train_adj, + train_adjmtx, + train_so2r, + num_e, + input_steps=args.input_step, + target_steps=args.target_step, + delta_steps=args.delta_step, + time_stamps=train_timestamps, + neg_samp=False, + ) + + test_dataset = TANGOtestDataset( + args, + test_triple, + test_adj, + test_adjmtx, + test_so2r, + num_e, + input_steps=args.input_step, + target_steps=args.target_step, + delta_steps=args.delta_step, + time_stamps=test_timestamps, + t_indep_trp=t_indep_trp, + ) + + train_loader = TANGOtrainDataLoader( + dataset=train_dataset, batch_size=1, shuffle=args.shuffle + ) + + test_loader = TANGOtestDataLoader( + dataset=test_dataset, batch_size=1, shuffle=args.shuffle + ) if val_exist: - val_dataset = TANGOtestDataset(args, - val_triple, - val_adj, - val_adjmtx, - val_so2r, - num_e, - 
input_steps=args.input_step, - target_steps=args.target_step, - delta_steps=args.delta_step, - time_stamps=val_timestamps, - t_indep_trp=t_indep_trp) - - val_loader = TANGOtestDataLoader(dataset=val_dataset, - batch_size=1, - shuffle=False) + val_dataset = TANGOtestDataset( + args, + val_triple, + val_adj, + val_adjmtx, + val_so2r, + num_e, + input_steps=args.input_step, + target_steps=args.target_step, + delta_steps=args.delta_step, + time_stamps=val_timestamps, + t_indep_trp=t_indep_trp, + ) + + val_loader = TANGOtestDataLoader( + dataset=val_dataset, batch_size=1, shuffle=False + ) eval_loader = val_loader else: eval_loader = test_loader - # instantiate model model = TANGO(num_e, num_rel, args, device, logger) model.to(device) for name, param in model.named_parameters(): - print(name, ' ', param.size()) + print(name, " ", param.size()) # optimizer optim = torch.optim.Adam(model.parameters(), lr=args.lr) @@ -225,103 +374,178 @@ def adjust_learning_rate(optimizer, lr, gamma): # resume unfinished training process if args.resume: best_val_mrr = load_model(loadpth, optim, model) - logger.info('Successfully Loaded previous model') + logger.info("Successfully Loaded previous model") - kill_cnt = 0 # counts the number of epochs before early stop + kill_cnt = 0 # counts the number of epochs before early stop if args.induct_test == False and args.test == False: for epoch in range(args.num_epoch): running_loss = 0 batch_num = 0 - ftime = 0 # forward time - btime = 0 # backward time + ftime = 0 # forward time + btime = 0 # backward time t1 = time.time() - for step, (sub_in, rel_in, obj_in, lab_in, sub_tar, rel_tar, obj_tar, lab_tar, tar_ts, in_ts, edge_idlist, \ - edge_typelist, adj_mtx, edge_jump_w, edge_jump_id, rel_jump) in enumerate(train_loader): + for ( + step, + ( + sub_in, + rel_in, + obj_in, + lab_in, + sub_tar, + rel_tar, + obj_tar, + lab_tar, + tar_ts, + in_ts, + edge_idlist, + edge_typelist, + adj_mtx, + edge_jump_w, + edge_jump_id, + rel_jump, + ), + ) in 
enumerate(train_loader): optim.zero_grad() model.train() - + # forward t3 = time.time() - loss = model(sub_tar, rel_tar, obj_tar, lab_tar, in_ts, tar_ts, edge_idlist, edge_typelist, edge_jump_id, edge_jump_w, rel_jump) + loss = model( + sub_tar, + rel_tar, + obj_tar, + lab_tar, + in_ts, + tar_ts, + edge_idlist, + edge_typelist, + edge_jump_id, + edge_jump_w, + rel_jump, + ) t4 = time.time() - ftime += (t4 - t3) + ftime += t4 - t3 # backward loss.backward() optim.step() t5 = time.time() - btime += (t5 - t4) + btime += t5 - t4 running_loss += loss.item() batch_num += 1 - running_loss /= batch_num # average loss + running_loss /= batch_num # average loss t2 = time.time() # report loss information - print("Epoch " + str(epoch + 1) + ": " + str(running_loss) + " Time: " + str(t2-t1)) - logger.info("Epoch " + str(epoch + 1) + ": " + str(running_loss) + " Time: " + str(t2-t1)) + print( + "Epoch " + + str(epoch + 1) + + ": " + + str(running_loss) + + " Time: " + + str(t2 - t1) + ) + logger.info( + "Epoch " + + str(epoch + 1) + + ": " + + str(running_loss) + + " Time: " + + str(t2 - t1) + ) # report forward and backward time - print("Epoch " + str(epoch + 1) + ": Forward Time: " + str(ftime) + " Backward Time: " + str(btime)) - logger.info("Epoch " + str(epoch + 1) + ": Forward Time: " + str(ftime) + " Backward Time: " + str(btime)) + print( + "Epoch " + + str(epoch + 1) + + ": Forward Time: " + + str(ftime) + + " Backward Time: " + + str(btime) + ) + logger.info( + "Epoch " + + str(epoch + 1) + + ": Forward Time: " + + str(ftime) + + " Backward Time: " + + str(btime) + ) # evaluation - if (epoch+1) % args.test_step == 0: + if (epoch + 1) % args.test_step == 0: if val_exist: - split = 'val' - results = predict(val_loader, model, args, num_e, test_adjmtx, logger) + split = "val" + results = predict( + val_loader, model, args, num_e, test_adjmtx, logger + ) else: - split = 'test' - results = predict(test_loader, model, args, num_e, test_adjmtx, logger) + split = "test" + 
results = predict( + test_loader, model, args, num_e, test_adjmtx, logger + ) print("===========RAW===========") - print("Epoch {}, HITS10 {}".format(epoch + 1, results['hits@10_raw'])) - print("Epoch {}, HITS3 {}".format(epoch + 1, results['hits@3_raw'])) - print("Epoch {}, HITS1 {}".format(epoch + 1, results['hits@1_raw'])) - print("Epoch {}, MRR {}".format(epoch + 1, results['mrr_raw'])) - print("Epoch {}, MAR {}".format(epoch + 1, results['mar_raw'])) + print("Epoch {}, HITS10 {}".format(epoch + 1, results["hits@10_raw"])) + print("Epoch {}, HITS3 {}".format(epoch + 1, results["hits@3_raw"])) + print("Epoch {}, HITS1 {}".format(epoch + 1, results["hits@1_raw"])) + print("Epoch {}, MRR {}".format(epoch + 1, results["mrr_raw"])) + print("Epoch {}, MAR {}".format(epoch + 1, results["mar_raw"])) print("=====TIME AWARE FILTER=====") - print("Epoch {}, HITS10 {}".format(epoch + 1, results['hits@10'])) - print("Epoch {}, HITS3 {}".format(epoch + 1, results['hits@3'])) - print("Epoch {}, HITS1 {}".format(epoch + 1, results['hits@1'])) - print("Epoch {}, MRR {}".format(epoch + 1, results['mrr'])) - print("Epoch {}, MAR {}".format(epoch + 1, results['mar'])) + print("Epoch {}, HITS10 {}".format(epoch + 1, results["hits@10"])) + print("Epoch {}, HITS3 {}".format(epoch + 1, results["hits@3"])) + print("Epoch {}, HITS1 {}".format(epoch + 1, results["hits@1"])) + print("Epoch {}, MRR {}".format(epoch + 1, results["mrr"])) + print("Epoch {}, MAR {}".format(epoch + 1, results["mar"])) print("====TIME UNAWARE FILTER====") - print("Epoch {}, HITS10 {}".format(epoch + 1, results['hits@10_ind'])) - print("Epoch {}, HITS3 {}".format(epoch + 1, results['hits@3_ind'])) - print("Epoch {}, HITS1 {}".format(epoch + 1, results['hits@1_ind'])) - print("Epoch {}, MRR {}".format(epoch + 1, results['mrr_ind'])) - print("Epoch {}, MAR {}".format(epoch + 1, results['mar_ind'])) + print("Epoch {}, HITS10 {}".format(epoch + 1, results["hits@10_ind"])) + print("Epoch {}, HITS3 {}".format(epoch + 
1, results["hits@3_ind"])) + print("Epoch {}, HITS1 {}".format(epoch + 1, results["hits@1_ind"])) + print("Epoch {}, MRR {}".format(epoch + 1, results["mrr_ind"])) + print("Epoch {}, MAR {}".format(epoch + 1, results["mar_ind"])) logger.info("===========RAW===========") - logger.info("Epoch {}, HITS10 {}".format(epoch + 1, results['hits@10_raw'])) - logger.info("Epoch {}, HITS3 {}".format(epoch + 1, results['hits@3_raw'])) - logger.info("Epoch {}, HITS1 {}".format(epoch + 1, results['hits@1_raw'])) - logger.info("Epoch {}, MRR {}".format(epoch + 1, results['mrr_raw'])) - logger.info("Epoch {}, MAR {}".format(epoch + 1, results['mar_raw'])) + logger.info( + "Epoch {}, HITS10 {}".format(epoch + 1, results["hits@10_raw"]) + ) + logger.info( + "Epoch {}, HITS3 {}".format(epoch + 1, results["hits@3_raw"]) + ) + logger.info( + "Epoch {}, HITS1 {}".format(epoch + 1, results["hits@1_raw"]) + ) + logger.info("Epoch {}, MRR {}".format(epoch + 1, results["mrr_raw"])) + logger.info("Epoch {}, MAR {}".format(epoch + 1, results["mar_raw"])) logger.info("=====TIME AWARE FILTER=====") - logger.info("Epoch {}, HITS10 {}".format(epoch + 1, results['hits@10'])) - logger.info("Epoch {}, HITS3 {}".format(epoch + 1, results['hits@3'])) - logger.info("Epoch {}, HITS1 {}".format(epoch + 1, results['hits@1'])) - logger.info("Epoch {}, MRR {}".format(epoch + 1, results['mrr'])) - logger.info("Epoch {}, MAR {}".format(epoch + 1, results['mar'])) + logger.info("Epoch {}, HITS10 {}".format(epoch + 1, results["hits@10"])) + logger.info("Epoch {}, HITS3 {}".format(epoch + 1, results["hits@3"])) + logger.info("Epoch {}, HITS1 {}".format(epoch + 1, results["hits@1"])) + logger.info("Epoch {}, MRR {}".format(epoch + 1, results["mrr"])) + logger.info("Epoch {}, MAR {}".format(epoch + 1, results["mar"])) logger.info("====TIME UNAWARE FILTER====") - logger.info("Epoch {}, HITS10 {}".format(epoch + 1, results['hits@10_ind'])) - logger.info("Epoch {}, HITS3 {}".format(epoch + 1, results['hits@3_ind'])) 
- logger.info("Epoch {}, HITS1 {}".format(epoch + 1, results['hits@1_ind'])) - logger.info("Epoch {}, MRR {}".format(epoch + 1, results['mrr_ind'])) - logger.info("Epoch {}, MAR {}".format(epoch + 1, results['mar_ind'])) - - if results['mrr'] > best_val_mrr: + logger.info( + "Epoch {}, HITS10 {}".format(epoch + 1, results["hits@10_ind"]) + ) + logger.info( + "Epoch {}, HITS3 {}".format(epoch + 1, results["hits@3_ind"]) + ) + logger.info( + "Epoch {}, HITS1 {}".format(epoch + 1, results["hits@1_ind"]) + ) + logger.info("Epoch {}, MRR {}".format(epoch + 1, results["mrr_ind"])) + logger.info("Epoch {}, MAR {}".format(epoch + 1, results["mar_ind"])) + + if results["mrr"] > best_val_mrr: # update best result best_val = results - best_val_mrr = results['mrr'] + best_val_mrr = results["mrr"] best_epoch = epoch save_model(model, args, best_val, best_epoch, optim, loadpth) kill_cnt = 0 @@ -338,54 +562,56 @@ def adjust_learning_rate(optimizer, lr, gamma): logger.info("Epoch {}, MRR {}".format(epoch + 1, best_val_mrr)) else: - if args.induct_test: # inductive link prediction, run if you have a trained model + if ( + args.induct_test + ): # inductive link prediction, run if you have a trained model print("Start inductive testing...") logger.info("Start inductive testing...") else: print("Start testing...") logger.info("Start testing...") epoch = 0 - split = 'test' + split = "test" results = predict(test_loader, model, args, num_e, test_adjmtx, logger) print("===========RAW===========") - print("Epoch {}, HITS10 {}".format(epoch + 1, results['hits@10_raw'])) - print("Epoch {}, HITS3 {}".format(epoch + 1, results['hits@3_raw'])) - print("Epoch {}, HITS1 {}".format(epoch + 1, results['hits@1_raw'])) - print("Epoch {}, MRR {}".format(epoch + 1, results['mrr_raw'])) - print("Epoch {}, MAR {}".format(epoch + 1, results['mar_raw'])) + print("Epoch {}, HITS10 {}".format(epoch + 1, results["hits@10_raw"])) + print("Epoch {}, HITS3 {}".format(epoch + 1, results["hits@3_raw"])) + 
print("Epoch {}, HITS1 {}".format(epoch + 1, results["hits@1_raw"])) + print("Epoch {}, MRR {}".format(epoch + 1, results["mrr_raw"])) + print("Epoch {}, MAR {}".format(epoch + 1, results["mar_raw"])) print("=====TIME AWARE FILTER=====") - print("Epoch {}, HITS10 {}".format(epoch + 1, results['hits@10'])) - print("Epoch {}, HITS3 {}".format(epoch + 1, results['hits@3'])) - print("Epoch {}, HITS1 {}".format(epoch + 1, results['hits@1'])) - print("Epoch {}, MRR {}".format(epoch + 1, results['mrr'])) - print("Epoch {}, MAR {}".format(epoch + 1, results['mar'])) + print("Epoch {}, HITS10 {}".format(epoch + 1, results["hits@10"])) + print("Epoch {}, HITS3 {}".format(epoch + 1, results["hits@3"])) + print("Epoch {}, HITS1 {}".format(epoch + 1, results["hits@1"])) + print("Epoch {}, MRR {}".format(epoch + 1, results["mrr"])) + print("Epoch {}, MAR {}".format(epoch + 1, results["mar"])) print("====TIME UNAWARE FILTER====") - print("Epoch {}, HITS10 {}".format(epoch + 1, results['hits@10_ind'])) - print("Epoch {}, HITS3 {}".format(epoch + 1, results['hits@3_ind'])) - print("Epoch {}, HITS1 {}".format(epoch + 1, results['hits@1_ind'])) - print("Epoch {}, MRR {}".format(epoch + 1, results['mrr_ind'])) - print("Epoch {}, MAR {}".format(epoch + 1, results['mar_ind'])) + print("Epoch {}, HITS10 {}".format(epoch + 1, results["hits@10_ind"])) + print("Epoch {}, HITS3 {}".format(epoch + 1, results["hits@3_ind"])) + print("Epoch {}, HITS1 {}".format(epoch + 1, results["hits@1_ind"])) + print("Epoch {}, MRR {}".format(epoch + 1, results["mrr_ind"])) + print("Epoch {}, MAR {}".format(epoch + 1, results["mar_ind"])) logger.info("===========RAW===========") - logger.info("Epoch {}, HITS10 {}".format(epoch + 1, results['hits@10_raw'])) - logger.info("Epoch {}, HITS3 {}".format(epoch + 1, results['hits@3_raw'])) - logger.info("Epoch {}, HITS1 {}".format(epoch + 1, results['hits@1_raw'])) - logger.info("Epoch {}, MRR {}".format(epoch + 1, results['mrr_raw'])) - logger.info("Epoch {}, MAR 
{}".format(epoch + 1, results['mar_raw'])) + logger.info("Epoch {}, HITS10 {}".format(epoch + 1, results["hits@10_raw"])) + logger.info("Epoch {}, HITS3 {}".format(epoch + 1, results["hits@3_raw"])) + logger.info("Epoch {}, HITS1 {}".format(epoch + 1, results["hits@1_raw"])) + logger.info("Epoch {}, MRR {}".format(epoch + 1, results["mrr_raw"])) + logger.info("Epoch {}, MAR {}".format(epoch + 1, results["mar_raw"])) logger.info("=====TIME AWARE FILTER=====") - logger.info("Epoch {}, HITS10 {}".format(epoch + 1, results['hits@10'])) - logger.info("Epoch {}, HITS3 {}".format(epoch + 1, results['hits@3'])) - logger.info("Epoch {}, HITS1 {}".format(epoch + 1, results['hits@1'])) - logger.info("Epoch {}, MRR {}".format(epoch + 1, results['mrr'])) - logger.info("Epoch {}, MAR {}".format(epoch + 1, results['mar'])) + logger.info("Epoch {}, HITS10 {}".format(epoch + 1, results["hits@10"])) + logger.info("Epoch {}, HITS3 {}".format(epoch + 1, results["hits@3"])) + logger.info("Epoch {}, HITS1 {}".format(epoch + 1, results["hits@1"])) + logger.info("Epoch {}, MRR {}".format(epoch + 1, results["mrr"])) + logger.info("Epoch {}, MAR {}".format(epoch + 1, results["mar"])) logger.info("====TIME UNAWARE FILTER====") - logger.info("Epoch {}, HITS10 {}".format(epoch + 1, results['hits@10_ind'])) - logger.info("Epoch {}, HITS3 {}".format(epoch + 1, results['hits@3_ind'])) - logger.info("Epoch {}, HITS1 {}".format(epoch + 1, results['hits@1_ind'])) - logger.info("Epoch {}, MRR {}".format(epoch + 1, results['mrr_ind'])) - logger.info("Epoch {}, MAR {}".format(epoch + 1, results['mar_ind'])) \ No newline at end of file + logger.info("Epoch {}, HITS10 {}".format(epoch + 1, results["hits@10_ind"])) + logger.info("Epoch {}, HITS3 {}".format(epoch + 1, results["hits@3_ind"])) + logger.info("Epoch {}, HITS1 {}".format(epoch + 1, results["hits@1_ind"])) + logger.info("Epoch {}, MRR {}".format(epoch + 1, results["mrr_ind"])) + logger.info("Epoch {}, MAR {}".format(epoch + 1, 
results["mar_ind"])) diff --git a/TANGO_dataloader.py b/TANGO_dataloader.py index 6f06c16..b89c126 100644 --- a/TANGO_dataloader.py +++ b/TANGO_dataloader.py @@ -2,20 +2,23 @@ import torch import torch.utils + class TANGOtrainDataset(torch.utils.data.Dataset): - def __init__(self, - params, - triples: list, # triples['train'] - adjs: list, # {'edge_index': tensor, 'edge_type': tensor} - adjlist: list, # [adjmtx,...,adjmtx], adjmtx is torch sparse tensor - so2r: list, - num_e: int, - input_steps: int, - target_steps: int, - delta_steps: int = 0, - time_stamps: list = None, - num_samp=None, - neg_samp=None): + def __init__( + self, + params, + triples: list, # triples['train'] + adjs: list, # {'edge_index': tensor, 'edge_type': tensor} + adjlist: list, # [adjmtx,...,adjmtx], adjmtx is torch sparse tensor + so2r: list, + num_e: int, + input_steps: int, + target_steps: int, + delta_steps: int = 0, + time_stamps: list = None, + num_samp=None, + neg_samp=None, + ): assert isinstance(triples, list) self.p = params @@ -29,73 +32,114 @@ def __init__(self, self.so2r = so2r self.neg_samp = neg_samp - self.len = len(self.triples) - self.input_steps - self.target_steps - self.delta_steps + 1 - - assert len(triples) == len(time_stamps), "length of time stamps do not match with trajectories" + self.len = ( + len(self.triples) + - self.input_steps + - self.target_steps + - self.delta_steps + + 1 + ) + + assert len(triples) == len( + time_stamps + ), "length of time stamps do not match with trajectories" self.time_stamps = time_stamps self.adjlist = adjlist def __getitem__(self, idx): # target timestamps target_time_stamps = [] - for t_idx in range(idx + self.input_steps + self.delta_steps, - idx + self.input_steps + self.delta_steps + self.target_steps): + for t_idx in range( + idx + self.input_steps + self.delta_steps, + idx + self.input_steps + self.delta_steps + self.target_steps, + ): target_time_stamps.append(self.time_stamps[t_idx]) # graph info: (sub, rel, obj) triple_input 
= [] for i_idx in range(idx, idx + self.input_steps): - triple_input.append(torch.tensor([list(trp['triple']) for trp in self.triples[i_idx]])) + triple_input.append( + torch.tensor([list(trp["triple"]) for trp in self.triples[i_idx]]) + ) # sub - subject_input = [torch.stack([_trp[i,:] for i in range(_trp.shape[0])], dim=0)[:,0] for _trp in triple_input] + subject_input = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 0] + for _trp in triple_input + ] # rel - relation_input = [torch.stack([_trp[i,:] for i in range(_trp.shape[0])], dim=0)[:,1] for _trp in triple_input] + relation_input = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 1] + for _trp in triple_input + ] # obj - object_input = [torch.stack([_trp[i,:] for i in range(_trp.shape[0])], dim=0)[:,2] for _trp in triple_input] - + object_input = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 2] + for _trp in triple_input + ] # graph info: label corresponding to (sub, rel, obj) label_input = [] for i_idx in range(idx, idx + self.input_steps): - label_input.append(torch.stack([self.get_label(trp['label']) for trp in self.triples[i_idx]], dim=0)) + label_input.append( + torch.stack( + [self.get_label(trp["label"]) for trp in self.triples[i_idx]], dim=0 + ) + ) # pred graph info: (sub, rel, obj) triple_tar = [] - for t_idx in range(idx + self.input_steps + self.delta_steps, - idx + self.input_steps + self.delta_steps + self.target_steps): - triple_tar.append(torch.tensor([list(trp['triple']) for trp in self.triples[t_idx]])) + for t_idx in range( + idx + self.input_steps + self.delta_steps, + idx + self.input_steps + self.delta_steps + self.target_steps, + ): + triple_tar.append( + torch.tensor([list(trp["triple"]) for trp in self.triples[t_idx]]) + ) # sub - subject_tar = [torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 0] for _trp in triple_tar] + subject_tar = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], 
dim=0)[:, 0] + for _trp in triple_tar + ] # rel - relation_tar = [torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 1] for _trp in triple_tar] + relation_tar = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 1] + for _trp in triple_tar + ] # obj - object_tar = [torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 2] for _trp in triple_tar] - + object_tar = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 2] + for _trp in triple_tar + ] # pred graph info: label corresponding to (sub, rel, obj) label_tar = [] - for t_idx in range(idx + self.input_steps + self.delta_steps, - idx + self.input_steps + self.delta_steps + self.target_steps): - label_tar.append(torch.stack([self.get_label(trp['label']) for trp in self.triples[t_idx]], dim=0)) + for t_idx in range( + idx + self.input_steps + self.delta_steps, + idx + self.input_steps + self.delta_steps + self.target_steps, + ): + label_tar.append( + torch.stack( + [self.get_label(trp["label"]) for trp in self.triples[t_idx]], dim=0 + ) + ) # input timestamps input_time_stamps = [] for i_idx in range(idx, idx + self.input_steps): input_time_stamps.append(self.time_stamps[i_idx]) - # edge information edge_index_list = [] edge_type_list = [] for i_idx in range(idx, idx + self.input_steps): - edge_index_list.append(self.adjs[i_idx]['edge_index']) - edge_type_list.append(self.adjs[i_idx]['edge_type']) + edge_index_list.append(self.adjs[i_idx]["edge_index"]) + edge_type_list.append(self.adjs[i_idx]["edge_type"]) # adjacency tensor ('mtx' means matrix, we preserve this name) adj_mtx_list = [] @@ -117,7 +161,15 @@ def __getitem__(self, idx): for i, a in enumerate(adj_mtx_list): if i != len(adj_mtx_list) - 1: jumped = torch.nonzero(a._values()) - edge_id_jump.append(torch.cat([a._indices()[:, jumped][0], a._indices()[:, jumped][2]], dim=1).t()) + edge_id_jump.append( + torch.cat( + [ + a._indices()[:, jumped][0], + a._indices()[:, jumped][2], + ], + dim=1, + 
).t() + ) edge_w_jump.append(a._values()[jumped]) rel_jump.append(a._indices()[:, jumped][1].squeeze(1)) else: @@ -130,33 +182,52 @@ def __getitem__(self, idx): edge_id_jump.append(a._indices()[:, jumped]) edge_w_jump.append(a._values()[jumped].unsqueeze(-1)) - return (subject_input, relation_input, object_input, label_input, subject_tar, relation_tar, object_tar, - label_tar, target_time_stamps, input_time_stamps, edge_index_list, edge_type_list, adj_mtx_list, - edge_w_jump, edge_id_jump, rel_jump) + return ( + subject_input, + relation_input, + object_input, + label_input, + subject_tar, + relation_tar, + object_tar, + label_tar, + target_time_stamps, + input_time_stamps, + edge_index_list, + edge_type_list, + adj_mtx_list, + edge_w_jump, + edge_id_jump, + rel_jump, + ) def __len__(self): return self.len def get_label(self, label): y = np.zeros([self.num_e], dtype=np.float32) - for e2 in label: y[e2] = 1.0 + for e2 in label: + y[e2] = 1.0 return torch.FloatTensor(y) + class TANGOtestDataset(torch.utils.data.Dataset): - def __init__(self, - params, - triples: list, # triples['train'] - adjs: list, # {'edge_index': tensor, 'edge_type': tensor} - adjlist: list, # [adjmtx,...,adjmtx], adjmtx is torch sparse tensor - so2r: list, - num_e: int, - input_steps: int, - target_steps: int, - delta_steps: int = 0, - time_stamps: list = None, - t_indep_trp: dict = None, - num_samp=None, - induct_tar=None): + def __init__( + self, + params, + triples: list, # triples['train'] + adjs: list, # {'edge_index': tensor, 'edge_type': tensor} + adjlist: list, # [adjmtx,...,adjmtx], adjmtx is torch sparse tensor + so2r: list, + num_e: int, + input_steps: int, + target_steps: int, + delta_steps: int = 0, + time_stamps: list = None, + t_indep_trp: dict = None, + num_samp=None, + induct_tar=None, + ): assert isinstance(triples, list) self.p = params @@ -171,85 +242,152 @@ def __init__(self, self.t_indep_trp = t_indep_trp self.induct_tar = induct_tar - self.len = len(self.triples) - 
self.input_steps - self.target_steps - self.delta_steps + 1 - - assert len(triples) == len(time_stamps), "length of time stamps do not match with trajectories" + self.len = ( + len(self.triples) + - self.input_steps + - self.target_steps + - self.delta_steps + + 1 + ) + + assert len(triples) == len( + time_stamps + ), "length of time stamps do not match with trajectories" self.time_stamps = time_stamps self.adjlist = adjlist - def __getitem__(self, idx): # target timestamps target_time_stamps = [] - for t_idx in range(idx + self.input_steps + self.delta_steps, - idx + self.input_steps + self.delta_steps + self.target_steps): + for t_idx in range( + idx + self.input_steps + self.delta_steps, + idx + self.input_steps + self.delta_steps + self.target_steps, + ): target_time_stamps.append(self.time_stamps[t_idx]) # graph info: (sub, rel, obj) triple_input = [] for i_idx in range(idx, idx + self.input_steps): - triple_input.append(torch.tensor([list(trp['triple']) for trp in self.triples[i_idx]])) + triple_input.append( + torch.tensor([list(trp["triple"]) for trp in self.triples[i_idx]]) + ) # sub - subject_input = [torch.stack([_trp[i,:] for i in range(_trp.shape[0])], dim=0)[:,0] for _trp in triple_input] + subject_input = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 0] + for _trp in triple_input + ] # rel - relation_input = [torch.stack([_trp[i,:] for i in range(_trp.shape[0])], dim=0)[:,1] for _trp in triple_input] + relation_input = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 1] + for _trp in triple_input + ] # obj - object_input = [torch.stack([_trp[i,:] for i in range(_trp.shape[0])], dim=0)[:,2] for _trp in triple_input] + object_input = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 2] + for _trp in triple_input + ] # graph info: label corresponding to (sub, rel, obj) label_input = [] for i_idx in range(idx, idx + self.input_steps): - 
label_input.append(torch.stack([self.get_label(trp['label']) for trp in self.triples[i_idx]], dim=0)) + label_input.append( + torch.stack( + [self.get_label(trp["label"]) for trp in self.triples[i_idx]], dim=0 + ) + ) if self.induct_tar == None: # pred graph info: (sub, rel, obj) triple_tar = [] - for t_idx in range(idx + self.input_steps + self.delta_steps, - idx + self.input_steps + self.delta_steps + self.target_steps): - triple_tar.append(torch.tensor([list(trp['triple']) for trp in self.triples[t_idx]])) + for t_idx in range( + idx + self.input_steps + self.delta_steps, + idx + self.input_steps + self.delta_steps + self.target_steps, + ): + triple_tar.append( + torch.tensor([list(trp["triple"]) for trp in self.triples[t_idx]]) + ) # sub - subject_tar = [torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 0] for _trp in triple_tar] + subject_tar = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 0] + for _trp in triple_tar + ] # rel - relation_tar = [torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 1] for _trp in triple_tar] + relation_tar = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 1] + for _trp in triple_tar + ] # obj - object_tar = [torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 2] for _trp in triple_tar] - + object_tar = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 2] + for _trp in triple_tar + ] # pred graph info: label corresponding to (sub, rel, obj) label_tar = [] - for t_idx in range(idx + self.input_steps + self.delta_steps, - idx + self.input_steps + self.delta_steps + self.target_steps): - label_tar.append(torch.stack([self.get_label(trp['label']) for trp in self.triples[t_idx]], dim=0)) + for t_idx in range( + idx + self.input_steps + self.delta_steps, + idx + self.input_steps + self.delta_steps + self.target_steps, + ): + label_tar.append( + torch.stack( + [self.get_label(trp["label"]) for trp in self.triples[t_idx]], + 
dim=0, + ) + ) else: # pred graph info: (sub, rel, obj) triple_tar = [] - for t_idx in range(idx + self.input_steps + self.delta_steps, - idx + self.input_steps + self.delta_steps + self.target_steps): - triple_tar.append(torch.tensor([list(trp['triple']) for trp in self.induct_tar[t_idx]])) + for t_idx in range( + idx + self.input_steps + self.delta_steps, + idx + self.input_steps + self.delta_steps + self.target_steps, + ): + triple_tar.append( + torch.tensor( + [list(trp["triple"]) for trp in self.induct_tar[t_idx]] + ) + ) if len(self.induct_tar[t_idx]) == 0: subject_tar, relation_tar, object_tar, label_tar = [], [], [], [] else: # sub - subject_tar = [torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 0] for _trp in triple_tar] + subject_tar = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 0] + for _trp in triple_tar + ] # rel - relation_tar = [torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 1] for _trp in triple_tar] + relation_tar = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 1] + for _trp in triple_tar + ] # obj - object_tar = [torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 2] for _trp in triple_tar] + object_tar = [ + torch.stack([_trp[i, :] for i in range(_trp.shape[0])], dim=0)[:, 2] + for _trp in triple_tar + ] # pred graph info: label corresponding to (sub, rel, obj) label_tar = [] - for t_idx in range(idx + self.input_steps + self.delta_steps, - idx + self.input_steps + self.delta_steps + self.target_steps): - label_tar.append(torch.stack([self.get_label(trp['label']) for trp in self.induct_tar[t_idx]], dim=0)) - + for t_idx in range( + idx + self.input_steps + self.delta_steps, + idx + self.input_steps + self.delta_steps + self.target_steps, + ): + label_tar.append( + torch.stack( + [ + self.get_label(trp["label"]) + for trp in self.induct_tar[t_idx] + ], + dim=0, + ) + ) # input timestamps input_time_stamps = [] @@ -260,15 +398,26 @@ def 
__getitem__(self, idx): edge_index_list = [] edge_type_list = [] for i_idx in range(idx, idx + self.input_steps): - edge_index_list.append(self.adjs[i_idx]['edge_index']) - edge_type_list.append(self.adjs[i_idx]['edge_type']) + edge_index_list.append(self.adjs[i_idx]["edge_index"]) + edge_type_list.append(self.adjs[i_idx]["edge_type"]) # time independent label indep_lab = [] - for t_idx in range(idx + self.input_steps + self.delta_steps, - idx + self.input_steps + self.delta_steps + self.target_steps): - indep_lab.append(torch.stack([self.get_label(self.t_indep_trp[(trp['triple'][0], trp['triple'][1])]) for trp in self.triples[t_idx]], dim=0)) - + for t_idx in range( + idx + self.input_steps + self.delta_steps, + idx + self.input_steps + self.delta_steps + self.target_steps, + ): + indep_lab.append( + torch.stack( + [ + self.get_label( + self.t_indep_trp[(trp["triple"][0], trp["triple"][1])] + ) + for trp in self.triples[t_idx] + ], + dim=0, + ) + ) # adjacency tensor adj_mtx_list = [] @@ -291,10 +440,18 @@ def __getitem__(self, idx): for i, a in enumerate(adj_mtx_list): if i != len(adj_mtx_list) - 1: jumped = torch.nonzero(a._values()) - edge_id_jump.append(torch.cat([a._indices()[:, jumped][0], a._indices()[:, jumped][2]], dim=1).t()) + edge_id_jump.append( + torch.cat( + [ + a._indices()[:, jumped][0], + a._indices()[:, jumped][2], + ], + dim=1, + ).t() + ) edge_w_jump.append(a._values()[jumped]) rel_jump.append(a._indices()[:, jumped][1].squeeze(1)) - #print(rel_jump[-1].shape) + # print(rel_jump[-1].shape) else: edge_id_jump.append(edge_id_jump[-1]) edge_w_jump.append(edge_w_jump[-1]) @@ -305,24 +462,41 @@ def __getitem__(self, idx): edge_id_jump.append(a._indices()[:, jumped]) edge_w_jump.append(a._values()[jumped].unsqueeze(-1)) - return (subject_input, relation_input, object_input, label_input, subject_tar, relation_tar, object_tar, - label_tar, target_time_stamps, input_time_stamps, edge_index_list, edge_type_list, indep_lab, - adj_mtx_list, edge_w_jump, 
edge_id_jump, rel_jump) + return ( + subject_input, + relation_input, + object_input, + label_input, + subject_tar, + relation_tar, + object_tar, + label_tar, + target_time_stamps, + input_time_stamps, + edge_index_list, + edge_type_list, + indep_lab, + adj_mtx_list, + edge_w_jump, + edge_id_jump, + rel_jump, + ) def __len__(self): return self.len def get_label(self, label): y = np.zeros([self.num_e], dtype=np.float32) - for e2 in label: y[e2] = 1.0 + for e2 in label: + y[e2] = 1.0 return torch.FloatTensor(y) + class TANGOtrainDataLoader(torch.utils.data.DataLoader): def __init__(self, *args, **kwargs): - kwargs['collate_fn'] = self.collate_fn + kwargs["collate_fn"] = self.collate_fn super(TANGOtrainDataLoader, self).__init__(*args, **kwargs) - def collate_fn(self, batch): for item in batch: sub_in = item[0] @@ -342,16 +516,31 @@ def collate_fn(self, batch): edg_jump_id = item[14] rel_jump = item[15] - return (sub_in, rel_in, obj_in, lab_in, sub_tar, rel_tar, obj_tar, lab_tar, tar_ts, in_ts, edg_id, edg_typ, - adj_mtx, edg_jump_w, edg_jump_id, rel_jump) + return ( + sub_in, + rel_in, + obj_in, + lab_in, + sub_tar, + rel_tar, + obj_tar, + lab_tar, + tar_ts, + in_ts, + edg_id, + edg_typ, + adj_mtx, + edg_jump_w, + edg_jump_id, + rel_jump, + ) class TANGOtestDataLoader(torch.utils.data.DataLoader): def __init__(self, *args, **kwargs): - kwargs['collate_fn'] = self.collate_fn + kwargs["collate_fn"] = self.collate_fn super(TANGOtestDataLoader, self).__init__(*args, **kwargs) - def collate_fn(self, batch): for item in batch: sub_in = item[0] @@ -372,5 +561,22 @@ def collate_fn(self, batch): edg_jump_id = item[15] rel_jump = item[16] - return (sub_in, rel_in, obj_in, lab_in, sub_tar, rel_tar, obj_tar, lab_tar, tar_ts, in_ts, edg_id, edg_typ, - indep_lab, adj_mtx, edg_jump_w, edg_jump_id, rel_jump) \ No newline at end of file + return ( + sub_in, + rel_in, + obj_in, + lab_in, + sub_tar, + rel_tar, + obj_tar, + lab_tar, + tar_ts, + in_ts, + edg_id, + edg_typ, + indep_lab, + 
adj_mtx, + edg_jump_w, + edg_jump_id, + rel_jump, + ) diff --git a/eval.py b/eval.py index 1036d23..1976264 100644 --- a/eval.py +++ b/eval.py @@ -2,6 +2,7 @@ import time from utils import * + def push_data(*args, device=None): out_args = [] for arg in args: @@ -30,88 +31,169 @@ def predict(loader, model, params, num_e, test_adjmtx, logger): print("Start evaluation") t1 = time.time() - for step, ( - sub_in, rel_in, obj_in, lab_in, sub_tar, rel_tar, obj_tar, lab_tar, tar_ts, in_ts, edge_idlist, edge_typelist, - indep_lab, adj_mtx, edge_jump_w, edge_jump_id, rel_jump) in enumerate(iter): + for ( + step, + ( + sub_in, + rel_in, + obj_in, + lab_in, + sub_tar, + rel_tar, + obj_tar, + lab_tar, + tar_ts, + in_ts, + edge_idlist, + edge_typelist, + indep_lab, + adj_mtx, + edge_jump_w, + edge_jump_id, + rel_jump, + ), + ) in enumerate(iter): if len(sub_tar) == 0: continue # forward - emb = model.forward_eval(in_ts, tar_ts, edge_idlist, edge_typelist, edge_jump_id, edge_jump_w, rel_jump) + emb = model.forward_eval( + in_ts, + tar_ts, + edge_idlist, + edge_typelist, + edge_jump_id, + edge_jump_w, + rel_jump, + ) rank_count = 0 while rank_count < sub_tar[0].shape[0]: - l, r = rank_count, (rank_count + rank_group_num) if (rank_count + rank_group_num) <= sub_tar[0].shape[ - 0] else sub_tar[0].shape[0] + l, r = ( + rank_count, + (rank_count + rank_group_num) + if (rank_count + rank_group_num) <= sub_tar[0].shape[0] + else sub_tar[0].shape[0], + ) # push data onto gpu - [sub_tar_, rel_tar_, obj_tar_, lab_tar_, indep_lab_] = \ - push_data2(sub_tar[0][l:r], rel_tar[0][l:r], obj_tar[0][l:r], lab_tar[0][l:r, :], - indep_lab[0][l:r, :], device=p.device) + [sub_tar_, rel_tar_, obj_tar_, lab_tar_, indep_lab_] = push_data2( + sub_tar[0][l:r], + rel_tar[0][l:r], + obj_tar[0][l:r], + lab_tar[0][l:r, :], + indep_lab[0][l:r, :], + device=p.device, + ) # compute scores for corresponding triples - score = model.score_comp(sub_tar_, rel_tar_, emb, model.odeblock.odefunc) + score = 
model.score_comp( + sub_tar_, rel_tar_, emb, model.odeblock.odefunc + ) b_range = torch.arange(score.shape[0], device=p.device) # raw ranking - ranks = 1 + torch.argsort(torch.argsort(score, dim=1, descending=True), dim=1, descending=False)[ - b_range, obj_tar_] + ranks = ( + 1 + + torch.argsort( + torch.argsort(score, dim=1, descending=True), + dim=1, + descending=False, + )[b_range, obj_tar_] + ) ranks = ranks.float() - results['count_raw'] = torch.numel(ranks) + results.get('count_raw', 0.0) - results['mar_raw'] = torch.sum(ranks).item() + results.get('mar_raw', 0.0) - results['mrr_raw'] = torch.sum(1.0 / ranks).item() + results.get('mrr_raw', 0.0) + results["count_raw"] = torch.numel(ranks) + results.get( + "count_raw", 0.0 + ) + results["mar_raw"] = torch.sum(ranks).item() + results.get( + "mar_raw", 0.0 + ) + results["mrr_raw"] = torch.sum(1.0 / ranks).item() + results.get( + "mrr_raw", 0.0 + ) for k in range(10): - results['hits@{}_raw'.format(k + 1)] = torch.numel(ranks[ranks <= (k + 1)]) + results.get( - 'hits@{}_raw'.format(k + 1), 0.0) + results["hits@{}_raw".format(k + 1)] = torch.numel( + ranks[ranks <= (k + 1)] + ) + results.get("hits@{}_raw".format(k + 1), 0.0) # time aware filtering target_score = score[b_range, obj_tar_] - score = torch.where(lab_tar_.byte(), -torch.ones_like(score) * 10000000, score) + score = torch.where( + lab_tar_.byte(), -torch.ones_like(score) * 10000000, score + ) score[b_range, obj_tar_] = target_score # time aware filtered ranking - ranks = 1 + torch.argsort(torch.argsort(score, dim=1, descending=True), dim=1, descending=False)[ - b_range, obj_tar_] + ranks = ( + 1 + + torch.argsort( + torch.argsort(score, dim=1, descending=True), + dim=1, + descending=False, + )[b_range, obj_tar_] + ) ranks = ranks.float() - results['count'] = torch.numel(ranks) + results.get('count', 0.0) - results['mar'] = torch.sum(ranks).item() + results.get('mar', 0.0) - results['mrr'] = torch.sum(1.0 / ranks).item() + results.get('mrr', 0.0) + 
results["count"] = torch.numel(ranks) + results.get("count", 0.0) + results["mar"] = torch.sum(ranks).item() + results.get("mar", 0.0) + results["mrr"] = torch.sum(1.0 / ranks).item() + results.get("mrr", 0.0) for k in range(10): - results['hits@{}'.format(k + 1)] = torch.numel(ranks[ranks <= (k + 1)]) + results.get( - 'hits@{}'.format(k + 1), 0.0) + results["hits@{}".format(k + 1)] = torch.numel( + ranks[ranks <= (k + 1)] + ) + results.get("hits@{}".format(k + 1), 0.0) # time unaware filtering - score = torch.where(indep_lab_.byte(), -torch.ones_like(score) * 10000000, score) + score = torch.where( + indep_lab_.byte(), -torch.ones_like(score) * 10000000, score + ) score[b_range, obj_tar_] = target_score # time unaware filtered ranking - ranks = 1 + torch.argsort(torch.argsort(score, dim=1, descending=True), dim=1, descending=False)[ - b_range, obj_tar_] + ranks = ( + 1 + + torch.argsort( + torch.argsort(score, dim=1, descending=True), + dim=1, + descending=False, + )[b_range, obj_tar_] + ) ranks = ranks.float() - results['count_ind'] = torch.numel(ranks) + results.get('count_ind', 0.0) - results['mar_ind'] = torch.sum(ranks).item() + results.get('mar_ind', 0.0) - results['mrr_ind'] = torch.sum(1.0 / ranks).item() + results.get('mrr_ind', 0.0) + results["count_ind"] = torch.numel(ranks) + results.get( + "count_ind", 0.0 + ) + results["mar_ind"] = torch.sum(ranks).item() + results.get( + "mar_ind", 0.0 + ) + results["mrr_ind"] = torch.sum(1.0 / ranks).item() + results.get( + "mrr_ind", 0.0 + ) for k in range(10): - results['hits@{}_ind'.format(k + 1)] = torch.numel(ranks[ranks <= (k + 1)]) + results.get( - 'hits@{}_ind'.format(k + 1), 0.0) + results["hits@{}_ind".format(k + 1)] = torch.numel( + ranks[ranks <= (k + 1)] + ) + results.get("hits@{}_ind".format(k + 1), 0.0) rank_count += rank_group_num del sub_tar_, rel_tar_, obj_tar_, lab_tar_, indep_lab_ - results['mar'] = round(results['mar'] / results['count'], 5) - results['mrr'] = round(results['mrr'] / 
results['count'], 5) - results['mar_raw'] = round(results['mar_raw'] / results['count_raw'], 5) - results['mrr_raw'] = round(results['mrr_raw'] / results['count_raw'], 5) - results['mar_ind'] = round(results['mar_ind'] / results['count_ind'], 5) - results['mrr_ind'] = round(results['mrr_ind'] / results['count_ind'], 5) + results["mar"] = round(results["mar"] / results["count"], 5) + results["mrr"] = round(results["mrr"] / results["count"], 5) + results["mar_raw"] = round(results["mar_raw"] / results["count_raw"], 5) + results["mrr_raw"] = round(results["mrr_raw"] / results["count_raw"], 5) + results["mar_ind"] = round(results["mar_ind"] / results["count_ind"], 5) + results["mrr_ind"] = round(results["mrr_ind"] / results["count_ind"], 5) for k in range(10): - results['hits@{}'.format(k + 1)] = round(results['hits@{}'.format(k + 1)] / results['count'], 5) - results['hits@{}_raw'.format(k + 1)] = round(results['hits@{}_raw'.format(k + 1)] / results['count_raw'], 5) - results['hits@{}_ind'.format(k + 1)] = round(results['hits@{}_ind'.format(k + 1)] / results['count_ind'], 5) + results["hits@{}".format(k + 1)] = round( + results["hits@{}".format(k + 1)] / results["count"], 5 + ) + results["hits@{}_raw".format(k + 1)] = round( + results["hits@{}_raw".format(k + 1)] / results["count_raw"], 5 + ) + results["hits@{}_ind".format(k + 1)] = round( + results["hits@{}_ind".format(k + 1)] / results["count_ind"], 5 + ) t2 = time.time() print("evaluation time: ", t2 - t1) logger.info("evaluation time: {}".format(t2 - t1)) - return results \ No newline at end of file + return results diff --git a/helper.py b/helper.py index 79c119a..6378888 100644 --- a/helper.py +++ b/helper.py @@ -38,12 +38,14 @@ def get_logger(name, log_dir, config_dir): A logger object which writes to both file and stdout """ - config_dict = json.load(open(config_dir + 'log_config.json')) - config_dict['handlers']['file_handler']['filename'] = log_dir + name.replace('/', '-') + config_dict = 
json.load(open(config_dir + "log_config.json")) + config_dict["handlers"]["file_handler"]["filename"] = log_dir + name.replace( + "/", "-" + ) logging.config.dictConfig(config_dict) logger = logging.getLogger(name) - std_out_format = '%(asctime)s - [%(levelname)s] - %(message)s' + std_out_format = "%(asctime)s - [%(levelname)s] - %(message)s" consoleHandler = logging.StreamHandler(sys.stdout) consoleHandler.setFormatter(logging.Formatter(std_out_format)) logger.addHandler(consoleHandler) @@ -53,25 +55,37 @@ def get_logger(name, log_dir, config_dir): def get_combined_results(left_results, right_results): results = {} - count = float(left_results['count']) + count = float(left_results["count"]) - results['left_mr'] = round(left_results['mr'] / count, 5) - results['left_mrr'] = round(left_results['mrr'] / count, 5) - results['right_mr'] = round(right_results['mr'] / count, 5) - results['right_mrr'] = round(right_results['mrr'] / count, 5) - results['mr'] = round((left_results['mr'] + right_results['mr']) / (2 * count), 5) - results['mrr'] = round((left_results['mrr'] + right_results['mrr']) / (2 * count), 5) + results["left_mr"] = round(left_results["mr"] / count, 5) + results["left_mrr"] = round(left_results["mrr"] / count, 5) + results["right_mr"] = round(right_results["mr"] / count, 5) + results["right_mrr"] = round(right_results["mrr"] / count, 5) + results["mr"] = round((left_results["mr"] + right_results["mr"]) / (2 * count), 5) + results["mrr"] = round( + (left_results["mrr"] + right_results["mrr"]) / (2 * count), 5 + ) for k in range(10): - results['left_hits@{}'.format(k + 1)] = round(left_results['hits@{}'.format(k + 1)] / count, 5) - results['right_hits@{}'.format(k + 1)] = round(right_results['hits@{}'.format(k + 1)] / count, 5) - results['hits@{}'.format(k + 1)] = round( - (left_results['hits@{}'.format(k + 1)] + right_results['hits@{}'.format(k + 1)]) / (2 * count), 5) + results["left_hits@{}".format(k + 1)] = round( + left_results["hits@{}".format(k + 
1)] / count, 5 + ) + results["right_hits@{}".format(k + 1)] = round( + right_results["hits@{}".format(k + 1)] / count, 5 + ) + results["hits@{}".format(k + 1)] = round( + ( + left_results["hits@{}".format(k + 1)] + + right_results["hits@{}".format(k + 1)] + ) + / (2 * count), + 5, + ) return results def get_param(shape): - param = Parameter(torch.Tensor(*shape)); + param = Parameter(torch.Tensor(*shape)) xavier_normal_(param.data) return param @@ -88,8 +102,14 @@ def conj(a): def cconv(a, b): - return torch.irfft(com_mult(torch.rfft(a, 1), torch.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],)) + return torch.irfft( + com_mult(torch.rfft(a, 1), torch.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],) + ) def ccorr(a, b): - return torch.irfft(com_mult(conj(torch.rfft(a, 1)), torch.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],)) \ No newline at end of file + return torch.irfft( + com_mult(conj(torch.rfft(a, 1)), torch.rfft(b, 1)), + 1, + signal_sizes=(a.shape[-1],), + ) diff --git a/models/MGCN.py b/models/MGCN.py index 6a18582..3ecfe26 100644 --- a/models/MGCN.py +++ b/models/MGCN.py @@ -2,71 +2,113 @@ from torch.nn.parameter import Parameter from .MGCNLayer import MGCNConvLayer + class MGCNLayerWrapper(torch.nn.Module): - def __init__(self, edge_index, edge_type, num_e, num_rel, act, drop1, drop2, sub, rel, params=None): - super().__init__() - self.nfe = 0 - self.p = params - self.edge_index = edge_index - self.edge_type = edge_type - self.p.hidsize = self.p.embsize if self.p.core_layer == 1 else self.p.hidsize - self.num_e = num_e - self.num_rel = num_rel - self.device = self.p.device - self.act = act - self.drop_l1 = torch.nn.Dropout(drop1) - self.drop_l2 = torch.nn.Dropout(drop2) - self.sub = sub - self.rel = rel + def __init__( + self, + edge_index, + edge_type, + num_e, + num_rel, + act, + drop1, + drop2, + sub, + rel, + params=None, + ): + super().__init__() + self.nfe = 0 + self.p = params + self.edge_index = edge_index + self.edge_type = edge_type + self.p.hidsize = 
self.p.embsize if self.p.core_layer == 1 else self.p.hidsize + self.num_e = num_e + self.num_rel = num_rel + self.device = self.p.device + self.act = act + self.drop_l1 = torch.nn.Dropout(drop1) + self.drop_l2 = torch.nn.Dropout(drop2) + self.sub = sub + self.rel = rel - # transition layer - self.jump = None - self.jump_weight = None + # transition layer + self.jump = None + self.jump_weight = None - # residual layer - if self.p.res: - self.res = torch.nn.Parameter(torch.FloatTensor([0.1])) + # residual layer + if self.p.res: + self.res = torch.nn.Parameter(torch.FloatTensor([0.1])) - # define MGCN Layer - self.conv1 = MGCNConvLayer(self.p.initsize, self.p.hidsize, act=self.act, params=self.p) - self.conv2 = MGCNConvLayer(self.p.hidsize, self.p.embsize, act=self.act, params=self.p) if self.p.core_layer == 2 else None + # define MGCN Layer + self.conv1 = MGCNConvLayer( + self.p.initsize, self.p.hidsize, act=self.act, params=self.p + ) + self.conv2 = ( + MGCNConvLayer(self.p.hidsize, self.p.embsize, act=self.act, params=self.p) + if self.p.core_layer == 2 + else None + ) - self.register_parameter('bias', Parameter(torch.zeros(num_e))) + self.register_parameter("bias", Parameter(torch.zeros(num_e))) - def set_graph(self, edge_index, edge_type): - self.edge_index = edge_index - self.edge_type = edge_type + def set_graph(self, edge_index, edge_type): + self.edge_index = edge_index + self.edge_type = edge_type - def set_jumpfunc(self, edge_id_jump, edge_w_jump, jumpfunc, jumpw=None, skip=False, rel_jump=None): - self.edge_id_jump = edge_id_jump - self.edge_w_jump = edge_w_jump - self.jump = jumpfunc - self.jump_weight = jumpw - self.skip = skip - self.rel_jump = rel_jump + def set_jumpfunc( + self, edge_id_jump, edge_w_jump, jumpfunc, jumpw=None, skip=False, rel_jump=None + ): + self.edge_id_jump = edge_id_jump + self.edge_w_jump = edge_w_jump + self.jump = jumpfunc + self.jump_weight = jumpw + self.skip = skip + self.rel_jump = rel_jump - def forward(self, t, emb): - 
self.nfe += 1 - jump_emb = emb.clone() - if self.p.res: - emb = emb + self.res * self.conv1(emb, self.edge_index, self.edge_type, self.num_e) - emb = self.drop_l1(emb) - emb = (emb + self.res * self.conv2(emb, self.edge_index, self.edge_type, self.num_e)) if self.p.core_layer == 2 else emb - emb = self.drop_l2(emb) if self.p.core_layer == 2 else emb - else: - emb = self.conv1(emb, self.edge_index, self.edge_type, self.num_e) - emb = self.drop_l1(emb) - emb = self.conv2(emb, self.edge_index, self.edge_type, self.num_e) if self.p.core_layer == 2 else emb - emb = self.drop_l2(emb) if self.p.core_layer == 2 else emb + def forward(self, t, emb): + self.nfe += 1 + jump_emb = emb.clone() + if self.p.res: + emb = emb + self.res * self.conv1( + emb, self.edge_index, self.edge_type, self.num_e + ) + emb = self.drop_l1(emb) + emb = ( + ( + emb + + self.res + * self.conv2(emb, self.edge_index, self.edge_type, self.num_e) + ) + if self.p.core_layer == 2 + else emb + ) + emb = self.drop_l2(emb) if self.p.core_layer == 2 else emb + else: + emb = self.conv1(emb, self.edge_index, self.edge_type, self.num_e) + emb = self.drop_l1(emb) + emb = ( + self.conv2(emb, self.edge_index, self.edge_type, self.num_e) + if self.p.core_layer == 2 + else emb + ) + emb = self.drop_l2(emb) if self.p.core_layer == 2 else emb - if self.p.jump: - if not self.skip: - if self.p.rel_jump: - jump_res = self.jump.forward(jump_emb, self.edge_id_jump, self.rel_jump, self.num_e, - dN=self.edge_w_jump) - else: - jump_res = self.jump(jump_emb, self.edge_id_jump, dN=self.edge_w_jump) - emb = emb + self.jump_weight * jump_res - emb = self.drop_l2(emb) + if self.p.jump: + if not self.skip: + if self.p.rel_jump: + jump_res = self.jump.forward( + jump_emb, + self.edge_id_jump, + self.rel_jump, + self.num_e, + dN=self.edge_w_jump, + ) + else: + jump_res = self.jump( + jump_emb, self.edge_id_jump, dN=self.edge_w_jump + ) + emb = emb + self.jump_weight * jump_res + emb = self.drop_l2(emb) - return emb \ No newline at 
end of file + return emb diff --git a/models/MGCNLayer.py b/models/MGCNLayer.py index 3ab5d27..7dd0af4 100644 --- a/models/MGCNLayer.py +++ b/models/MGCNLayer.py @@ -4,127 +4,163 @@ from torch_scatter import scatter_add from .message_passing import MessagePassing + class MGCNConvLayer(MessagePassing): - def __init__(self, in_channels, out_channels, act=lambda x: x, params=None, isjump=False, diag=False): - super(self.__class__, self).__init__() - - self.p = params - self.in_channels = in_channels - self.out_channels = out_channels - - self.act = act - self.device = None - self.diag = diag - - if self.diag: - self.w = get_param((1, out_channels)) - self.w_rel = get_param((1, out_channels)) - else: - self.w = get_param((in_channels, out_channels)) - self.w_rel = get_param((in_channels, out_channels)) # for custom rgcn layer - - self.drop = torch.nn.Dropout(self.p.dropout) - self.bn = torch.nn.BatchNorm1d(out_channels) - - if self.p.bias: self.register_parameter('bias', Parameter(torch.zeros(out_channels))) - - def forward(self, x, edge_index, edge_type, num_e, dN=None): - if self.device is None: - self.device = self.p.device - - ent_emb = x[:num_e,:] - rel_embed = x[num_e:,:] - self.norm = self.compute_norm(edge_index, num_e) - res = self.propagate('add', edge_index, edge_type=edge_type, rel_embed=rel_embed, x=ent_emb, edge_norm=self.norm, dN=dN) - out = self.drop(res) - - if self.p.bias: out = out + self.bias - out = self.bn(out) - - # if self.diag: - # return torch.cat([self.act(out), rel_embed * self.w_rel], dim=0) - # else: - # return torch.cat([self.act(out), torch.matmul(rel_embed, self.w_rel)], dim=0) - return torch.cat([self.act(out), rel_embed], dim=0) - - def forward_jump(self, x, edge_index, rel_jump, num_e, dN): - if self.device is None: - self.device = self.p.device - - ent_emb = x[:num_e, :] - rel_embed = x[num_e:, :] - - self.norm = self.compute_norm(edge_index, num_e) - res = self.propagate_jump('add', edge_index, rel_jump=rel_jump, 
rel_embed=rel_embed, x=ent_emb, - edge_norm=self.norm, dN=dN) - out = self.drop(res) - - if self.p.bias: out = out + self.bias - out = self.bn(out) - - # return torch.cat([self.act(out), torch.matmul(rel_embed, self.w_rel)], dim=0) - return torch.cat([self.act(out), rel_embed], dim=0) - - def message(self, x_j, edge_type, rel_embed, edge_norm, dN): - weight = self.w - rel_emb = torch.index_select(rel_embed, 0, edge_type) - if self.p.opn == 'None': - xj_rel = x_j - else: - xj_rel = self.rel_transform(x_j, rel_emb) - if dN is not None: - if self.diag: - out = xj_rel * weight * dN - else: - out = torch.mm(xj_rel, weight) * dN - else: - if self.diag: - out = xj_rel * weight - else: - out = torch.mm(xj_rel, weight) - - return out if edge_norm is None else out * edge_norm.view(-1, 1) - - def message_jump(self, x_j, rel_jump, rel_embed, edge_norm, dN): - weight = self.w - rel_emb = [] - - for r_j in rel_jump: - rel_emb.append(torch.mean(torch.index_select(rel_embed, 0, r_j), dim=0, keepdim=True)) - rel_emb = torch.stack(rel_emb, 0).squeeze(1) - if self.p.opn == 'None': - xj_rel = x_j - else: - xj_rel = self.rel_transform(x_j, rel_emb) - if dN is not None: - out = torch.mm(xj_rel, weight) * dN - else: - out = torch.mm(xj_rel, weight) - - return out if edge_norm is None else out * edge_norm.view(-1, 1) - - def rel_transform(self, ent_embed, rel_embed): - if self.p.opn == 'corr': trans_embed = ccorr(ent_embed, rel_embed) - elif self.p.opn == 'sub': trans_embed = ent_embed - rel_embed - elif self.p.opn == 'mult': trans_embed = ent_embed * rel_embed - else: raise NotImplementedError - - return trans_embed - - def update(self, aggr_out): - return aggr_out - - def compute_norm(self, edge_index, num_ent): - row, col = edge_index - edge_weight = torch.ones_like(row).float() - deg = scatter_add( edge_weight, row, dim=0, dim_size=num_ent) # Summing number of weights of the edges - deg_inv = deg.pow(-0.5) # D^{-0.5} - deg_inv[deg_inv == float('inf')] = 0 - norm = deg_inv[row] * 
edge_weight * deg_inv[col] # D^{-0.5} - - return norm + def __init__( + self, + in_channels, + out_channels, + act=lambda x: x, + params=None, + isjump=False, + diag=False, + ): + super(self.__class__, self).__init__() + + self.p = params + self.in_channels = in_channels + self.out_channels = out_channels + + self.act = act + self.device = None + self.diag = diag + + if self.diag: + self.w = get_param((1, out_channels)) + self.w_rel = get_param((1, out_channels)) + else: + self.w = get_param((in_channels, out_channels)) + self.w_rel = get_param((in_channels, out_channels)) # for custom rgcn layer + + self.drop = torch.nn.Dropout(self.p.dropout) + self.bn = torch.nn.BatchNorm1d(out_channels) + + if self.p.bias: + self.register_parameter("bias", Parameter(torch.zeros(out_channels))) + + def forward(self, x, edge_index, edge_type, num_e, dN=None): + if self.device is None: + self.device = self.p.device + + ent_emb = x[:num_e, :] + rel_embed = x[num_e:, :] + self.norm = self.compute_norm(edge_index, num_e) + res = self.propagate( + "add", + edge_index, + edge_type=edge_type, + rel_embed=rel_embed, + x=ent_emb, + edge_norm=self.norm, + dN=dN, + ) + out = self.drop(res) + + if self.p.bias: + out = out + self.bias + out = self.bn(out) + + # if self.diag: + # return torch.cat([self.act(out), rel_embed * self.w_rel], dim=0) + # else: + # return torch.cat([self.act(out), torch.matmul(rel_embed, self.w_rel)], dim=0) + return torch.cat([self.act(out), rel_embed], dim=0) + + def forward_jump(self, x, edge_index, rel_jump, num_e, dN): + if self.device is None: + self.device = self.p.device + + ent_emb = x[:num_e, :] + rel_embed = x[num_e:, :] + + self.norm = self.compute_norm(edge_index, num_e) + res = self.propagate_jump( + "add", + edge_index, + rel_jump=rel_jump, + rel_embed=rel_embed, + x=ent_emb, + edge_norm=self.norm, + dN=dN, + ) + out = self.drop(res) + + if self.p.bias: + out = out + self.bias + out = self.bn(out) + + # return torch.cat([self.act(out), 
torch.matmul(rel_embed, self.w_rel)], dim=0) + return torch.cat([self.act(out), rel_embed], dim=0) + + def message(self, x_j, edge_type, rel_embed, edge_norm, dN): + weight = self.w + rel_emb = torch.index_select(rel_embed, 0, edge_type) + if self.p.opn == "None": + xj_rel = x_j + else: + xj_rel = self.rel_transform(x_j, rel_emb) + if dN is not None: + if self.diag: + out = xj_rel * weight * dN + else: + out = torch.mm(xj_rel, weight) * dN + else: + if self.diag: + out = xj_rel * weight + else: + out = torch.mm(xj_rel, weight) + + return out if edge_norm is None else out * edge_norm.view(-1, 1) + + def message_jump(self, x_j, rel_jump, rel_embed, edge_norm, dN): + weight = self.w + rel_emb = [] + + for r_j in rel_jump: + rel_emb.append( + torch.mean(torch.index_select(rel_embed, 0, r_j), dim=0, keepdim=True) + ) + rel_emb = torch.stack(rel_emb, 0).squeeze(1) + if self.p.opn == "None": + xj_rel = x_j + else: + xj_rel = self.rel_transform(x_j, rel_emb) + if dN is not None: + out = torch.mm(xj_rel, weight) * dN + else: + out = torch.mm(xj_rel, weight) + + return out if edge_norm is None else out * edge_norm.view(-1, 1) + + def rel_transform(self, ent_embed, rel_embed): + if self.p.opn == "corr": + trans_embed = ccorr(ent_embed, rel_embed) + elif self.p.opn == "sub": + trans_embed = ent_embed - rel_embed + elif self.p.opn == "mult": + trans_embed = ent_embed * rel_embed + else: + raise NotImplementedError + + return trans_embed + + def update(self, aggr_out): + return aggr_out + + def compute_norm(self, edge_index, num_ent): + row, col = edge_index + edge_weight = torch.ones_like(row).float() + deg = scatter_add( + edge_weight, row, dim=0, dim_size=num_ent + ) # Summing number of weights of the edges + deg_inv = deg.pow(-0.5) # D^{-0.5} + deg_inv[deg_inv == float("inf")] = 0 + norm = deg_inv[row] * edge_weight * deg_inv[col] # D^{-0.5} + + return norm + def get_param(shape): - param = torch.nn.Parameter(torch.Tensor(*shape)); - xavier_normal_(param.data) - return param 
\ No newline at end of file + param = torch.nn.Parameter(torch.Tensor(*shape)) + xavier_normal_(param.data) + return param diff --git a/models/message_passing.py b/models/message_passing.py index ab0da5e..09d77ea 100644 --- a/models/message_passing.py +++ b/models/message_passing.py @@ -1,8 +1,9 @@ import inspect, torch from torch_scatter import scatter + def scatter_(name, src, index, dim_size=None): - r"""Aggregates all values from the :attr:`src` tensor at the indices + r"""Aggregates all values from the :attr:`src` tensor at the indices specified in the :attr:`index` tensor along the first dimension. If multiple indices reference the same location, their contributions are aggregated according to :attr:`name` (either :obj:`"add"`, @@ -19,14 +20,15 @@ def scatter_(name, src, index, dim_size=None): :rtype: :class:`Tensor` """ - if name == 'add': name = 'sum' - assert name in ['sum', 'mean', 'max'] - out = scatter(src, index, dim=0, out=None, dim_size=dim_size, reduce=name) - return out[0] if isinstance(out, tuple) else out + if name == "add": + name = "sum" + assert name in ["sum", "mean", "max"] + out = scatter(src, index, dim=0, out=None, dim_size=dim_size, reduce=name) + return out[0] if isinstance(out, tuple) else out class MessagePassing(torch.nn.Module): - r"""Base class for creating message passing layers + r"""Base class for creating message passing layers .. math:: \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, @@ -42,103 +44,125 @@ class MessagePassing(torch.nn.Module): """ - def __init__(self, aggr='add'): - super(MessagePassing, self).__init__() + def __init__(self, aggr="add"): + super(MessagePassing, self).__init__() - self.message_args = inspect.getargspec(self.message)[0][1:] # In the defined message function: get the list of arguments as list of string| For eg. 
in rgcn this will be ['x_j', 'edge_type', 'edge_norm'] (arguments of message function) - self.update_args = inspect.getargspec(self.update)[0][2:] # Same for update function starting from 3rd argument | first=self, second=out + self.message_args = inspect.getargspec(self.message)[0][ + 1: + ] # In the defined message function: get the list of arguments as list of string| For eg. in rgcn this will be ['x_j', 'edge_type', 'edge_norm'] (arguments of message function) + self.update_args = inspect.getargspec(self.update)[0][ + 2: + ] # Same for update function starting from 3rd argument | first=self, second=out - self.message_args_jump = inspect.getargspec(self.message_jump)[0][1:] + self.message_args_jump = inspect.getargspec(self.message_jump)[0][1:] - def propagate(self, aggr, edge_index, **kwargs): - r"""The initial call to start propagating messages. + def propagate(self, aggr, edge_index, **kwargs): + r"""The initial call to start propagating messages. Takes in an aggregation scheme (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`), the edge indices, and all additional data which is needed to construct messages and to update node embeddings.""" - assert aggr in ['add', 'mean', 'max'] - kwargs['edge_index'] = edge_index - - - size = None - message_args = [] - for arg in self.message_args: - if arg[-2:] == '_i': # If arguments ends with _i then include indic - tmp = kwargs[arg[:-2]] # Take the front part of the variable | Mostly it will be 'x', - size = tmp.size(0) - message_args.append(tmp[edge_index[0]]) # Lookup for head entities in edges - elif arg[-2:] == '_j': - tmp = kwargs[arg[:-2]] # tmp = kwargs['x'] - size = tmp.size(0) - message_args.append(tmp[edge_index[1]]) # Lookup for tail entities in edges - else: - message_args.append(kwargs[arg]) # Take things from kwargs - - update_args = [kwargs[arg] for arg in self.update_args] # Take update args from kwargs - - out = self.message(*message_args) - out = scatter_(aggr, out, edge_index[0], dim_size=size) # 
Aggregated neighbors for each vertex - out = self.update(out, *update_args) - - return out - - def message(self, x_j): # pragma: no cover - r"""Constructs messages in analogy to :math:`\phi_{\mathbf{\Theta}}` + assert aggr in ["add", "mean", "max"] + kwargs["edge_index"] = edge_index + + size = None + message_args = [] + for arg in self.message_args: + if arg[-2:] == "_i": # If arguments ends with _i then include indic + tmp = kwargs[ + arg[:-2] + ] # Take the front part of the variable | Mostly it will be 'x', + size = tmp.size(0) + message_args.append( + tmp[edge_index[0]] + ) # Lookup for head entities in edges + elif arg[-2:] == "_j": + tmp = kwargs[arg[:-2]] # tmp = kwargs['x'] + size = tmp.size(0) + message_args.append( + tmp[edge_index[1]] + ) # Lookup for tail entities in edges + else: + message_args.append(kwargs[arg]) # Take things from kwargs + + update_args = [ + kwargs[arg] for arg in self.update_args + ] # Take update args from kwargs + + out = self.message(*message_args) + out = scatter_( + aggr, out, edge_index[0], dim_size=size + ) # Aggregated neighbors for each vertex + out = self.update(out, *update_args) + + return out + + def message(self, x_j): # pragma: no cover + r"""Constructs messages in analogy to :math:`\phi_{\mathbf{\Theta}}` for each edge in :math:`(i,j) \in \mathcal{E}`. Can take any argument which was initially passed to :meth:`propagate`. In addition, features can be lifted to the source node :math:`i` and target node :math:`j` by appending :obj:`_i` or :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`.""" - return x_j + return x_j - def propagate_jump(self, aggr, edge_index, **kwargs): - """The initial call to start propagating messages. + def propagate_jump(self, aggr, edge_index, **kwargs): + """The initial call to start propagating messages. 
Takes in an aggregation scheme (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`), the edge indices, and all additional data which is needed to construct messages and to update node embeddings.""" - assert aggr in ['add', 'mean', 'max'] - kwargs['edge_index'] = edge_index - - - size = None - message_args = [] - for arg in self.message_args_jump: - if arg[-2:] == '_i': # If arguments ends with _i then include indic - tmp = kwargs[arg[:-2]] # Take the front part of the variable | Mostly it will be 'x', - size = tmp.size(0) - message_args.append(tmp[edge_index[0]]) # Lookup for head entities in edges - elif arg[-2:] == '_j': - tmp = kwargs[arg[:-2]] # tmp = kwargs['x'] - size = tmp.size(0) - message_args.append(tmp[edge_index[1]]) # Lookup for tail entities in edges - else: - message_args.append(kwargs[arg]) # Take things from kwargs - - update_args = [kwargs[arg] for arg in self.update_args] # Take update args from kwargs - - out = self.message_jump(*message_args) - out = scatter_(aggr, out, edge_index[0], dim_size=size) # Aggregated neighbors for each vertex - out = self.update(out, *update_args) - - return out - - def message_jump(self, x_j): # pragma: no cover - r"""Constructs messages in analogy to :math:`\phi_{\mathbf{\Theta}}` + assert aggr in ["add", "mean", "max"] + kwargs["edge_index"] = edge_index + + size = None + message_args = [] + for arg in self.message_args_jump: + if arg[-2:] == "_i": # If arguments ends with _i then include indic + tmp = kwargs[ + arg[:-2] + ] # Take the front part of the variable | Mostly it will be 'x', + size = tmp.size(0) + message_args.append( + tmp[edge_index[0]] + ) # Lookup for head entities in edges + elif arg[-2:] == "_j": + tmp = kwargs[arg[:-2]] # tmp = kwargs['x'] + size = tmp.size(0) + message_args.append( + tmp[edge_index[1]] + ) # Lookup for tail entities in edges + else: + message_args.append(kwargs[arg]) # Take things from kwargs + + update_args = [ + kwargs[arg] for arg in self.update_args + ] # Take update args from 
kwargs + + out = self.message_jump(*message_args) + out = scatter_( + aggr, out, edge_index[0], dim_size=size + ) # Aggregated neighbors for each vertex + out = self.update(out, *update_args) + + return out + + def message_jump(self, x_j): # pragma: no cover + r"""Constructs messages in analogy to :math:`\phi_{\mathbf{\Theta}}` for each edge in :math:`(i,j) \in \mathcal{E}`. Can take any argument which was initially passed to :meth:`propagate`. In addition, features can be lifted to the source node :math:`i` and target node :math:`j` by appending :obj:`_i` or :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`.""" - return x_j + return x_j - def update(self, aggr_out): # pragma: no cover - r"""Updates node embeddings in analogy to + def update(self, aggr_out): # pragma: no cover + r"""Updates node embeddings in analogy to :math:`\gamma_{\mathbf{\Theta}}` for each node :math:`i \in \mathcal{V}`. Takes in the output of aggregation as first argument and any argument which was initially passed to :meth:`propagate`.""" - return aggr_out + return aggr_out diff --git a/models/models.py b/models/models.py index 440d1af..b7c521b 100644 --- a/models/models.py +++ b/models/models.py @@ -9,6 +9,7 @@ from .MGCN import * from .MGCNLayer import * + class TANGO(nn.Module): def __init__(self, num_e, num_rel, params, device, logger): super().__init__() @@ -31,11 +32,11 @@ def __init__(self, num_e, num_rel, params, device, logger): self.device = device self.logger = logger - if self.p.activation.lower() == 'tanh': + if self.p.activation.lower() == "tanh": self.act = torch.tanh - elif self.p.activation.lower() == 'relu': + elif self.p.activation.lower() == "relu": self.act = F.relu - elif self.p.activation.lower() == 'leakyrelu': + elif self.p.activation.lower() == "leakyrelu": self.act = F.leaky_relu # define loss @@ -59,7 +60,14 @@ def __init__(self, num_e, num_rel, params, device, logger): # score function TuckER if self.score_func.lower() == "tucker": - self.W_tk, 
self.input_dropout, self.hidden_dropout1, self.hidden_dropout2, self.bn0, self.bn1 = self.TuckER() + ( + self.W_tk, + self.input_dropout, + self.hidden_dropout1, + self.hidden_dropout2, + self.bn0, + self.bn1, + ) = self.TuckER() def get_param(self, shape): # a function to initialize embedding @@ -68,8 +76,18 @@ def get_param(self, shape): return param def add_base(self): - model = MGCNLayerWrapper(None, None, self.num_e, self.num_rel, self.act, drop1=self.drop, drop2=self.drop, - sub=None, rel=None, params=self.p) + model = MGCNLayerWrapper( + None, + None, + self.num_e, + self.num_rel, + self.act, + drop1=self.drop, + drop2=self.drop, + sub=None, + rel=None, + params=self.p, + ) model.to(self.device) return model @@ -78,12 +96,24 @@ def construct_gde_func(self): return gdefunc def construct_GDEBlock(self, gdefunc): - gde = ODEBlock(odefunc=gdefunc, method=self.solver, atol=self.atol, rtol=self.rtol, adjoint=self.adjoint_flag).to(self.device) + gde = ODEBlock( + odefunc=gdefunc, + method=self.solver, + atol=self.atol, + rtol=self.rtol, + adjoint=self.adjoint_flag, + ).to(self.device) return gde def TuckER(self): - W = torch.nn.Parameter(torch.tensor(np.random.uniform(-1, 1, (self.hidsize, self.hidsize, self.hidsize)), - dtype=torch.float, device=self.device, requires_grad=True)) + W = torch.nn.Parameter( + torch.tensor( + np.random.uniform(-1, 1, (self.hidsize, self.hidsize, self.hidsize)), + dtype=torch.float, + device=self.device, + requires_grad=True, + ) + ) input_dropout = torch.nn.Dropout(self.drop) hidden_dropout1 = torch.nn.Dropout(self.drop) hidden_dropout2 = torch.nn.Dropout(self.drop) @@ -101,7 +131,14 @@ def TuckER(self): def Jump(self): if self.p.rel_jump: - jump = MGCNConvLayer(self.hidsize, self.hidsize, act=self.act, params=self.p, isjump=True, diag=True) + jump = MGCNConvLayer( + self.hidsize, + self.hidsize, + act=self.act, + params=self.p, + isjump=True, + diag=True, + ) else: jump = GCNConvLayer(self.hidsize, self.hidsize, act=self.act, 
params=self.p) @@ -113,14 +150,17 @@ def loss_comp(self, sub, rel, emb, label, core, obj=None): score = self.score_comp(sub, rel, emb, core) return self.loss(score, obj) - def score_comp(self, sub, rel, emb, core): sub_emb, rel_emb, all_emb = self.find_related(sub, rel, emb) - if self.score_func.lower() == 'distmult': - obj_emb = torch.cat([torch.index_select(self.emb_e, 0, sub), sub_emb], dim=1) * rel_emb.repeat(1,2) - s = torch.mm(obj_emb, torch.cat([self.emb_e, all_emb], dim=1).transpose(1,0)) + if self.score_func.lower() == "distmult": + obj_emb = torch.cat( + [torch.index_select(self.emb_e, 0, sub), sub_emb], dim=1 + ) * rel_emb.repeat(1, 2) + s = torch.mm( + obj_emb, torch.cat([self.emb_e, all_emb], dim=1).transpose(1, 0) + ) - if self.score_func.lower() == 'tucker': + if self.score_func.lower() == "tucker": x = self.bn0(sub_emb) x = self.input_dropout(x) x = x.view(-1, 1, sub_emb.size(1)) @@ -138,8 +178,8 @@ def score_comp(self, sub, rel, emb, core): return s def find_related(self, sub, rel, emb): - x = emb[:self.num_e,:] - r = emb[self.num_e:,:] + x = emb[: self.num_e, :] + r = emb[self.num_e :, :] assert x.shape[0] == self.num_e assert r.shape[0] == self.num_rel * 2 sub_emb = torch.index_select(x, 0, sub) @@ -153,22 +193,73 @@ def push_data(self, *args): out_args.append(arg) return out_args - def forward(self, sub_tar, rel_tar, obj_tar, lab_tar, times, tar_times, edge_index_list, - edge_type_list, edge_id_jump, edge_w_jump, rel_jump): + def forward( + self, + sub_tar, + rel_tar, + obj_tar, + lab_tar, + times, + tar_times, + edge_index_list, + edge_type_list, + edge_id_jump, + edge_w_jump, + rel_jump, + ): # self.test_flag = 0 # push data onto gpu if self.p.jump: - [sub_tar, rel_tar, obj_tar, lab_tar, times, tar_times, edge_index_list, edge_type_list, edge_id_jump, edge_w_jump, rel_jump] = \ - self.push_data(sub_tar, rel_tar, obj_tar, lab_tar, times, tar_times, edge_index_list, edge_type_list, edge_id_jump, edge_w_jump, rel_jump) + [ + sub_tar, + rel_tar, + 
obj_tar, + lab_tar, + times, + tar_times, + edge_index_list, + edge_type_list, + edge_id_jump, + edge_w_jump, + rel_jump, + ] = self.push_data( + sub_tar, + rel_tar, + obj_tar, + lab_tar, + times, + tar_times, + edge_index_list, + edge_type_list, + edge_id_jump, + edge_w_jump, + rel_jump, + ) else: - [sub_tar, rel_tar, obj_tar, lab_tar, times, tar_times, edge_index_list, edge_type_list] = \ - self.push_data(sub_tar, rel_tar, obj_tar, lab_tar, times, tar_times, edge_index_list, edge_type_list) + [ + sub_tar, + rel_tar, + obj_tar, + lab_tar, + times, + tar_times, + edge_index_list, + edge_type_list, + ] = self.push_data( + sub_tar, + rel_tar, + obj_tar, + lab_tar, + times, + tar_times, + edge_index_list, + edge_type_list, + ) # for RE decoder # if self.score_func.lower() == 're' or self.score_func.lower(): # [obj_tar] = self.push_data(obj_tar) - emb = torch.cat([self.emb_e, self.emb_r], dim=0) for i in range(len(times)): @@ -177,38 +268,101 @@ def forward(self, sub_tar, rel_tar, obj_tar, lab_tar, times, tar_times, edge_ind # ODE if self.p.jump: if self.p.rel_jump: - self.odeblock.odefunc.set_jumpfunc(edge_id_jump[i], edge_w_jump[i], self.jump, jumpw=self.jump_weight, - skip=False, rel_jump=rel_jump[i]) + self.odeblock.odefunc.set_jumpfunc( + edge_id_jump[i], + edge_w_jump[i], + self.jump, + jumpw=self.jump_weight, + skip=False, + rel_jump=rel_jump[i], + ) else: - self.odeblock.odefunc.set_jumpfunc(edge_id_jump[i], edge_w_jump[i], self.jump, self.jump_weight, False) - - emb = self.odeblock.forward_nobatch(emb, start=times[i], end=times[i+1], cheby_grid=self.p.cheby_grid) + self.odeblock.odefunc.set_jumpfunc( + edge_id_jump[i], + edge_w_jump[i], + self.jump, + self.jump_weight, + False, + ) + + emb = self.odeblock.forward_nobatch( + emb, start=times[i], end=times[i + 1], cheby_grid=self.p.cheby_grid + ) else: # ODE if self.p.jump: if self.p.rel_jump: - self.odeblock.odefunc.set_jumpfunc(edge_id_jump[i], edge_w_jump[i], self.jump, - jumpw=self.jump_weight, - skip=False, 
rel_jump=rel_jump[i]) + self.odeblock.odefunc.set_jumpfunc( + edge_id_jump[i], + edge_w_jump[i], + self.jump, + jumpw=self.jump_weight, + skip=False, + rel_jump=rel_jump[i], + ) else: - self.odeblock.odefunc.set_jumpfunc(edge_id_jump[i], edge_w_jump[i], self.jump, self.jump_weight, - False) - - emb = self.odeblock.forward_nobatch(emb, start=times[i], end=tar_times[0], cheby_grid=self.p.cheby_grid) - - loss = self.loss_comp(sub_tar[0], rel_tar[0], emb, lab_tar[0], self.odeblock.odefunc, - obj=obj_tar[0]) + self.odeblock.odefunc.set_jumpfunc( + edge_id_jump[i], + edge_w_jump[i], + self.jump, + self.jump_weight, + False, + ) + + emb = self.odeblock.forward_nobatch( + emb, start=times[i], end=tar_times[0], cheby_grid=self.p.cheby_grid + ) + + loss = self.loss_comp( + sub_tar[0], + rel_tar[0], + emb, + lab_tar[0], + self.odeblock.odefunc, + obj=obj_tar[0], + ) return loss - def forward_eval(self, times, tar_times, edge_index_list, edge_type_list, edge_id_jump, edge_w_jump, rel_jump): + def forward_eval( + self, + times, + tar_times, + edge_index_list, + edge_type_list, + edge_id_jump, + edge_w_jump, + rel_jump, + ): # push data onto gpu if self.p.jump: - [times, tar_times, edge_index_list, edge_type_list, edge_id_jump, edge_w_jump, rel_jump] = \ - self.push_data(times, tar_times, edge_index_list, edge_type_list, edge_id_jump, edge_w_jump, rel_jump) + [ + times, + tar_times, + edge_index_list, + edge_type_list, + edge_id_jump, + edge_w_jump, + rel_jump, + ] = self.push_data( + times, + tar_times, + edge_index_list, + edge_type_list, + edge_id_jump, + edge_w_jump, + rel_jump, + ) else: - [times, tar_times, edge_index_list, edge_type_list, edge_index_list] = \ - self.push_data(times, tar_times, edge_index_list, edge_type_list, edge_index_list) + [ + times, + tar_times, + edge_index_list, + edge_type_list, + edge_index_list, + ] = self.push_data( + times, tar_times, edge_index_list, edge_type_list, edge_index_list + ) emb = torch.cat([self.emb_e, self.emb_r], dim=0) @@ 
-218,23 +372,47 @@ def forward_eval(self, times, tar_times, edge_index_list, edge_type_list, edge_i # ODE if self.p.jump: if self.p.rel_jump: - self.odeblock.odefunc.set_jumpfunc(edge_id_jump[i], edge_w_jump[i], self.jump, - jumpw=self.jump_weight, - skip=False, rel_jump=rel_jump[i]) + self.odeblock.odefunc.set_jumpfunc( + edge_id_jump[i], + edge_w_jump[i], + self.jump, + jumpw=self.jump_weight, + skip=False, + rel_jump=rel_jump[i], + ) else: - self.odeblock.odefunc.set_jumpfunc(edge_id_jump[i], edge_w_jump[i], self.jump, self.jump_weight, - False) - emb = self.odeblock.forward_nobatch(emb, start=times[i], end=times[i + 1], cheby_grid=self.p.cheby_grid) + self.odeblock.odefunc.set_jumpfunc( + edge_id_jump[i], + edge_w_jump[i], + self.jump, + self.jump_weight, + False, + ) + emb = self.odeblock.forward_nobatch( + emb, start=times[i], end=times[i + 1], cheby_grid=self.p.cheby_grid + ) else: # ODE if self.p.jump: if self.p.rel_jump: - self.odeblock.odefunc.set_jumpfunc(edge_id_jump[i], edge_w_jump[i], self.jump, - jumpw=self.jump_weight, - skip=False, rel_jump=rel_jump[i]) + self.odeblock.odefunc.set_jumpfunc( + edge_id_jump[i], + edge_w_jump[i], + self.jump, + jumpw=self.jump_weight, + skip=False, + rel_jump=rel_jump[i], + ) else: - self.odeblock.odefunc.set_jumpfunc(edge_id_jump[i], edge_w_jump[i], self.jump, self.jump_weight, - False) - emb = self.odeblock.forward_nobatch(emb, start=times[i], end=tar_times[0], cheby_grid=self.p.cheby_grid) - - return emb \ No newline at end of file + self.odeblock.odefunc.set_jumpfunc( + edge_id_jump[i], + edge_w_jump[i], + self.jump, + self.jump_weight, + False, + ) + emb = self.odeblock.forward_nobatch( + emb, start=times[i], end=tar_times[0], cheby_grid=self.p.cheby_grid + ) + + return emb diff --git a/models/odeblock.py b/models/odeblock.py index ee8d117..4aedf75 100644 --- a/models/odeblock.py +++ b/models/odeblock.py @@ -4,7 +4,14 @@ class ODEBlock(nn.Module): - def __init__(self, odefunc:nn.Module, method:str='dopri5', 
rtol:float=1e-3, atol:float=1e-4, adjoint:bool=True): + def __init__( + self, + odefunc: nn.Module, + method: str = "dopri5", + rtol: float = 1e-3, + atol: float = 1e-4, + adjoint: bool = True, + ): """ Standard ODEBlock class. Can handle all types of ODE functions :method:str = {'euler', 'rk4', 'dopri5', 'adams'} """ @@ -14,36 +21,68 @@ def __init__(self, odefunc:nn.Module, method:str='dopri5', rtol:float=1e-3, atol self.adjoint_flag = adjoint self.atol, self.rtol = atol, rtol - def forward(self, x:torch.Tensor, start, stop): + def forward(self, x: torch.Tensor, start, stop): self.integration_time = torch.tensor([start, stop]).float() self.integration_time = self.integration_time.type_as(x) if self.adjoint_flag: - out = torchdiffeq.odeint_adjoint(self.odefunc, x, self.integration_time, - rtol=self.rtol, atol=self.atol, method=self.method) + out = torchdiffeq.odeint_adjoint( + self.odefunc, + x, + self.integration_time, + rtol=self.rtol, + atol=self.atol, + method=self.method, + ) else: - out = torchdiffeq.odeint(self.odefunc, x, self.integration_time, - rtol=self.rtol, atol=self.atol, method=self.method) - + out = torchdiffeq.odeint( + self.odefunc, + x, + self.integration_time, + rtol=self.rtol, + atol=self.atol, + method=self.method, + ) + return out[-1] - def forward_nobatch(self, x: torch.Tensor, start: float, end: float, cheby_grid: int=0): + def forward_nobatch( + self, x: torch.Tensor, start: float, end: float, cheby_grid: int = 0 + ): self.integration_time = torch.tensor([start, end]).float() self.integration_time = self.integration_time.type_as(x) if self.adjoint_flag: - out = torchdiffeq.odeint_adjoint(self.odefunc, x, self.integration_time, - rtol=self.rtol, atol=self.atol, method=self.method, cheby_grid=cheby_grid) + out = torchdiffeq.odeint_adjoint( + self.odefunc, + x, + self.integration_time, + rtol=self.rtol, + atol=self.atol, + method=self.method, + cheby_grid=cheby_grid, + ) else: - out = torchdiffeq.odeint(self.odefunc, x, self.integration_time, 
- rtol=self.rtol, atol=self.atol, method=self.method) + out = torchdiffeq.odeint( + self.odefunc, + x, + self.integration_time, + rtol=self.rtol, + atol=self.atol, + method=self.method, + ) return out[-1] - def trajectory(self, x:torch.Tensor, T:int, num_points:int): + def trajectory(self, x: torch.Tensor, T: int, num_points: int): self.integration_time = torch.linspace(0, t_end, num_points) self.integration_time = self.integration_time.type_as(x) - out = torchdiffeq.odeint(self.odefunc, x, self.integration_time, - rtol=self.rtol, atol=self.atol, method=self.method) + out = torchdiffeq.odeint( + self.odefunc, + x, + self.integration_time, + rtol=self.rtol, + atol=self.atol, + method=self.method, + ) return out - diff --git a/torchdiffeq/__init__.py b/torchdiffeq/__init__.py index 2858051..86ad134 100644 --- a/torchdiffeq/__init__.py +++ b/torchdiffeq/__init__.py @@ -1,3 +1,4 @@ from ._impl import odeint from ._impl import odeint_adjoint + __version__ = "0.1.0" diff --git a/torchdiffeq/_impl/adams.py b/torchdiffeq/_impl/adams.py index cfcca6d..fa44a5d 100644 --- a/torchdiffeq/_impl/adams.py +++ b/torchdiffeq/_impl/adams.py @@ -2,20 +2,38 @@ import torch from .solvers import AdaptiveStepsizeODESolver from .misc import ( - _handle_unused_kwargs, _select_initial_step, - _optimal_step_size, _compute_error_ratio + _handle_unused_kwargs, + _select_initial_step, + _optimal_step_size, + _compute_error_ratio, ) _MIN_ORDER = 1 _MAX_ORDER = 12 gamma_star = [ - 1, -1 / 2, -1 / 12, -1 / 24, -19 / 720, -3 / 160, -863 / 60480, -275 / 24192, -33953 / 3628800, -0.00789255, - -0.00678585, -0.00592406, -0.00523669, -0.0046775, -0.00421495, -0.0038269 + 1, + -1 / 2, + -1 / 12, + -1 / 24, + -19 / 720, + -3 / 160, + -863 / 60480, + -275 / 24192, + -33953 / 3628800, + -0.00789255, + -0.00678585, + -0.00592406, + -0.00523669, + -0.0046775, + -0.00421495, + -0.0038269, ] -class _VCABMState(collections.namedtuple('_VCABMState', 'y_n, prev_f, prev_t, next_t, phi, order')): +class _VCABMState( 
+ collections.namedtuple("_VCABMState", "y_n, prev_f, prev_t, next_t, phi, order") +): """Saved state of the variable step size Adams-Bashforth-Moulton solver as described in Solving Ordinary Differential Equations I - Nonstiff Problems III.5 @@ -54,15 +72,29 @@ def compute_implicit_phi(explicit_phi, f_n, k): implicit_phi = collections.deque(maxlen=k) implicit_phi.append(f_n) for j in range(1, k): - implicit_phi.append(tuple(iphi_ - ephi_ for iphi_, ephi_ in zip(implicit_phi[j - 1], explicit_phi[j - 1]))) + implicit_phi.append( + tuple( + iphi_ - ephi_ + for iphi_, ephi_ in zip(implicit_phi[j - 1], explicit_phi[j - 1]) + ) + ) return implicit_phi class VariableCoefficientAdamsBashforth(AdaptiveStepsizeODESolver): - def __init__( - self, func, y0, rtol, atol, implicit=True, first_step=None, max_order=_MAX_ORDER, safety=0.9, ifactor=10.0, dfactor=0.2, - **unused_kwargs + self, + func, + y0, + rtol, + atol, + implicit=True, + first_step=None, + max_order=_MAX_ORDER, + safety=0.9, + ifactor=10.0, + dfactor=0.2, + **unused_kwargs, ): _handle_unused_kwargs(self, unused_kwargs) del unused_kwargs @@ -74,9 +106,15 @@ def __init__( self.implicit = implicit self.first_step = first_step self.max_order = int(max(_MIN_ORDER, min(max_order, _MAX_ORDER))) - self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device) - self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device) - self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device) + self.safety = _convert_to_tensor( + safety, dtype=torch.float64, device=y0[0].device + ) + self.ifactor = _convert_to_tensor( + ifactor, dtype=torch.float64, device=y0[0].device + ) + self.dfactor = _convert_to_tensor( + dfactor, dtype=torch.float64, device=y0[0].device + ) def before_integrate(self, t): prev_f = collections.deque(maxlen=self.max_order + 1) @@ -89,11 +127,17 @@ def before_integrate(self, t): prev_f.appendleft(f0) phi.appendleft(f0) if self.first_step is 
None: - first_step = _select_initial_step(self.func, t[0], self.y0, 2, self.rtol[0], self.atol[0], f0=f0).to(t) + first_step = _select_initial_step( + self.func, t[0], self.y0, 2, self.rtol[0], self.atol[0], f0=f0 + ).to(t) else: - first_step = _select_initial_step(self.func, t[0], self.y0, 2, self.rtol[0], self.atol[0], f0=f0).to(t) + first_step = _select_initial_step( + self.func, t[0], self.y0, 2, self.rtol[0], self.atol[0], f0=f0 + ).to(t) - self.vcabm_state = _VCABMState(self.y0, prev_f, prev_t, next_t=t[0] + first_step, phi=phi, order=1) + self.vcabm_state = _VCABMState( + self.y0, prev_f, prev_t, next_t=t[0] + first_step, phi=phi, order=1 + ) def advance(self, final_t): final_t = _convert_to_tensor(final_t).to(self.vcabm_state.prev_t[0]) @@ -106,14 +150,17 @@ def _adaptive_adams_step(self, vcabm_state, final_t): y0, prev_f, prev_t, next_t, prev_phi, order = vcabm_state if next_t > final_t: next_t = final_t - dt = (next_t - prev_t[0]) + dt = next_t - prev_t[0] dt_cast = dt.to(y0[0]) # Explicit predictor step. g, phi = g_and_explicit_phi(prev_t, next_t, prev_phi, order) g = g.to(y0[0]) p_next = tuple( - y0_ + _scaled_dot_product(dt_cast, g[:max(1, order - 1)], phi_[:max(1, order - 1)]) + y0_ + + _scaled_dot_product( + dt_cast, g[: max(1, order - 1)], phi_[: max(1, order - 1)] + ) for y0_, phi_ in zip(y0, tuple(zip(*phi))) ) @@ -123,7 +170,8 @@ def _adaptive_adams_step(self, vcabm_state, final_t): # Implicit corrector step. y_next = tuple( - p_next_ + dt_cast * g[order - 1] * iphi_ for p_next_, iphi_ in zip(p_next, implicit_phi_p[order - 1]) + p_next_ + dt_cast * g[order - 1] * iphi_ + for p_next_, iphi_ in zip(p_next, implicit_phi_p[order - 1]) ) # Error estimation. 
@@ -131,14 +179,21 @@ def _adaptive_adams_step(self, vcabm_state, final_t): atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_)) for atol_, rtol_, y0_, y1_ in zip(self.atol, self.rtol, y0, y_next) ) - local_error = tuple(dt_cast * (g[order] - g[order - 1]) * iphi_ for iphi_ in implicit_phi_p[order]) + local_error = tuple( + dt_cast * (g[order] - g[order - 1]) * iphi_ + for iphi_ in implicit_phi_p[order] + ) error_k = _compute_error_ratio(local_error, tolerance) accept_step = (torch.tensor(error_k) <= 1).all() if not accept_step: # Retry with adjusted step size if step is rejected. - dt_next = _optimal_step_size(dt, error_k, self.safety, self.ifactor, self.dfactor, order=order) - return _VCABMState(y0, prev_f, prev_t, prev_t[0] + dt_next, prev_phi, order=order) + dt_next = _optimal_step_size( + dt, error_k, self.safety, self.ifactor, self.dfactor, order=order + ) + return _VCABMState( + y0, prev_f, prev_t, prev_t[0] + dt_next, prev_phi, order=order + ) # We accept the step. Evaluate f and update phi. 
next_f0 = self.func(next_t.to(p_next[0]), y_next) @@ -150,25 +205,43 @@ def _adaptive_adams_step(self, vcabm_state, final_t): next_order = min(order + 1, 3, self.max_order) else: error_km1 = _compute_error_ratio( - tuple(dt_cast * (g[order - 1] - g[order - 2]) * iphi_ for iphi_ in implicit_phi_p[order - 1]), tolerance + tuple( + dt_cast * (g[order - 1] - g[order - 2]) * iphi_ + for iphi_ in implicit_phi_p[order - 1] + ), + tolerance, ) error_km2 = _compute_error_ratio( - tuple(dt_cast * (g[order - 2] - g[order - 3]) * iphi_ for iphi_ in implicit_phi_p[order - 2]), tolerance + tuple( + dt_cast * (g[order - 2] - g[order - 3]) * iphi_ + for iphi_ in implicit_phi_p[order - 2] + ), + tolerance, ) if min(error_km1 + error_km2) < max(error_k): next_order = order - 1 elif order < self.max_order: error_kp1 = _compute_error_ratio( - tuple(dt_cast * gamma_star[order] * iphi_ for iphi_ in implicit_phi_p[order]), tolerance + tuple( + dt_cast * gamma_star[order] * iphi_ + for iphi_ in implicit_phi_p[order] + ), + tolerance, ) if max(error_kp1) < max(error_k): next_order = order + 1 # Keep step size constant if increasing order. Else use adaptive step size. 
- dt_next = dt if next_order > order else _optimal_step_size( - dt, error_k, self.safety, self.ifactor, self.dfactor, order=order + 1 + dt_next = ( + dt + if next_order > order + else _optimal_step_size( + dt, error_k, self.safety, self.ifactor, self.dfactor, order=order + 1 + ) ) prev_f.appendleft(next_f0) prev_t.appendleft(next_t) - return _VCABMState(p_next, prev_f, prev_t, next_t + dt_next, implicit_phi, order=next_order) + return _VCABMState( + p_next, prev_f, prev_t, next_t + dt_next, implicit_phi, order=next_order + ) diff --git a/torchdiffeq/_impl/adaptive_heun.py b/torchdiffeq/_impl/adaptive_heun.py index 10f9696..4433179 100644 --- a/torchdiffeq/_impl/adaptive_heun.py +++ b/torchdiffeq/_impl/adaptive_heun.py @@ -3,20 +3,13 @@ _ADAPTIVE_HEUN_TABLEAU = _ButcherTableau( - alpha=torch.tensor([1.], dtype=torch.float64), - beta=[ - torch.tensor([1.], dtype=torch.float64), - ], + alpha=torch.tensor([1.0], dtype=torch.float64), + beta=[torch.tensor([1.0], dtype=torch.float64),], c_sol=torch.tensor([0.5, 0.5], dtype=torch.float64), - c_error=torch.tensor([ - 0.5, - -0.5, - ], dtype=torch.float64), + c_error=torch.tensor([0.5, -0.5,], dtype=torch.float64), ) -_AH_C_MID = torch.tensor([ - 0.5, 0. 
-], dtype=torch.float64) +_AH_C_MID = torch.tensor([0.5, 0.0], dtype=torch.float64) class AdaptiveHeunSolver(RKAdaptiveStepsizeODESolver): diff --git a/torchdiffeq/_impl/adjoint.py b/torchdiffeq/_impl/adjoint.py index 9cb1ef7..f4b3e82 100644 --- a/torchdiffeq/_impl/adjoint.py +++ b/torchdiffeq/_impl/adjoint.py @@ -1,14 +1,38 @@ import torch import torch.nn as nn from .odeint import SOLVERS, odeint -from .misc import _check_inputs, _flat_to_shape, _rms_norm, _mixed_linf_rms_norm, _wrap_norm, cby_grid_type1, barycentric_weights, _cby1_interp +from .misc import ( + _check_inputs, + _flat_to_shape, + _rms_norm, + _mixed_linf_rms_norm, + _wrap_norm, + cby_grid_type1, + barycentric_weights, + _cby1_interp, +) class OdeintAdjointMethod(torch.autograd.Function): - @staticmethod - def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, adjoint_rtol, adjoint_atol, adjoint_method, cheby_grid, - adjoint_options, t_requires_grad, *adjoint_params): + def forward( + ctx, + shapes, + func, + y0, + t, + rtol, + atol, + method, + options, + adjoint_rtol, + adjoint_atol, + adjoint_method, + cheby_grid, + adjoint_options, + t_requires_grad, + *adjoint_params, + ): ctx.shapes = shapes ctx.func = func @@ -17,24 +41,41 @@ def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, adjoint_rtol, ctx.adjoint_method = adjoint_method ctx.adjoint_options = adjoint_options ctx.t_requires_grad = t_requires_grad - ctx.cheby = cheby_grid # cheby node number + ctx.cheby = cheby_grid # cheby node number ########################################################################################################### if cheby_grid > 0: - #ctx.cheby = 1 # cheby flag + # ctx.cheby = 1 # cheby flag num_cby_grids = cheby_grid - ctx.cby_grids = torch.tensor(cby_grid_type1(t_min=t[0].item(), t_max=t[-1].item(), n=num_cby_grids), - device=y0[0].device, dtype=torch.float32) - ctx.weights = torch.tensor(barycentric_weights(num_cby_grids), device=y0[0].device, dtype=torch.float32) + ctx.cby_grids = 
torch.tensor( + cby_grid_type1(t_min=t[0].item(), t_max=t[-1].item(), n=num_cby_grids), + device=y0[0].device, + dtype=torch.float32, + ) + ctx.weights = torch.tensor( + barycentric_weights(num_cby_grids), + device=y0[0].device, + dtype=torch.float32, + ) with torch.no_grad(): - y = odeint(func, y0, ctx.cby_grids, rtol=rtol, atol=atol, method=method, options=options) + y = odeint( + func, + y0, + ctx.cby_grids, + rtol=rtol, + atol=atol, + method=method, + options=options, + ) ctx.values = y y = torch.cat((y[0:1], y[-1:]), 0) ctx.save_for_backward(t, y, *adjoint_params) else: - #ctx.cheby = 0 # cheby flag + # ctx.cheby = 0 # cheby flag with torch.no_grad(): - y = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options) + y = odeint( + func, y0, t, rtol=rtol, atol=atol, method=method, options=options + ) ctx.save_for_backward(t, y, *adjoint_params) return y @@ -69,19 +110,19 @@ def cby_interpolate(_t): # We assume that any grid points are given to us ordered in the same direction as for the forward pass (for # compatibility with setting adjoint_options = options), so we need to flip them around here. try: - grid_points = adjoint_options['grid_points'] + grid_points = adjoint_options["grid_points"] except KeyError: pass else: - adjoint_options['grid_points'] = grid_points.flip(0) + adjoint_options["grid_points"] = grid_points.flip(0) # Backward compatibility: by default use a mixed L-infinity/RMS norm over the input, where we treat t, each # element of y, and each element of adj_y separately over the Linf, but consider all the parameters # together. 
- #if 'norm' not in adjoint_options: + # if 'norm' not in adjoint_options: # if shapes is None: # shapes = [y[-1].shape] # [-1] because y has shape (len(t), *y0.shape) - # adj_t, y, adj_y, adj_params, corresponding to the order in aug_state below + # adj_t, y, adj_y, adj_params, corresponding to the order in aug_state below # adjoint_shapes = [torch.Size(())] + shapes + shapes + [torch.Size([sum(param.numel() # for param in adjoint_params)])] # adjoint_options['norm'] = _mixed_linf_rms_norm(adjoint_shapes) @@ -92,12 +133,23 @@ def cby_interpolate(_t): ################################## if cheby_flag: # For interpolation, without y - aug_state = [torch.zeros((), dtype=y.dtype, device=y.device), grad_y[-1]] # vjp_t, vjp_y - aug_state.extend([torch.zeros_like(param) for param in adjoint_params]) # vjp_params + aug_state = [ + torch.zeros((), dtype=y.dtype, device=y.device), + grad_y[-1], + ] # vjp_t, vjp_y + aug_state.extend( + [torch.zeros_like(param) for param in adjoint_params] + ) # vjp_params else: # [-1] because y and grad_y are both of shape (len(t), *y0.shape) - aug_state = [torch.zeros((), dtype=y.dtype, device=y.device), y[-1], grad_y[-1]] # vjp_t, y, vjp_y - aug_state.extend([torch.zeros_like(param) for param in adjoint_params]) # vjp_params + aug_state = [ + torch.zeros((), dtype=y.dtype, device=y.device), + y[-1], + grad_y[-1], + ] # vjp_t, y, vjp_y + aug_state.extend( + [torch.zeros_like(param) for param in adjoint_params] + ) # vjp_params ########################################################################################################### ################################## @@ -127,20 +179,31 @@ def augmented_dynamics_cheby(t, y_aug): # Workaround for PyTorch bug #39784 _t = torch.as_strided(t, (), ()) _y = torch.as_strided(y, (), ()) - _params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params) + _params = tuple( + torch.as_strided(param, (), ()) for param in adjoint_params + ) vjp_t, vjp_y, *vjp_params = torch.autograd.grad( - 
func_eval, (t, y) + adjoint_params, -adj_y, - allow_unused=True, retain_graph=True + func_eval, + (t, y) + adjoint_params, + -adj_y, + allow_unused=True, + retain_graph=True, ) # autograd.grad returns None if no gradient, set to zero. vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y - vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param - for param, vjp_param in zip(adjoint_params, vjp_params)] + vjp_params = [ + torch.zeros_like(param) if vjp_param is None else vjp_param + for param, vjp_param in zip(adjoint_params, vjp_params) + ] - return (vjp_t, vjp_y, *vjp_params) # For interpolation, without func_eval + return ( + vjp_t, + vjp_y, + *vjp_params, + ) # For interpolation, without func_eval def augmented_dynamics(t, y_aug): # Dynamics of the original system augmented with @@ -162,22 +225,28 @@ def augmented_dynamics(t, y_aug): # Workaround for PyTorch bug #39784 _t = torch.as_strided(t, (), ()) _y = torch.as_strided(y, (), ()) - _params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params) + _params = tuple( + torch.as_strided(param, (), ()) for param in adjoint_params + ) vjp_t, vjp_y, *vjp_params = torch.autograd.grad( - func_eval, (t, y) + adjoint_params, -adj_y, - allow_unused=True, retain_graph=True + func_eval, + (t, y) + adjoint_params, + -adj_y, + allow_unused=True, + retain_graph=True, ) # autograd.grad returns None if no gradient, set to zero. 
vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y - vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param - for param, vjp_param in zip(adjoint_params, vjp_params)] + vjp_params = [ + torch.zeros_like(param) if vjp_param is None else vjp_param + for param, vjp_param in zip(adjoint_params, vjp_params) + ] return (vjp_t, func_eval, vjp_y, *vjp_params) - ################################## # Solve adjoint ODE # ################################## @@ -198,24 +267,40 @@ def augmented_dynamics(t, y_aug): if cheby_flag: # Run the augmented system backwards in time. aug_state = odeint( - augmented_dynamics_cheby, tuple(aug_state), - t[i - 1:i + 1].flip(0), - rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options + augmented_dynamics_cheby, + tuple(aug_state), + t[i - 1 : i + 1].flip(0), + rtol=adjoint_rtol, + atol=adjoint_atol, + method=adjoint_method, + options=adjoint_options, ) - aug_state = [a[1] for a in aug_state] # extract just the t[i - 1] value - ###################################################################################################### + aug_state = [ + a[1] for a in aug_state + ] # extract just the t[i - 1] value + ###################################################################################################### aug_state[1] += grad_y[i - 1] # For Interpolation, y is neglected ###################################################################################################### else: # Run the augmented system backwards in time. 
aug_state = odeint( - augmented_dynamics, tuple(aug_state), - t[i - 1:i + 1].flip(0), - rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options + augmented_dynamics, + tuple(aug_state), + t[i - 1 : i + 1].flip(0), + rtol=adjoint_rtol, + atol=adjoint_atol, + method=adjoint_method, + options=adjoint_options, ) - aug_state = [a[1] for a in aug_state] # extract just the t[i - 1] value - aug_state[1] = y[i - 1] # update to use our forward-pass estimate of the state - aug_state[2] += grad_y[i - 1] # update any gradients wrt state at this time point + aug_state = [ + a[1] for a in aug_state + ] # extract just the t[i - 1] value + aug_state[1] = y[ + i - 1 + ] # update to use our forward-pass estimate of the state + aug_state[2] += grad_y[ + i - 1 + ] # update any gradients wrt state at this time point if t_requires_grad: time_vjps[0] = aug_state[0] @@ -230,18 +315,49 @@ def augmented_dynamics(t, y_aug): adj_params = aug_state[3:] ########################################################################################################### - return (None, None, adj_y, time_vjps, None, None, None, None, None, None, None, None, None, None, *adj_params) - - -def odeint_adjoint(func, y0, t, rtol=1e-7, atol=1e-9, method=None, options=None, adjoint_rtol=None, adjoint_atol=None, - adjoint_method=None, cheby_grid=0, adjoint_options=None, adjoint_params=None): + return ( + None, + None, + adj_y, + time_vjps, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + *adj_params, + ) + + +def odeint_adjoint( + func, + y0, + t, + rtol=1e-7, + atol=1e-9, + method=None, + options=None, + adjoint_rtol=None, + adjoint_atol=None, + adjoint_method=None, + cheby_grid=0, + adjoint_options=None, + adjoint_params=None, +): # We need this in order to access the variables inside this module, # since we have no other way of getting variables along the execution path. 
if adjoint_params is None and not isinstance(func, nn.Module): - raise ValueError('func must be an instance of nn.Module to specify the adjoint parameters; alternatively they ' - 'can be specified explicitly via the `adjoint_params` argument. If there are no parameters ' - 'then it is allowable to set `adjoint_params=()`.') + raise ValueError( + "func must be an instance of nn.Module to specify the adjoint parameters; alternatively they " + "can be specified explicitly via the `adjoint_params` argument. If there are no parameters " + "then it is allowable to set `adjoint_params=()`." + ) # Must come before _check_inputs as we don't want to use normalised input (in particular any changes to options) if adjoint_rtol is None: @@ -251,21 +367,46 @@ def odeint_adjoint(func, y0, t, rtol=1e-7, atol=1e-9, method=None, options=None, if adjoint_method is None: adjoint_method = method if adjoint_options is None: - adjoint_options = {k: v for k, v in options.items() if k != "norm"} if options is not None else {} + adjoint_options = ( + {k: v for k, v in options.items() if k != "norm"} + if options is not None + else {} + ) if adjoint_params is None: adjoint_params = tuple(func.parameters()) # Normalise to non-tupled input - shapes, func, y0, t, rtol, atol, method, options = _check_inputs(func, y0, t, rtol, atol, method, options, SOLVERS) + shapes, func, y0, t, rtol, atol, method, options = _check_inputs( + func, y0, t, rtol, atol, method, options, SOLVERS + ) if "norm" in options and "norm" not in adjoint_options: - adjoint_shapes = [torch.Size(()), y0.shape, y0.shape] + [torch.Size([sum(param.numel() for param in adjoint_params)])] - adjoint_options["norm"] = _wrap_norm([_rms_norm, options["norm"], options["norm"]], adjoint_shapes) - - #Odeint = OdeintAdjointMethod() - #Odeint.add_cheby(cheby_grid=cheby_grid) - solution = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, adjoint_rtol, adjoint_atol, - adjoint_method, cheby_grid, adjoint_options, 
t.requires_grad, *adjoint_params) + adjoint_shapes = [torch.Size(()), y0.shape, y0.shape] + [ + torch.Size([sum(param.numel() for param in adjoint_params)]) + ] + adjoint_options["norm"] = _wrap_norm( + [_rms_norm, options["norm"], options["norm"]], adjoint_shapes + ) + + # Odeint = OdeintAdjointMethod() + # Odeint.add_cheby(cheby_grid=cheby_grid) + solution = OdeintAdjointMethod.apply( + shapes, + func, + y0, + t, + rtol, + atol, + method, + options, + adjoint_rtol, + adjoint_atol, + adjoint_method, + cheby_grid, + adjoint_options, + t.requires_grad, + *adjoint_params, + ) if shapes is not None: solution = _flat_to_shape(solution, (len(t),), shapes) diff --git a/torchdiffeq/_impl/bosh3.py b/torchdiffeq/_impl/bosh3.py index aec6c2b..6d5fa4b 100644 --- a/torchdiffeq/_impl/bosh3.py +++ b/torchdiffeq/_impl/bosh3.py @@ -3,17 +3,19 @@ _BOGACKI_SHAMPINE_TABLEAU = _ButcherTableau( - alpha=torch.tensor([1/2, 3/4, 1.], dtype=torch.float64), + alpha=torch.tensor([1 / 2, 3 / 4, 1.0], dtype=torch.float64), beta=[ - torch.tensor([1/2], dtype=torch.float64), - torch.tensor([0., 3/4], dtype=torch.float64), - torch.tensor([2/9, 1/3, 4/9], dtype=torch.float64) + torch.tensor([1 / 2], dtype=torch.float64), + torch.tensor([0.0, 3 / 4], dtype=torch.float64), + torch.tensor([2 / 9, 1 / 3, 4 / 9], dtype=torch.float64), ], - c_sol=torch.tensor([2/9, 1/3, 4/9, 0.], dtype=torch.float64), - c_error=torch.tensor([2/9-7/24, 1/3-1/4, 4/9-1/3, -1/8], dtype=torch.float64), + c_sol=torch.tensor([2 / 9, 1 / 3, 4 / 9, 0.0], dtype=torch.float64), + c_error=torch.tensor( + [2 / 9 - 7 / 24, 1 / 3 - 1 / 4, 4 / 9 - 1 / 3, -1 / 8], dtype=torch.float64 + ), ) -_BS_C_MID = torch.tensor([ 0., 0.5, 0., 0. 
], dtype=torch.float64) +_BS_C_MID = torch.tensor([0.0, 0.5, 0.0, 0.0], dtype=torch.float64) class Bosh3Solver(RKAdaptiveStepsizeODESolver): diff --git a/torchdiffeq/_impl/dopri5.py b/torchdiffeq/_impl/dopri5.py index 1a925ef..858edf6 100644 --- a/torchdiffeq/_impl/dopri5.py +++ b/torchdiffeq/_impl/dopri5.py @@ -3,31 +3,53 @@ _DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau( - alpha=torch.tensor([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], dtype=torch.float64), + alpha=torch.tensor([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1.0, 1.0], dtype=torch.float64), beta=[ torch.tensor([1 / 5], dtype=torch.float64), torch.tensor([3 / 40, 9 / 40], dtype=torch.float64), torch.tensor([44 / 45, -56 / 15, 32 / 9], dtype=torch.float64), - torch.tensor([19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], dtype=torch.float64), - torch.tensor([9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], dtype=torch.float64), - torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], dtype=torch.float64), + torch.tensor( + [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], dtype=torch.float64 + ), + torch.tensor( + [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], + dtype=torch.float64, + ), + torch.tensor( + [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], + dtype=torch.float64, + ), ], - c_sol=torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], dtype=torch.float64), - c_error=torch.tensor([ - 35 / 384 - 1951 / 21600, - 0, - 500 / 1113 - 22642 / 50085, - 125 / 192 - 451 / 720, - -2187 / 6784 - -12231 / 42400, - 11 / 84 - 649 / 6300, - -1. 
/ 60., - ], dtype=torch.float64), + c_sol=torch.tensor( + [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], + dtype=torch.float64, + ), + c_error=torch.tensor( + [ + 35 / 384 - 1951 / 21600, + 0, + 500 / 1113 - 22642 / 50085, + 125 / 192 - 451 / 720, + -2187 / 6784 - -12231 / 42400, + 11 / 84 - 649 / 6300, + -1.0 / 60.0, + ], + dtype=torch.float64, + ), ) -DPS_C_MID = torch.tensor([ - 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2, - 187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2 -], dtype=torch.float64) +DPS_C_MID = torch.tensor( + [ + 6025192743 / 30085553152 / 2, + 0, + 51252292925 / 65400821598 / 2, + -2691868925 / 45128329728 / 2, + 187940372067 / 1594534317056 / 2, + -1776094331 / 19743644256 / 2, + 11237099 / 235043384 / 2, + ], + dtype=torch.float64, +) class Dopri5Solver(RKAdaptiveStepsizeODESolver): diff --git a/torchdiffeq/_impl/dopri8.py b/torchdiffeq/_impl/dopri8.py index 1f9fb22..13bbdf8 100644 --- a/torchdiffeq/_impl/dopri8.py +++ b/torchdiffeq/_impl/dopri8.py @@ -2,63 +2,240 @@ from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver -A = [ 1/18, 1/12, 1/8, 5/16, 3/8, 59/400, 93/200, 5490023248/9719169821, 13/20, 1201146811/1299019798, 1, 1, 1] - -B = [ - [1/18], - - [1/48, 1/16], - - [1/32, 0, 3/32], - - [5/16, 0, -75/64, 75/64], - - [3/80, 0, 0, 3/16, 3/20], - - [29443841/614563906, 0, 0, 77736538/692538347, -28693883/1125000000, 23124283/1800000000], - - [16016141/946692911, 0, 0, 61564180/158732637, 22789713/633445777, 545815736/2771057229, -180193667/1043307555], - - [39632708/573591083, 0, 0, -433636366/683701615, -421739975/2616292301, 100302831/723423059, 790204164/839813087, 800635310/3783071287], - - [246121993/1340847787, 0, 0, -37695042795/15268766246, -309121744/1061227803, -12992083/490766935, 6005943493/2108947869, 393006217/1396673457, 123872331/1001029789], - - [-1028468189/846180014, 0, 0, 8478235783/508512852, 
1311729495/1432422823, -10304129995/1701304382, -48777925059/3047939560, 15336726248/1032824649, -45442868181/3398467696, 3065993473/597172653], - - [185892177/718116043, 0, 0, -3185094517/667107341, -477755414/1098053517, -703635378/230739211, 5731566787/1027545527, 5232866602/850066563, -4093664535/808688257, 3962137247/1805957418, 65686358/487910083], - - [403863854/491063109, 0, 0, -5068492393/434740067, -411421997/543043805, 652783627/914296604, 11173962825/925320556, -13158990841/6184727034, 3936647629/1978049680, -160528059/685178525, 248638103/1413531060, 0], - - [ 14005451/335480064, 0, 0, 0, 0, -59238493/1068277825, 181606767/758867731, 561292985/797845732, -1041891430/1371343529, 760417239/1151165299, 118820643/751138087, -528747749/2220607170, 1/4] +A = [ + 1 / 18, + 1 / 12, + 1 / 8, + 5 / 16, + 3 / 8, + 59 / 400, + 93 / 200, + 5490023248 / 9719169821, + 13 / 20, + 1201146811 / 1299019798, + 1, + 1, + 1, ] -C_sol = [ 14005451/335480064, 0, 0, 0, 0, -59238493/1068277825, 181606767/758867731, 561292985/797845732, -1041891430/1371343529, 760417239/1151165299, 118820643/751138087, -528747749/2220607170, 1/4, 0] - -C_err = [ 14005451/335480064 - 13451932/455176623, 0, 0, 0, 0, -59238493/1068277825 - -808719846/976000145, 181606767/758867731 - 1757004468/5645159321, 561292985/797845732 - 656045339/265891186, -1041891430/1371343529 - -3867574721/1518517206, 760417239/1151165299 - 465885868/322736535, 118820643/751138087 - 53011238/667516719, -528747749/2220607170 - 2/45, 1/4, 0] - -h = 1/2 - -C_mid = [0.] 
* 14 - -C_mid[0] = (- 6.3448349392860401388*(h**5) + 22.1396504998094068976*(h**4) - 30.0610568289666450593*(h**3) + 19.9990069333683970610*(h**2) - 6.6910181737837595697*h + 1.0) / (1/h) - -C_mid[5] = (- 39.6107919852202505218*(h**5) + 116.4422149550342161651*(h**4) - 121.4999627731334642623*(h**3) + 52.2273532792945524050*(h**2) - 7.6142658045872677172*h) / (1/h) - -C_mid[6] = (20.3761213808791436958*(h**5) - 67.1451318825957197185*(h**4) + 83.1721004639847717481*(h**3) - 46.8919164181093621583*(h**2) + 10.7281392630428866124*h) / (1/h) - -C_mid[7] = (7.3347098826795362023*(h**5) - 16.5672243527496524646*(h**4) + 9.5724507555993664382*(h**3) - 0.1890893225010595467*(h**2) + 0.5526637063753648783*h) / (1/h) - -C_mid[8] = (32.8801774352459155182*(h**5) - 89.9916014847245016028*(h**4) + 87.8406057677205645007*(h**3) - 35.7075975946222072821*(h**2) + 4.2186562625665153803*h) / (1/h) - -C_mid[9] = (- 10.1588990526426760954*(h**5) + 22.6237489648532849093*(h**4) - 17.4152107770762969005*(h**3) + 6.2736448083240352160*(h**2) - 0.6627209125361597559*h) / (1/h) - -C_mid[10] = (- 12.5401268098782561200*(h**5) + 32.2362340167355370113*(h**4) - 28.5903289514790976966*(h**3) + 10.3160881272450748458*(h**2) - 1.2636789001135462218*h) / (1/h) +B = [ + [1 / 18], + [1 / 48, 1 / 16], + [1 / 32, 0, 3 / 32], + [5 / 16, 0, -75 / 64, 75 / 64], + [3 / 80, 0, 0, 3 / 16, 3 / 20], + [ + 29443841 / 614563906, + 0, + 0, + 77736538 / 692538347, + -28693883 / 1125000000, + 23124283 / 1800000000, + ], + [ + 16016141 / 946692911, + 0, + 0, + 61564180 / 158732637, + 22789713 / 633445777, + 545815736 / 2771057229, + -180193667 / 1043307555, + ], + [ + 39632708 / 573591083, + 0, + 0, + -433636366 / 683701615, + -421739975 / 2616292301, + 100302831 / 723423059, + 790204164 / 839813087, + 800635310 / 3783071287, + ], + [ + 246121993 / 1340847787, + 0, + 0, + -37695042795 / 15268766246, + -309121744 / 1061227803, + -12992083 / 490766935, + 6005943493 / 2108947869, + 393006217 / 1396673457, + 
123872331 / 1001029789, + ], + [ + -1028468189 / 846180014, + 0, + 0, + 8478235783 / 508512852, + 1311729495 / 1432422823, + -10304129995 / 1701304382, + -48777925059 / 3047939560, + 15336726248 / 1032824649, + -45442868181 / 3398467696, + 3065993473 / 597172653, + ], + [ + 185892177 / 718116043, + 0, + 0, + -3185094517 / 667107341, + -477755414 / 1098053517, + -703635378 / 230739211, + 5731566787 / 1027545527, + 5232866602 / 850066563, + -4093664535 / 808688257, + 3962137247 / 1805957418, + 65686358 / 487910083, + ], + [ + 403863854 / 491063109, + 0, + 0, + -5068492393 / 434740067, + -411421997 / 543043805, + 652783627 / 914296604, + 11173962825 / 925320556, + -13158990841 / 6184727034, + 3936647629 / 1978049680, + -160528059 / 685178525, + 248638103 / 1413531060, + 0, + ], + [ + 14005451 / 335480064, + 0, + 0, + 0, + 0, + -59238493 / 1068277825, + 181606767 / 758867731, + 561292985 / 797845732, + -1041891430 / 1371343529, + 760417239 / 1151165299, + 118820643 / 751138087, + -528747749 / 2220607170, + 1 / 4, + ], +] -C_mid[11] = (29.5553001484516038033*(h**5) - 82.1020315488359848644*(h**4) + 81.6630950584341412934*(h**3) - 34.7650769866611817349*(h**2) + 5.4106037898590422230*h) / (1/h) +C_sol = [ + 14005451 / 335480064, + 0, + 0, + 0, + 0, + -59238493 / 1068277825, + 181606767 / 758867731, + 561292985 / 797845732, + -1041891430 / 1371343529, + 760417239 / 1151165299, + 118820643 / 751138087, + -528747749 / 2220607170, + 1 / 4, + 0, +] -C_mid[12] = (- 41.7923486424390588923*(h**5) + 116.2662185791119533462*(h**4) - 114.9375291377009418170*(h**3) + 47.7457971078225540396*(h**2) - 7.0321379067945741781*h) / (1/h) +C_err = [ + 14005451 / 335480064 - 13451932 / 455176623, + 0, + 0, + 0, + 0, + -59238493 / 1068277825 - -808719846 / 976000145, + 181606767 / 758867731 - 1757004468 / 5645159321, + 561292985 / 797845732 - 656045339 / 265891186, + -1041891430 / 1371343529 - -3867574721 / 1518517206, + 760417239 / 1151165299 - 465885868 / 322736535, + 118820643 / 751138087 
- 53011238 / 667516719, + -528747749 / 2220607170 - 2 / 45, + 1 / 4, + 0, +] -C_mid[13] = (20.3006925822100825485*(h**5) - 53.9020777466385396792*(h**4) + 50.2558364226176017553*(h**3) - 19.0082099341608028453*(h**2) + 2.3537586759714983486*h) / (1/h) +h = 1 / 2 + +C_mid = [0.0] * 14 + +C_mid[0] = ( + -6.3448349392860401388 * (h ** 5) + + 22.1396504998094068976 * (h ** 4) + - 30.0610568289666450593 * (h ** 3) + + 19.9990069333683970610 * (h ** 2) + - 6.6910181737837595697 * h + + 1.0 +) / (1 / h) + +C_mid[5] = ( + -39.6107919852202505218 * (h ** 5) + + 116.4422149550342161651 * (h ** 4) + - 121.4999627731334642623 * (h ** 3) + + 52.2273532792945524050 * (h ** 2) + - 7.6142658045872677172 * h +) / (1 / h) + +C_mid[6] = ( + 20.3761213808791436958 * (h ** 5) + - 67.1451318825957197185 * (h ** 4) + + 83.1721004639847717481 * (h ** 3) + - 46.8919164181093621583 * (h ** 2) + + 10.7281392630428866124 * h +) / (1 / h) + +C_mid[7] = ( + 7.3347098826795362023 * (h ** 5) + - 16.5672243527496524646 * (h ** 4) + + 9.5724507555993664382 * (h ** 3) + - 0.1890893225010595467 * (h ** 2) + + 0.5526637063753648783 * h +) / (1 / h) + +C_mid[8] = ( + 32.8801774352459155182 * (h ** 5) + - 89.9916014847245016028 * (h ** 4) + + 87.8406057677205645007 * (h ** 3) + - 35.7075975946222072821 * (h ** 2) + + 4.2186562625665153803 * h +) / (1 / h) + +C_mid[9] = ( + -10.1588990526426760954 * (h ** 5) + + 22.6237489648532849093 * (h ** 4) + - 17.4152107770762969005 * (h ** 3) + + 6.2736448083240352160 * (h ** 2) + - 0.6627209125361597559 * h +) / (1 / h) + +C_mid[10] = ( + -12.5401268098782561200 * (h ** 5) + + 32.2362340167355370113 * (h ** 4) + - 28.5903289514790976966 * (h ** 3) + + 10.3160881272450748458 * (h ** 2) + - 1.2636789001135462218 * h +) / (1 / h) + +C_mid[11] = ( + 29.5553001484516038033 * (h ** 5) + - 82.1020315488359848644 * (h ** 4) + + 81.6630950584341412934 * (h ** 3) + - 34.7650769866611817349 * (h ** 2) + + 5.4106037898590422230 * h +) / (1 / h) + +C_mid[12] = ( + 
-41.7923486424390588923 * (h ** 5) + + 116.2662185791119533462 * (h ** 4) + - 114.9375291377009418170 * (h ** 3) + + 47.7457971078225540396 * (h ** 2) + - 7.0321379067945741781 * h +) / (1 / h) + +C_mid[13] = ( + 20.3006925822100825485 * (h ** 5) + - 53.9020777466385396792 * (h ** 4) + + 50.2558364226176017553 * (h ** 3) + - 19.0082099341608028453 * (h ** 2) + + 2.3537586759714983486 * h +) / (1 / h) A = torch.tensor(A, dtype=torch.float64) diff --git a/torchdiffeq/_impl/fixed_adams.py b/torchdiffeq/_impl/fixed_adams.py index 17724b5..08c7e20 100644 --- a/torchdiffeq/_impl/fixed_adams.py +++ b/torchdiffeq/_impl/fixed_adams.py @@ -16,61 +16,203 @@ [4277, -7923, 9982, -7298, 2877, -475], [198721, -447288, 705549, -688256, 407139, -134472, 19087], [434241, -1152169, 2183877, -2664477, 2102243, -1041723, 295767, -36799], - [14097247, -43125206, 95476786, -139855262, 137968480, -91172642, 38833486, -9664106, 1070017], - [30277247, -104995189, 265932680, -454661776, 538363838, -444772162, 252618224, -94307320, 20884811, -2082753], [ - 2132509567, -8271795124, 23591063805, -46113029016, 63716378958, -63176201472, 44857168434, -22329634920, - 7417904451, -1479574348, 134211265 + 14097247, + -43125206, + 95476786, + -139855262, + 137968480, + -91172642, + 38833486, + -9664106, + 1070017, ], [ - 4527766399, -19433810163, 61633227185, -135579356757, 214139355366, -247741639374, 211103573298, -131365867290, - 58189107627, -17410248271, 3158642445, -262747265 + 30277247, + -104995189, + 265932680, + -454661776, + 538363838, + -444772162, + 252618224, + -94307320, + 20884811, + -2082753, ], [ - 13064406523627, -61497552797274, 214696591002612, -524924579905150, 932884546055895, -1233589244941764, - 1226443086129408, -915883387152444, 507140369728425, -202322913738370, 55060974662412, -9160551085734, - 703604254357 + 2132509567, + -8271795124, + 23591063805, + -46113029016, + 63716378958, + -63176201472, + 44857168434, + -22329634920, + 7417904451, + -1479574348, + 134211265, ], 
[ - 27511554976875, -140970750679621, 537247052515662, -1445313351681906, 2854429571790805, -4246767353305755, - 4825671323488452, -4204551925534524, 2793869602879077, -1393306307155755, 505586141196430, -126174972681906, - 19382853593787, -1382741929621 + 4527766399, + -19433810163, + 61633227185, + -135579356757, + 214139355366, + -247741639374, + 211103573298, + -131365867290, + 58189107627, + -17410248271, + 3158642445, + -262747265, ], [ - 173233498598849, -960122866404112, 3966421670215481, -11643637530577472, 25298910337081429, -41825269932507728, - 53471026659940509, -53246738660646912, 41280216336284259, -24704503655607728, 11205849753515179, - -3728807256577472, 859236476684231, -122594813904112, 8164168737599 + 13064406523627, + -61497552797274, + 214696591002612, + -524924579905150, + 932884546055895, + -1233589244941764, + 1226443086129408, + -915883387152444, + 507140369728425, + -202322913738370, + 55060974662412, + -9160551085734, + 703604254357, ], [ - 362555126427073, -2161567671248849, 9622096909515337, -30607373860520569, 72558117072259733, - -131963191940828581, 187463140112902893, -210020588912321949, 186087544263596643, -129930094104237331, - 70724351582843483, -29417910911251819, 9038571752734087, -1934443196892599, 257650275915823, -16088129229375 + 27511554976875, + -140970750679621, + 537247052515662, + -1445313351681906, + 2854429571790805, + -4246767353305755, + 4825671323488452, + -4204551925534524, + 2793869602879077, + -1393306307155755, + 505586141196430, + -126174972681906, + 19382853593787, + -1382741929621, ], [ - 192996103681340479, -1231887339593444974, 5878428128276811750, -20141834622844109630, 51733880057282977010, - -102651404730855807942, 160414858999474733422, -199694296833704562550, 199061418623907202560, - -158848144481581407370, 100878076849144434322, -50353311405771659322, 19338911944324897550, - -5518639984393844930, 1102560345141059610, -137692773163513234, 8092989203533249 + 173233498598849, + -960122866404112, + 
3966421670215481, + -11643637530577472, + 25298910337081429, + -41825269932507728, + 53471026659940509, + -53246738660646912, + 41280216336284259, + -24704503655607728, + 11205849753515179, + -3728807256577472, + 859236476684231, + -122594813904112, + 8164168737599, ], [ - 401972381695456831, -2735437642844079789, 13930159965811142228, -51150187791975812900, 141500575026572531760, - -304188128232928718008, 518600355541383671092, -710171024091234303204, 786600875277595877750, - -706174326992944287370, 512538584122114046748, -298477260353977522892, 137563142659866897224, - -49070094880794267600, 13071639236569712860, -2448689255584545196, 287848942064256339, -15980174332775873 + 362555126427073, + -2161567671248849, + 9622096909515337, + -30607373860520569, + 72558117072259733, + -131963191940828581, + 187463140112902893, + -210020588912321949, + 186087544263596643, + -129930094104237331, + 70724351582843483, + -29417910911251819, + 9038571752734087, + -1934443196892599, + 257650275915823, + -16088129229375, ], [ - 333374427829017307697, -2409687649238345289684, 13044139139831833251471, -51099831122607588046344, - 151474888613495715415020, -350702929608291455167896, 647758157491921902292692, -967713746544629658690408, - 1179078743786280451953222, -1176161829956768365219840, 960377035444205950813626, -639182123082298748001432, - 343690461612471516746028, -147118738993288163742312, 48988597853073465932820, -12236035290567356418552, - 2157574942881818312049, -239560589366324764716, 12600467236042756559 + 192996103681340479, + -1231887339593444974, + 5878428128276811750, + -20141834622844109630, + 51733880057282977010, + -102651404730855807942, + 160414858999474733422, + -199694296833704562550, + 199061418623907202560, + -158848144481581407370, + 100878076849144434322, + -50353311405771659322, + 19338911944324897550, + -5518639984393844930, + 1102560345141059610, + -137692773163513234, + 8092989203533249, ], [ - 691668239157222107697, -5292843584961252933125, 
30349492858024727686755, -126346544855927856134295, - 399537307669842150996468, -991168450545135070835076, 1971629028083798845750380, -3191065388846318679544380, - 4241614331208149947151790, -4654326468801478894406214, 4222756879776354065593786, -3161821089800186539248210, - 1943018818982002395655620, -970350191086531368649620, 387739787034699092364924, -121059601023985433003532, - 28462032496476316665705, -4740335757093710713245, 498669220956647866875, -24919383499187492303 + 401972381695456831, + -2735437642844079789, + 13930159965811142228, + -51150187791975812900, + 141500575026572531760, + -304188128232928718008, + 518600355541383671092, + -710171024091234303204, + 786600875277595877750, + -706174326992944287370, + 512538584122114046748, + -298477260353977522892, + 137563142659866897224, + -49070094880794267600, + 13071639236569712860, + -2448689255584545196, + 287848942064256339, + -15980174332775873, + ], + [ + 333374427829017307697, + -2409687649238345289684, + 13044139139831833251471, + -51099831122607588046344, + 151474888613495715415020, + -350702929608291455167896, + 647758157491921902292692, + -967713746544629658690408, + 1179078743786280451953222, + -1176161829956768365219840, + 960377035444205950813626, + -639182123082298748001432, + 343690461612471516746028, + -147118738993288163742312, + 48988597853073465932820, + -12236035290567356418552, + 2157574942881818312049, + -239560589366324764716, + 12600467236042756559, + ], + [ + 691668239157222107697, + -5292843584961252933125, + 30349492858024727686755, + -126346544855927856134295, + 399537307669842150996468, + -991168450545135070835076, + 1971629028083798845750380, + -3191065388846318679544380, + 4241614331208149947151790, + -4654326468801478894406214, + 4222756879776354065593786, + -3161821089800186539248210, + 1943018818982002395655620, + -970350191086531368649620, + 387739787034699092364924, + -121059601023985433003532, + 28462032496476316665705, + -4740335757093710713245, + 498669220956647866875, 
+ -24919383499187492303, ], ] @@ -85,70 +227,227 @@ [19087, 65112, -46461, 37504, -20211, 6312, -863], [36799, 139849, -121797, 123133, -88547, 41499, -11351, 1375], [1070017, 4467094, -4604594, 5595358, -5033120, 3146338, -1291214, 312874, -33953], - [2082753, 9449717, -11271304, 16002320, -17283646, 13510082, -7394032, 2687864, -583435, 57281], [ - 134211265, 656185652, -890175549, 1446205080, -1823311566, 1710774528, -1170597042, 567450984, -184776195, - 36284876, -3250433 + 2082753, + 9449717, + -11271304, + 16002320, + -17283646, + 13510082, + -7394032, + 2687864, + -583435, + 57281, ], [ - 262747265, 1374799219, -2092490673, 3828828885, -5519460582, 6043521486, -4963166514, 3007739418, -1305971115, - 384709327, -68928781, 5675265 + 134211265, + 656185652, + -890175549, + 1446205080, + -1823311566, + 1710774528, + -1170597042, + 567450984, + -184776195, + 36284876, + -3250433, ], [ - 703604254357, 3917551216986, -6616420957428, 13465774256510, -21847538039895, 27345870698436, -26204344465152, - 19058185652796, -10344711794985, 4063327863170, -1092096992268, 179842822566, -13695779093 + 262747265, + 1374799219, + -2092490673, + 3828828885, + -5519460582, + 6043521486, + -4963166514, + 3007739418, + -1305971115, + 384709327, + -68928781, + 5675265, ], [ - 1382741929621, 8153167962181, -15141235084110, 33928990133618, -61188680131285, 86180228689563, -94393338653892, - 80101021029180, -52177910882661, 25620259777835, -9181635605134, 2268078814386, -345457086395, 24466579093 + 703604254357, + 3917551216986, + -6616420957428, + 13465774256510, + -21847538039895, + 27345870698436, + -26204344465152, + 19058185652796, + -10344711794985, + 4063327863170, + -1092096992268, + 179842822566, + -13695779093, ], [ - 8164168737599, 50770967534864, -102885148956217, 251724894607936, -499547203754837, 781911618071632, - -963605400824733, 934600833490944, -710312834197347, 418551804601264, -187504936597931, 61759426692544, - -14110480969927, 1998759236336, -132282840127 + 
1382741929621, + 8153167962181, + -15141235084110, + 33928990133618, + -61188680131285, + 86180228689563, + -94393338653892, + 80101021029180, + -52177910882661, + 25620259777835, + -9181635605134, + 2268078814386, + -345457086395, + 24466579093, ], [ - 16088129229375, 105145058757073, -230992163723849, 612744541065337, -1326978663058069, 2285168598349733, - -3129453071993581, 3414941728852893, -2966365730265699, 2039345879546643, -1096355235402331, 451403108933483, - -137515713789319, 29219384284087, -3867689367599, 240208245823 + 8164168737599, + 50770967534864, + -102885148956217, + 251724894607936, + -499547203754837, + 781911618071632, + -963605400824733, + 934600833490944, + -710312834197347, + 418551804601264, + -187504936597931, + 61759426692544, + -14110480969927, + 1998759236336, + -132282840127, ], [ - 8092989203533249, 55415287221275246, -131240807912923110, 375195469874202430, -880520318434977010, - 1654462865819232198, -2492570347928318318, 3022404969160106870, -2953729295811279360, 2320851086013919370, - -1455690451266780818, 719242466216944698, -273894214307914510, 77597639915764930, -15407325991235610, - 1913813460537746, -111956703448001 + 16088129229375, + 105145058757073, + -230992163723849, + 612744541065337, + -1326978663058069, + 2285168598349733, + -3129453071993581, + 3414941728852893, + -2966365730265699, + 2039345879546643, + -1096355235402331, + 451403108933483, + -137515713789319, + 29219384284087, + -3867689367599, + 240208245823, ], [ - 15980174332775873, 114329243705491117, -290470969929371220, 890337710266029860, -2250854333681641520, - 4582441343348851896, -7532171919277411636, 10047287575124288740, -10910555637627652470, 9644799218032932490, - -6913858539337636636, 3985516155854664396, -1821304040326216520, 645008976643217360, -170761422500096220, - 31816981024600492, -3722582669836627, 205804074290625 + 8092989203533249, + 55415287221275246, + -131240807912923110, + 375195469874202430, + -880520318434977010, + 
1654462865819232198, + -2492570347928318318, + 3022404969160106870, + -2953729295811279360, + 2320851086013919370, + -1455690451266780818, + 719242466216944698, + -273894214307914510, + 77597639915764930, + -15407325991235610, + 1913813460537746, + -111956703448001, ], [ - 12600467236042756559, 93965550344204933076, -255007751875033918095, 834286388106402145800, - -2260420115705863623660, 4956655592790542146968, -8827052559979384209108, 12845814402199484797800, - -15345231910046032448070, 15072781455122686545920, -12155867625610599812538, 8008520809622324571288, - -4269779992576330506540, 1814584564159445787240, -600505972582990474260, 149186846171741510136, - -26182538841925312881, 2895045518506940460, -151711881512390095 + 15980174332775873, + 114329243705491117, + -290470969929371220, + 890337710266029860, + -2250854333681641520, + 4582441343348851896, + -7532171919277411636, + 10047287575124288740, + -10910555637627652470, + 9644799218032932490, + -6913858539337636636, + 3985516155854664396, + -1821304040326216520, + 645008976643217360, + -170761422500096220, + 31816981024600492, + -3722582669836627, + 205804074290625, ], [ - 24919383499187492303, 193280569173472261637, -558160720115629395555, 1941395668950986461335, - -5612131802364455926260, 13187185898439270330756, -25293146116627869170796, 39878419226784442421820, - -51970649453670274135470, 56154678684618739939910, -50320851025594566473146, 37297227252822858381906, - -22726350407538133839300, 11268210124987992327060, -4474886658024166985340, 1389665263296211699212, - -325187970422032795497, 53935307402575440285, -5652892248087175675, 281550972898020815 + 12600467236042756559, + 93965550344204933076, + -255007751875033918095, + 834286388106402145800, + -2260420115705863623660, + 4956655592790542146968, + -8827052559979384209108, + 12845814402199484797800, + -15345231910046032448070, + 15072781455122686545920, + -12155867625610599812538, + 8008520809622324571288, + -4269779992576330506540, + 
1814584564159445787240, + -600505972582990474260, + 149186846171741510136, + -26182538841925312881, + 2895045518506940460, + -151711881512390095, + ], + [ + 24919383499187492303, + 193280569173472261637, + -558160720115629395555, + 1941395668950986461335, + -5612131802364455926260, + 13187185898439270330756, + -25293146116627869170796, + 39878419226784442421820, + -51970649453670274135470, + 56154678684618739939910, + -50320851025594566473146, + 37297227252822858381906, + -22726350407538133839300, + 11268210124987992327060, + -4474886658024166985340, + 1389665263296211699212, + -325187970422032795497, + 53935307402575440285, + -5652892248087175675, + 281550972898020815, ], ] _DIVISOR = [ - None, 11, 2, 12, 24, 720, 1440, 60480, 120960, 3628800, 7257600, 479001600, 958003200, 2615348736000, 5230697472000, - 31384184832000, 62768369664000, 32011868528640000, 64023737057280000, 51090942171709440000, 102181884343418880000 + None, + 11, + 2, + 12, + 24, + 720, + 1440, + 60480, + 120960, + 3628800, + 7257600, + 479001600, + 958003200, + 2615348736000, + 5230697472000, + 31384184832000, + 62768369664000, + 32011868528640000, + 64023737057280000, + 51090942171709440000, + 102181884343418880000, ] -_BASHFORTH_DIVISOR = [torch.tensor([b / divisor for b in bashforth], dtype=torch.float64) - for bashforth, divisor in zip(_BASHFORTH_COEFFICIENTS, _DIVISOR)] -_MOULTON_DIVISOR = [torch.tensor([m / divisor for m in moulton], dtype=torch.float64) - for moulton, divisor in zip(_MOULTON_COEFFICIENTS, _DIVISOR)] +_BASHFORTH_DIVISOR = [ + torch.tensor([b / divisor for b in bashforth], dtype=torch.float64) + for bashforth, divisor in zip(_BASHFORTH_COEFFICIENTS, _DIVISOR) +] +_MOULTON_DIVISOR = [ + torch.tensor([m / divisor for m in moulton], dtype=torch.float64) + for moulton, divisor in zip(_MOULTON_COEFFICIENTS, _DIVISOR) +] _MIN_ORDER = 4 _MAX_ORDER = 12 @@ -163,12 +462,27 @@ def _dot_product(x, y): class AdamsBashforthMoulton(FixedGridODESolver): order = 4 - def __init__(self, 
func, y0, rtol=1e-3, atol=1e-4, implicit=True, max_iters=_MAX_ITERS, max_order=_MAX_ORDER, - **kwargs): + def __init__( + self, + func, + y0, + rtol=1e-3, + atol=1e-4, + implicit=True, + max_iters=_MAX_ITERS, + max_order=_MAX_ORDER, + **kwargs, + ): super(AdamsBashforthMoulton, self).__init__(func, y0, **kwargs) - assert max_order <= _MAX_ORDER, "max_order must be at most {}".format(_MAX_ORDER) + assert max_order <= _MAX_ORDER, "max_order must be at most {}".format( + _MAX_ORDER + ) if max_order < _MIN_ORDER: - warnings.warn("max_order is below {}, so the solver reduces to `rk4`.".format(_MIN_ORDER)) + warnings.warn( + "max_order is below {}, so the solver reduces to `rk4`.".format( + _MIN_ORDER + ) + ) self.rtol = torch.as_tensor(rtol, dtype=y0.dtype, device=y0.device) self.atol = torch.as_tensor(atol, dtype=y0.dtype, device=y0.device) @@ -188,7 +502,9 @@ def _update_history(self, t, f): def _has_converged(self, y0, y1): """Checks that each element is within the error tolerance.""" - error_ratio = _compute_error_ratio(torch.abs(y0 - y1), self.rtol, self.atol, y0, y1, _linf_norm) + error_ratio = _compute_error_ratio( + torch.abs(y0 - y1), self.rtol, self.atol, y0, y1, _linf_norm + ) return error_ratio < 1 def _step_func(self, func, t, dt, y): @@ -201,22 +517,31 @@ def _step_func(self, func, t, dt, y): else: # Adams-Bashforth predictor. bashforth_coeffs = self.bashforth[order] - dy = _dot_product(dt * bashforth_coeffs, self.prev_f).type_as(y) # bashforth is float64 so cast back + dy = _dot_product(dt * bashforth_coeffs, self.prev_f).type_as( + y + ) # bashforth is float64 so cast back # Adams-Moulton corrector. 
if self.implicit: moulton_coeffs = self.moulton[order + 1] - delta = dt * _dot_product(moulton_coeffs[1:], self.prev_f).type_as(y) # moulton is float64 so cast back + delta = dt * _dot_product(moulton_coeffs[1:], self.prev_f).type_as( + y + ) # moulton is float64 so cast back converged = False for _ in range(self.max_iters): dy_old = dy f = func(t + dt, y + dy) - dy = (dt * (moulton_coeffs[0]) * f).type_as(y) + delta # moulton is float64 so cast back + dy = (dt * (moulton_coeffs[0]) * f).type_as( + y + ) + delta # moulton is float64 so cast back converged = self._has_converged(dy_old, dy) if converged: break if not converged: - warnings.warn('Functional iteration did not converge. Solution may be incorrect.', file=sys.stderr) + warnings.warn( + "Functional iteration did not converge. Solution may be incorrect.", + file=sys.stderr, + ) self.prev_f.pop() self._update_history(t, f) return dy diff --git a/torchdiffeq/_impl/fixed_grid.py b/torchdiffeq/_impl/fixed_grid.py index 6906c2b..ec2da26 100644 --- a/torchdiffeq/_impl/fixed_grid.py +++ b/torchdiffeq/_impl/fixed_grid.py @@ -6,7 +6,7 @@ class Euler(FixedGridODESolver): order = 1 - def __init__(self, eps=0., **kwargs): + def __init__(self, eps=0.0, **kwargs): super(Euler, self).__init__(**kwargs) self.eps = torch.as_tensor(eps, dtype=self.dtype, device=self.device) @@ -17,7 +17,7 @@ def _step_func(self, func, t, dt, y): class Midpoint(FixedGridODESolver): order = 2 - def __init__(self, eps=0., **kwargs): + def __init__(self, eps=0.0, **kwargs): super(Midpoint, self).__init__(**kwargs) self.eps = torch.as_tensor(eps, dtype=self.dtype, device=self.device) @@ -30,7 +30,7 @@ def _step_func(self, func, t, dt, y): class RK4(FixedGridODESolver): order = 4 - def __init__(self, eps=0., **kwargs): + def __init__(self, eps=0.0, **kwargs): super(RK4, self).__init__(**kwargs) self.eps = torch.as_tensor(eps, dtype=self.dtype, device=self.device) diff --git a/torchdiffeq/_impl/interp.py b/torchdiffeq/_impl/interp.py index 
c6a083b..f5e272d 100644 --- a/torchdiffeq/_impl/interp.py +++ b/torchdiffeq/_impl/interp.py @@ -35,7 +35,9 @@ def _interp_evaluate(coefficients, t0, t1, t): Polynomial interpolation of the coefficients at time `t`. """ - assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1) + assert (t0 <= t) & ( + t <= t1 + ), "invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}".format(t0, t, t1) x = (t - t0) / (t1 - t0) total = coefficients[0] + x * coefficients[1] diff --git a/torchdiffeq/_impl/misc.py b/torchdiffeq/_impl/misc.py index 01cd4f0..354fa5c 100644 --- a/torchdiffeq/_impl/misc.py +++ b/torchdiffeq/_impl/misc.py @@ -5,7 +5,11 @@ def _handle_unused_kwargs(solver, unused_kwargs): if len(unused_kwargs) > 0: - warnings.warn('{}: Unexpected arguments {}'.format(solver.__class__.__name__, unused_kwargs)) + warnings.warn( + "{}: Unexpected arguments {}".format( + solver.__class__.__name__, unused_kwargs + ) + ) def _rms_norm(tensor): @@ -20,8 +24,11 @@ def _norm(tensor): next_total = total + shape.numel() out.append(_rms_norm(tensor[total:next_total])) total = next_total - assert total == tensor.numel(), "Shapes do not total to the full size of the tensor." + assert ( + total == tensor.numel() + ), "Shapes do not total to the full size of the tensor." return max(out) + return _norm @@ -36,8 +43,11 @@ def _norm(tensor): else: out.append(_rms_norm(tensor[total:next_total])) total = next_total - assert total == tensor.numel(), "Shapes do not total to the full size of the tensor." + assert ( + total == tensor.numel() + ), "Shapes do not total to the full size of the tensor." return max(out) + return _norm @@ -82,7 +92,7 @@ def _select_initial_step(func, t0, y0, order, rtol, atol, norm, f0=None): if d1 <= 1e-15 and d2 <= 1e-15: h1 = torch.max(torch.tensor(1e-6, dtype=dtype, device=device), h0 * 1e-3) else: - h1 = (0.01 / max(d1, d2)) ** (1. 
/ float(order + 1)) + h1 = (0.01 / max(d1, d2)) ** (1.0 / float(order + 1)) return torch.min(100 * h0, h1).to(t_dtype) @@ -99,7 +109,9 @@ def _optimal_step_size(last_step, error_ratio, safety, ifactor, dfactor, order): if error_ratio < 1: dfactor = torch.ones((), dtype=last_step.dtype, device=last_step.device) error_ratio = error_ratio.type_as(last_step) - exponent = torch.tensor(order, dtype=last_step.dtype, device=last_step.device).reciprocal() + exponent = torch.tensor( + order, dtype=last_step.dtype, device=last_step.device + ).reciprocal() factor = torch.min(ifactor, torch.max(safety / error_ratio ** exponent, dfactor)) return last_step * factor @@ -113,12 +125,16 @@ def _assert_one_dimensional(name, t): def _assert_increasing(name, t): - assert (t[1:] > t[:-1]).all(), '{} must be strictly increasing or decreasing'.format(name) + assert ( + t[1:] > t[:-1] + ).all(), "{} must be strictly increasing or decreasing".format(name) def _assert_floating(name, t): if not torch.is_floating_point(t): - raise TypeError('`{}` must be a floating point Tensor but is a {}'.format(name, t.type())) + raise TypeError( + "`{}` must be a floating point Tensor but is a {}".format(name, t.type()) + ) def _tuple_tol(name, tol, shapes): @@ -127,8 +143,12 @@ def _tuple_tol(name, tol, shapes): except TypeError: return tol tol = tuple(tol) - assert len(tol) == len(shapes), "If using tupled {} it must have the same length as the tuple y0".format(name) - tol = [torch.as_tensor(tol_).expand(shape.numel()) for tol_, shape in zip(tol, shapes)] + assert len(tol) == len( + shapes + ), "If using tupled {} it must have the same length as the tuple y0".format(name) + tol = [ + torch.as_tensor(tol_).expand(shape.numel()) for tol_, shape in zip(tol, shapes) + ] return torch.cat(tol) @@ -167,13 +187,13 @@ def _check_inputs(func, y0, t, rtol, atol, method, options, SOLVERS): # Normalise to tensor (non-tupled) input shapes = None if not torch.is_tensor(y0): - assert isinstance(y0, tuple), 'y0 must be 
either a torch.Tensor or a tuple' + assert isinstance(y0, tuple), "y0 must be either a torch.Tensor or a tuple" shapes = [y0_.shape for y0_ in y0] - rtol = _tuple_tol('rtol', rtol, shapes) - atol = _tuple_tol('atol', atol, shapes) + rtol = _tuple_tol("rtol", rtol, shapes) + atol = _tuple_tol("atol", atol, shapes) y0 = torch.cat([y0_.reshape(-1) for y0_ in y0]) func = _TupleFunc(func, shapes) - _assert_floating('y0', y0) + _assert_floating("y0", y0) # Normalise method and options if options is None: @@ -181,51 +201,54 @@ def _check_inputs(func, y0, t, rtol, atol, method, options, SOLVERS): else: options = options.copy() if method is None: - method = 'dopri5' + method = "dopri5" if method not in SOLVERS: - raise ValueError('Invalid method "{}". Must be one of {}'.format(method, - '{"' + '", "'.join(SOLVERS.keys()) + '"}.')) + raise ValueError( + 'Invalid method "{}". Must be one of {}'.format( + method, '{"' + '", "'.join(SOLVERS.keys()) + '"}.' + ) + ) try: - grid_points = options['grid_points'] + grid_points = options["grid_points"] except KeyError: pass else: - assert torch.is_tensor(grid_points), 'grid_points must be a torch.Tensor' - _assert_one_dimensional('grid_points', grid_points) + assert torch.is_tensor(grid_points), "grid_points must be a torch.Tensor" + _assert_one_dimensional("grid_points", grid_points) assert not grid_points.requires_grad, "grid_points cannot require gradient" - _assert_floating('grid_points', grid_points) + _assert_floating("grid_points", grid_points) - if 'norm' not in options: + if "norm" not in options: if shapes is None: # L2 norm over a single input - options['norm'] = _rms_norm + options["norm"] = _rms_norm else: # Mixed Linf/L2 norm over tupled input (chosen mostly just for backward compatibility reasons) - options['norm'] = _mixed_linf_rms_norm(shapes) + options["norm"] = _mixed_linf_rms_norm(shapes) # Normalise time - assert torch.is_tensor(t), 't must be a torch.Tensor' - _assert_one_dimensional('t', t) - 
_assert_floating('t', t) + assert torch.is_tensor(t), "t must be a torch.Tensor" + _assert_one_dimensional("t", t) + _assert_floating("t", t) if _decreasing(t): t = -t func = _ReverseFunc(func) try: - grid_points = options['grid_points'] + grid_points = options["grid_points"] except KeyError: pass else: - options['grid_points'] = -grid_points + options["grid_points"] = -grid_points # Can only do after having normalised time - _assert_increasing('t', t) + _assert_increasing("t", t) try: - grid_points = options['grid_points'] + grid_points = options["grid_points"] except KeyError: pass else: - _assert_increasing('grid_points', grid_points) + _assert_increasing("grid_points", grid_points) # Tol checking if torch.is_tensor(rtol): @@ -241,13 +264,14 @@ def _check_inputs(func, y0, t, rtol, atol, method, options, SOLVERS): return shapes, func, y0, t, rtol, atol, method, options + #################################################################################################################### # Chebyshev grids def _cby_grids(t_min, t_max, n): - k = torch.linspace(0.0, np.pi, n+1, dtype=t_min.dtype, device=t_min.device) + k = torch.linspace(0.0, np.pi, n + 1, dtype=t_min.dtype, device=t_min.device) grids = torch.cos(k) - grids = (t_min + t_max) / 2. + grids * (t_min - t_max) / 2. + grids = (t_min + t_max) / 2.0 + grids * (t_min - t_max) / 2.0 return grids @@ -259,16 +283,16 @@ def cby_grid_type1(t_min=0, t_max=1, n=11): c = np.cos(theta) if (n % 2) == 1: if 2 * i + 1 == n: - c = 0. - grids[i] = ((1. - c) * t_min + (1. + c) * t_max) / 2. + c = 0.0 + grids[i] = ((1.0 - c) * t_min + (1.0 + c) * t_max) / 2.0 return grids def barycentric_weights(n): w = np.zeros(n) - s = 1. + s = 1.0 for j in range(n): - w[j] = s * np.sin((2 * j + 1) * np.pi / (2. 
* n)) + w[j] = s * np.sin((2 * j + 1) * np.pi / (2.0 * n)) s = -s return w @@ -280,15 +304,15 @@ def _cby1_interp(w, nodes, values, t): return values[idx][0] cof = w / (t - nodes) - s = '' + s = "" if len(values.shape) == 5: - s = 'i,ijkmn->jkmn' + s = "i,ijkmn->jkmn" elif len(values.shape) == 4: - s = 'i,ijkm->jkm' + s = "i,ijkm->jkm" elif len(values.shape) == 3: - s = 'i,ijk->jk' + s = "i,ijk->jk" elif len(values.shape) == 2: - s = 'i,ij->j' + s = "i,ij->j" num = torch.einsum(s, [cof, values]) den = cof.sum() res = num / den diff --git a/torchdiffeq/_impl/odeint.py b/torchdiffeq/_impl/odeint.py index df6b36f..9b791fa 100644 --- a/torchdiffeq/_impl/odeint.py +++ b/torchdiffeq/_impl/odeint.py @@ -7,17 +7,17 @@ from .misc import _check_inputs, _flat_to_shape SOLVERS = { - 'dopri8': Dopri8Solver, - 'dopri5': Dopri5Solver, - 'bosh3': Bosh3Solver, - 'adaptive_heun': AdaptiveHeunSolver, - 'euler': Euler, - 'midpoint': Midpoint, - 'rk4': RK4, - 'explicit_adams': AdamsBashforth, - 'implicit_adams': AdamsBashforthMoulton, + "dopri8": Dopri8Solver, + "dopri5": Dopri5Solver, + "bosh3": Bosh3Solver, + "adaptive_heun": AdaptiveHeunSolver, + "euler": Euler, + "midpoint": Midpoint, + "rk4": RK4, + "explicit_adams": AdamsBashforth, + "implicit_adams": AdamsBashforthMoulton, # Backward compatibility: use the same name as before - 'fixed_adams': AdamsBashforthMoulton, + "fixed_adams": AdamsBashforthMoulton, # ~Backwards compatibility } @@ -59,7 +59,9 @@ def odeint(func, y0, t, rtol=1e-7, atol=1e-9, method=None, options=None): Raises: ValueError: if an invalid `method` is provided. 
""" - shapes, func, y0, t, rtol, atol, method, options = _check_inputs(func, y0, t, rtol, atol, method, options, SOLVERS) + shapes, func, y0, t, rtol, atol, method, options = _check_inputs( + func, y0, t, rtol, atol, method, options, SOLVERS + ) solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options) solution = solver.integrate(t) diff --git a/torchdiffeq/_impl/rk_common.py b/torchdiffeq/_impl/rk_common.py index 34eb824..2efead2 100644 --- a/torchdiffeq/_impl/rk_common.py +++ b/torchdiffeq/_impl/rk_common.py @@ -2,16 +2,18 @@ import collections import torch from .interp import _interp_evaluate, _interp_fit -from .misc import (_compute_error_ratio, - _select_initial_step, - _optimal_step_size) +from .misc import _compute_error_ratio, _select_initial_step, _optimal_step_size from .solvers import AdaptiveStepsizeODESolver -_ButcherTableau = collections.namedtuple('_ButcherTableau', 'alpha, beta, c_sol, c_error') +_ButcherTableau = collections.namedtuple( + "_ButcherTableau", "alpha, beta, c_sol, c_error" +) -_RungeKuttaState = collections.namedtuple('_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff') +_RungeKuttaState = collections.namedtuple( + "_RungeKuttaState", "y1, f1, t0, t1, dt, interp_coeff" +) # Saved state of the Runge Kutta solver. 
# # Attributes: @@ -63,7 +65,7 @@ def _runge_kutta_step(func, y0, f0, t0, dt, tableau): k = _UncheckedAssign.apply(k, f0, (..., 0)) for i, (alpha_i, beta_i) in enumerate(zip(tableau.alpha, tableau.beta)): ti = t0 + alpha_i * dt - yi = y0 + k[..., :i + 1].matmul(beta_i * dt).view_as(f0) + yi = y0 + k[..., : i + 1].matmul(beta_i * dt).view_as(f0) f = func(ti, yi) k = _UncheckedAssign.apply(k, f, (..., i + 1)) @@ -108,8 +110,22 @@ class RKAdaptiveStepsizeODESolver(AdaptiveStepsizeODESolver): tableau: _ButcherTableau mid: torch.Tensor - def __init__(self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, - max_num_steps=2 ** 31 - 1, grid_points=None, eps=0., dtype=torch.float64, **kwargs): + def __init__( + self, + func, + y0, + rtol, + atol, + first_step=None, + safety=0.9, + ifactor=10.0, + dfactor=0.2, + max_num_steps=2 ** 31 - 1, + grid_points=None, + eps=0.0, + dtype=torch.float64, + **kwargs, + ): super(RKAdaptiveStepsizeODESolver, self).__init__(dtype=dtype, y0=y0, **kwargs) # We use mixed precision. 
y has its original dtype (probably float32), whilst all 'time'-like objects use @@ -120,41 +136,69 @@ def __init__(self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10 self.func = lambda t, y: func(t.type_as(y), y) self.rtol = torch.as_tensor(rtol, dtype=dtype, device=device) self.atol = torch.as_tensor(atol, dtype=dtype, device=device) - self.first_step = None if first_step is None else torch.as_tensor(first_step, dtype=dtype, device=device) + self.first_step = ( + None + if first_step is None + else torch.as_tensor(first_step, dtype=dtype, device=device) + ) self.safety = torch.as_tensor(safety, dtype=dtype, device=device) self.ifactor = torch.as_tensor(ifactor, dtype=dtype, device=device) self.dfactor = torch.as_tensor(dfactor, dtype=dtype, device=device) - self.max_num_steps = torch.as_tensor(max_num_steps, dtype=torch.int32, device=device) - grid_points = torch.tensor([], dtype=dtype, device=device) if grid_points is None else grid_points.to(dtype) + self.max_num_steps = torch.as_tensor( + max_num_steps, dtype=torch.int32, device=device + ) + grid_points = ( + torch.tensor([], dtype=dtype, device=device) + if grid_points is None + else grid_points.to(dtype) + ) self.grid_points = grid_points self.eps = torch.as_tensor(eps, dtype=dtype, device=device) self.dtype = dtype # Copy from class to instance to set device - self.tableau = _ButcherTableau(alpha=self.tableau.alpha.to(device=device, dtype=y0.dtype), - beta=[b.to(device=device, dtype=y0.dtype) for b in self.tableau.beta], - c_sol=self.tableau.c_sol.to(device=device, dtype=y0.dtype), - c_error=self.tableau.c_error.to(device=device, dtype=y0.dtype)) + self.tableau = _ButcherTableau( + alpha=self.tableau.alpha.to(device=device, dtype=y0.dtype), + beta=[b.to(device=device, dtype=y0.dtype) for b in self.tableau.beta], + c_sol=self.tableau.c_sol.to(device=device, dtype=y0.dtype), + c_error=self.tableau.c_error.to(device=device, dtype=y0.dtype), + ) self.mid = self.mid.to(device=device, 
dtype=y0.dtype) def _before_integrate(self, t): f0 = self.func(t[0], self.y0) if self.first_step is None: - first_step = _select_initial_step(self.func, t[0], self.y0, self.order - 1, self.rtol, self.atol, - self.norm, f0=f0) + first_step = _select_initial_step( + self.func, + t[0], + self.y0, + self.order - 1, + self.rtol, + self.atol, + self.norm, + f0=f0, + ) else: first_step = self.first_step - self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, [self.y0] * 5) - self.next_grid_index = min(bisect.bisect(self.grid_points.tolist(), t[0]), len(self.grid_points) - 1) + self.rk_state = _RungeKuttaState( + self.y0, f0, t[0], t[0], first_step, [self.y0] * 5 + ) + self.next_grid_index = min( + bisect.bisect(self.grid_points.tolist(), t[0]), len(self.grid_points) - 1 + ) def _advance(self, next_t): """Interpolate through the next time point, integrating as necessary.""" n_steps = 0 while next_t > self.rk_state.t1: - assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps) + assert ( + n_steps < self.max_num_steps + ), "max_num_steps exceeded ({}>={})".format(n_steps, self.max_num_steps) self.rk_state = self._adaptive_step(self.rk_state) n_steps += 1 - return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t) + return _interp_evaluate( + self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t + ) def _adaptive_step(self, rk_state): """Take an adaptive Runge-Kutta step to integrate the ODE.""" @@ -168,17 +212,19 @@ def _adaptive_step(self, rk_state): # dt.dtype == self.dtype # for coeff in interp_coeff: coeff.dtype == self.y0.dtype - ######################################################## # Assertions # ######################################################## - assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item()) - assert torch.isfinite(y0).all(), 'non-finite values in state `y`: {}'.format(y0) + assert t0 + dt > t0, "underflow in dt 
{}".format(dt.item()) + assert torch.isfinite(y0).all(), "non-finite values in state `y`: {}".format(y0) ######################################################## # Make step, respecting prescribed grid points # ######################################################## - on_grid = len(self.grid_points) and t0 < self.grid_points[self.next_grid_index] < t0 + dt + on_grid = ( + len(self.grid_points) + and t0 < self.grid_points[self.next_grid_index] < t0 + dt + ) if on_grid: dt = self.grid_points[self.next_grid_index] - t0 eps = min(0.5 * dt, self.eps) @@ -186,7 +232,9 @@ def _adaptive_step(self, rk_state): else: eps = 0 - y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=self.tableau) + y1, f1, y1_error, k = _runge_kutta_step( + self.func, y0, f0, t0, dt, tableau=self.tableau + ) # dtypes: # y1.dtype == self.y0.dtype # f1.dtype == self.y0.dtype @@ -196,7 +244,9 @@ def _adaptive_step(self, rk_state): ######################################################## # Error Ratio # ######################################################## - error_ratio = _compute_error_ratio(y1_error, self.rtol, self.atol, y0, y1, self.norm) + error_ratio = _compute_error_ratio( + y1_error, self.rtol, self.atol, y0, y1, self.norm + ) accept_step = error_ratio <= 1 # dtypes: # error_ratio.dtype == self.dtype @@ -215,7 +265,9 @@ def _adaptive_step(self, rk_state): self.next_grid_index += 1 f_next = f1 if accept_step else f0 interp_coeff = self._interp_fit(y0, y1, k, dt) if accept_step else interp_coeff - dt_next = _optimal_step_size(dt, error_ratio, self.safety, self.ifactor, self.dfactor, self.order) + dt_next = _optimal_step_size( + dt, error_ratio, self.safety, self.ifactor, self.dfactor, self.order + ) rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff) return rk_state diff --git a/torchdiffeq/_impl/solvers.py b/torchdiffeq/_impl/solvers.py index e00bd1e..1a5804d 100644 --- a/torchdiffeq/_impl/solvers.py +++ b/torchdiffeq/_impl/solvers.py @@ 
-21,7 +21,9 @@ def _advance(self, next_t): raise NotImplementedError def integrate(self, t): - solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) + solution = torch.empty( + len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device + ) solution[0] = self.y0 t = t.to(self.dtype) self._before_integrate(t) @@ -33,10 +35,12 @@ def integrate(self, t): class FixedGridODESolver(metaclass=abc.ABCMeta): order: int - def __init__(self, func, y0, step_size=None, grid_constructor=None, **unused_kwargs): - unused_kwargs.pop('rtol', None) - unused_kwargs.pop('atol', None) - unused_kwargs.pop('norm', None) + def __init__( + self, func, y0, step_size=None, grid_constructor=None, **unused_kwargs + ): + unused_kwargs.pop("rtol", None) + unused_kwargs.pop("atol", None) + unused_kwargs.pop("norm", None) _handle_unused_kwargs(self, unused_kwargs) del unused_kwargs @@ -54,7 +58,9 @@ def __init__(self, func, y0, step_size=None, grid_constructor=None, **unused_kwa if grid_constructor is None: self.grid_constructor = self._grid_constructor_from_step_size(step_size) else: - raise ValueError("step_size and grid_constructor are mutually exclusive arguments.") + raise ValueError( + "step_size and grid_constructor are mutually exclusive arguments." 
+ ) @staticmethod def _grid_constructor_from_step_size(step_size): @@ -63,11 +69,15 @@ def _grid_constructor(func, y0, t): end_time = t[-1] niters = torch.ceil((end_time - start_time) / step_size + 1).item() - t_infer = torch.arange(0, niters, dtype=t.dtype, device=t.device) * step_size + start_time + t_infer = ( + torch.arange(0, niters, dtype=t.dtype, device=t.device) * step_size + + start_time + ) if t_infer[-1] > t[-1]: t_infer[-1] = t[-1] return t_infer + return _grid_constructor @abc.abstractmethod @@ -78,7 +88,9 @@ def integrate(self, t): time_grid = self.grid_constructor(self.func, self.y0, t) assert time_grid[0] == t[0] and time_grid[-1] == t[-1] - solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) + solution = torch.empty( + len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device + ) solution[0] = self.y0 j = 1 diff --git a/torchdiffeq/_impl/tsit5.py b/torchdiffeq/_impl/tsit5.py index cabb585..9fb0610 100644 --- a/torchdiffeq/_impl/tsit5.py +++ b/torchdiffeq/_impl/tsit5.py @@ -6,16 +6,42 @@ # Parameters from Tsitouras (2011). 
_TSITOURAS_TABLEAU = _ButcherTableau( - alpha=[0.161, 0.327, 0.9, 0.9800255409045097, 1., 1.], + alpha=[0.161, 0.327, 0.9, 0.9800255409045097, 1.0, 1.0], beta=[ [0.161], [-0.008480655492357, 0.3354806554923570], [2.897153057105494, -6.359448489975075, 4.362295432869581], - [5.32586482843925895, -11.74888356406283, 7.495539342889836, -0.09249506636175525], - [5.86145544294642038, -12.92096931784711, 8.159367898576159, -0.071584973281401006, -0.02826905039406838], - [0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774], + [ + 5.32586482843925895, + -11.74888356406283, + 7.495539342889836, + -0.09249506636175525, + ], + [ + 5.86145544294642038, + -12.92096931784711, + 8.159367898576159, + -0.071584973281401006, + -0.02826905039406838, + ], + [ + 0.09646076681806523, + 0.01, + 0.4798896504144996, + 1.379008574103742, + -3.290069515436081, + 2.324710524099774, + ], + ], + c_sol=[ + 0.09646076681806523, + 0.01, + 0.4798896504144996, + 1.379008574103742, + -3.290069515436081, + 2.324710524099774, + 0, ], - c_sol=[0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774, 0], c_error=[ 0.09646076681806523 - 0.001780011052226, 0.01 - 0.000816434459657, @@ -30,13 +56,32 @@ def _interp_coeff_tsit5(t0, dt, eval_t): t = float((eval_t - t0) / dt) - b1 = -1.0530884977290216 * t * (t - 1.3299890189751412) * (t**2 - 1.4364028541716351 * t + 0.7139816917074209) - b2 = 0.1017 * t**2 * (t**2 - 2.1966568338249754 * t + 1.2949852507374631) - b3 = 2.490627285651252793 * t**2 * (t**2 - 2.38535645472061657 * t + 1.57803468208092486) - b4 = -16.54810288924490272 * (t - 1.21712927295533244) * (t - 0.61620406037800089) * t**2 - b5 = 47.37952196281928122 * (t - 1.203071208372362603) * (t - 0.658047292653547382) * t**2 - b6 = -34.87065786149660974 * (t - 1.2) * (t - 0.666666666666666667) * t**2 - b7 = 2.5 * (t - 1) * (t - 0.6) * t**2 + b1 = ( + -1.0530884977290216 + * t + * (t - 1.3299890189751412) 
+ * (t ** 2 - 1.4364028541716351 * t + 0.7139816917074209) + ) + b2 = 0.1017 * t ** 2 * (t ** 2 - 2.1966568338249754 * t + 1.2949852507374631) + b3 = ( + 2.490627285651252793 + * t ** 2 + * (t ** 2 - 2.38535645472061657 * t + 1.57803468208092486) + ) + b4 = ( + -16.54810288924490272 + * (t - 1.21712927295533244) + * (t - 0.61620406037800089) + * t ** 2 + ) + b5 = ( + 47.37952196281928122 + * (t - 1.203071208372362603) + * (t - 0.658047292653547382) + * t ** 2 + ) + b6 = -34.87065786149660974 * (t - 1.2) * (t - 0.666666666666666667) * t ** 2 + b7 = 2.5 * (t - 1) * (t - 0.6) * t ** 2 return [b1, b2, b3, b4, b5, b6, b7] @@ -44,19 +89,27 @@ def _interp_eval_tsit5(t0, t1, k, eval_t): dt = t1 - t0 y0 = tuple(k_[0] for k_ in k) interp_coeff = _interp_coeff_tsit5(t0, dt, eval_t) - y_t = tuple(y0_ + _scaled_dot_product(dt, interp_coeff, k_) for y0_, k_ in zip(y0, k)) + y_t = tuple( + y0_ + _scaled_dot_product(dt, interp_coeff, k_) for y0_, k_ in zip(y0, k) + ) return y_t -def _optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5): +def _optimal_step_size( + last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5 +): """Calculate the optimal size for the next Runge-Kutta step.""" if mean_error_ratio == 0: return last_step * ifactor if mean_error_ratio < 1: - dfactor = _convert_to_tensor(1, dtype=torch.float64, device=mean_error_ratio.device) + dfactor = _convert_to_tensor( + 1, dtype=torch.float64, device=mean_error_ratio.device + ) error_ratio = torch.sqrt(mean_error_ratio).type_as(last_step) exponent = torch.tensor(1 / order).type_as(last_step) - factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor)) + factor = torch.max( + 1 / ifactor, torch.min(error_ratio ** exponent / safety, 1 / dfactor) + ) return last_step / factor @@ -65,10 +118,20 @@ def _abs_square(x): class Tsit5Solver(AdaptiveStepsizeODESolver): - def __init__( - self, func, y0, rtol, atol, first_step=None, safety=0.9, 
ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1, - grid_points=(), eps=0., **unused_kwargs + self, + func, + y0, + rtol, + atol, + first_step=None, + safety=0.9, + ifactor=10.0, + dfactor=0.2, + max_num_steps=2 ** 31 - 1, + grid_points=(), + eps=0.0, + **unused_kwargs, ): _handle_unused_kwargs(self, unused_kwargs) del unused_kwargs @@ -78,34 +141,57 @@ def __init__( self.rtol = rtol self.atol = atol self.first_step = first_step - self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device) - self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device) - self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device) - self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device) - self.grid_points = tuple(_convert_to_tensor(point, dtype=torch.float64, device=y0[0].device) - for point in grid_points) + self.safety = _convert_to_tensor( + safety, dtype=torch.float64, device=y0[0].device + ) + self.ifactor = _convert_to_tensor( + ifactor, dtype=torch.float64, device=y0[0].device + ) + self.dfactor = _convert_to_tensor( + dfactor, dtype=torch.float64, device=y0[0].device + ) + self.max_num_steps = _convert_to_tensor( + max_num_steps, dtype=torch.int32, device=y0[0].device + ) + self.grid_points = tuple( + _convert_to_tensor(point, dtype=torch.float64, device=y0[0].device) + for point in grid_points + ) self.eps = _convert_to_tensor(eps, dtype=torch.float64, device=y0[0].device) def before_integrate(self, t): if self.first_step is None: - first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol, self.atol).to(t) + first_step = _select_initial_step( + self.func, t[0], self.y0, 4, self.rtol, self.atol + ).to(t) else: - first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device) + first_step = _convert_to_tensor( + self.first_step, dtype=t.dtype, device=t.device + ) self.rk_state = _RungeKuttaState( self.y0, - 
self.func(t[0].type_as(self.y0[0]), self.y0), t[0], t[0], first_step, - tuple(map(lambda x: [x] * 7, self.y0)) + self.func(t[0].type_as(self.y0[0]), self.y0), + t[0], + t[0], + first_step, + tuple(map(lambda x: [x] * 7, self.y0)), + ) + self.next_grid_index = min( + bisect.bisect(self.grid_points, t[0]), len(self.grid_points) - 1 ) - self.next_grid_index = min(bisect.bisect(self.grid_points, t[0]), len(self.grid_points) - 1) def advance(self, next_t): """Interpolate through the next time point, integrating as necessary.""" n_steps = 0 while next_t > self.rk_state.t1: - assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps) + assert ( + n_steps < self.max_num_steps + ), "max_num_steps exceeded ({}>={})".format(n_steps, self.max_num_steps) self.rk_state = self._adaptive_tsit5_step(self.rk_state) n_steps += 1 - return _interp_eval_tsit5(self.rk_state.t0, self.rk_state.t1, self.rk_state.interp_coeff, next_t) + return _interp_eval_tsit5( + self.rk_state.t0, self.rk_state.t1, self.rk_state.interp_coeff, next_t + ) def _adaptive_tsit5_step(self, rk_state): """Take an adaptive Runge-Kutta step to integrate the ODE.""" @@ -113,14 +199,19 @@ def _adaptive_tsit5_step(self, rk_state): ######################################################## # Assertions # ######################################################## - assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item()) + assert t0 + dt > t0, "underflow in dt {}".format(dt.item()) for y0_ in y0: - assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_) + assert _is_finite( + torch.abs(y0_) + ), "non-finite values in state `y`: {}".format(y0_) ######################################################## # Make step, respecting prescribed grid points # ######################################################## - on_grid = len(self.grid_points) and t0 < self.grid_points[self.next_grid_index] < t0 + dt + on_grid = ( + len(self.grid_points) + and t0 < 
self.grid_points[self.next_grid_index] < t0 + dt + ) if on_grid: dt = self.grid_points[self.next_grid_index] - t0 eps = min(0.5 * dt, self.eps) @@ -128,20 +219,27 @@ def _adaptive_tsit5_step(self, rk_state): else: eps = 0 - y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_TSITOURAS_TABLEAU) + y1, f1, y1_error, k = _runge_kutta_step( + self.func, y0, f0, t0, dt, tableau=_TSITOURAS_TABLEAU + ) ######################################################## # Error Ratio # ######################################################## - error_tol = tuple(self.atol + self.rtol * torch.max(torch.abs(y0_), torch.abs(y1_)) for y0_, y1_ in zip(y0, y1)) - tensor_error_ratio = tuple(y1_error_ / error_tol_ for y1_error_, error_tol_ in zip(y1_error, error_tol)) - sq_error_ratio = tuple( - torch.mul(tensor_error_ratio_, tensor_error_ratio_) for tensor_error_ratio_ in tensor_error_ratio + error_tol = tuple( + self.atol + self.rtol * torch.max(torch.abs(y0_), torch.abs(y1_)) + for y0_, y1_ in zip(y0, y1) + ) + tensor_error_ratio = tuple( + y1_error_ / error_tol_ for y1_error_, error_tol_ in zip(y1_error, error_tol) ) - mean_error_ratio = ( - sum(torch.sum(sq_error_ratio_) for sq_error_ratio_ in sq_error_ratio) / - sum(sq_error_ratio_.numel() for sq_error_ratio_ in sq_error_ratio) + sq_error_ratio = tuple( + torch.mul(tensor_error_ratio_, tensor_error_ratio_) + for tensor_error_ratio_ in tensor_error_ratio ) + mean_error_ratio = sum( + torch.sum(sq_error_ratio_) for sq_error_ratio_ in sq_error_ratio + ) / sum(sq_error_ratio_.numel() for sq_error_ratio_ in sq_error_ratio) accept_step = mean_error_ratio <= 1 ######################################################## @@ -157,7 +255,9 @@ def _adaptive_tsit5_step(self, rk_state): if self.next_grid_index != len(self.grid_points) - 1: self.next_grid_index += 1 f_next = f1 if accept_step else f0 - dt_next = _optimal_step_size(dt, mean_error_ratio, self.safety, self.ifactor, self.dfactor) + dt_next = _optimal_step_size( + dt, 
mean_error_ratio, self.safety, self.ifactor, self.dfactor + ) k_next = k if accept_step else self.rk_state.interp_coeff rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, k_next) return rk_state diff --git a/utils.py b/utils.py index f6c4b2b..7cd7155 100644 --- a/utils.py +++ b/utils.py @@ -3,85 +3,102 @@ import logging import os + def setup_logger(name): cur_dir = os.getcwd() - if not os.path.exists(cur_dir+'/log/'): - os.mkdir(cur_dir+'/log/') - logging.basicConfig(level=logging.DEBUG, - format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', - datefmt='%m-%d %H:%M', - filename=cur_dir+'/log/'+name+'.log', - filemode='a') + if not os.path.exists(cur_dir + "/log/"): + os.mkdir(cur_dir + "/log/") + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s", + datefmt="%m-%d %H:%M", + filename=cur_dir + "/log/" + name + ".log", + filemode="a", + ) logger = logging.getLogger(name) return logger + def setup_tKG(dataset, logger, embsize, scale, val_exist, input_step): # load data after preprocess - loaddata = open('{}/data_tKG.pkl'.format(dataset), 'rb') + loaddata = open("{}/data_tKG.pkl".format(dataset), "rb") data = pickle.load(loaddata) loaddata.close() - loadsr2o = open('{}/sr2o_all_tKG.pkl'.format(dataset), 'rb') + loadsr2o = open("{}/sr2o_all_tKG.pkl".format(dataset), "rb") sr2o = pickle.load(loadsr2o) loadsr2o.close() - loadso2r = open('{}/so2r_all_tKG.pkl'.format(dataset), 'rb') + loadso2r = open("{}/so2r_all_tKG.pkl".format(dataset), "rb") so2r = pickle.load(loadso2r) loadso2r.close() - loadtriples = open('{}/triples_tKG.pkl'.format(dataset), 'rb') + loadtriples = open("{}/triples_tKG.pkl".format(dataset), "rb") triples = pickle.load(loadtriples) loadtriples.close() - loadadjs = open('{}/adjs_tKG.pkl'.format(dataset), 'rb') + loadadjs = open("{}/adjs_tKG.pkl".format(dataset), "rb") adjs = pickle.load(loadadjs) loadadjs.close() - loadtimestamp = open('{}/timestamp_tKG.pkl'.format(dataset), 
'rb') + loadtimestamp = open("{}/timestamp_tKG.pkl".format(dataset), "rb") timestamp = pickle.load(loadtimestamp) loadtimestamp.close() - loadindep = open('{}/t_indep_trp.pkl'.format(dataset), 'rb') + loadindep = open("{}/t_indep_trp.pkl".format(dataset), "rb") t_indep_trp = pickle.load(loadindep) loadindep.close() - loadnei = open('{}/neighbor_tKG.pkl'.format(dataset), 'rb') + loadnei = open("{}/neighbor_tKG.pkl".format(dataset), "rb") neighbor = pickle.load(loadnei) loadnei.close() def get_total_number(inPath, fileName): - with open(os.path.join(inPath, fileName), 'r') as fr: + with open(os.path.join(inPath, fileName), "r") as fr: for line in fr: line_split = line.split() return int(line_split[0]), int(line_split[1]) - num_e, num_rel = get_total_number('{}/'.format(dataset), 'stat.txt') + num_e, num_rel = get_total_number("{}/".format(dataset), "stat.txt") logger.info("number of entities:" + str(num_e)) logger.info("number of relations:" + str(num_rel)) if val_exist: # timestamps # normalize timestamps, and scale - ts_max = max(max(max(timestamp['train']), max(timestamp['test'])), max(timestamp['valid'])) # max timestamp in the dataset - train_timestamps = (torch.tensor(timestamp['train']) / torch.tensor(ts_max, dtype=torch.float)) * scale - test_timestamps = (torch.tensor(timestamp['test']) / torch.tensor(ts_max, dtype=torch.float)) * scale - val_timestamps = (torch.tensor(timestamp['valid']) / torch.tensor(ts_max, dtype=torch.float)) * scale + ts_max = max( + max(max(timestamp["train"]), max(timestamp["test"])), + max(timestamp["valid"]), + ) # max timestamp in the dataset + train_timestamps = ( + torch.tensor(timestamp["train"]) / torch.tensor(ts_max, dtype=torch.float) + ) * scale + test_timestamps = ( + torch.tensor(timestamp["test"]) / torch.tensor(ts_max, dtype=torch.float) + ) * scale + val_timestamps = ( + torch.tensor(timestamp["valid"]) / torch.tensor(ts_max, dtype=torch.float) + ) * scale # extend val and test timestamps - val_timestamps = 
torch.cat([train_timestamps[-input_step:], val_timestamps], dim=0) - test_timestamps = torch.cat([val_timestamps[-input_step:], test_timestamps], dim=0) - - print("number of training snapshots:", len(timestamp['train'])) - print("number of validation snapshots:", len(timestamp['valid'])) - print("number of testing snapshots:", len(timestamp['test'])) - logger.info("number of training snapshots:" + str(len(timestamp['train']))) - logger.info("number of validation snapshots:" + str(len(timestamp['valid']))) - logger.info("number of testing snapshots:" + str(len(timestamp['test']))) + val_timestamps = torch.cat( + [train_timestamps[-input_step:], val_timestamps], dim=0 + ) + test_timestamps = torch.cat( + [val_timestamps[-input_step:], test_timestamps], dim=0 + ) + + print("number of training snapshots:", len(timestamp["train"])) + print("number of validation snapshots:", len(timestamp["valid"])) + print("number of testing snapshots:", len(timestamp["test"])) + logger.info("number of training snapshots:" + str(len(timestamp["train"]))) + logger.info("number of validation snapshots:" + str(len(timestamp["valid"]))) + logger.info("number of testing snapshots:" + str(len(timestamp["test"]))) # adjs - train_adj = adjs['train'] - test_adj = adjs['test'] - val_adj = adjs['valid'] + train_adj = adjs["train"] + test_adj = adjs["test"] + val_adj = adjs["valid"] # extend val and test adj val_adj_extend = train_adj[-input_step:] @@ -90,9 +107,9 @@ def get_total_number(inPath, fileName): test_adj = test_adj_extend + test_adj # triples - train_triple = triples['train'] - val_triple = triples['valid'] - test_triple = triples['test'] + train_triple = triples["train"] + val_triple = triples["valid"] + test_triple = triples["test"] # extend val and test triples val_triple_extend = train_triple[-input_step:] @@ -101,9 +118,9 @@ def get_total_number(inPath, fileName): test_triple = test_triple_extend + test_triple # one hop neighbor - train_1nei = neighbor['train'] - test_1nei = 
neighbor['test'] - val_1nei = neighbor['valid'] + train_1nei = neighbor["train"] + test_1nei = neighbor["test"] + val_1nei = neighbor["valid"] # extend val and test neighbor val_1nei_extend = train_1nei[-input_step:] @@ -112,9 +129,9 @@ def get_total_number(inPath, fileName): test_1nei = test_1nei_extend + test_1nei # so2r - train_so2r = so2r['train'] - val_so2r = so2r['valid'] - test_so2r = so2r['test'] + train_so2r = so2r["train"] + val_so2r = so2r["valid"] + test_so2r = so2r["test"] # extend val and test so2r val_so2r_extend = train_so2r[-input_step:] @@ -125,128 +142,187 @@ def get_total_number(inPath, fileName): else: # timestamps # normalize timestamps, and scale - ts_max = max(max(timestamp['train']), max(timestamp['test'])) # max timestamp in the dataset - train_timestamps = (torch.tensor(timestamp['train']) / torch.tensor(ts_max, dtype=torch.float)) * scale - test_timestamps = (torch.tensor(timestamp['test']) / torch.tensor(ts_max, dtype=torch.float)) * scale - #train_timestamps = torch.tensor(timestamp['train']) * scale / 2.4 - #test_timestamps = torch.tensor(timestamp['test']) * scale / 2.4 + ts_max = max( + max(timestamp["train"]), max(timestamp["test"]) + ) # max timestamp in the dataset + train_timestamps = ( + torch.tensor(timestamp["train"]) / torch.tensor(ts_max, dtype=torch.float) + ) * scale + test_timestamps = ( + torch.tensor(timestamp["test"]) / torch.tensor(ts_max, dtype=torch.float) + ) * scale + # train_timestamps = torch.tensor(timestamp['train']) * scale / 2.4 + # test_timestamps = torch.tensor(timestamp['test']) * scale / 2.4 # extend test timestamps - test_timestamps = torch.cat([train_timestamps[-input_step:], test_timestamps], dim=0) + test_timestamps = torch.cat( + [train_timestamps[-input_step:], test_timestamps], dim=0 + ) - print("number of training snapshots:", len(timestamp['train'])) - print("number of testing snapshots:", len(timestamp['test'])) - logger.info("number of training snapshots:" + str(len(timestamp['train']))) - 
logger.info("number of testing snapshots:" + str(len(timestamp['test']))) + print("number of training snapshots:", len(timestamp["train"])) + print("number of testing snapshots:", len(timestamp["test"])) + logger.info("number of training snapshots:" + str(len(timestamp["train"]))) + logger.info("number of testing snapshots:" + str(len(timestamp["test"]))) # adjs - train_adj = adjs['train'] - test_adj = adjs['test'] + train_adj = adjs["train"] + test_adj = adjs["test"] # extend test adj test_adj_extend = train_adj[-input_step:] test_adj = test_adj_extend + test_adj # triples - train_triple = triples['train'] - test_triple = triples['test'] + train_triple = triples["train"] + test_triple = triples["test"] # extend test triples test_triple_extend = train_triple[-input_step:] test_triple = test_triple_extend + test_triple # one hop neighbor - train_1nei = neighbor['train'] - test_1nei = neighbor['test'] + train_1nei = neighbor["train"] + test_1nei = neighbor["test"] # extend test neighbor test_1nei_extend = train_1nei[-input_step:] test_1nei = test_1nei_extend + test_1nei # so2r - train_so2r = so2r['train'] - test_so2r = so2r['test'] + train_so2r = so2r["train"] + test_so2r = so2r["test"] # extend val and test so2r test_so2r_extend = train_so2r[-input_step:] test_so2r = test_so2r_extend + test_so2r if val_exist: - return num_e, num_rel, train_timestamps, test_timestamps, val_timestamps, train_adj, test_adj, val_adj, \ - train_triple, test_triple, val_triple, train_1nei, test_1nei, val_1nei, t_indep_trp, train_so2r, val_so2r, test_so2r - #return num_e, num_rel, train_node_feature, test_node_feature, val_node_feature, train_timestamps, test_timestamps, val_timestamps, train_adj, test_adj, val_adj, triples + return ( + num_e, + num_rel, + train_timestamps, + test_timestamps, + val_timestamps, + train_adj, + test_adj, + val_adj, + train_triple, + test_triple, + val_triple, + train_1nei, + test_1nei, + val_1nei, + t_indep_trp, + train_so2r, + val_so2r, + test_so2r, + ) + # 
return num_e, num_rel, train_node_feature, test_node_feature, val_node_feature, train_timestamps, test_timestamps, val_timestamps, train_adj, test_adj, val_adj, triples else: - return num_e, num_rel, train_timestamps, test_timestamps, train_adj, test_adj, train_triple, test_triple, train_1nei, test_1nei, t_indep_trp, train_so2r, test_so2r - #return num_e, num_rel, train_node_feature, test_node_feature, train_timestamps, test_timestamps, train_adj, test_adj, triples + return ( + num_e, + num_rel, + train_timestamps, + test_timestamps, + train_adj, + test_adj, + train_triple, + test_triple, + train_1nei, + test_1nei, + t_indep_trp, + train_so2r, + test_so2r, + ) + # return num_e, num_rel, train_node_feature, test_node_feature, train_timestamps, test_timestamps, train_adj, test_adj, triples + def setup_tKG2(dataset, logger, embsize, scale, val_exist, input_step): # load data after preprocess - loaddata = open('{}/data_tKG_j.pkl'.format(dataset), 'rb') + loaddata = open("{}/data_tKG_j.pkl".format(dataset), "rb") data = pickle.load(loaddata) loaddata.close() - loadsr2o = open('{}/sr2o_all_tKG_j.pkl'.format(dataset), 'rb') + loadsr2o = open("{}/sr2o_all_tKG_j.pkl".format(dataset), "rb") sr2o = pickle.load(loadsr2o) loadsr2o.close() - loadso2r = open('{}/so2r_all_tKG_j.pkl'.format(dataset), 'rb') + loadso2r = open("{}/so2r_all_tKG_j.pkl".format(dataset), "rb") so2r = pickle.load(loadso2r) loadso2r.close() - loadtriples = open('{}/triples_tKG_j.pkl'.format(dataset), 'rb') + loadtriples = open("{}/triples_tKG_j.pkl".format(dataset), "rb") triples = pickle.load(loadtriples) loadtriples.close() - loadadjs = open('{}/adjs_tKG_j.pkl'.format(dataset), 'rb') + loadadjs = open("{}/adjs_tKG_j.pkl".format(dataset), "rb") adjs = pickle.load(loadadjs) loadadjs.close() - loadtimestamp = open('{}/timestamp_tKG_j.pkl'.format(dataset), 'rb') + loadtimestamp = open("{}/timestamp_tKG_j.pkl".format(dataset), "rb") timestamp = pickle.load(loadtimestamp) loadtimestamp.close() - loadindep = 
open('{}/t_indep_trp_j.pkl'.format(dataset), 'rb') + loadindep = open("{}/t_indep_trp_j.pkl".format(dataset), "rb") t_indep_trp = pickle.load(loadindep) loadindep.close() - loadnei = open('{}/neighbor_tKG_j.pkl'.format(dataset), 'rb') + loadnei = open("{}/neighbor_tKG_j.pkl".format(dataset), "rb") neighbor = pickle.load(loadnei) loadnei.close() def get_total_number(inPath, fileName): - with open(os.path.join(inPath, fileName), 'r') as fr: + with open(os.path.join(inPath, fileName), "r") as fr: for line in fr: line_split = line.split() return int(line_split[0]), int(line_split[1]) - num_e, num_rel = get_total_number('{}/'.format(dataset), 'stat.txt') - #num_rel *= 2 # include inv rel + num_e, num_rel = get_total_number("{}/".format(dataset), "stat.txt") + # num_rel *= 2 # include inv rel logger.info("number of entities:" + str(num_e)) logger.info("number of relations:" + str(num_rel)) if val_exist: # timestamps # normalize timestamps, and scale - ts_max = max(max(max(timestamp['train_jump']), max(timestamp['test_jump'])), max(timestamp['valid_jump'])) # max timestamp in the dataset - train_timestamps = (torch.tensor(timestamp['train_jump']) / torch.tensor(ts_max, dtype=torch.float)) * scale - test_timestamps = (torch.tensor(timestamp['test_jump']) / torch.tensor(ts_max, dtype=torch.float)) * scale - val_timestamps = (torch.tensor(timestamp['valid_jump']) / torch.tensor(ts_max, dtype=torch.float)) * scale + ts_max = max( + max(max(timestamp["train_jump"]), max(timestamp["test_jump"])), + max(timestamp["valid_jump"]), + ) # max timestamp in the dataset + train_timestamps = ( + torch.tensor(timestamp["train_jump"]) + / torch.tensor(ts_max, dtype=torch.float) + ) * scale + test_timestamps = ( + torch.tensor(timestamp["test_jump"]) + / torch.tensor(ts_max, dtype=torch.float) + ) * scale + val_timestamps = ( + torch.tensor(timestamp["valid_jump"]) + / torch.tensor(ts_max, dtype=torch.float) + ) * scale # extend val and test timestamps - val_timestamps = 
torch.cat([train_timestamps[-input_step:], val_timestamps], dim=0) - test_timestamps = torch.cat([val_timestamps[-input_step:], test_timestamps], dim=0) - - print("number of training snapshots:", len(timestamp['train_jump'])) - print("number of validation snapshots:", len(timestamp['valid_jump'])) - print("number of testing snapshots:", len(timestamp['test_jump'])) - logger.info("number of training snapshots:" + str(len(timestamp['train_jump']))) - logger.info("number of validation snapshots:" + str(len(timestamp['valid_jump']))) - logger.info("number of testing snapshots:" + str(len(timestamp['test_jump']))) + val_timestamps = torch.cat( + [train_timestamps[-input_step:], val_timestamps], dim=0 + ) + test_timestamps = torch.cat( + [val_timestamps[-input_step:], test_timestamps], dim=0 + ) + + print("number of training snapshots:", len(timestamp["train_jump"])) + print("number of validation snapshots:", len(timestamp["valid_jump"])) + print("number of testing snapshots:", len(timestamp["test_jump"])) + logger.info("number of training snapshots:" + str(len(timestamp["train_jump"]))) + logger.info( + "number of validation snapshots:" + str(len(timestamp["valid_jump"])) + ) + logger.info("number of testing snapshots:" + str(len(timestamp["test_jump"]))) # adjs - train_adj = adjs['train_jump'] - test_adj = adjs['test_jump'] - val_adj = adjs['valid_jump'] + train_adj = adjs["train_jump"] + test_adj = adjs["test_jump"] + val_adj = adjs["valid_jump"] # extend val and test adj val_adj_extend = train_adj[-input_step:] @@ -255,9 +331,9 @@ def get_total_number(inPath, fileName): test_adj = test_adj_extend + test_adj # triples - train_triple = triples['train_jump'] - val_triple = triples['valid_jump'] - test_triple = triples['test_jump'] + train_triple = triples["train_jump"] + val_triple = triples["valid_jump"] + test_triple = triples["test_jump"] # extend val and test triples val_triple_extend = train_triple[-input_step:] @@ -266,9 +342,9 @@ def get_total_number(inPath, 
fileName): test_triple = test_triple_extend + test_triple # one hop neighbor - train_1nei = neighbor['train_jump'] - test_1nei = neighbor['test_jump'] - val_1nei = neighbor['valid_jump'] + train_1nei = neighbor["train_jump"] + test_1nei = neighbor["test_jump"] + val_1nei = neighbor["valid_jump"] # extend val and test neighbor val_1nei_extend = train_1nei[-input_step:] @@ -277,9 +353,9 @@ def get_total_number(inPath, fileName): test_1nei = test_1nei_extend + test_1nei # so2r - train_so2r = so2r['train_jump'] - val_so2r = so2r['valid_jump'] - test_so2r = so2r['test_jump'] + train_so2r = so2r["train_jump"] + val_so2r = so2r["valid_jump"] + test_so2r = so2r["test_jump"] # extend val and test so2r val_so2r_extend = train_so2r[-input_step:] @@ -290,138 +366,195 @@ def get_total_number(inPath, fileName): else: # timestamps # normalize timestamps, and scale - ts_max = max(max(timestamp['train']), max(timestamp['test_jump'])) # max timestamp in the dataset - train_timestamps = (torch.tensor(timestamp['train_jump']) / torch.tensor(ts_max, dtype=torch.float)) * scale - test_timestamps = (torch.tensor(timestamp['test_jump']) / torch.tensor(ts_max, dtype=torch.float)) * scale - #train_timestamps = torch.tensor(timestamp['train']) * scale / 2.4 - #test_timestamps = torch.tensor(timestamp['test']) * scale / 2.4 + ts_max = max( + max(timestamp["train"]), max(timestamp["test_jump"]) + ) # max timestamp in the dataset + train_timestamps = ( + torch.tensor(timestamp["train_jump"]) + / torch.tensor(ts_max, dtype=torch.float) + ) * scale + test_timestamps = ( + torch.tensor(timestamp["test_jump"]) + / torch.tensor(ts_max, dtype=torch.float) + ) * scale + # train_timestamps = torch.tensor(timestamp['train']) * scale / 2.4 + # test_timestamps = torch.tensor(timestamp['test']) * scale / 2.4 # extend test timestamps - test_timestamps = torch.cat([train_timestamps[-input_step:], test_timestamps], dim=0) + test_timestamps = torch.cat( + [train_timestamps[-input_step:], test_timestamps], 
dim=0 + ) - print("number of training snapshots:", len(timestamp['train_jump'])) - print("number of testing snapshots:", len(timestamp['test_jump'])) - logger.info("number of training snapshots:" + str(len(timestamp['train_jump']))) - logger.info("number of testing snapshots:" + str(len(timestamp['test_jump']))) + print("number of training snapshots:", len(timestamp["train_jump"])) + print("number of testing snapshots:", len(timestamp["test_jump"])) + logger.info("number of training snapshots:" + str(len(timestamp["train_jump"]))) + logger.info("number of testing snapshots:" + str(len(timestamp["test_jump"]))) # adjs - train_adj = adjs['train_jump'] - test_adj = adjs['test_jump'] + train_adj = adjs["train_jump"] + test_adj = adjs["test_jump"] # extend test adj test_adj_extend = train_adj[-input_step:] test_adj = test_adj_extend + test_adj # triples - train_triple = triples['train_jump'] - test_triple = triples['test_jump'] + train_triple = triples["train_jump"] + test_triple = triples["test_jump"] # extend test triples test_triple_extend = train_triple[-input_step:] test_triple = test_triple_extend + test_triple # one hop neighbor - train_1nei = neighbor['train_jump'] - test_1nei = neighbor['test_jump'] + train_1nei = neighbor["train_jump"] + test_1nei = neighbor["test_jump"] # extend test neighbor test_1nei_extend = train_1nei[-input_step:] test_1nei = test_1nei_extend + test_1nei # so2r - train_so2r = so2r['train_jump'] - test_so2r = so2r['test_jump'] + train_so2r = so2r["train_jump"] + test_so2r = so2r["test_jump"] # extend val and test so2r test_so2r_extend = train_so2r[-input_step:] test_so2r = test_so2r_extend + test_so2r if val_exist: - return num_e, num_rel, train_timestamps, test_timestamps, val_timestamps, train_adj, test_adj, val_adj, \ - train_triple, test_triple, val_triple, train_1nei, test_1nei, val_1nei, t_indep_trp, train_so2r, val_so2r, test_so2r - #return num_e, num_rel, train_node_feature, test_node_feature, val_node_feature, train_timestamps, 
def setup_induct_test(dataset, logger, scale, input_step):
    """Load the preprocessed pickles needed for the inductive test split.

    Args:
        dataset: Directory containing the ``*_tKG.pkl`` files and ``stat.txt``.
        logger: Logger used for info messages.
        scale: Multiplier applied to the normalized test timestamps.
        input_step: Kept for signature parity with the other ``setup_*``
            loaders; not used on the inductive test path.

    Returns:
        Tuple ``(num_e, num_rel, test_timestamps, test_adj, test_triple,
        test_1nei, t_indep_trp, test_so2r, induct_tar)``.
    """
    print("Preparing for inductive test...")

    def _load(name):
        # ``with`` closes the handle even when unpickling raises; the
        # original open()/load()/close() sequence leaked the file on error.
        with open(os.path.join(dataset, name), "rb") as f:
            return pickle.load(f)

    # Loaded only so a missing file still fails loudly, as before; the
    # contents of data/sr2o are not used on the inductive test path.
    _load("data_tKG.pkl")
    _load("sr2o_all_tKG.pkl")

    so2r = _load("so2r_all_tKG.pkl")
    triples = _load("triples_tKG.pkl")
    adjs = _load("adjs_tKG.pkl")
    timestamp = _load("timestamp_tKG.pkl")
    t_indep_trp = _load("t_indep_trp.pkl")
    neighbor = _load("neighbor_tKG.pkl")
    induct = _load("inductive.pkl")

    # stat.txt: first line is "<num_entities> <num_relations>".
    with open(os.path.join(dataset, "stat.txt"), "r") as fr:
        first = fr.readline().split()
    num_e, num_rel = int(first[0]), int(first[1])
    logger.info("number of entities:" + str(num_e))
    logger.info("number of relations:" + str(num_rel))

    # Normalize timestamps by the max over all splits, then rescale.
    # (max() takes multiple args directly; no need for nested max calls.)
    ts_max = max(
        max(timestamp["train"]), max(timestamp["test"]), max(timestamp["valid"])
    )
    test_timestamps = (
        torch.tensor(timestamp["test"]) / torch.tensor(ts_max, dtype=torch.float)
    ) * scale

    print("number of testing snapshots:", len(timestamp["test"]))
    logger.info("number of testing snapshots:" + str(len(timestamp["test"])))

    return (
        num_e,
        num_rel,
        test_timestamps,
        adjs["test"],       # adjacency data for the test snapshots
        triples["test"],    # triples for the test snapshots
        neighbor["test"],   # one-hop neighbors for the test snapshots
        t_indep_trp,        # time-independent triples
        so2r["test"],
        induct,             # inductive evaluation targets
    )


def load_adjmtx(dataset):
    """Load the precomputed adjacency-list pickle for *dataset*."""
    with open(os.path.join(dataset, "adjlist_tKG.pkl"), "rb") as f:
        return pickle.load(f)