-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathProtFeatureGUI.py
More file actions
129 lines (113 loc) · 5.72 KB
/
ProtFeatureGUI.py
File metadata and controls
129 lines (113 loc) · 5.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import tkinter as tk
from tkinter import ttk, filedialog
import pathlib, threading, queue, torch
from utils import dict2csv, arr2dict
from PCA import PCA
from protein import Protein
from message_hub import hub
class ProtFeatureGUI:
    """Tkinter window that runs the feature-extraction + PCA job in a worker thread.

    The window collects a UniProt ID, an output folder and a "use cache"
    flag.  Pressing the run button starts ``worker`` on a background thread;
    status text travels back through the shared ``hub`` queue and is shown in
    the log box by ``poll_queue``, which runs on the Tk thread.

    The worker never touches Tk objects.  It signals completion through a
    ``threading.Event`` and the poll loop re-enables the run button.
    """

    def __init__(self, root):
        """Build the widgets inside ``root`` (the Tk root window)."""
        self.root = root
        root.title("ESM 饱和突变特征提取 & PCA")
        pad = dict(padx=5, pady=5, sticky="w")

        # True once the poll loop has observed the end of a job.
        self.job_done = False
        # Set by the worker thread when it is finished (success or failure).
        self._worker_finished = threading.Event()
        # Identifier of the pending ``after`` callback, or None when idle.
        self._poll_id = None

        self.uniprot_id = tk.StringVar()
        self.out_dir = tk.StringVar()
        self.use_cache = tk.BooleanVar(value=True)  # checked by default

        # Row 0: UniProt ID entry and cache checkbox.
        ttk.Label(self.root, text="UniProt ID:").grid(row=0, column=0, **pad)
        ttk.Entry(self.root, textvariable=self.uniprot_id, width=15).grid(row=0, column=1, **pad)
        ttk.Checkbutton(self.root, text="启用缓存", variable=self.use_cache).grid(
            row=0, column=2, columnspan=2, **pad
        )

        # Row 1: output folder entry and browse button.
        ttk.Label(self.root, text="输出文件夹:").grid(row=1, column=0, **pad)
        ttk.Entry(self.root, textvariable=self.out_dir, width=40).grid(row=1, column=1, **pad)
        ttk.Button(self.root, text="浏览", command=self.browse).grid(row=1, column=2, **pad)

        # Row 2: run button.  Row 3: log box.
        self.run_btn = ttk.Button(self.root, text="开始运行", command=self.run)
        self.run_btn.grid(row=2, column=0, columnspan=3, pady=10)
        self.log = tk.Text(self.root, height=10, width=70)
        self.log.grid(row=3, column=0, columnspan=3, padx=5, pady=5)

    def browse(self):
        """Ask for a directory and store it in the output-folder field."""
        select_dir = filedialog.askdirectory()
        # An empty result means the dialog was cancelled; keep the old value.
        if select_dir:
            self.out_dir.set(select_dir)

    def log_msg(self, msg, max_lines=100):
        """Append ``msg`` to the log box, trim old lines, and scroll to the end.

        Args:
            msg: Text to append; a newline is added after it.
            max_lines: Once Tk's line count for the widget exceeds this value,
                the oldest lines are deleted from the top.
        """
        self.log.insert("end", msg + "\n")
        # 'end-1c' is an index of the form "<line>.<column>"; keep the line part.
        lines = int(self.log.index('end-1c').split('.')[0])
        if lines > max_lines:
            self.log.delete('1.0', f'{lines - max_lines + 1}.0')
        self.log.see("end")

    @staticmethod
    def _collapse_progress(batch):
        """Return the texts to log for a drained batch of ``(type, text)`` pairs.

        Every "message" entry is kept.  A run of "progress" entries is reduced
        to its last non-empty text, emitted right before the next "message"
        entry or at the end of the batch.  Entries of any other type are
        ignored.
        """
        lines = []
        pending = None  # most recent progress text not yet emitted
        for msg_type, msg in batch:
            if msg_type == "message":
                if pending:
                    lines.append(pending)
                    pending = None
                lines.append(msg)
            elif msg_type == "progress":
                pending = msg
        if pending:
            lines.append(pending)
        return lines

    def poll_queue(self, interval=2000):
        """Drain ``hub`` into the log box and reschedule while a job is running.

        All "message" entries are logged; consecutive "progress" entries are
        collapsed to the newest one so that a burst of progress updates cannot
        flood the UI.  Each drained entry is also echoed to stdout.

        Args:
            interval: Delay in milliseconds before the next poll.
        """
        self._poll_id = None
        # Sample the flag BEFORE draining: the worker enqueues its last message
        # before setting the flag, so if the flag is already set here the drain
        # below is guaranteed to see everything the worker produced.
        finished = self._worker_finished.is_set()

        batch = []
        try:
            while True:  # take everything that is currently queued
                batch.append(hub.get_nowait())
        except queue.Empty:
            pass

        for msg_type, msg in batch:
            print(f"{msg_type}: {msg}")
        for line in self._collapse_progress(batch):
            self.log_msg(line)

        if finished:
            # Job is over: restore the button from the Tk thread and stop polling.
            self.job_done = True
            self.run_btn.config(state="normal")
            return

        if not self.job_done:
            # Forward ``interval`` so a caller-supplied value survives rescheduling.
            self._poll_id = self.root.after(interval, self.poll_queue, interval)

    def run(self):
        """Read the form, start the worker thread, and start polling ``hub``."""
        uniprot_id = self.uniprot_id.get().strip()
        # NOTE(review): an empty folder field yields Path(""), which pathlib
        # treats as the current directory - confirm that this is intended.
        out_dir = pathlib.Path(self.out_dir.get())
        use_cache = self.use_cache.get()
        print(f"use cache checkbox: {use_cache}")
        print(f"uniprot_id: {uniprot_id}")

        if not uniprot_id:
            # Nothing to look up; do not start a worker for an empty ID.
            self.log_msg("请先输入 UniProt ID")
            return

        # Reset the per-job state so a second run is polled like the first one.
        self.job_done = False
        self._worker_finished.clear()
        if self._poll_id is not None:
            self.root.after_cancel(self._poll_id)
            self._poll_id = None

        self.run_btn.config(state="disabled")
        threading.Thread(
            target=self.worker,
            args=(uniprot_id, out_dir, use_cache),
            daemon=True,
        ).start()
        self.poll_queue()  # start the polling loop

    def worker(self, uniprot_id, out_dir, use_cache):
        """Run the whole job; intended to execute on a background thread.

        Status text is reported through ``hub``.  Any exception is caught and
        reported as a message rather than propagated.  This method does not
        call into Tk: completion is signalled through an event that
        ``poll_queue`` checks.

        Args:
            uniprot_id: Identifier passed to ``Protein``.
            out_dir: ``pathlib.Path`` of the folder that receives the CSV files.
            use_cache: Flag passed to ``Protein``.
        """
        try:
            hub.put(("message","1. 下载蛋白序列"))
            prot = Protein(uniprot_id=uniprot_id, out_dir=out_dir, use_cache=use_cache)
            prot.deep_mut()

            hub.put(("message","2. 提取 ESM 特征"))
            repr_dict = prot.collect_repr()  # dict {variant: vector}

            hub.put(("message", "3. PCA 降维"))
            variants = list(repr_dict.keys())
            # Flatten every vector to one row and stack them: (N_total, n_features).
            tensors = [repr_dict[v].view(1, -1) for v in variants]
            data = torch.cat(tensors, dim=0).numpy()
            n_features = data.shape[1]

            # One PCA over all variants together.
            embedding_pca, pca_composition, explained_variance_ratio_ = PCA(data)

            # Re-key the PCA outputs as dictionaries for export.
            PCA_dict = arr2dict(embedding_pca, variants)
            # NOTE(review): names are generated per input feature; confirm this
            # matches the number of components the project's PCA returns.
            PC_names = [f"PC{i+1}" for i in range(n_features)]
            composition_dict = arr2dict(pca_composition, PC_names)
            explained_variance_dict = arr2dict(
                explained_variance_ratio_.reshape(1, -1), ["explained_variance"]
            )

            hub.put(("message","4. 保存结果"))
            repr_file = out_dir / f"{prot.receptor}_{prot.species}_esm2_33.csv"
            pca_data = out_dir / f"{prot.receptor}_{prot.species}_pca.csv"
            pca_comp = out_dir / f"{prot.receptor}_{prot.species}_pca_composition.csv"
            pca_ev = out_dir / f"{prot.receptor}_{prot.species}_pca_explained_variance.csv"
            dict2csv(repr_dict=repr_dict, file_path=repr_file, key_name="variant", column_title="Feature")
            dict2csv(repr_dict=PCA_dict, file_path=pca_data, key_name="variant", column_title="PC")
            dict2csv(repr_dict=composition_dict, file_path=pca_comp, key_name="PCs", column_title="Feature")
            dict2csv(repr_dict=explained_variance_dict, file_path=pca_ev, key_name="", column_title="PC")

            hub.put(("message","5. 全部完成"))
        except Exception as e:
            # Thread boundary: report the failure in the UI log instead of
            # letting the thread die silently.
            hub.put(("message",f"出错:{e}"))
        finally:
            # Must come after the last hub.put so the poll loop sees every message.
            self._worker_finished.set()
if __name__ == "__main__":
    # Script entry point: create the Tk root, attach the GUI to it, and hand
    # control to the Tk event loop until the window is closed.
    app_root = tk.Tk()
    ProtFeatureGUI(app_root)
    app_root.mainloop()