Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
98dc8d2
refactor the code to take in input as other models in PG
Feb 10, 2026
fd24564
enable reading in multi-line fastas
Feb 10, 2026
50d32d9
add print statement for num sequences and weights dim for debugging
Feb 10, 2026
2388ca4
change weights_calc_method default to eve, simplify seq_name_to_seque…
Feb 10, 2026
db45a55
add directory to paths, update target sequence and MSA filepath input
Feb 11, 2026
63c8dba
add options to input DMS_index, comment out code to calculate VespaG_…
Feb 11, 2026
0836dc3
add a check of pdb file suffixes
Feb 11, 2026
9e4af94
add another error message printout for debugging
Feb 11, 2026
3a7d1d6
add options to input DMS_index
Feb 11, 2026
73c7c83
update ref file input fields
Feb 11, 2026
3a5c4c6
fix import path
Feb 11, 2026
da2f3a8
update ref file input fields (target_aa_seq), update the way to acces…
Feb 11, 2026
c62a980
update ref file input fields (target_aa_seq and pdb_range)
Feb 11, 2026
24c481b
add import path, update ref file input fields target_aa_seq, add comm…
Feb 11, 2026
1932686
remove old gemme compute_fitness script as the GEMME docker doesn't h…
Feb 11, 2026
e2e4838
add new gemme compute_fitness scripts to accommodate preprocessing an…
Feb 11, 2026
b0a7056
update progen3 scoring: generate a score.csv for each assay and calcu…
Feb 11, 2026
1aad992
update names of scripts used
Feb 11, 2026
9c9d808
update ref file input fields target_aa_seq
Feb 11, 2026
9bf6396
enable scoring assays through a ref file and DMS indices
Feb 11, 2026
07b3026
enable scoring assays through a ref file and DMS indices
Feb 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 87 additions & 32 deletions proteingym/baselines/AIDO/compute_fitness.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import seaborn as sns
import matplotlib.pyplot as plt

sys.path.append("/n/groups/marks/users/elain/PG2/ProteinGym/AIDO/")
from utils import misc
from utils import protein

Expand Down Expand Up @@ -67,54 +68,108 @@ def main(args):
#############################

os.makedirs(args.output_path, exist_ok=True)
for dms_id in args.dms_ids:

## Read Sequence
dms2seq, dms2annot = misc.load_fasta(f"{args.input_data_path}/query.fasta", load_annotation=True)
q_seq = dms2seq[dms_id]
start, end = [int(x) for x in dms2annot[dms_id].split('-')]

## Read MSA
msa = misc.load_msa_txt(f"{args.input_data_path}/msa_data/{dms_id}.txt.gz")
assert q_seq == msa[0]

## Read PDB
with open(f"{args.input_data_path}/struc_data/{dms_id}.pdb") as IN:
text = IN.read()

prot = protein.from_pdb_string(text, molecular_type='protein')
assert prot.seq(True) == q_seq
if args.dms_ids != None:
for dms_id in args.dms_ids:

## Read Sequence
dms2seq, dms2annot = misc.load_fasta(f"{args.fasta_dir}/{dms_id}.fasta", load_annotation=True)
q_seq = dms2seq[dms_id]
start, end = [int(x) for x in dms2annot[dms_id].split('-')]

## Read MSA
msa = misc.load_msa_txt(f"{args.msa_dir}/{dms_id}.a2m") # msa filename
assert q_seq == msa[0]

## Read PDB
with open(f"{args.pdb_dir}/{dms_id}.pdb") as IN: # pdb filename
text = IN.read()

prot = protein.from_pdb_string(text, molecular_type='protein')
assert prot.seq(True) == q_seq

## Read DMS Table
dms_df = pd.read_csv(f'{args.dms_dir}/{dms_id}.csv') # dms filename

## Read DMS Table
dms_df = pd.read_csv(f'{args.input_data_path}/dms_data/{dms_id}.csv')
## Inference
all_poses, logits_table = misc.get_logits_table_sliding(q_seq, prot, msa, dms_df, model, tokenizer, str_tokenizer, start, mask_str=args.mask_str, disable_tqdm=False)
result_df = misc.get_scores_from_table(q_seq, logits_table, all_poses, dms_df, tokenizer, start, temp_mt=1.0, temp_wt=1.5)

assert np.all(dms_df['mutant'] == result_df['Mutation'])
assert np.allclose(dms_df['DMS_score'] , result_df['GT_Score'], 1e-05, 1e-05)
dms_df['AIDO.Protein-RAG-16B-zeroshot'] = result_df['Pred_Score']
result_df = dms_df

result_df.to_csv(join(args.output_path, f"{dms_id}.csv"), index=False)
r = round(spearmanr(result_df['AIDO.Protein-RAG-16B-zeroshot'], result_df['DMS_score'])[0], 4)
print(f"{dms_id}: R={r}")
elif args.reference_file != None:
reference_file_df = pd.read_csv(args.reference_file)
for index, row in reference_file_df.iterrows():
dms_id = row['DMS_id']

## Read Sequence
dms2seq, dms2annot = misc.load_fasta(f"{args.fasta_dir}/{dms_id}.fasta", load_annotation=True)
q_seq = dms2seq[dms_id]
start, end = [int(x) for x in dms2annot[dms_id].split('-')]

## Read MSA
msa = misc.load_msa_txt(f"{args.msa_dir}/{row['MSA_filename']}")
assert q_seq == msa[0]

## Inference
all_poses, logits_table = misc.get_logits_table_sliding(q_seq, prot, msa, dms_df, model, tokenizer, str_tokenizer, start, mask_str=args.mask_str, disable_tqdm=False)
result_df = misc.get_scores_from_table(q_seq, logits_table, all_poses, dms_df, tokenizer, start, temp_mt=1.0, temp_wt=1.5)
## Read PDB
with open(f"{args.pdb_dir}/{row['pdb_file']}") as IN: # pdb filename
text = IN.read()

prot = protein.from_pdb_string(text, molecular_type='protein')
assert prot.seq(True) == q_seq

## Read DMS Table
dms_df = pd.read_csv(f'{args.dms_dir}/{dms_id}.csv') # dms filename

assert np.all(dms_df['mutant'] == result_df['Mutation'])
assert np.allclose(dms_df['DMS_score'] , result_df['GT_Score'], 1e-05, 1e-05)
dms_df['AIDO.Protein-RAG-16B-zeroshot'] = result_df['Pred_Score']
result_df = dms_df
## Inference
all_poses, logits_table = misc.get_logits_table_sliding(q_seq, prot, msa, dms_df, model, tokenizer, str_tokenizer, start, mask_str=args.mask_str, disable_tqdm=False)
result_df = misc.get_scores_from_table(q_seq, logits_table, all_poses, dms_df, tokenizer, start, temp_mt=1.0, temp_wt=1.5)

assert np.all(dms_df['mutant'] == result_df['Mutation'])
assert np.allclose(dms_df['DMS_score'] , result_df['GT_Score'], 1e-05, 1e-05)
dms_df['AIDO.Protein-RAG-16B-zeroshot'] = result_df['Pred_Score']
result_df = dms_df

result_df.to_csv(join(args.output_path, f"{dms_id}.csv"), index=False)
r = round(spearmanr(result_df['AIDO.Protein-RAG-16B-zeroshot'], result_df['DMS_score'])[0], 4)
print(f"{dms_id}: R={r}")
else:
raise FileNotFoundError("No input for the DMS_id info. Specify --dms_ids or --reference_file")

result_df.to_csv(join(args.output_path, f"{dms_id}.csv"), index=False)
r = round(spearmanr(result_df['AIDO.Protein-RAG-16B-zeroshot'], result_df['DMS_score'])[0], 4)
print(f"{dms_id}: R={r}")

if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--dms_ids", type=str, default=None, nargs="+")
parser.add_argument("--dms_ids", type=str, default=None, nargs="*")
parser.add_argument("--reference_file", type=str, default=None)
parser.add_argument("--dms_index", type=int, default=None, help="Index of a single DMS_id in full list of assays (all_dms_ids)")
parser.add_argument("--input_data_path", type=str, default=f"{SCRIPT_PATH}", help="Input data path",)
# parser.add_argument("--input_data_path", type=str, default=f"{SCRIPT_PATH}", help="Input data path",)
parser.add_argument("--fasta_dir", required=True,
help="Directory containing fasta files")
parser.add_argument("--msa_dir", required=True,
help="Directory containing MSA files")
parser.add_argument("--dms_dir", required=True,
help="Directory containing DMS CSV files")
parser.add_argument("--pdb_dir", required=False, default=None,
help="Directory containing PDB files (only needed for esm3_open with use_structure=True)")
parser.add_argument("--output_path", type=str, default=f"{SCRIPT_PATH}/output", help="Output path",)
parser.add_argument("--hf_cache_location", type=str, default=None, help="Hugging Face cache directory for downloading models and tokenizers")
parser.add_argument("--mask-str", action='store_true', help="Mask the structure input")
args = parser.parse_args()
all_dms_ids = [n[:-7] for n in os.listdir(f'{args.input_data_path}/msa_data') if n.endswith('.txt.gz')]
all_dms_ids = [n for n in os.listdir(f'{args.dms_dir}') if n.endswith('.csv')]

print(args.reference_file)

if args.dms_ids is None:
if args.dms_index is not None: # Compute scores for a single DMS id, indexed by dms_index
args.dms_ids=[all_dms_ids[args.dms_index]]
else:
args.dms_ids=all_dms_ids # Compute scores for all assays in ProteinGym
if args.reference_file is None:
args.dms_ids=all_dms_ids # Compute scores for all assays in ProteinGym

main(args)
63 changes: 52 additions & 11 deletions proteingym/baselines/AIDO/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,25 +164,62 @@ def load_msa_txt(file_or_stream, load_id=False, load_annot=False, sort=False):
else:
with open(file_or_stream) as IN:
lines = IN.read().strip().split('\n')


q_seq = ""
sequence = ""
current_id = ""
sequence_list = []
id_ = 0
annot = None
for idx,line in enumerate(lines):
data = line.strip().split()
if idx == 0:
if line.startswith(">"):
if current_id:
# output the previous sequence and identity
sequence = "".join(sequence_list)
if len(msa) == 0:
q_seq = sequence

assert len(sequence) == len(q_seq)
msa.append(sequence)
if id_ == 0:
id_ = round(np.mean([ r1==r2 for r1,r2 in zip(sequence, q_seq) ]), 3)
id_arr.append(id_)
annotations.append( annot )

# reset
sequence_list = []
id_ = 0

current_id = line[1:]
continue

if not q_seq:
assert len(data) == 1, f"Expect 1 element for the 1st line, but got {data} in {file_or_stream}"
q_seq = data[0]
sequence_list.append(data[0])
else:
if len(data) >= 2:
id_arr.append( float(data[1]) )
else:
assert len(q_seq) == len(data[0])
id_ = round(np.mean([ r1==r2 for r1,r2 in zip(q_seq, data[0]) ]), 3)
id_arr.append(id_)
msa.append( data[0] )
id_ = float(data[1])
# else:
# id_arr.append(id_)
sequence_list.append(data[0])
if len(data) >= 3:
annot = " ".join(data[2:])
annotations.append( annot )
else:
annotations.append(None)
annot = None

# output the last entry
sequence = "".join(sequence_list)
if len(msa) == 0:
q_seq = sequence

assert len(sequence) == len(q_seq)
msa.append(sequence)
if id_ == 0:
id_ = round(np.mean([ r1==r2 for r1,r2 in zip(sequence, q_seq) ]), 3)
id_arr.append(id_)
annotations.append( annot )


id_arr = np.array(id_arr, dtype=np.float64)
if sort:
Expand Down Expand Up @@ -281,6 +318,7 @@ def get_logits_table_sliding(q_seq, prot, msa, dms_df, model, tokenizer, str_tok
# assert model_type == 'emb_model_step1'
assert len(q_seq) == prot.aatype.shape[0], f"len(q_seq)={len(q_seq)}, prot.aatype.shape[0]={prot.aatype.shape[0]}"
assert q_seq == msa[0]
print("q_seq len:", len(q_seq))

pd_scores = []
all_poses = set()
Expand All @@ -306,10 +344,12 @@ def get_logits_table_sliding(q_seq, prot, msa, dms_df, model, tokenizer, str_tok
f_end = min(f_start + sliding_window, len(q_seq))

f_q_seq = q_seq[f_start:f_end]
print("f_q_seq len", len(f_q_seq))

# f_msa = rag_utils.greedy_select(list(set([ seq[f_start:f_end] for seq in msa[1:] ])), num_seqs=None, num_tokens=12800, seed=0)
f_msa = greedy_select([ seq[f_start:f_end] for seq in msa[1:] ], num_seqs=None, num_tokens=12800, seed=0)
f_msa.sort(key=lambda x: x.count('-'))
print("f_msa len", len(f_msa[0]))

str_embs, str_toks = str_tokenizer.encode(prot.aatype[f_start:f_end], prot.atom_positions[f_start:f_end], prot.atom_mask[f_start:f_end], get_embedding=True)
str_embs, str_toks = str_embs.cuda().bfloat16(), str_toks.cuda()
Expand All @@ -328,6 +368,7 @@ def get_logits_table_sliding(q_seq, prot, msa, dms_df, model, tokenizer, str_tok
if f_start <= pos < f_end:
masked_tokens = tokens.clone()
masked_tokens[pos_encoding[0]==pos-f_start] = tokenizer.TokenToId('tMASK')
print("masked_tokens", masked_tokens[None].shape)
lm_output = model.transformer(
input_ids=masked_tokens[None],
position_ids=pos_encoding[None],
Expand Down
2 changes: 2 additions & 0 deletions proteingym/baselines/EVE/EVE/VAE_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ def train_model(self, data, training_parameters, use_dataloader=False):
best_model_step_index = training_parameters['num_training_steps']

seq_sample_probs = weights_train / np.sum(weights_train)
# print("num sequences: ", len(data.seq_name_to_sequence))
# print("num weights: ", weights_train.shape[0])
assert len(data.seq_name_to_sequence) == weights_train.shape[0] # One weight per sequence

# TMP TODO: Keep old behaviour for comparison
Expand Down
7 changes: 3 additions & 4 deletions proteingym/baselines/EVE/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self,
threshold_focus_cols_frac_gaps=0.3,
remove_sequences_with_indeterminate_AA_in_focus_cols=True,
num_cpus=1,
weights_calc_method="evcouplings",
weights_calc_method="eve",
overwrite_weights=False,
debug_only_weights=False,
):
Expand Down Expand Up @@ -214,8 +214,7 @@ def _lower_case_and_filter_fragments(seq):
msa_df = msa_df[seq_below_threshold]
# Overwrite seq_name_to_sequence with clean version
seq_name_to_sequence = defaultdict(str)
for seq_idx in range(len(msa_df['sequence'])):
seq_name_to_sequence[msa_df.index[seq_idx]] = msa_df.sequence[seq_idx]
seq_name_to_sequence = dict(zip(msa_df.index, msa_df.sequence))

return seq_name_to_sequence

Expand Down Expand Up @@ -468,4 +467,4 @@ def collate_fn(batch_seqs):
sampler=sampler,
collate_fn=collate_fn,) #pin_memory=True

return dataloader
return dataloader
7 changes: 5 additions & 2 deletions proteingym/baselines/PoET/scripts/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from tqdm import tqdm, trange

import sys
# append paths of git cloned PoET repo and the PG baseline PoET (ProteinGym/proteingym/baselines/PoET/)
from poet.alphabets import Uniprot21
from poet.fasta import parse_stream
from poet.models.modules.packed_sequence import PackedTensorSequences
Expand Down Expand Up @@ -269,7 +271,7 @@ def main():
ref_series = pd.read_csv(args.DMS_reference_file_path).iloc[args.DMS_index]
msa_start = int(ref_series["MSA_start"])
msa_end = int(ref_series["MSA_end"])
wt_sequence = ref_series["target_seq"][msa_start - 1 : msa_end]
wt_sequence = ref_series["target_aa_seq"][msa_start - 1 : msa_end]
variants_filename = ref_series["DMS_filename"]
variants_df = pd.read_csv(args.DMS_data_folder / variants_filename)
if "mutated_sequence" in variants_df.columns:
Expand All @@ -287,7 +289,8 @@ def main():
variants.append(wt)

# process msa
msa_filepath = (args.MSA_folder / variants_filename).with_suffix(".a3m.zst")
msa_filename = ref_series["MSA_filename"].removesuffix(".a2m")
msa_filepath = args.MSA_folder.joinpath(msa_filename + ".a3m.zst")
msa_sequences = get_seqs_from_fastalike(msa_filepath)
assert msa_sequences[0].decode() == wt_sequence
msa = get_encoded_msa_from_a3m_seqs(msa_sequences=msa_sequences, alphabet=alphabet)
Expand Down
Loading