From b2582aee2f7b72822001674662c8fb6d00516644 Mon Sep 17 00:00:00 2001 From: hsbyeon1 Date: Wed, 11 Dec 2024 21:27:04 +0900 Subject: [PATCH 1/6] feat:copy colabfold mmseqs code --- chai_lab/data/dataset/msas/local_mmseqs.py | 541 +++++++++++++++++++++ 1 file changed, 541 insertions(+) create mode 100644 chai_lab/data/dataset/msas/local_mmseqs.py diff --git a/chai_lab/data/dataset/msas/local_mmseqs.py b/chai_lab/data/dataset/msas/local_mmseqs.py new file mode 100644 index 00000000..0b2c5e47 --- /dev/null +++ b/chai_lab/data/dataset/msas/local_mmseqs.py @@ -0,0 +1,541 @@ +""" +Functionality for running mmseqs locally. Takes in a fasta file, outputs final.a3m +""" + +import logging +import math +import os +import shutil +import subprocess +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter +from pathlib import Path +from typing import List, Union + +from colabfold.batch import get_queries, msa_to_str +from colabfold.utils import safe_filename + +logger = logging.getLogger(__name__) + +MODULE_OUTPUT_POS = { + "align": 4, + "convertalis": 4, + "expandaln": 5, + "filterresult": 4, + "lndb": 2, + "mergedbs": 2, + "mvdb": 2, + "pairaln": 4, + "result2msa": 4, + "search": 3, +} + +def run_mmseqs(mmseqs: Path, params: List[Union[str, Path]]): + module = params[0] + if module in MODULE_OUTPUT_POS: + output_pos = MODULE_OUTPUT_POS[module] + output_path = Path(params[output_pos]).with_suffix('.dbtype') + if output_path.exists(): + logger.info(f"Skipping {module} because {output_path} already exists") + return + + params_log = " ".join(str(i) for i in params) + logger.info(f"Running {mmseqs} {params_log}") + # hide MMseqs2 verbose paramters list that clogs up the log + os.environ["MMSEQS_CALL_DEPTH"] = "1" + subprocess.check_call([mmseqs] + params) + + +def mmseqs_search_monomer( + dbbase: Path, + base: Path, + uniref_db: Path = Path("uniref30_2302_db"), + template_db: Path = Path(""), # Unused by default + metagenomic_db: Path = Path("colabfold_envdb_202108_db"), + mmseqs: Path = Path("mmseqs"), + use_env: bool = True, + use_templates: bool = False, + filter: bool = True, + expand_eval: float = math.inf, + align_eval: int = 10, + diff: int = 3000, + qsc: float = -20.0, + max_accept: int = 1000000, + prefilter_mode: int = 0, + s: float = 8, + db_load_mode: int = 2, + threads: int = 32, + unpack: bool = True, +): + """Run mmseqs with a local colabfold database set + + db1: uniprot db (UniRef30) + db2: Template (unused by default) + db3: metagenomic db (colabfold_envdb_202108 or bfd_mgy_colabfold, the former is preferred) + """ + if filter: + # 0.1 was not used in benchmarks due to POSIX shell bug in line above + # EXPAND_EVAL=0.1 + align_eval = 10 + qsc = 0.8 + max_accept = 100000 + + used_dbs = [uniref_db] + if use_templates: + used_dbs.append(template_db) + if use_env: + used_dbs.append(metagenomic_db) + + for db in used_dbs: + if not dbbase.joinpath(f"{db}.dbtype").is_file(): + raise FileNotFoundError(f"Database {db} does not exist") + if ( + ( + not dbbase.joinpath(f"{db}.idx").is_file() + and not dbbase.joinpath(f"{db}.idx.index").is_file() + ) + or os.environ.get("MMSEQS_IGNORE_INDEX", False) + ): + logger.info("Search does not use index") + db_load_mode = 0 + dbSuffix1 = "_seq" + dbSuffix2 = "_aln" + dbSuffix3 = "" + else: + dbSuffix1 = ".idx" + dbSuffix2 = ".idx" + dbSuffix3 = ".idx" + + search_param = ["--num-iterations", "3", "--db-load-mode", str(db_load_mode), "-a", "-e", "0.1", "--max-seqs", "10000"] + search_param += ["--prefilter-mode", str(prefilter_mode)] + if s is not 
None: + search_param += ["-s", "{:.1f}".format(s)] + else: + search_param += ["--k-score", "'seq:96,prof:80'"] + + filter_param = ["--filter-msa", str(filter), "--filter-min-enable", "1000", "--diff", str(diff), "--qid", "0.0,0.2,0.4,0.6,0.8,1.0", "--qsc", "0", "--max-seq-id", "0.95",] + expand_param = ["--expansion-mode", "0", "-e", str(expand_eval), "--expand-filter-clusters", str(filter), "--max-seq-id", "0.95",] + + if not base.joinpath("uniref.a3m").with_suffix('.a3m.dbtype').exists(): + run_mmseqs(mmseqs, ["search", base.joinpath("qdb"), dbbase.joinpath(uniref_db), base.joinpath("res"), base.joinpath("tmp"), "--threads", str(threads)] + search_param) + run_mmseqs(mmseqs, ["mvdb", base.joinpath("tmp/latest/profile_1"), base.joinpath("prof_res")]) + run_mmseqs(mmseqs, ["lndb", base.joinpath("qdb_h"), base.joinpath("prof_res_h")]) + run_mmseqs(mmseqs, ["expandaln", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res"), dbbase.joinpath(f"{uniref_db}{dbSuffix2}"), base.joinpath("res_exp"), "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + expand_param) + run_mmseqs(mmseqs, ["align", base.joinpath("prof_res"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res_exp"), base.joinpath("res_exp_realign"), "--db-load-mode", str(db_load_mode), "-e", str(align_eval), "--max-accept", str(max_accept), "--threads", str(threads), "--alt-ali", "10", "-a"]) + run_mmseqs(mmseqs, ["filterresult", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), + base.joinpath("res_exp_realign"), base.joinpath("res_exp_realign_filter"), "--db-load-mode", + str(db_load_mode), "--qid", "0", "--qsc", str(qsc), "--diff", "0", "--threads", + str(threads), "--max-seq-id", "1.0", "--filter-min-enable", "100"]) + run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), + base.joinpath("res_exp_realign_filter"), base.joinpath("uniref.a3m"), "--msa-format-mode", + "6", "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + filter_param) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign_filter")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")]) + else: + logger.info(f"Skipping {uniref_db} search because uniref.a3m already exists") + + if use_env and not base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m").with_suffix('.a3m.dbtype').exists(): + run_mmseqs(mmseqs, ["search", base.joinpath("prof_res"), dbbase.joinpath(metagenomic_db), base.joinpath("res_env"), + base.joinpath("tmp3"), "--threads", str(threads)] + search_param) + run_mmseqs(mmseqs, ["expandaln", base.joinpath("prof_res"), dbbase.joinpath(f"{metagenomic_db}{dbSuffix1}"), base.joinpath("res_env"), + dbbase.joinpath(f"{metagenomic_db}{dbSuffix2}"), base.joinpath("res_env_exp"), "-e", str(expand_eval), + "--expansion-mode", "0", "--db-load-mode", str(db_load_mode), "--threads", str(threads)]) + run_mmseqs(mmseqs, ["align", base.joinpath("tmp3/latest/profile_1"), dbbase.joinpath(f"{metagenomic_db}{dbSuffix1}"), + base.joinpath("res_env_exp"), base.joinpath("res_env_exp_realign"), "--db-load-mode", + str(db_load_mode), "-e", str(align_eval), "--max-accept", str(max_accept), "--threads", + str(threads), "--alt-ali", "10", "-a"]) + run_mmseqs(mmseqs, ["filterresult", base.joinpath("qdb"), dbbase.joinpath(f"{metagenomic_db}{dbSuffix1}"), + base.joinpath("res_env_exp_realign"), 
base.joinpath("res_env_exp_realign_filter"), + "--db-load-mode", str(db_load_mode), "--qid", "0", "--qsc", str(qsc), "--diff", "0", + "--max-seq-id", "1.0", "--threads", str(threads), "--filter-min-enable", "100"]) + run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{metagenomic_db}{dbSuffix1}"), + base.joinpath("res_env_exp_realign_filter"), + base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m"), "--msa-format-mode", "6", + "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + filter_param) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp_realign_filter")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp_realign")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env")]) + elif use_env: + logger.info(f"Skipping {metagenomic_db} search because bfd.mgnify30.metaeuk30.smag30.a3m already exists") + + if use_templates and not base.joinpath(f"{template_db}.m8").with_suffix('.m8.dbtype').exists(): + run_mmseqs(mmseqs, ["search", base.joinpath("prof_res"), dbbase.joinpath(template_db), base.joinpath("res_pdb"), + base.joinpath("tmp2"), "--db-load-mode", str(db_load_mode), "--threads", str(threads), "-s", "7.5", "-a", "-e", "0.1", "--prefilter-mode", str(prefilter_mode)]) + run_mmseqs(mmseqs, ["convertalis", base.joinpath("prof_res"), dbbase.joinpath(f"{template_db}{dbSuffix3}"), base.joinpath("res_pdb"), + base.joinpath(f"{template_db}"), "--format-output", + "query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,cigar", + "--db-output", "1", + "--db-load-mode", str(db_load_mode), "--threads", str(threads)]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_pdb")]) + elif use_templates: + logger.info(f"Skipping {template_db} search because {template_db}.m8 already exists") + + if use_env: + run_mmseqs(mmseqs, ["mergedbs", base.joinpath("qdb"), base.joinpath("final.a3m"), base.joinpath("uniref.a3m"), base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("uniref.a3m")]) + else: + run_mmseqs(mmseqs, ["mvdb", base.joinpath("uniref.a3m"), base.joinpath("final.a3m")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("uniref.a3m")]) + + if unpack: + run_mmseqs(mmseqs, ["unpackdb", base.joinpath("final.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".a3m"]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("final.a3m")]) + + if use_templates: + run_mmseqs(mmseqs, ["unpackdb", base.joinpath(f"{template_db}"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".m8"]) + if base.joinpath(f"{template_db}").exists(): + run_mmseqs(mmseqs, ["rmdb", base.joinpath(f"{template_db}")]) + + run_mmseqs(mmseqs, ["rmdb", base.joinpath("prof_res")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("prof_res_h")]) + shutil.rmtree(base.joinpath("tmp")) + if use_templates: + shutil.rmtree(base.joinpath("tmp2")) + if use_env: + shutil.rmtree(base.joinpath("tmp3")) + +def mmseqs_search_pair( + dbbase: Path, + base: Path, + uniref_db: Path = Path("uniref30_2302_db"), + spire_db: Path = Path("spire_ctg10_2401_db"), + mmseqs: Path = Path("mmseqs"), + pair_env: bool = True, + prefilter_mode: int = 0, + s: float = 8, + threads: int = 64, + db_load_mode: int = 2, + pairing_strategy: int = 0, # 0: greedy, 1: complete + unpack: bool = True, +): + if not dbbase.joinpath(f"{uniref_db}.dbtype").is_file(): + raise FileNotFoundError(f"Database 
{uniref_db} does not exist") + if ( + ( + not dbbase.joinpath(f"{uniref_db}.idx").is_file() + and not dbbase.joinpath(f"{uniref_db}.idx.index").is_file() + ) + or os.environ.get("MMSEQS_IGNORE_INDEX", False) + ): + logger.info("Search does not use index") + db_load_mode = 0 + dbSuffix1 = "_seq" + dbSuffix2 = "_aln" + else: + dbSuffix1 = ".idx" + dbSuffix2 = ".idx" + + if pair_env: + db = spire_db + output = ".env.paired.a3m" + else: + db = uniref_db + output = ".paired.a3m" + + # fmt: off + # @formatter:off + search_param = ["--num-iterations", "3", "--db-load-mode", str(db_load_mode), "-a", "-e", "0.1", "--max-seqs", "10000",] + search_param += ["--prefilter-mode", str(prefilter_mode)] + if s is not None: + search_param += ["-s", "{:.1f}".format(s)] + else: + search_param += ["--k-score", "'seq:96,prof:80'"] + expand_param = ["--expansion-mode", "0", "-e", "inf", "--expand-filter-clusters", "0", "--max-seq-id", "0.95",] + run_mmseqs(mmseqs, ["search", base.joinpath("qdb"), dbbase.joinpath(db), base.joinpath("res"), base.joinpath("tmp"), "--threads", str(threads),] + search_param,) + run_mmseqs(mmseqs, ["expandaln", base.joinpath("qdb"), dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res"), dbbase.joinpath(f"{db}{dbSuffix2}"), base.joinpath("res_exp"), "--db-load-mode", str(db_load_mode), "--threads", str(threads),] + expand_param,) + run_mmseqs(mmseqs, ["align", base.joinpath("qdb"), dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res_exp"), base.joinpath("res_exp_realign"), "--db-load-mode", str(db_load_mode), "-e", "0.001", "--max-accept", "1000000", "--threads", str(threads), "-c", "0.5", "--cov-mode", "1",],) + run_mmseqs(mmseqs, ["pairaln", base.joinpath("qdb"), dbbase.joinpath(f"{db}"), base.joinpath("res_exp_realign"), base.joinpath("res_exp_realign_pair"), "--db-load-mode", str(db_load_mode), "--pairing-mode", str(pairing_strategy), "--pairing-dummy-mode", "0", "--threads", str(threads), ],) + run_mmseqs(mmseqs, ["align", base.joinpath("qdb"), dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res_exp_realign_pair"), base.joinpath("res_exp_realign_pair_bt"), "--db-load-mode", str(db_load_mode), "-e", "inf", "-a", "--threads", str(threads), ],) + run_mmseqs(mmseqs, ["pairaln", base.joinpath("qdb"), dbbase.joinpath(f"{db}"), base.joinpath("res_exp_realign_pair_bt"), base.joinpath("res_final"), "--db-load-mode", str(db_load_mode), "--pairing-mode", str(pairing_strategy), "--pairing-dummy-mode", "1", "--threads", str(threads),],) + run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res_final"), base.joinpath("pair.a3m"), "--db-load-mode", str(db_load_mode), "--msa-format-mode", "5", "--threads", str(threads),],) + if unpack: + run_mmseqs(mmseqs, ["unpackdb", base.joinpath("pair.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", output,],) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("pair.a3m")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign_pair")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign_pair_bt")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_final")]) + shutil.rmtree(base.joinpath("tmp")) + # @formatter:on + # fmt: on + +def main(): + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument( + "query", + type=Path, + help="fasta files with the 
queries.", + ) + parser.add_argument( + "dbbase", + type=Path, + help="The path to the database and indices you downloaded and created with setup_databases.sh", + ) + parser.add_argument( + "base", type=Path, help="Directory for the results (and intermediate files)" + ) + parser.add_argument( + "--prefilter-mode", + type=int, + default=0, + choices=[0, 1, 2], + help="Prefiltering algorithm to use: 0: k-mer (high-mem), 1: ungapped (high-cpu), 2: exhaustive (no prefilter, very slow). See wiki for more details: https://github.com/sokrypton/ColabFold/wiki#colabfold_search", + ) + parser.add_argument( + "-s", + type=float, + default=None, + help="MMseqs2 sensitivity. Lowering this will result in a much faster search but possibly sparser MSAs. By default, the k-mer threshold is directly set to the same one of the server, which corresponds to a sensitivity of ~8.", + ) + # dbs are uniref, templates and environmental + # We normally don't use templates + parser.add_argument( + "--db1", type=Path, default=Path("uniref30_2302_db"), help="UniRef database" + ) + parser.add_argument("--db2", type=Path, default=Path(""), help="Templates database") + parser.add_argument( + "--db3", + type=Path, + default=Path("colabfold_envdb_202108_db"), + help="Environmental database", + ) + parser.add_argument("--db4", type=Path, default=Path("spire_ctg10_2401_db"), help="Environmental pairing database") + + # poor man's boolean arguments + parser.add_argument( + "--use-env", type=int, default=1, choices=[0, 1], help="Use --db3" + ) + parser.add_argument( + "--use-env-pairing", type=int, default=0, choices=[0, 1], help="Use --db4" + ) + parser.add_argument( + "--use-templates", type=int, default=0, choices=[0, 1], help="Use --db2" + ) + parser.add_argument( + "--filter", + type=int, + default=1, + choices=[0, 1], + help="Filter the MSA by pre-defined align_eval, qsc, max_accept", + ) + + # mmseqs params + parser.add_argument( + "--mmseqs", + type=Path, + default=Path("mmseqs"), + help="Location of the mmseqs binary.", + ) + parser.add_argument( + "--expand-eval", + type=float, + default=math.inf, + help="e-val threshold for 'expandaln'.", + ) + parser.add_argument( + "--align-eval", type=int, default=10, help="e-val threshold for 'align'." + ) + parser.add_argument( + "--diff", + type=int, + default=3000, + help="filterresult - Keep at least this many seqs in each MSA block.", + ) + parser.add_argument( + "--qsc", + type=float, + default=-20.0, + help="filterresult - reduce diversity of output MSAs using min score thresh.", + ) + parser.add_argument( + "--max-accept", + type=int, + default=1000000, + help="align - Maximum accepted alignments before alignment calculation for a query is stopped.", + ) + parser.add_argument( + "--pairing_strategy", type=int, default=0, help="pairaln - Pairing strategy." + ) + parser.add_argument( + "--db-load-mode", + type=int, + default=0, + help="Database preload mode 0: auto, 1: fread, 2: mmap, 3: mmap+touch", + ) + parser.add_argument( + "--unpack", type=int, default=1, choices=[0, 1], help="Unpack results to loose files or keep MMseqs2 databases." + ) + parser.add_argument( + "--threads", type=int, default=64, help="Number of threads to use." 
+ ) + args = parser.parse_args() + + logging.basicConfig(level = logging.INFO) + + queries, is_complex = get_queries(args.query, None) + + queries_unique = [] + for job_number, (raw_jobname, query_sequences, a3m_lines) in enumerate(queries): + # remove duplicates before searching + query_sequences = ( + [query_sequences] if isinstance(query_sequences, str) else query_sequences + ) + query_seqs_unique = [] + for x in query_sequences: + if x not in query_seqs_unique: + query_seqs_unique.append(x) + query_seqs_cardinality = [0] * len(query_seqs_unique) + for seq in query_sequences: + seq_idx = query_seqs_unique.index(seq) + query_seqs_cardinality[seq_idx] += 1 + + queries_unique.append([raw_jobname, query_seqs_unique, query_seqs_cardinality]) + + args.base.mkdir(exist_ok=True, parents=True) + query_file = args.base.joinpath("query.fas") + with query_file.open("w") as f: + for job_number, ( + raw_jobname, + query_sequences, + query_seqs_cardinality, + ) in enumerate(queries_unique): + for j, seq in enumerate(query_sequences): + # The header of first sequence set as 101 + query_seq_headername = 101 + j + f.write(f">{query_seq_headername}\n{seq}\n") + + run_mmseqs( + args.mmseqs, + ["createdb", query_file, args.base.joinpath("qdb"), "--shuffle", "0"], + ) + with args.base.joinpath("qdb.lookup").open("w") as f: + id = 0 + file_number = 0 + for job_number, ( + raw_jobname, + query_sequences, + query_seqs_cardinality, + ) in enumerate(queries_unique): + for seq in query_sequences: + raw_jobname_first = raw_jobname.split()[0] + f.write(f"{id}\t{raw_jobname_first}\t{file_number}\n") + id += 1 + file_number += 1 + + mmseqs_search_monomer( + mmseqs=args.mmseqs, + dbbase=args.dbbase, + base=args.base, + uniref_db=args.db1, + template_db=args.db2, + metagenomic_db=args.db3, + use_env=args.use_env, + use_templates=args.use_templates, + filter=args.filter, + expand_eval=args.expand_eval, + align_eval=args.align_eval, + diff=args.diff, + qsc=args.qsc, + max_accept=args.max_accept, + prefilter_mode=args.prefilter_mode, + s=args.s, + db_load_mode=args.db_load_mode, + threads=args.threads, + unpack=args.unpack, + ) + if is_complex is True: + mmseqs_search_pair( + mmseqs=args.mmseqs, + dbbase=args.dbbase, + base=args.base, + uniref_db=args.db1, + prefilter_mode=args.prefilter_mode, + s=args.s, + db_load_mode=args.db_load_mode, + threads=args.threads, + pairing_strategy=args.pairing_strategy, + pair_env=False, + unpack=args.unpack, + ) + if args.use_env_pairing: + mmseqs_search_pair( + mmseqs=args.mmseqs, + dbbase=args.dbbase, + base=args.base, + uniref_db=args.db1, + spire_db=args.db4, + prefilter_mode=args.prefilter_mode, + s=args.s, + db_load_mode=args.db_load_mode, + threads=args.threads, + pairing_strategy=args.pairing_strategy, + pair_env=True, + unpack=args.unpack, + ) + + if args.unpack: + id = 0 + for job_number, ( + raw_jobname, + query_sequences, + query_seqs_cardinality, + ) in enumerate(queries_unique): + unpaired_msa = [] + paired_msa = None + if len(query_seqs_cardinality) > 1: + paired_msa = [] + for seq in query_sequences: + with args.base.joinpath(f"{id}.a3m").open("r") as f: + unpaired_msa.append(f.read()) + args.base.joinpath(f"{id}.a3m").unlink() + + if args.use_env_pairing: + with open(args.base.joinpath(f"{id}.paired.a3m"), 'a') as file_pair: + with open(args.base.joinpath(f"{id}.env.paired.a3m"), 'r') as file_pair_env: + while chunk := file_pair_env.read(10 * 1024 * 1024): + file_pair.write(chunk) + args.base.joinpath(f"{id}.env.paired.a3m").unlink() + + if len(query_seqs_cardinality) > 
1: + with args.base.joinpath(f"{id}.paired.a3m").open("r") as f: + paired_msa.append(f.read()) + args.base.joinpath(f"{id}.paired.a3m").unlink() + id += 1 + msa = msa_to_str( + unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality + ) + args.base.joinpath(f"{job_number}.a3m").write_text(msa) + + if args.unpack: + # rename a3m files + for job_number, (raw_jobname, query_sequences, query_seqs_cardinality) in enumerate(queries_unique): + os.rename( + args.base.joinpath(f"{job_number}.a3m"), + args.base.joinpath(f"{safe_filename(raw_jobname)}.a3m"), + ) + + # rename m8 files + if args.use_templates: + id = 0 + for raw_jobname, query_sequences, query_seqs_cardinality in queries_unique: + with args.base.joinpath(f"{safe_filename(raw_jobname)}_{args.db2}.m8").open( + "w" + ) as f: + for _ in range(len(query_seqs_cardinality)): + with args.base.joinpath(f"{id}.m8").open("r") as g: + f.write(g.read()) + os.remove(args.base.joinpath(f"{id}.m8")) + id += 1 + run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb")]) + run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb_h")]) + + query_file.unlink() + + +if __name__ == "__main__": + main() From ca13a4d58cdec14a0afeeb0d0c524f89bd0d3dc7 Mon Sep 17 00:00:00 2001 From: hsbyeon1 Date: Wed, 11 Dec 2024 21:41:21 +0900 Subject: [PATCH 2/6] feat:copy colabfold util codes for mmseqs --- .../dataset/msas/colabfold_codes/__init__.py | 0 .../dataset/msas/colabfold_codes/batch.py | 224 ++++++++++++++++++ .../dataset/msas/colabfold_codes/utils.py | 2 + 3 files changed, 226 insertions(+) create mode 100644 chai_lab/data/dataset/msas/colabfold_codes/__init__.py create mode 100644 chai_lab/data/dataset/msas/colabfold_codes/batch.py create mode 100644 chai_lab/data/dataset/msas/colabfold_codes/utils.py diff --git a/chai_lab/data/dataset/msas/colabfold_codes/__init__.py b/chai_lab/data/dataset/msas/colabfold_codes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chai_lab/data/dataset/msas/colabfold_codes/batch.py b/chai_lab/data/dataset/msas/colabfold_codes/batch.py new file mode 100644 index 00000000..78d7e020 --- /dev/null +++ b/chai_lab/data/dataset/msas/colabfold_codes/batch.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import warnings +from Bio import BiopythonDeprecationWarning # what can possibly go wrong... +warnings.simplefilter(action='ignore', category=BiopythonDeprecationWarning) + +import logging +import random +from pathlib import Path +from typing import List, Optional, Tuple, Union, TYPE_CHECKING +import pandas + +logger = logging.getLogger(__name__) + +def parse_fasta(fasta_string: str) -> Tuple[List[str], List[str]]: + """Parses FASTA string and returns list of strings with amino-acid sequences. + + Arguments: + fasta_string: The string contents of a FASTA file. + + Returns: + A tuple of two lists: + * A list of sequences. + * A list of sequence descriptions taken from the comment lines. In the + same order as the sequences. + """ + sequences = [] + descriptions = [] + index = -1 + for line in fasta_string.splitlines(): + line = line.strip() + if line.startswith("#"): + continue + if line.startswith(">"): + index += 1 + descriptions.append(line[1:]) # Remove the '>' at the beginning. + sequences.append("") + continue + elif not line: + continue # Skip blank lines. 
+ sequences[index] += line + + return sequences, descriptions + +def get_queries( + input_path: Union[str, Path], sort_queries_by: str = "length" +) -> Tuple[List[Tuple[str, str, Optional[List[str]]]], bool]: + """Reads a directory of fasta files, a single fasta file or a csv file and returns a tuple + of job name, sequence and the optional a3m lines""" + + input_path = Path(input_path) + if not input_path.exists(): + raise OSError(f"{input_path} could not be found") + + if input_path.is_file(): + if input_path.suffix == ".csv" or input_path.suffix == ".tsv": + sep = "\t" if input_path.suffix == ".tsv" else "," + df = pandas.read_csv(input_path, sep=sep, dtype=str) + assert "id" in df.columns and "sequence" in df.columns + queries = [ + (seq_id, sequence.upper().split(":"), None) + for seq_id, sequence in df[["id", "sequence"]].itertuples(index=False) + ] + for i in range(len(queries)): + if len(queries[i][1]) == 1: + queries[i] = (queries[i][0], queries[i][1][0], None) + elif input_path.suffix == ".a3m": + (seqs, header) = parse_fasta(input_path.read_text()) + if len(seqs) == 0: + raise ValueError(f"{input_path} is empty") + query_sequence = seqs[0] + # Use a list so we can easily extend this to multiple msas later + a3m_lines = [input_path.read_text()] + queries = [(input_path.stem, query_sequence, a3m_lines)] + elif input_path.suffix in [".fasta", ".faa", ".fa"]: + (sequences, headers) = parse_fasta(input_path.read_text()) + queries = [] + for sequence, header in zip(sequences, headers): + sequence = sequence.upper() + if sequence.count(":") == 0: + # Single sequence + queries.append((header, sequence, None)) + else: + # Complex mode + queries.append((header, sequence.upper().split(":"), None)) + else: + raise ValueError(f"Unknown file format {input_path.suffix}") + else: + assert input_path.is_dir(), "Expected either an input file or a input directory" + queries = [] + for file in sorted(input_path.iterdir()): + if not file.is_file(): + continue + if file.suffix.lower() not in [".a3m", ".fasta", ".faa"]: + logger.warning(f"non-fasta/a3m file in input directory: {file}") + continue + (seqs, header) = parse_fasta(file.read_text()) + if len(seqs) == 0: + logger.error(f"{file} is empty") + continue + query_sequence = seqs[0] + if len(seqs) > 1 and file.suffix in [".fasta", ".faa", ".fa"]: + logger.warning( + f"More than one sequence in {file}, ignoring all but the first sequence" + ) + + if file.suffix.lower() == ".a3m": + a3m_lines = [file.read_text()] + queries.append((file.stem, query_sequence.upper(), a3m_lines)) + else: + if query_sequence.count(":") == 0: + # Single sequence + queries.append((file.stem, query_sequence, None)) + else: + # Complex mode + queries.append((file.stem, query_sequence.upper().split(":"), None)) + + # sort by seq. 
len + if sort_queries_by == "length": + queries.sort(key=lambda t: len("".join(t[1]))) + + elif sort_queries_by == "random": + random.shuffle(queries) + + is_complex = False + for job_number, (_, query_sequence, a3m_lines) in enumerate(queries): + if isinstance(query_sequence, list): + is_complex = True + break + if a3m_lines is not None and a3m_lines[0].startswith("#"): + a3m_line = a3m_lines[0].splitlines()[0] + tab_sep_entries = a3m_line[1:].split("\t") + if len(tab_sep_entries) == 2: + query_seq_len = tab_sep_entries[0].split(",") + query_seq_len = list(map(int, query_seq_len)) + query_seqs_cardinality = tab_sep_entries[1].split(",") + query_seqs_cardinality = list(map(int, query_seqs_cardinality)) + is_single_protein = ( + True + if len(query_seq_len) == 1 and query_seqs_cardinality[0] == 1 + else False + ) + if not is_single_protein: + is_complex = True + break + return queries, is_complex + +def pair_sequences( + a3m_lines: List[str], query_sequences: List[str], query_cardinality: List[int] +) -> str: + a3m_line_paired = [""] * len(a3m_lines[0].splitlines()) + for n, seq in enumerate(query_sequences): + lines = a3m_lines[n].splitlines() + for i, line in enumerate(lines): + if line.startswith(">"): + if n != 0: + line = line.replace(">", "\t", 1) + a3m_line_paired[i] = a3m_line_paired[i] + line + else: + a3m_line_paired[i] = a3m_line_paired[i] + line * query_cardinality[n] + return "\n".join(a3m_line_paired) + +def pad_sequences( + a3m_lines: List[str], query_sequences: List[str], query_cardinality: List[int] +) -> str: + _blank_seq = [ + ("-" * len(seq)) + for n, seq in enumerate(query_sequences) + for _ in range(query_cardinality[n]) + ] + a3m_lines_combined = [] + pos = 0 + for n, seq in enumerate(query_sequences): + for j in range(0, query_cardinality[n]): + lines = a3m_lines[n].split("\n") + for a3m_line in lines: + if len(a3m_line) == 0: + continue + if a3m_line.startswith(">"): + a3m_lines_combined.append(a3m_line) + else: + a3m_lines_combined.append( + "".join(_blank_seq[:pos] + [a3m_line] + _blank_seq[pos + 1 :]) + ) + pos += 1 + return "\n".join(a3m_lines_combined) + +def pair_msa( + query_seqs_unique: List[str], + query_seqs_cardinality: List[int], + paired_msa: Optional[List[str]], + unpaired_msa: Optional[List[str]], +) -> str: + if paired_msa is None and unpaired_msa is not None: + a3m_lines = pad_sequences( + unpaired_msa, query_seqs_unique, query_seqs_cardinality + ) + elif paired_msa is not None and unpaired_msa is not None: + a3m_lines = ( + pair_sequences(paired_msa, query_seqs_unique, query_seqs_cardinality) + + "\n" + + pad_sequences(unpaired_msa, query_seqs_unique, query_seqs_cardinality) + ) + elif paired_msa is not None and unpaired_msa is None: + a3m_lines = pair_sequences( + paired_msa, query_seqs_unique, query_seqs_cardinality + ) + else: + raise ValueError(f"Invalid pairing") + return a3m_lines + +def msa_to_str( + unpaired_msa: List[str], + paired_msa: List[str], + query_seqs_unique: List[str], + query_seqs_cardinality: List[int], +) -> str: + msa = "#" + ",".join(map(str, map(len, query_seqs_unique))) + "\t" + msa += ",".join(map(str, query_seqs_cardinality)) + "\n" + # build msa with cardinality of 1, it makes it easier to parse and manipulate + query_seqs_cardinality = [1 for _ in query_seqs_cardinality] + msa += pair_msa(query_seqs_unique, query_seqs_cardinality, paired_msa, unpaired_msa) + return msa + diff --git a/chai_lab/data/dataset/msas/colabfold_codes/utils.py b/chai_lab/data/dataset/msas/colabfold_codes/utils.py new file mode 100644 index 
00000000..4f3b8903 --- /dev/null +++ b/chai_lab/data/dataset/msas/colabfold_codes/utils.py @@ -0,0 +1,2 @@ +def safe_filename(file: str) -> str: + return "".join([c if c.isalnum() or c in ["_", ".", "-"] else "_" for c in file]) From 5495bdc7a95be41201ed4182a5093ed857a40ccc Mon Sep 17 00:00:00 2001 From: hsbyeon1 Date: Wed, 11 Dec 2024 21:56:31 +0900 Subject: [PATCH 3/6] fix: colabfold utils import --- chai_lab/data/dataset/msas/local_mmseqs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chai_lab/data/dataset/msas/local_mmseqs.py b/chai_lab/data/dataset/msas/local_mmseqs.py index 0b2c5e47..5904d5ad 100644 --- a/chai_lab/data/dataset/msas/local_mmseqs.py +++ b/chai_lab/data/dataset/msas/local_mmseqs.py @@ -11,8 +11,8 @@ from pathlib import Path from typing import List, Union -from colabfold.batch import get_queries, msa_to_str -from colabfold.utils import safe_filename +from chai_lab.data.dataset.msas.colabfold_codes.batch import get_queries, msa_to_str +from chai_lab.data.dataset.msas.colabfold_codes.utils import safe_filename logger = logging.getLogger(__name__) From eb510211b7b1327f1b68aeb966df7d37e96f56df Mon Sep 17 00:00:00 2001 From: hsbyeon1 Date: Wed, 11 Dec 2024 21:57:21 +0900 Subject: [PATCH 4/6] feat: copy chai msa pipeline --- chai_lab/data/dataset/msas/local_msa.py | 186 ++++++++++++++++++++++++ 1 file changed, 186 insertions(+) create mode 100644 chai_lab/data/dataset/msas/local_msa.py diff --git a/chai_lab/data/dataset/msas/local_msa.py b/chai_lab/data/dataset/msas/local_msa.py new file mode 100644 index 00000000..8abfd184 --- /dev/null +++ b/chai_lab/data/dataset/msas/local_msa.py @@ -0,0 +1,186 @@ +import logging +from multiprocessing import Value +import os +import random +import tarfile +import time +import typing +from pathlib import Path + +import pandas as pd +import requests +from tqdm import tqdm + +logger = logging.getLogger(__name__) + +TQDM_BAR_FORMAT = ( + "{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]" +) + + +def run_mmseqs_locally(**kwargs) -> dict: + """ + returns status + TODO:change to a code that runs mmseqs2 locally and returns its status? + """ + return {} + + +def _run_mmseqs2( + x, + prefix, + use_env=True, + use_filter=True, + use_templates=False, + filter=None, + use_pairing=False, + pairing_strategy="greedy", + user_agent: str = "", +) -> list[str] | tuple[list[str], list[str]]: + #TODO :Run mmseqs2 locally and copy pairing from mmseqs2 + + if use_env: + raise ValueError("use_env=True not supported since our cluster has no metagenomics db downloaded") + # process input x + seqs = [x] if isinstance(x, str) else x + + # compatibility to old option + if filter is not None: + use_filter = filter + + # setup mode + if use_filter: + mode = "env" if use_env else "all" + else: + mode = "env-nofilter" if use_env else "nofilter" + + if use_pairing: + use_templates = False + mode = "" + # greedy is default, complete was the previous behavior + if pairing_strategy == "greedy": + mode = "pairgreedy" + elif pairing_strategy == "complete": + mode = "paircomplete" + if use_env: + mode = mode + "-env" + + # define path + path = f"{prefix}_{mode}" + if not os.path.isdir(path): + os.mkdir(path) + + # call mmseqs2 api + tar_gz_file = f"{path}/out.tar.gz" + N, REDO = 101, True + + # deduplicate and keep track of order + seqs_unique = [] + # TODO this might be slow for large sets + [seqs_unique.append(x) for x in seqs if x not in seqs_unique] + Ms = [N + seqs_unique.index(seq) for seq in seqs] + # lets do it! 
+ if not os.path.isfile(tar_gz_file): # if mmseqs2 output does not exist + out = run_mmseqs_locally(seqs_unique, mode, N) + + # prep list of a3m files + if use_pairing: + a3m_files = [f"{path}/pair.a3m"] + else: + a3m_files = [f"{path}/uniref.a3m"] + if use_env: + a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m") + + # extract a3m files + if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files): + with tarfile.open(tar_gz_file) as tar_gz: + tar_gz.extractall(path) + + # templates + if use_templates: + templates = {} + # print("seq\tpdb\tcid\tevalue") + for line in open(f"{path}/pdb70.m8", "r"): + p = line.rstrip().split() + M, pdb, _, _ = p[0], p[1], p[2], p[10] + M = int(M) + if M not in templates: + templates[M] = [] + templates[M].append(pdb) + # if len(templates[M]) <= 20: + # print(f"{int(M)-N}\t{pdb}\t{qid}\t{e_value}") + + template_paths = {} + for k, TMPL in templates.items(): + TMPL_PATH = f"{prefix}_{mode}/templates_{k}" + if not os.path.isdir(TMPL_PATH): + os.mkdir(TMPL_PATH) + TMPL_LINE = ",".join(TMPL[:20]) + response = None + while True: + error_count = 0 + try: + # https://requests.readthedocs.io/en/latest/user/advanced/#advanced + # "good practice to set connect timeouts to slightly larger than a multiple of 3" + response = requests.get( + f"{host_url}/template/{TMPL_LINE}", + stream=True, + timeout=6.02, + headers=headers, + ) + except requests.exceptions.Timeout: + logger.warning( + "Timeout while submitting to template server. Retrying..." + ) + continue + except Exception as e: + error_count += 1 + logger.warning( + f"Error while fetching result from template server. Retrying... ({error_count}/5)" + ) + logger.warning(f"Error: {e}") + time.sleep(5) + if error_count > 5: + raise + continue + break + with tarfile.open(fileobj=response.raw, mode="r|gz") as tar: + tar.extractall(path=TMPL_PATH) + os.symlink("pdb70_a3m.ffindex", f"{TMPL_PATH}/pdb70_cs219.ffindex") + with open(f"{TMPL_PATH}/pdb70_cs219.ffdata", "w") as f: + f.write("") + template_paths[k] = TMPL_PATH + + # gather a3m lines + a3m_lines = {} + for a3m_file in a3m_files: + update_M, M = True, None + for line in open(a3m_file, "r"): + if len(line) > 0: + if "\x00" in line: + line = line.replace("\x00", "") + update_M = True + if line.startswith(">") and update_M: + M = int(line[1:].rstrip()) + update_M = False + if M not in a3m_lines: + a3m_lines[M] = [] + a3m_lines[M].append(line) + + # return results + + a3m_lines = ["".join(a3m_lines[n]) for n in Ms] + + if use_templates: + template_paths_ = [] + for n in Ms: + if n not in template_paths: + template_paths_.append(None) + # print(f"{n-N}\tno_templates_found") + else: + template_paths_.append(template_paths[n]) + template_paths = template_paths_ + + return (a3m_lines, template_paths) if use_templates else a3m_lines + +def run_ \ No newline at end of file From e5305847857ea994bfe777a85179ec13c06a62e0 Mon Sep 17 00:00:00 2001 From: hsbyeon1 Date: Wed, 11 Dec 2024 21:58:57 +0900 Subject: [PATCH 5/6] comment:add todo --- chai_lab/data/dataset/msas/local_mmseqs.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chai_lab/data/dataset/msas/local_mmseqs.py b/chai_lab/data/dataset/msas/local_mmseqs.py index 5904d5ad..300894c7 100644 --- a/chai_lab/data/dataset/msas/local_mmseqs.py +++ b/chai_lab/data/dataset/msas/local_mmseqs.py @@ -265,6 +265,8 @@ def mmseqs_search_pair( # fmt: on def main(): + # a python wrapper for mmseqs2 + # TODO: change argparse to function kwargs parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 
parser.add_argument( "query", From 626ffbf1d38a4a22599681cc2d1cf71bff486886 Mon Sep 17 00:00:00 2001 From: hsbyeon1 Date: Wed, 11 Dec 2024 21:59:32 +0900 Subject: [PATCH 6/6] remove typo --- chai_lab/data/dataset/msas/local_msa.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/chai_lab/data/dataset/msas/local_msa.py b/chai_lab/data/dataset/msas/local_msa.py index 8abfd184..741a2b70 100644 --- a/chai_lab/data/dataset/msas/local_msa.py +++ b/chai_lab/data/dataset/msas/local_msa.py @@ -181,6 +181,4 @@ def _run_mmseqs2( template_paths_.append(template_paths[n]) template_paths = template_paths_ - return (a3m_lines, template_paths) if use_templates else a3m_lines - -def run_ \ No newline at end of file + return (a3m_lines, template_paths) if use_templates else a3m_lines \ No newline at end of file
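
For reviewers: below is a minimal, hypothetical driver showing how the new local_mmseqs entry point added in PATCH 1/6 could be exercised end-to-end. It is a sketch, not part of the patch: it assumes the chai_lab package is importable, an mmseqs binary is on PATH, and a ColabFold database set built with setup_databases.sh (uniref30_2302_db, colabfold_envdb_202108_db) lives under db_dir; all paths below are placeholders.

    # Hypothetical usage sketch for chai_lab/data/dataset/msas/local_mmseqs.py (not part of the patch).
    import subprocess
    import sys
    from pathlib import Path

    query_fasta = Path("examples/complex.fasta")  # per-chain sequences; chains of one complex joined by ':'
    db_dir = Path("/data/colabfold_dbs")          # must hold uniref30_2302_db and colabfold_envdb_202108_db
    out_dir = Path("msas_out")                    # per-query .a3m files are written here

    subprocess.check_call([
        sys.executable, "-m", "chai_lab.data.dataset.msas.local_mmseqs",
        str(query_fasta), str(db_dir), str(out_dir),
        "--use-env", "1",        # also search the metagenomic db (--db3)
        "--use-templates", "0",  # template search stays off, matching the default
        "--threads", "32",
        "--db-load-mode", "2",   # mmap pre-built .idx files when present
    ])
    # With --unpack 1 (the default), out_dir ends up with one <safe_filename(jobname)>.a3m per query,
    # containing paired and unpaired MSA blocks where applicable.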