-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path01_extrapolation_trajectories.py
More file actions
136 lines (114 loc) · 5.55 KB
/
01_extrapolation_trajectories.py
File metadata and controls
136 lines (114 loc) · 5.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python
# coding: utf-8
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from os.path import isfile, join, exists
from os import listdir, chdir
from scipy.stats import spearmanr
from os.path import abspath
import sys
# Locations of the nn-extrapolation project, expressed relative to the nn4dms
# root that becomes the working directory below: the project root is one level
# up, and the pretrained models live under it.
nnextrap_root_relpath = ".."
pretrained_dir = "nn-extrapolation-models/pretrained_models"

# nn4dms resolves its own data/model paths relative to its repository root, so
# the script must run from there. Its modules (encode, inference, ...) sit in
# the "code" sub-directory, which is therefore put on the import path.
print('Setting working directory to nn4dms root.')
chdir('nn4dms_nn-extrapolate')
module_path = abspath("code")
if module_path not in sys.path:
    sys.path.append(module_path)
import encode as enc
import inference as inf
import inference_lr as inf_lr
import design_tools as dt
# The 20 canonical amino-acid one-letter codes, in alphabetical order; these are
# the substitution targets considered at every position.
CHARS = list("ACDEFGHIKLMNPQRSTVWY")
# Wild-type sequence (56 residues) that all variants are derived from.
WT = "MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE"
# Pretrained architectures to evaluate; each name also selects the
# "<name>s/" sub-directory of the pretrained model store.
models = "lr fcn gcn cnn".split()
# Local import: only needed here to make sure the output directory exists.
from os import makedirs

NUM_MODELS = 100  # pretrained replicates expected per architecture (model_0 .. model_99)
NUM_STEPS = 55    # number of greedy mutation steps per up/down trajectory


def _find_model_paths(model, num_models=NUM_MODELS):
    """Locate the frozen-graph (.pb) file of each pretrained replicate of ``model``.

    Looks in ``<nnextrap_root_relpath>/<pretrained_dir>/<model>s/model_<i>`` for
    every ``i`` in ``range(num_models)``.

    Returns:
        (model_paths, found_ids): the ``.pb`` file paths and the replicate indices
        they belong to, both in increasing replicate order.
    """
    model_paths = []
    found_ids = []
    for i in range(num_models):
        path = join(nnextrap_root_relpath, pretrained_dir, model + 's/model_' + str(i))
        if not exists(path):
            continue
        # Require a real ".pb" suffix (a substring test also matches ".pbtxt") and
        # sort so the pick does not depend on the arbitrary listdir() order.
        pb_files = sorted(f for f in listdir(path) if f.endswith('.pb'))
        if not pb_files:
            # Skip rather than reuse the file name found for an earlier replicate.
            continue
        model_paths.append(join(path, pb_files[0]))
        found_ids.append(i)
    return model_paths, found_ids


def _predict_all(model, sessions, seqs):
    """Encode ``seqs`` and run every restored replicate on them.

    Returns a list with one prediction array per session, in session order.
    """
    encoded_variants = enc.encode(encoding="one_hot,aa_index", char_seqs=seqs, wt_aa=list(WT))
    # Linear-regression models have their own inference entry point.
    run = inf_lr.run_inference_lr if model == 'lr' else inf.run_inference
    return [run(encoded_data=encoded_variants, sess=sess) for sess in sessions]


def _greedy_trajectory(model, sessions, ascending, num_steps=NUM_STEPS):
    """Greedily accumulate up to ``num_steps`` single-site substitutions on WT.

    At each step every substitution at a not-yet-mutated position (positions
    ``1 .. len(WT)-1``; position 0 is never mutated) is added to the current
    mutant, all replicates score the resulting sequences, and the candidate with
    the highest (``ascending=False``) or lowest (``ascending=True``) median
    prediction is kept. Mutation strings are ``<wt aa><position><new aa>`` with
    ``position`` an index into ``WT`` (presumably the 0-based convention
    ``dt.mut2seq`` expects — confirm).

    Returns:
        (curr_muts, func_all): the chosen mutation at each step, and the
        per-replicate predictions for the mutant chosen at each step.
    """
    curr_muts = []
    func_all = []
    for step in range(num_steps):
        print("starting mutation ", step)
        mutated_positions = {int(mut[1:-1]) for mut in curr_muts}
        # All single-site neighbours: don't re-mutate a position, don't mutate to self.
        possible_muts = [WT[pos] + str(pos) + aa
                         for pos in range(1, len(WT)) if pos not in mutated_positions
                         for aa in CHARS if aa != WT[pos]]
        if not possible_muts:
            break  # every mutable position is already used
        join_muts = [curr_muts + [mut] for mut in possible_muts]
        seqs = [dt.mut2seq(WT, muts) for muts in join_muts]
        functions_all = _predict_all(model, sessions, seqs)
        functions = np.median(functions_all, axis=0)
        # Next mutant in the trajectory = best candidate by median ensemble prediction.
        seqs_mut_df = pd.DataFrame(
            data=list(zip(possible_muts, join_muts, seqs, np.array(functions_all).T, functions)),
            columns=['added_mut', 'join_mut', 'seq', 'func_all', 'func'])
        seqs_mut_df.sort_values('func', ascending=ascending, inplace=True)
        best = seqs_mut_df.iloc[0]
        curr_muts.append(best['added_mut'])
        func_all.append(best['func_all'])
    return curr_muts, func_all


def _close_sessions(sessions):
    """Release restored sessions so they do not pile up across models/directions."""
    # NOTE(review): assumes inf.restore_sess_from_pb returns a tf Session-like
    # object with close(); objects without one are left alone.
    for sess in sessions:
        close = getattr(sess, 'close', None)
        if close is not None:
            close()


# For each direction, write one CSV with columns per architecture:
#   'all'      -> greedy trajectory maximising the median prediction
#   'all_down' -> greedy trajectory minimising the median prediction
#   'wt'       -> per-replicate predictions for the wild type only
for direction in ['all', 'all_down', 'wt']:
    func_all_list = []
    label_list = []
    for model in models:
        model_paths, found_models = _find_model_paths(model, NUM_MODELS)
        if len(model_paths) != NUM_MODELS:
            found = set(found_models)
            print('Could not find all models, missing models: ',
                  ','.join(str(i) for i in range(NUM_MODELS) if i not in found))
        if not model_paths:
            raise FileNotFoundError(
                'No pretrained ' + model + ' models found under '
                + join(nnextrap_root_relpath, pretrained_dir, model + 's'))
        model_sesses = [inf.restore_sess_from_pb(model_path) for model_path in model_paths]
        try:
            if direction == 'wt':
                # calculate wt fitness for each model
                functions_all = _predict_all(model, model_sesses, [WT])
                func_all_list.append([np.array(functions_all)])
                label_list.append(model + "_func")
            else:
                # make upward ('all') or downward ('all_down') trajectory for each model
                curr_muts, func_all = _greedy_trajectory(
                    model, model_sesses, ascending=(direction != 'all'), num_steps=NUM_STEPS)
                func_all_list.append(curr_muts)
                func_all_list.append(func_all)
                label_list.append(model + "_mut")
                label_list.append(model + "_func")
        finally:
            _close_sessions(model_sesses)
    func_df = pd.DataFrame(data=list(zip(*func_all_list)), columns=label_list)
    out_dir = join(nnextrap_root_relpath, 'gen_data')
    makedirs(out_dir, exist_ok=True)
    func_df.to_csv(join(out_dir, 'mut_func_' + direction + '.csv'))