-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
60 lines (56 loc) · 2.58 KB
/
utils.py
File metadata and controls
60 lines (56 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
from esm import pretrained
from esm import FastaBatchedDataset
import csv, torch
from message_hub import hub
def extract(dataset: FastaBatchedDataset, repr_layer, temp_dir=None):
    """Compute mean-pooled per-sequence ESM-2 representations.

    Loads the ``esm2_t33_650M_UR50D`` checkpoint, runs every sequence in
    *dataset* through it, and averages the residue embeddings of layer
    *repr_layer* over sequence positions (BOS at index 0 and the trailing
    EOS are excluded by the ``1 : len(seq) + 1`` slice).

    Args:
        dataset: batched FASTA dataset; batch labels are the variant names.
        repr_layer: transformer layer to extract, in ``0..model.num_layers``.
        temp_dir: optional ``pathlib.Path``; when given, each variant's vector
            is also saved to ``{variant}_esm2_{repr_layer}.csv`` inside it.

    Returns:
        dict mapping variant name -> 1-D torch tensor (mean embedding).
    """
    toks_per_batch = 4096
    hub.put(("message", "加载模型:esm2_t33_650M_UR50D"))
    model, alphabet = pretrained.load_model_and_alphabet("esm2_t33_650M_UR50D")
    model.eval()
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=alphabet.get_batch_converter(), batch_sampler=batches
    )
    assert 0 <= repr_layer <= model.num_layers
    all_repr = {}
    total = len(dataset)
    current = 0
    with torch.no_grad():
        for variants, strs, toks in data_loader:
            # need_head_weights was previously True, but the head weights were
            # never read from `out`; dropping it skips that extra computation.
            out = model(toks, repr_layers=[repr_layer], return_contacts=False)
            batch_repr = out["representations"][repr_layer]
            for i, variant in enumerate(variants):
                # Mean over residue positions only (skip BOS/EOS tokens).
                all_repr[variant] = batch_repr[i, 1 : len(strs[i]) + 1].mean(0).clone()
                if temp_dir:
                    path = temp_dir / f"{variant}_esm2_{repr_layer}.csv"
                    np.savetxt(path, all_repr[variant].detach().numpy(), delimiter=',')
                current += 1
                hub.put(("progress", f"正在提取特征,进度:{current}/{total}"))
    return all_repr
def dict2csv(repr_dict, file_path, key_name="variant", column_title="feature"):
    """Write a ``{name: vector}`` mapping to a CSV file.

    The header row is ``key_name, {column_title}1, ..., {column_title}N``
    where N is the length of the first vector; each following row is a key
    plus its vector values.

    Args:
        repr_dict: mapping of name -> vector. Vectors may be torch tensors,
            numpy arrays (anything with ``.tolist()``), or plain sequences.
        file_path: ``pathlib.Path`` to write (must support ``.open``).
        key_name: header title for the key column.
        column_title: prefix for the numbered feature columns.
    """
    with file_path.open("w", newline="") as f:
        w = csv.writer(f)
        if not repr_dict:
            # No vector to infer a width from; emit only the key column
            # header instead of crashing on next(iter(...)).
            w.writerow([key_name])
            return
        width = len(next(iter(repr_dict.values())))
        w.writerow([key_name] + [f"{column_title}{i + 1}" for i in range(width)])
        for var, vec in repr_dict.items():
            # Accept plain lists as well as tensors/arrays (the original
            # docstring promised list support but .tolist() rejected it).
            values = vec.tolist() if hasattr(vec, "tolist") else list(vec)
            w.writerow([var, *values])
def arr2dict(array, keys):
    """Pair each row of *array* with the key at the same index.

    Args:
        array: iterable of rows (e.g. a 2-D numpy array or list of lists).
        keys: indexable sequence of keys, at least as long as *array*
            (a shorter sequence raises IndexError, as before).

    Returns:
        dict mapping ``keys[i]`` -> ``array[i]``.
    """
    # Comprehension instead of a manual loop; also avoids shadowing the
    # builtin name `dict` that the original used as a local variable.
    return {keys[idx]: row for idx, row in enumerate(array)}
def csv2arr(file_path, dtype=float, delimiter: str = ','):
    """Read a CSV of numbers, flatten every row into one sequence, and
    return the result as a 1-D torch tensor.

    Blank lines are skipped and surrounding whitespace on each cell is
    stripped before conversion with *dtype*.
    """
    values = []
    with open(file_path, newline='', encoding='utf-8') as fh:
        for record in csv.reader(fh, delimiter=delimiter):
            # Ignore empty rows; flatten everything else into one list.
            if record:
                values.extend(dtype(cell.strip()) for cell in record)
    flat = np.array(values, dtype=dtype)
    return torch.from_numpy(flat)