Add dropout during inference flag for improved uncertainty estimation across all backends #533
This PR adds a new `--dropout_during_inference` flag to the `run_structure_prediction.py` script that enables dropout during inference for improved uncertainty estimation. The flag is supported across all available backends: AlphaFold, AlphaFold3, AlphaLink, and UniFold.

Problem
Currently, there's no way to enable dropout during inference in AlphaPulldown, which limits the ability to obtain uncertainty estimates for predictions. Dropout during inference can provide valuable insights into model confidence and prediction reliability.
Solution
Added a new boolean flag `--dropout_during_inference` that:

- Defaults to `False` to maintain backward compatibility
- Sets `eval_dropout=True` in the respective model configurations
- Is passed to backend `setup()` functions via the existing parameter-passing mechanism

Implementation Details
The implementation follows the pseudocode specified in the issue.
Changes Made
- `run_structure_prediction.py`: Added the flag definition and included it in `default_model_flags`
- `alphafold_backend.py`: Added the parameter to `setup()` and logic to set `model_config.model.global_config.eval_dropout = True`
- `alphafold3_backend.py`: Added the parameter and logic in `make_model_config()` to set `config.global_config.eval_dropout = True`
- `alphalink_backend.py`: Added the parameter and logic with fallback handling for different config structures
- `unifold_backend.py`: Added the parameter and logic with proper error handling and logging
- `backend/__init__.py`: Updated the available backends list to include alphalink

Error Handling
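The fallback handling for configs that may lack a `global_config` attribute can be sketched as below. The helper name is hypothetical and the configs are stand-ins; the real backends log through their own loggers.

```python
# Hedged sketch of the fallback handling: hasattr() guard plus a warning
# when the model config has no global_config. Helper name is hypothetical.
import logging
from types import SimpleNamespace

def try_enable_eval_dropout(config) -> bool:
    """Enable eval_dropout where the config supports it; warn otherwise."""
    if hasattr(config, "global_config"):
        config.global_config.eval_dropout = True
        return True
    logging.warning(
        "Model config has no global_config attribute; "
        "--dropout_during_inference has no effect for this backend."
    )
    return False

supported = SimpleNamespace(global_config=SimpleNamespace(eval_dropout=False))
print(try_enable_eval_dropout(supported))           # -> True
print(try_enable_eval_dropout(SimpleNamespace()))   # -> False (logs a warning)
```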
The implementation includes proper error handling with `hasattr()` checks and warning messages when `global_config` is not available, ensuring compatibility across different model configurations.

Usage
```shell
# Enable dropout during inference for any backend
python alphapulldown/scripts/run_structure_prediction.py \
  --input protein1.fasta \
  --output_directory ./results \
  --data_directory /path/to/weights \
  --features_directory /path/to/features \
  --fold_backend alphafold \
  --dropout_during_inference
```

Testing
Added comprehensive tests in `test/test_dropout_during_inference.py` that verify the backend `setup()` functions accept the new parameter.

The changes are minimal (132 insertions, 2 deletions) and maintain full backward compatibility.