diff --git a/alphafold/data/mmcif_parsing.py b/alphafold/data/mmcif_parsing.py
index 61cf149c0..48d4ce07f 100644
--- a/alphafold/data/mmcif_parsing.py
+++ b/alphafold/data/mmcif_parsing.py
@@ -165,6 +165,7 @@ def mmcif_loop_to_dict(prefix: str,
 def parse(*,
           file_id: str,
           mmcif_string: str,
+          is_pdb_file: bool = False,
           catch_all_errors: bool = True) -> ParsingResult:
   """Entry point, parses an mmcif_string.
 
@@ -181,10 +182,14 @@ def parse(*,
   """
   errors = {}
   try:
-    parser = PDB.MMCIFParser(QUIET=True)
+    if not is_pdb_file:
+      parser = PDB.MMCIFParser(QUIET=True)
+    else:
+      parser = PDB.PDBParser(QUIET=True)
     handle = io.StringIO(mmcif_string)
     full_structure = parser.get_structure('', handle)
     first_model_structure = _get_first_model(full_structure)
+    # NOTE(review): PDB.PDBParser has no _mmcif_dict; is_pdb_file=True raises here.
     # Extract the _mmcif_dict from the parser, which contains useful fields not
     # reflected in the Biopython structure.
     parsed_info = parser._mmcif_dict  # pylint:disable=protected-access
@@ -273,6 +278,7 @@ def parse(*,
     return ParsingResult(mmcif_object=mmcif_object, errors=errors)
   except Exception as e:  # pylint:disable=broad-except
+    print(f'mmcif_parsing: failed to parse {file_id}: {e}')
     errors[(file_id, '')] = e
     if not catch_all_errors:
       raise
diff --git a/alphafold/data/pipeline.py b/alphafold/data/pipeline.py
index a90eb5776..4ba269b18 100644
--- a/alphafold/data/pipeline.py
+++ b/alphafold/data/pipeline.py
@@ -26,7 +26,7 @@ from alphafold.data.tools import hmmsearch
 from alphafold.data.tools import jackhmmer
 import numpy as np
-
+from concurrent.futures import ProcessPoolExecutor, as_completed
 # Internal import (7716).
FeatureDict = MutableMapping[str, np.ndarray]
@@ -91,7 +91,7 @@ def run_msa_tool(msa_runner, input_fasta_path: str, msa_out_path: str,
   """Runs an MSA tool, checking if output already exists first."""
   if not use_precomputed_msas or not os.path.exists(msa_out_path):
     if msa_format == 'sto' and max_sto_sequences is not None:
-      result = msa_runner.query(input_fasta_path, max_sto_sequences)[0] # pytype: disable=wrong-arg-count
+      result = msa_runner.query(input_fasta_path, max_sto_sequences)[0]  # pytype: disable=wrong-arg-count
     else:
       result = msa_runner.query(input_fasta_path)[0]
     with open(msa_out_path, 'w') as f:
@@ -147,19 +147,9 @@ def __init__(self,
     self.uniref_max_hits = uniref_max_hits
     self.use_precomputed_msas = use_precomputed_msas
 
-  def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
-    """Runs alignment tools on the input sequence and creates features."""
-    with open(input_fasta_path) as f:
-      input_fasta_str = f.read()
-    input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
-    if len(input_seqs) != 1:
-      raise ValueError(
-          f'More than one input sequence found in {input_fasta_path}.')
-    input_sequence = input_seqs[0]
-    input_description = input_descs[0]
-    num_res = len(input_sequence)
-
-    uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
+  def run_jackhmmer_uniref90(self, input_fasta_path: str, uniref90_out_path: str):
+    """Runs the jackhmmer uniref90 search; invoked from a worker process."""
+    logging.info('Running uniref90 alignment.')
     jackhmmer_uniref90_result = run_msa_tool(
         msa_runner=self.jackhmmer_uniref90_runner,
         input_fasta_path=input_fasta_path,
@@ -167,7 +157,15 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
         msa_format='sto',
         use_precomputed_msas=self.use_precomputed_msas,
         max_sto_sequences=self.uniref_max_hits)
-    mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
+    msa_for_templates = jackhmmer_uniref90_result['sto']
+    msa_for_templates =
parsers.deduplicate_stockholm_msa(msa_for_templates)
+    msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
+        msa_for_templates)
+    return jackhmmer_uniref90_result, msa_for_templates
+
+  def run_jackhmmer_mgnify(self, input_fasta_path: str, mgnify_out_path: str):
+    """Runs the jackhmmer mgnify search; invoked from a worker process."""
+    logging.info('Running mgnify alignment.')
     jackhmmer_mgnify_result = run_msa_tool(
         msa_runner=self.jackhmmer_mgnify_runner,
         input_fasta_path=input_fasta_path,
@@ -175,11 +173,69 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
         msa_format='sto',
         use_precomputed_msas=self.use_precomputed_msas,
         max_sto_sequences=self.mgnify_max_hits)
+    return jackhmmer_mgnify_result
-    msa_for_templates = jackhmmer_uniref90_result['sto']
-    msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
-    msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
-        msa_for_templates)
+  def run_bfd_alignments(self, msa_output_dir: str, input_fasta_path: str):
+    """Runs the small-BFD jackhmmer or BFD+uniref hhblits search."""
+    logging.info('Running BFD alignment.')
+    if self._use_small_bfd:
+      bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
+      jackhmmer_small_bfd_result = run_msa_tool(
+          msa_runner=self.jackhmmer_small_bfd_runner,
+          input_fasta_path=input_fasta_path,
+          msa_out_path=bfd_out_path,
+          msa_format='sto',
+          use_precomputed_msas=self.use_precomputed_msas)
+      bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
+    else:
+      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
+      hhblits_bfd_uniref_result = run_msa_tool(
+          msa_runner=self.hhblits_bfd_uniref_runner,
+          input_fasta_path=input_fasta_path,
+          msa_out_path=bfd_out_path,
+          msa_format='a3m',
+          use_precomputed_msas=self.use_precomputed_msas)
+      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])
+
+    return bfd_msa
+
+  def run_all_msa_runners(
input_fasta_path: str,
+                          uniref90_out_path: str, mgnify_out_path: str,
+                          msa_output_dir: str):
+    """Runs the three MSA searches concurrently in worker processes."""
+    logging.info('Running MSA tools in parallel.')
+    with ProcessPoolExecutor(max_workers=3) as executor:  # NOTE(review): self and its runners must be picklable.
+      uniref90_future = executor.submit(self.run_jackhmmer_uniref90, input_fasta_path, uniref90_out_path)
+      mgnify_future = executor.submit(self.run_jackhmmer_mgnify, input_fasta_path, mgnify_out_path)
+      bfd_future = executor.submit(self.run_bfd_alignments, msa_output_dir, input_fasta_path)
+      jackhmmer_uniref90_result, msa_for_templates = uniref90_future.result()
+      jackhmmer_mgnify_result = mgnify_future.result()
+      bfd_msa = bfd_future.result()
+
+    return {'uniref90_results': (jackhmmer_uniref90_result, msa_for_templates),
+            'mgnify_results': jackhmmer_mgnify_result, 'bfd_results': bfd_msa}
+
+  def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
+    """Runs alignment tools on the input sequence and creates features."""
+    with open(input_fasta_path) as f:
+      input_fasta_str = f.read()
+    input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
+    if len(input_seqs) != 1:
+      raise ValueError(
+          f'More than one input sequence found in {input_fasta_path}.')
+    input_sequence = input_seqs[0]
+    input_description = input_descs[0]
+    num_res = len(input_sequence)
+
+    uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
+
+    mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
+
+    msa_results = self.run_all_msa_runners(input_fasta_path, uniref90_out_path,
+                                           mgnify_out_path, msa_output_dir)
+    jackhmmer_uniref90_result, msa_for_templates = msa_results['uniref90_results']
+    jackhmmer_mgnify_result = msa_results['mgnify_results']
+    bfd_msa = msa_results['bfd_results']
     if
self.template_searcher.input_format == 'sto': pdb_templates_result = self.template_searcher.query(msa_for_templates) @@ -201,25 +257,6 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: pdb_template_hits = self.template_searcher.get_template_hits( output_string=pdb_templates_result, input_sequence=input_sequence) - if self._use_small_bfd: - bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto') - jackhmmer_small_bfd_result = run_msa_tool( - msa_runner=self.jackhmmer_small_bfd_runner, - input_fasta_path=input_fasta_path, - msa_out_path=bfd_out_path, - msa_format='sto', - use_precomputed_msas=self.use_precomputed_msas) - bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto']) - else: - bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m') - hhblits_bfd_uniref_result = run_msa_tool( - msa_runner=self.hhblits_bfd_uniref_runner, - input_fasta_path=input_fasta_path, - msa_out_path=bfd_out_path, - msa_format='a3m', - use_precomputed_msas=self.use_precomputed_msas) - bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m']) - templates_result = self.template_featurizer.get_templates( query_sequence=input_sequence, hits=pdb_template_hits) diff --git a/alphafold/data/templates.py b/alphafold/data/templates.py index f2de65c6e..1a828c60c 100644 --- a/alphafold/data/templates.py +++ b/alphafold/data/templates.py @@ -601,6 +601,7 @@ def _extract_template_features( templates_aatype = residue_constants.sequence_to_onehot( output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID) + return ( { 'template_all_atom_positions': np.array(templates_all_atom_positions), diff --git a/alphafold/data/tools/hhblits.py b/alphafold/data/tools/hhblits.py index 1d8c180d8..027834c16 100644 --- a/alphafold/data/tools/hhblits.py +++ b/alphafold/data/tools/hhblits.py @@ -35,7 +35,7 @@ def __init__(self, *, binary_path: str, databases: Sequence[str], - n_cpu: int = 4, + n_cpu: int = 2, n_iter: int = 3, e_value: float = 0.001, 
maxseq: int = 1_000_000,
diff --git a/alphafold/data/tools/jackhmmer.py b/alphafold/data/tools/jackhmmer.py
index 68997f857..16739cc51 100644
--- a/alphafold/data/tools/jackhmmer.py
+++ b/alphafold/data/tools/jackhmmer.py
@@ -35,7 +35,7 @@ def __init__(self, *,
                binary_path: str,
                database_path: str,
-               n_cpu: int = 8,
+               n_cpu: int = 2,  # NOTE(review): lowered so three parallel MSA runs share cores — tune to host CPU count.
                n_iter: int = 1,
                e_value: float = 0.0001,
                z_value: Optional[int] = None,