From d7b7bc080ce3716cc484d258d647cd6732094d4b Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Wed, 27 Mar 2024 12:02:42 +0100 Subject: [PATCH 1/6] update make_fixed_size to accommodate numpy arrays --- alphafold/model/tf/data_transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/alphafold/model/tf/data_transforms.py b/alphafold/model/tf/data_transforms.py index 7af966ef4..05945dfca 100644 --- a/alphafold/model/tf/data_transforms.py +++ b/alphafold/model/tf/data_transforms.py @@ -20,7 +20,6 @@ from alphafold.model.tf import utils import numpy as np import tensorflow.compat.v1 as tf - # Pylint gets confused by the curry1 decorator because it changes the number # of arguments to the function. # pylint:disable=no-value-for-parameter @@ -413,19 +412,20 @@ def make_masked_msa(protein, config, replace_fraction): def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num_res, num_templates=0): """Guess at the MSA and sequence dimensions to make fixed size.""" - pad_size_map = { NUM_RES: num_res, NUM_MSA_SEQ: msa_cluster_size, NUM_EXTRA_SEQ: extra_msa_size, NUM_TEMPLATES: num_templates, } - for k, v in protein.items(): # Don't transfer this to the accelerator. 
if k == 'extra_cluster_assignment': continue - shape = v.shape.as_list() + if type(v) ==np.ndarray: + shape = v.shape + else: + shape = v.shape.as_list() schema = shape_schema[k] assert len(shape) == len(schema), ( f'Rank mismatch between shape and shape schema for {k}: ' From 3257ff2f6f6945ab1dcd6b86abbf49cf3f98141b Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Wed, 27 Mar 2024 12:20:23 +0100 Subject: [PATCH 2/6] now allow monomeric modelling to manually set the targeted total number of residues to pad --- alphafold/model/features.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/alphafold/model/features.py b/alphafold/model/features.py index c261cef19..062248996 100644 --- a/alphafold/model/features.py +++ b/alphafold/model/features.py @@ -77,10 +77,13 @@ def tf_example_to_features(tf_example: tf.train.Example, def np_example_to_features(np_example: FeatureDict, config: ml_collections.ConfigDict, - random_seed: int = 0) -> FeatureDict: + random_seed: int = 0, desired_num_res: int = None) -> FeatureDict: """Preprocesses NumPy feature dict using TF pipeline.""" np_example = dict(np_example) - num_res = int(np_example['seq_length'][0]) + if desired_num_res is not None: + num_res = desired_num_res + else: + num_res = int(np_example['seq_length'][0]) cfg, feature_names = make_data_config(config, num_res=num_res) if 'deletion_matrix_int' in np_example: From f253385a33fdfe0f8848f95f0dea7dcb0c926654 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Tue, 16 Apr 2024 15:32:57 +0200 Subject: [PATCH 3/6] update model.py to accommodate new way of alphapulldown modelling with padding --- alphafold/model/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/alphafold/model/model.py b/alphafold/model/model.py index 072355acd..25eeea27a 100644 --- a/alphafold/model/model.py +++ b/alphafold/model/model.py @@ -170,8 +170,8 @@ def predict(self, # already happening when computing get_confidence_metrics, and this ensures # all 
outputs are blocked on. jax.tree_map(lambda x: x.block_until_ready(), result) - result.update( - get_confidence_metrics(result, multimer_mode=self.multimer_mode)) + # result.update( + # get_confidence_metrics(result, multimer_mode=self.multimer_mode)) logging.info('Output shape was %s', tree.map_structure(lambda x: x.shape, result)) return result From 0a15fed29378900f1979bef316dd5b6500ef9eba Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Tue, 16 Apr 2024 15:44:39 +0200 Subject: [PATCH 4/6] update model.py to accommodate new way of alphapulldown modelling with padding --- alphafold/model/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/alphafold/model/model.py b/alphafold/model/model.py index 25eeea27a..7cacd04e1 100644 --- a/alphafold/model/model.py +++ b/alphafold/model/model.py @@ -172,6 +172,8 @@ def predict(self, jax.tree_map(lambda x: x.block_until_ready(), result) # result.update( # get_confidence_metrics(result, multimer_mode=self.multimer_mode)) + result.update({"plddt": confidence.compute_plddt( + result['predicted_lddt']['logits'])}) logging.info('Output shape was %s', tree.map_structure(lambda x: x.shape, result)) return result From a3fce635976635565375a10e6852105bd4c9d81d Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Tue, 16 Apr 2024 17:55:42 +0200 Subject: [PATCH 5/6] update the code --- alphafold/model/model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/alphafold/model/model.py b/alphafold/model/model.py index 7cacd04e1..043b6f902 100644 --- a/alphafold/model/model.py +++ b/alphafold/model/model.py @@ -165,15 +165,16 @@ def predict(self, logging.info('Running predict with shape(feat) = %s', tree.map_structure(lambda x: x.shape, feat)) result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat) - + logging.info(f"PAE has keys: {result['predicted_aligned_error'].keys()}") # This block is to ensure benchmark timings are accurate. 
Some blocking is # already happening when computing get_confidence_metrics, and this ensures # all outputs are blocked on. jax.tree_map(lambda x: x.block_until_ready(), result) # result.update( - # get_confidence_metrics(result, multimer_mode=self.multimer_mode)) + # get_confidence_metrics(result, multimer_mode=self.multimer_mode)) result.update({"plddt": confidence.compute_plddt( result['predicted_lddt']['logits'])}) + logging.info('Output shape was %s', tree.map_structure(lambda x: x.shape, result)) return result From c844e1bb60a3beb50bb8d562c6be046da1e43e3d Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Mon, 29 Apr 2024 14:50:20 +0200 Subject: [PATCH 6/6] remove unnecessary logging --- alphafold/model/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/alphafold/model/model.py b/alphafold/model/model.py index 043b6f902..a48d36d6c 100644 --- a/alphafold/model/model.py +++ b/alphafold/model/model.py @@ -165,7 +165,6 @@ def predict(self, logging.info('Running predict with shape(feat) = %s', tree.map_structure(lambda x: x.shape, feat)) result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat) - logging.info(f"PAE has keys: {result['predicted_aligned_error'].keys()}") # This block is to ensure benchmark timings are accurate. Some blocking is # already happening when computing get_confidence_metrics, and this ensures # all outputs are blocked on.