Skip to content
Merged

Dev #30

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions aide_predict/bespoke_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,17 +432,17 @@ def fit(self, X: Union[ProteinSequences, List[str]] = None, y: Optional[np.ndarr
X, y = hook(self, X, y)

# Handle MSA requirements, structure checks, etc.
if not empty_X:
if self.requires_msa_for_fit:
if empty_X:
if self.wt is not None and self.wt.has_msa:
warnings.warn("No input sequences provided, but the wild type sequence has an MSA. Attempting to use the wild type sequence MSA.")
X = self.wt.msa
else:
raise ValueError("No input sequences provided and the wild type sequence does not have an MSA. Cannot fit model.")
if self.requires_msa_for_fit:
if empty_X:
if self.wt is not None and self.wt.has_msa:
warnings.warn("No input sequences provided, but the wild type sequence has an MSA. Attempting to use the wild type sequence MSA.")
X = self.wt.msa
else:
raise ValueError("No input sequences provided and the wild type sequence does not have an MSA. Cannot fit model.")

X = self._enforce_aligned(X)
X = self._enforce_aligned(X)

if not empty_X:
if self.requires_structure:
if any(seq.structure is None for seq in X) and self.wt is None:
raise ValueError("This model requires structure information, at least one of the sequences does not have it, and there is no avialable WT structure.")
Expand Down
6 changes: 3 additions & 3 deletions aide_predict/bespoke_models/embedders/ssemb.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,14 @@ def _transform(self, X: ProteinSequences) -> List[np.ndarray]:
for i, seq in enumerate(batch):
seq_idx = batch_start + i
seq_id = seq.id if seq.id else f"seq_{seq_idx}"
seq_id = seq_id.replace("/", "_").replace(" ", "_")
sequence_ids.append(seq_id)
seq_to_idx[seq_id] = seq_idx

# Get structure path
if seq.structure is None:
if not hasattr(self, '_wt_structure_path'):
raise ValueError(f"Sequence {seq_id} has no structure and no WT structure is available")
structure_path = self._wt_structure_path
assert self.wt.structure
structure_path = self.wt.structure.pdb_file
logger.warning(f"Using WT structure for sequence {seq_id}")
else:
structure_path = seq.structure.pdb_file
Expand Down
3 changes: 2 additions & 1 deletion aide_predict/bespoke_models/predictors/eve.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
RequiresWTToFunctionMixin,
RequiresFixedLengthMixin,
RequiresWTMSAMixin,
RequiresMSAForFitMixin,
AcceptsLowerCaseMixin,
RequiresWTDuringInferenceMixin,
CanRegressMixin
Expand Down Expand Up @@ -59,7 +60,7 @@
else:
AVAILABLE = MessageBool(True, "EVE model is available")

class EVEWrapper(RequiresWTToFunctionMixin, RequiresFixedLengthMixin, RequiresWTDuringInferenceMixin, RequiresWTMSAMixin, AcceptsLowerCaseMixin, CanRegressMixin, ProteinModelWrapper):
class EVEWrapper(RequiresWTToFunctionMixin, RequiresFixedLengthMixin, RequiresWTDuringInferenceMixin, RequiresMSAForFitMixin, RequiresWTMSAMixin, AcceptsLowerCaseMixin, CanRegressMixin, ProteinModelWrapper):
"""
Wrapper for EVE (Evolutionary Variational Autoencoder) model.

Expand Down
4 changes: 2 additions & 2 deletions aide_predict/bespoke_models/predictors/evmutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import List, Union, Optional
import numpy as np
import pandas as pd
from aide_predict.bespoke_models.base import ProteinModelWrapper, RequiresWTMSAMixin, CanRegressMixin, RequiresWTToFunctionMixin, RequiresFixedLengthMixin, MessageBool, AcceptsLowerCaseMixin, CacheMixin
from aide_predict.bespoke_models.base import ProteinModelWrapper, RequiresWTMSAMixin, CanRegressMixin, RequiresWTToFunctionMixin, RequiresMSAForFitMixin, RequiresFixedLengthMixin, MessageBool, AcceptsLowerCaseMixin, CacheMixin
from aide_predict.utils.data_structures import ProteinSequences, ProteinSequence

from tqdm import tqdm
Expand All @@ -38,6 +38,7 @@ class EVMutationWrapper(
CacheMixin,
RequiresWTToFunctionMixin,
RequiresFixedLengthMixin,
RequiresMSAForFitMixin,
RequiresWTMSAMixin,
CanRegressMixin,
AcceptsLowerCaseMixin,
Expand Down Expand Up @@ -110,7 +111,6 @@ def _fit(self, X: ProteinSequences, y: Optional[np.ndarray] = None) -> 'EVCoupli
Returns:
EVCouplingsWrapper: The fitted model.
"""
X = self.wt.msa
if not X.width == len(self.wt):
raise ValueError("The sequences in the MSA must all have the same length as the wild-type sequence")
if not str(X[0]).upper() == str(self.wt).upper():
Expand Down
1 change: 1 addition & 0 deletions aide_predict/bespoke_models/predictors/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def _fit(self, X: ProteinSequences, y: Optional[np.ndarray] = None) -> 'HMMWrapp
Raises:
ValueError: If the input sequences are not aligned.
"""

if not X.aligned:
raise ValueError("Input sequences must be aligned for HMM building.")

Expand Down
2 changes: 2 additions & 0 deletions aide_predict/bespoke_models/predictors/msa_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self, metadata_folder: str=None,
batch_size: int = 32,
device: str = 'cpu',
n_msa_seqs: int = 360,
use_cache: bool = False,
wt: Optional[Union[str, ProteinSequence]] = None):
"""
Initialize the MSATransformerLikelihoodWrapper.
Expand All @@ -67,6 +68,7 @@ def __init__(self, metadata_folder: str=None,
flatten=flatten,
pool=pool,
batch_size=batch_size,
use_cache=use_cache,
device=device,
wt=wt)
self.n_msa_seqs = n_msa_seqs
Expand Down
2 changes: 2 additions & 0 deletions aide_predict/bespoke_models/predictors/saprot.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(
flatten: bool = True,
wt: str = None,
batch_size: int = 2,
use_cache: bool = False,
device: str = 'cpu',
foldseek_path: str = 'foldseek'
):
Expand All @@ -104,6 +105,7 @@ def __init__(
positions=positions,
pool=pool,
flatten=flatten,
use_cache=use_cache,
wt=wt,
batch_size=batch_size,
device=device
Expand Down
3 changes: 2 additions & 1 deletion aide_predict/bespoke_models/predictors/vespa.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class VESPAWrapper(CanRegressMixin, ExpectsNoFitMixin, RequiresWTDuringInference

def __init__(self, metadata_folder: Optional[str] = None,
wt: Optional[Union[str, ProteinSequence]] = None,
use_cache: bool = False,
light: bool = True) -> None:
"""
Initialize the VESPAWrapper.
Expand All @@ -58,7 +59,7 @@ def __init__(self, metadata_folder: Optional[str] = None,
wt (Optional[Union[str, ProteinSequence]]): Wild-type protein sequence.
light (bool): If True, use the lighter VESPAl model. If False, use the full VESPA model.
"""
super().__init__(metadata_folder=metadata_folder, wt=wt)
super().__init__(metadata_folder=metadata_folder, use_cache=use_cache, wt=wt)
self.light = light

def _fit(self, X: ProteinSequences, y: Optional[np.ndarray] = None) -> 'VESPAWrapper':
Expand Down
11 changes: 6 additions & 5 deletions aide_predict/utils/data_structures/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def align(self, other: 'ProteinSequence') -> 'ProteinSequence':
aligned_self.msa = other
return aligned_self

def saturation_mutagenesis(self, positions: List[int]=None) -> List['ProteinSequence']:
def saturation_mutagenesis(self, positions: List[int]=None, include_wt: bool=False) -> List['ProteinSequence']:
"""
Perform saturation mutagenesis at the specified positions.

Expand All @@ -442,10 +442,11 @@ def saturation_mutagenesis(self, positions: List[int]=None) -> List['ProteinSequ
positions = range(len(self))
for i in positions:
for aa in AA_SINGLE:
if aa != self[i]:
mutated = self._mutate(i, aa)
mutated.id = f"{self[i]}{i+1}{aa}"
sequences.append(mutated)
if aa == self[i] and not include_wt:
continue
mutated = self._mutate(i, aa)
mutated.id = f"{self[i]}{i+1}{aa}"
sequences.append(mutated)
return ProteinSequences(sequences)

def upper(self) -> 'ProteinSequence':
Expand Down
129 changes: 129 additions & 0 deletions aide_predict/utils/msa.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,132 @@ def compute_conservation(self, msa, normalize=True, gap_treatment='exclude', gap

logger.info(f"Conservation calculation complete: min={conservation.min():.4f}, max={conservation.max():.4f}, mean={conservation.mean():.4f}")
return conservation


def remove_gappy_columns(self,
                         msa: ProteinSequences,
                         gap_threshold: Optional[float] = None,
                         focus_seq_id: Optional[str] = None) -> ProteinSequences:
    """
    Remove columns from the MSA that exceed the gap threshold.

    Columns with a gap fraction above the threshold are removed from every
    sequence. Sequences that consist only of gap characters once those columns
    are gone are dropped as well; every other sequence is retained.

    Args:
        msa (ProteinSequences): The input multiple sequence alignment.
        gap_threshold (Optional[float]): Maximum allowed fraction of gaps per column.
            If None, uses self.threshold_focus_cols_frac_gaps.
        focus_seq_id (Optional[str]): If provided, only columns where the focus
            sequence is not a gap are scored; columns where the focus sequence
            has a gap are always removed. If None, all columns are scored.

    Returns:
        ProteinSequences: New MSA with high-gap columns (and any resulting
        all-gap sequences) removed. Weights, when present on the input, are
        carried over for the sequences that remain.

    Raises:
        ValueError: If the input MSA is not aligned.
        ValueError: If gap_threshold is not between 0 and 1.
        ValueError: If focus_seq_id is provided but not found in MSA.
        ValueError: If no columns pass the gap threshold.
        ValueError: If no sequences remain after all-gap sequences are removed.
    """
    if not msa.aligned:
        raise ValueError("Input MSA must be aligned")

    if gap_threshold is None:
        gap_threshold = self.threshold_focus_cols_frac_gaps

    if not 0 <= gap_threshold <= 1:
        raise ValueError("gap_threshold must be between 0 and 1")

    logger.info(f"Removing columns with gap fraction > {gap_threshold}")
    logger.debug(f"Input MSA: {len(msa)} sequences, {msa.width} columns")

    # Get MSA as array for easier manipulation
    msa_array = msa.as_array()

    # If focus sequence is specified, first filter to focus sequence non-gap positions
    if focus_seq_id is not None:
        if focus_seq_id not in msa.id_mapping:
            raise ValueError(f"Focus sequence ID '{focus_seq_id}' not found in MSA")

        focus_seq = msa[focus_seq_id]
        focus_seq_array = np.array(list(str(focus_seq)))

        # Only consider positions that are not gaps in the focus sequence
        focus_positions = focus_seq_array != '-'
        msa_array_filtered = msa_array[:, focus_positions]

        logger.debug(f"Focus sequence '{focus_seq_id}' has {np.sum(focus_positions)} non-gap positions")
    else:
        focus_positions = None
        msa_array_filtered = msa_array

    # Calculate gap fraction for each column
    # NOTE(review): only '-' counts as a gap here, while the all-gap sequence
    # check below uses GAP_CHARACTERS -- confirm the difference is intended.
    gap_fractions = np.mean(msa_array_filtered == '-', axis=0)

    # Identify columns that pass the threshold
    columns_to_keep_filtered = gap_fractions <= gap_threshold

    # min()/max() raise on a zero-size array, and the f-string is evaluated even
    # when debug logging is off. Skip the statistics when there are no columns to
    # score (e.g. an all-gap focus sequence) so the explicit "No columns pass"
    # error below is what the caller sees.
    if gap_fractions.size:
        logger.debug(f"Gap fractions: min={gap_fractions.min():.3f}, "
                     f"max={gap_fractions.max():.3f}, mean={gap_fractions.mean():.3f}")
    logger.debug(f"Columns passing threshold: {np.sum(columns_to_keep_filtered)}/{len(columns_to_keep_filtered)}")

    # Map back to original column indices if focus sequence was used
    if focus_positions is not None:
        columns_to_keep = np.zeros(msa.width, dtype=bool)
        columns_to_keep[focus_positions] = columns_to_keep_filtered
    else:
        columns_to_keep = columns_to_keep_filtered

    # Check if any columns remain
    if not columns_to_keep.any():
        raise ValueError(f"No columns pass the gap threshold of {gap_threshold}. "
                         f"Consider increasing the threshold.")

    # Filter the MSA array
    filtered_msa_array = msa_array[:, columns_to_keep]

    # Create new ProteinSequences with filtered columns
    filtered_sequences = []
    # Row indices of the sequences that survive, used to subset the weights
    valid_sequence_indices = []

    for i, original_seq in enumerate(msa):
        filtered_seq_str = ''.join(filtered_msa_array[i])

        # Check if sequence is all gaps after column removal
        if all(char in GAP_CHARACTERS for char in filtered_seq_str):
            logger.debug(f"Removing sequence '{original_seq.id}' - all gaps after column filtering")
            continue  # Skip this sequence

        # Create new ProteinSequence preserving metadata
        filtered_seq = ProteinSequence(
            filtered_seq_str,
            id=original_seq.id,
            structure=original_seq.structure
        )

        # Preserve MSA reference if it exists
        if original_seq.has_msa:
            filtered_seq.msa = original_seq.msa

        filtered_sequences.append(filtered_seq)
        valid_sequence_indices.append(i)

    # Check if any sequences remain
    if not filtered_sequences:
        raise ValueError("No sequences remain after removing all-gap sequences. "
                         f"Consider relaxing the gap threshold (current: {gap_threshold})")

    # Create new ProteinSequences object
    filtered_msa = ProteinSequences(filtered_sequences)

    # Preserve weights for valid sequences only
    if hasattr(msa, 'weights') and msa.weights is not None:
        filtered_msa.weights = msa.weights[valid_sequence_indices]

    removed_sequences = len(msa) - len(filtered_msa)
    logger.info(f"Filtered MSA: {len(filtered_msa)} sequences, {filtered_msa.width} columns "
                f"(removed {msa.width - filtered_msa.width} columns, {removed_sequences} all-gap sequences)")

    return filtered_msa
Binary file modified docs/_build/latex/aide.pdf
Binary file not shown.
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

project = 'aide'
copyright = '2024, Evan Komp'
copyright = '2025, Gregg T. Beckham'
author = 'Evan Komp, Gregg T. Beckham'
release = '1.0.0'
release = '1.1.01'

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
2 changes: 2 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ User Guide
user_guide/structure_pred.md
user_guide/msa_search.md
user_guide/badass.md
user_guide/roadmap.md
user_guide/resource_test.md

.. toctree::
:maxdepth: 2
Expand Down
44 changes: 44 additions & 0 deletions docs/user_guide/resource_test.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
---
title: Resource benchmarking
---
# Resource testing

The tables below report the cost of running each tool. The tests were run on a dual-socket Intel Xeon Sapphire Rapids 52-core CPU. When the model supports GPU, one NVIDIA H100 was provided.

The test system was GFP (238 amino acids); the MSA depth (when applicable) was 201. Times measure the total time to fit the model (when applicable) and run prediction on 50 variants. Missing values occur either because the core model does not support that type of prediction or because AIDE's wrapper does not support it.

## Zero shot predictors

| Model Name | Marginal Method | GPU Total Time (s) | CPU Total Time (s) |
|------------|----------------|-------------------|-------------------|
| HMMWrapper | - | - | 0.136 |
| ESM2LikelihoodWrapper | wildtype_marginal | 0.980 | 2.560 |
| ESM2LikelihoodWrapper | mutant_marginal | 0.534 | 30.837 |
| ESM2LikelihoodWrapper | masked_marginal | 0.718 | 62.507 |
| MSATransformerLikelihoodWrapper | wildtype_marginal | 4.067 | 33.974 |
| MSATransformerLikelihoodWrapper | mutant_marginal | 57.297 | Timeout (>1800s) |
| MSATransformerLikelihoodWrapper | masked_marginal | 110.086 | Timeout (>1800s) |
| EVMutationWrapper | - | - | 96.697 |
| SaProtLikelihoodWrapper | wildtype_marginal | 5.356 | 24.291 |
| SaProtLikelihoodWrapper | mutant_marginal | 7.326 | 220.906 |
| SaProtLikelihoodWrapper | masked_marginal | 14.814 | 429.626 |
| VESPAWrapper | - | 244.852 | - |
| EVEWrapper | - | 925.930 | - |
| SSEmbWrapper | - | 192.999 | - |

## Embedders

Cost for embedding 21 GFP sequences.

| Model Name | GPU Total Time (s) | CPU Total Time (s) |
|------------|-------------------|-------------------|
| ESM2Embedding | 0.887 | 1.477 |
| OneHotAlignedEmbedding | - | 0.092 |
| OneHotProteinEmbedding | - | 0.023 |
| MSATransformerEmbedding | 18.962 | 62.653 |
| SaProtEmbedding | 10.439 | 32.360 |
| KmerEmbedding | - | 0.005 |
| SSEmbEmbedding | 665.772 | - |



Loading
Loading