-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdeepchem_benchmark.py
More file actions
79 lines (62 loc) · 2.91 KB
/
deepchem_benchmark.py
File metadata and controls
79 lines (62 loc) · 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from torch.autograd import Variable
import deepchem as dc
from deepchem.feat import Featurizer
from dataloader import *
import model
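# Source code of the fully connected DeepChem models used below
# ('tf' for classification, 'tf_regression' for regression); note this
# tensorgraph path may no longer exist in newer DeepChem releases:
# https://github.com/deepchem/deepchem/blob/master/deepchem/models/tensorgraph/fcnet.py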
# Path of the COMET checkpoint evaluated when this script is run directly.
ckpt = "./model/model_ck_004_000016000.tar"
def benchmark(ckpt_file):
    """Benchmark a COMET encoder checkpoint on DeepChem MoleculeNet datasets.

    The checkpoint's encoder is wrapped in a DeepChem ``Featurizer`` so each
    molecule is embedded with the pretrained encoder, then
    ``dc.molnet.run_benchmark`` is run once with the classification model
    ('tf') and once with the regression model ('tf_regression').

    Args:
        ckpt_file: Path to a file loadable with ``torch.load`` containing the
            keys ``'args'`` (the argument namespace passed to
            ``model.Encoder``; must have a ``model_name`` attribute) and
            ``'encoder'`` (the encoder ``state_dict``).

    Returns:
        A 2-tuple ``(classification_results, regression_results)`` holding
        what ``dc.molnet.run_benchmark`` returned for each suite. The
        positional order matches the original implementation.
    """
    import os

    # Load onto the CPU so the script also runs on machines without a GPU.
    checkpoint = torch.load(ckpt_file, map_location=torch.device('cpu'))
    args = checkpoint['args']
    # Molecules are featurized one at a time.
    args.batch_size = 1
    args.test_batch_size = 1
    comet = model.Encoder(args)
    comet.load_state_dict(checkpoint['encoder'])
    # Inference only: put layers such as dropout/batch-norm in eval mode.
    comet.eval()

    def mol_to_graph(mol):
        """Return (per-atom feature array, adjacency matrix) for an RDKit mol."""
        adj = Chem.rdmolops.GetAdjacencyMatrix(mol)
        atom_features = [atom_feature(atom) for atom in mol.GetAtoms()]
        return np.array(atom_features), adj

    class MyEncoder(Featurizer):
        """DeepChem featurizer that embeds a molecule with the COMET encoder."""

        # NOTE(review): presumably read by DeepChem as the featurizer label;
        # confirm against the installed DeepChem version.
        name = ['comet_encoder']

        def __init__(self, encoder):
            self.model = encoder

        def _featurize(self, mol):
            """Return the encoder's molecule vector for ``mol`` as an ndarray."""
            X, A = mol_to_graph(mol)
            # Add a leading batch dimension of size 1.
            X = torch.unsqueeze(torch.from_numpy(X), dim=0).long()
            A = torch.unsqueeze(torch.from_numpy(A.astype(float)), dim=0).float()
            # No gradients are needed for featurization.
            with torch.no_grad():
                _, _, molvec = self.model(X, A)
            # DeepChem stacks features with NumPy, and NumPy cannot convert
            # a grad-tracking tensor, so return a plain ndarray.
            return torch.squeeze(molvec).cpu().numpy()

    featurizer = MyEncoder(comet)

    filename = args.model_name
    out_dir = './benchmark/'
    os.makedirs(out_dir, exist_ok=True)
    cls_path = out_dir + 'cls_' + filename + '.csv'
    reg_path = out_dir + 'reg_' + filename + '.csv'

    # NOTE(review): some DeepChem releases treat `out_path` as a directory
    # and append to <out_path>/results.csv. Also confirm that every dataset
    # listed below accepts a custom Featurizer instance.
    cls_results = dc.molnet.run_benchmark(
        datasets=['bace_c', 'bbbp', 'clintox', 'hiv', 'muv', 'pcba',
                  'sider', 'tox21', 'toxcast'],
        model='tf',
        split=None,
        metric=None,
        featurizer=featurizer,
        out_path=cls_path,
        hyper_parameters=None,
        test=True,
        reload=False,
        seed=123,
    )
    reg_results = dc.molnet.run_benchmark(
        datasets=['bace_r', 'chembl', 'clearance', 'delaney', 'hopv',
                  'kaggle', 'lipo', 'nci', 'pdbbind', 'ppb', 'qm7', 'qm7b',
                  'qm8', 'qm9', 'sampl'],
        model='tf_regression',
        split=None,
        metric=None,
        featurizer=featurizer,
        out_path=reg_path,
        hyper_parameters=None,
        test=True,
        reload=False,
        seed=123,
    )
    return cls_results, reg_results
# Run the benchmark only when executed as a script, not on import.
if __name__ == '__main__':
    benchmark(ckpt)