Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions chebifier/prediction_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_predictor import BasePredictor
from .chemlog_predictor import ChemlogPeptidesPredictor, ChemlogExtraPredictor
from .c3p_predictor import C3PPredictor
from .chebi_lookup import ChEBILookupPredictor
from .chemlog_predictor import ChemlogExtraPredictor, ChemlogPeptidesPredictor
from .electra_predictor import ElectraPredictor
from .gnn_predictor import ResGatedPredictor
from .chebi_lookup import ChEBILookupPredictor

__all__ = [
"BasePredictor",
Expand All @@ -11,4 +12,5 @@
"ResGatedPredictor",
"ChEBILookupPredictor",
"ChemlogExtraPredictor",
"C3PPredictor",
]
6 changes: 4 additions & 2 deletions chebifier/prediction_models/c3p_predictor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from pathlib import Path
from typing import List, Optional

from c3p import classifier as c3p_classifier

from chebifier import modelwise_smiles_lru_cache
from chebifier.prediction_models import BasePredictor

Expand All @@ -26,6 +24,8 @@ def __init__(

@modelwise_smiles_lru_cache.batch_decorator
def predict_smiles_list(self, smiles_list: list[str]) -> list:
from c3p import classifier as c3p_classifier

result_list = c3p_classifier.classify(
list(smiles_list),
self.program_directory,
Expand All @@ -50,6 +50,8 @@ def explain_smiles(self, smiles):
C3P provides natural language explanations for each prediction (positive or negative). Since there are more
than 300 classes, only take the positive ones.
"""
from c3p import classifier as c3p_classifier

highlights = []
result_list = c3p_classifier.classify(
[smiles], self.program_directory, self.chemical_classes, strict=False
Expand Down
4 changes: 3 additions & 1 deletion chebifier/prediction_models/chebi_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
from typing import Optional

import networkx as nx
from rdkit import Chem

from chebifier import modelwise_smiles_lru_cache
Expand All @@ -18,6 +17,7 @@ def __init__(
chebi_version: int = 241,
**kwargs,
):

super().__init__(model_name, **kwargs)
self._description = (
description
Expand All @@ -42,6 +42,8 @@ def get_smiles_lookup(self):
return smiles_lookup

def build_smiles_lookup(self):
import networkx as nx

smiles_lookup = dict()
for chebi_id, smiles in nx.get_node_attributes(
self.chebi_graph, "smiles"
Expand Down
55 changes: 35 additions & 20 deletions chebifier/prediction_models/chemlog_predictor.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,9 @@
from typing import Optional

import tqdm
from chemlog.alg_classification.charge_classifier import get_charge_category
from chemlog.alg_classification.peptide_size_classifier import get_n_amino_acid_residues
from chemlog.alg_classification.proteinogenics_classifier import (
get_proteinogenic_amino_acids,
)
from chemlog.alg_classification.substructure_classifier import (
is_diketopiperazine,
is_emericellamide,
)
from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call
from chemlog_extra.alg_classification.by_element_classification import (
OrganoXCompoundClassifier,
XMolecularEntityClassifier,
)

from chebifier import modelwise_smiles_lru_cache

from .base_predictor import BasePredictor
from .. import modelwise_smiles_lru_cache

AA_DICT = {
"A": "L-alanine",
Expand Down Expand Up @@ -48,15 +33,16 @@


class ChemlogExtraPredictor(BasePredictor):
CHEMLOG_CLASSIFIER = None

def __init__(self, model_name: str, **kwargs):
super().__init__(model_name, **kwargs)
self.chebi_graph = kwargs.get("chebi_graph", None)
self.classifier = self.CHEMLOG_CLASSIFIER()
self.classifier = None

@modelwise_smiles_lru_cache.batch_decorator
def predict_smiles_list(self, smiles_list: list[str]) -> list:
from chemlog.cli import _smiles_to_mol

mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list]
res = self.classifier.classify(mol_list)
if self.chebi_graph is not None:
Expand All @@ -73,15 +59,29 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list:


class ChemlogXMolecularEntityPredictor(ChemlogExtraPredictor):
CHEMLOG_CLASSIFIER = XMolecularEntityClassifier
def __init__(self, model_name: str, **kwargs):
from chemlog_extra.alg_classification.by_element_classification import (
XMolecularEntityClassifier,
)

super().__init__(model_name, **kwargs)
self.classifier = XMolecularEntityClassifier()


class ChemlogOrganoXCompoundPredictor(ChemlogExtraPredictor):
CHEMLOG_CLASSIFIER = OrganoXCompoundClassifier
def __init__(self, model_name: str, **kwargs):
from chemlog_extra.alg_classification.by_element_classification import (
OrganoXCompoundClassifier,
)

super().__init__(model_name, **kwargs)
self.classifier = OrganoXCompoundClassifier()


class ChemlogPeptidesPredictor(BasePredictor):
def __init__(self, model_name: str, **kwargs):
from chemlog.cli import CLASSIFIERS

super().__init__(model_name, **kwargs)
self.strategy = "algo"
self.chebi_graph = kwargs.get("chebi_graph", None)
Expand All @@ -97,6 +97,8 @@ def __init__(self, model_name: str, **kwargs):
print(f"Initialised ChemLog model {self.model_name}")

def predict_smiles(self, smiles: str) -> Optional[dict]:
from chemlog.cli import _smiles_to_mol, strategy_call

mol = _smiles_to_mol(smiles)
if mol is None:
return None
Expand Down Expand Up @@ -133,6 +135,19 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list:

def get_chemlog_result_info(self, smiles):
"""Get classification for single molecule with additional information."""
from chemlog.alg_classification.charge_classifier import get_charge_category
from chemlog.alg_classification.peptide_size_classifier import (
get_n_amino_acid_residues,
)
from chemlog.alg_classification.proteinogenics_classifier import (
get_proteinogenic_amino_acids,
)
from chemlog.alg_classification.substructure_classifier import (
is_diketopiperazine,
is_emericellamide,
)
from chemlog.cli import _smiles_to_mol

mol = _smiles_to_mol(smiles)
if mol is None or not smiles:
return {"error": "Failed to parse SMILES"}
Expand Down
15 changes: 12 additions & 3 deletions chebifier/prediction_models/electra_predictor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import TYPE_CHECKING

import numpy as np
from chebai.models.electra import Electra
from chebai.preprocessing.reader import EMBEDDING_OFFSET, ChemDataReader

from .nn_predictor import NNPredictor

if TYPE_CHECKING:
from chebai.models.electra import Electra


def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):
n_nodes = len(node_labels)
Expand Down Expand Up @@ -37,10 +40,14 @@ def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):

class ElectraPredictor(NNPredictor):
def __init__(self, model_name: str, ckpt_path: str, **kwargs):
from chebai.preprocessing.reader import ChemDataReader

super().__init__(model_name, ckpt_path, reader_cls=ChemDataReader, **kwargs)
print(f"Initialised Electra model {self.model_name} (device: {self.device})")

def init_model(self, ckpt_path: str, **kwargs) -> Electra:
def init_model(self, ckpt_path: str, **kwargs) -> "Electra":
from chebai.models.electra import Electra

model = Electra.load_from_checkpoint(
ckpt_path,
map_location=self.device,
Expand All @@ -53,6 +60,8 @@ def init_model(self, ckpt_path: str, **kwargs) -> Electra:
return model

def explain_smiles(self, smiles) -> dict:
from chebai.preprocessing.reader import EMBEDDING_OFFSET

reader = self.reader_cls()
token_dict = reader.to_data(dict(features=smiles, labels=None))
tokens = np.array(token_dict["features"]).astype(int).tolist()
Expand Down
32 changes: 22 additions & 10 deletions chebifier/prediction_models/gnn_predictor.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import chebai_graph.preprocessing.properties as p
import torch
from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred
from chebai_graph.preprocessing.property_encoder import IndexEncoder, OneHotEncoder
from chebai_graph.preprocessing.reader import GraphPropertyReader
from torch_geometric.data.data import Data as GeomData
from typing import TYPE_CHECKING

from .nn_predictor import NNPredictor

if TYPE_CHECKING:
from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred


class ResGatedPredictor(NNPredictor):
def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwargs):
from chebai_graph.preprocessing.properties import MolecularProperty
from chebai_graph.preprocessing.reader import GraphPropertyReader

super().__init__(
model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs
)
Expand All @@ -23,7 +24,7 @@ def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwar
properties = []
self.molecular_properties = properties
assert isinstance(self.molecular_properties, list) and all(
isinstance(prop, p.MolecularProperty) for prop in self.molecular_properties
isinstance(prop, MolecularProperty) for prop in self.molecular_properties
)
print(f"Initialised GNN model {self.model_name} (device: {self.device})")

Expand All @@ -32,7 +33,10 @@ def load_class(self, class_path: str):
module = __import__(module_path, fromlist=[class_name])
return getattr(module, class_name)

def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred:
def init_model(self, ckpt_path: str, **kwargs) -> "ResGatedGraphConvNetGraphPred":
import torch
from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred

model = ResGatedGraphConvNetGraphPred.load_from_checkpoint(
ckpt_path,
map_location=torch.device(self.device),
Expand All @@ -45,6 +49,14 @@ def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred:
return model

def read_smiles(self, smiles):
import torch
from chebai_graph.preprocessing.properties import AtomProperty, BondProperty
from chebai_graph.preprocessing.property_encoder import (
IndexEncoder,
OneHotEncoder,
)
from torch_geometric.data.data import Data as GeomData

reader = self.reader_cls()
d = reader.to_data(dict(features=smiles, labels=None))
geom_data = d["features"]
Expand Down Expand Up @@ -87,9 +99,9 @@ def read_smiles(self, smiles):
encoded_values = encoded_values.unsqueeze(1)
else:
encoded_values = torch.zeros((0, prop.encoder.get_encoding_length()))
if isinstance(prop, p.AtomProperty):
if isinstance(prop, AtomProperty):
x = torch.cat([x, encoded_values], dim=1)
elif isinstance(prop, p.BondProperty):
elif isinstance(prop, BondProperty):
edge_attr = torch.cat([edge_attr, encoded_values], dim=1)
else:
molecule_attr = torch.cat([molecule_attr, encoded_values[0]], dim=1)
Expand Down
5 changes: 4 additions & 1 deletion chebifier/prediction_models/nn_predictor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
import torch
import tqdm
from rdkit import Chem

Expand All @@ -17,6 +16,8 @@ def __init__(
target_labels_path: str,
**kwargs,
):
import torch

super().__init__(model_name, **kwargs)
self.reader_cls = reader_cls

Expand Down Expand Up @@ -56,6 +57,8 @@ def read_smiles(self, smiles):
def predict_smiles_list(self, smiles_list: list[str]) -> list:
"""Returns a list with the length of smiles_list, each element is either None (=failure) or a dictionary
Of classes and predicted values."""
import torch

token_dicts = []
could_not_parse = []
index_map = dict()
Expand Down
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ classifiers = [
dependencies = [
"click",
"pyyaml",
"torch",
"tqdm",
"rdkit",
"chebai>=1.0.1",
"chemlog>=1.0.4",
# Package to install manually if required
#"chebai>=1.0.1",
#"chemlog>=1.0.4",

# pypi does not support git dependencies
#"chemlog_extra @ git+https://github.com/ChEB-AI/chemlog-extra.git",
"c3p"

# forked version of c3p is windows-compatible
#"c3p @ git+https://github.com/sfluegel05/c3p.git"
]
Expand Down