diff --git a/chebifier/prediction_models/__init__.py b/chebifier/prediction_models/__init__.py index 51fa6d6..bc29580 100644 --- a/chebifier/prediction_models/__init__.py +++ b/chebifier/prediction_models/__init__.py @@ -1,8 +1,9 @@ from .base_predictor import BasePredictor -from .chemlog_predictor import ChemlogPeptidesPredictor, ChemlogExtraPredictor +from .c3p_predictor import C3PPredictor +from .chebi_lookup import ChEBILookupPredictor +from .chemlog_predictor import ChemlogExtraPredictor, ChemlogPeptidesPredictor from .electra_predictor import ElectraPredictor from .gnn_predictor import ResGatedPredictor -from .chebi_lookup import ChEBILookupPredictor __all__ = [ "BasePredictor", @@ -11,4 +12,5 @@ "ResGatedPredictor", "ChEBILookupPredictor", "ChemlogExtraPredictor", + "C3PPredictor", ] diff --git a/chebifier/prediction_models/c3p_predictor.py b/chebifier/prediction_models/c3p_predictor.py index dc4704d..4ef58a3 100644 --- a/chebifier/prediction_models/c3p_predictor.py +++ b/chebifier/prediction_models/c3p_predictor.py @@ -1,8 +1,6 @@ from pathlib import Path from typing import List, Optional -from c3p import classifier as c3p_classifier - from chebifier import modelwise_smiles_lru_cache from chebifier.prediction_models import BasePredictor @@ -26,6 +24,8 @@ def __init__( @modelwise_smiles_lru_cache.batch_decorator def predict_smiles_list(self, smiles_list: list[str]) -> list: + from c3p import classifier as c3p_classifier + result_list = c3p_classifier.classify( list(smiles_list), self.program_directory, @@ -50,6 +50,8 @@ def explain_smiles(self, smiles): C3P provides natural language explanations for each prediction (positive or negative). Since there are more than 300 classes, only take the positive ones. 
""" + from c3p import classifier as c3p_classifier + highlights = [] result_list = c3p_classifier.classify( [smiles], self.program_directory, self.chemical_classes, strict=False diff --git a/chebifier/prediction_models/chebi_lookup.py b/chebifier/prediction_models/chebi_lookup.py index d145e24..c8ad75a 100644 --- a/chebifier/prediction_models/chebi_lookup.py +++ b/chebifier/prediction_models/chebi_lookup.py @@ -2,7 +2,6 @@ import os from typing import Optional -import networkx as nx from rdkit import Chem from chebifier import modelwise_smiles_lru_cache @@ -18,6 +17,7 @@ def __init__( chebi_version: int = 241, **kwargs, ): + super().__init__(model_name, **kwargs) self._description = ( description @@ -42,6 +42,8 @@ def get_smiles_lookup(self): return smiles_lookup def build_smiles_lookup(self): + import networkx as nx + smiles_lookup = dict() for chebi_id, smiles in nx.get_node_attributes( self.chebi_graph, "smiles" diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index 99fa3b9..5a402d1 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -1,24 +1,9 @@ from typing import Optional import tqdm -from chemlog.alg_classification.charge_classifier import get_charge_category -from chemlog.alg_classification.peptide_size_classifier import get_n_amino_acid_residues -from chemlog.alg_classification.proteinogenics_classifier import ( - get_proteinogenic_amino_acids, -) -from chemlog.alg_classification.substructure_classifier import ( - is_diketopiperazine, - is_emericellamide, -) -from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call -from chemlog_extra.alg_classification.by_element_classification import ( - OrganoXCompoundClassifier, - XMolecularEntityClassifier, -) - -from chebifier import modelwise_smiles_lru_cache from .base_predictor import BasePredictor +from .. 
import modelwise_smiles_lru_cache AA_DICT = { "A": "L-alanine", @@ -48,15 +33,16 @@ class ChemlogExtraPredictor(BasePredictor): - CHEMLOG_CLASSIFIER = None def __init__(self, model_name: str, **kwargs): super().__init__(model_name, **kwargs) self.chebi_graph = kwargs.get("chebi_graph", None) - self.classifier = self.CHEMLOG_CLASSIFIER() + self.classifier = None @modelwise_smiles_lru_cache.batch_decorator def predict_smiles_list(self, smiles_list: list[str]) -> list: + from chemlog.cli import _smiles_to_mol + mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list] res = self.classifier.classify(mol_list) if self.chebi_graph is not None: @@ -73,15 +59,29 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list: class ChemlogXMolecularEntityPredictor(ChemlogExtraPredictor): - CHEMLOG_CLASSIFIER = XMolecularEntityClassifier + def __init__(self, model_name: str, **kwargs): + from chemlog_extra.alg_classification.by_element_classification import ( + XMolecularEntityClassifier, + ) + + super().__init__(model_name, **kwargs) + self.classifier = XMolecularEntityClassifier() class ChemlogOrganoXCompoundPredictor(ChemlogExtraPredictor): - CHEMLOG_CLASSIFIER = OrganoXCompoundClassifier + def __init__(self, model_name: str, **kwargs): + from chemlog_extra.alg_classification.by_element_classification import ( + OrganoXCompoundClassifier, + ) + + super().__init__(model_name, **kwargs) + self.classifier = OrganoXCompoundClassifier() class ChemlogPeptidesPredictor(BasePredictor): def __init__(self, model_name: str, **kwargs): + from chemlog.cli import CLASSIFIERS + super().__init__(model_name, **kwargs) self.strategy = "algo" self.chebi_graph = kwargs.get("chebi_graph", None) @@ -97,6 +97,8 @@ def __init__(self, model_name: str, **kwargs): print(f"Initialised ChemLog model {self.model_name}") def predict_smiles(self, smiles: str) -> Optional[dict]: + from chemlog.cli import _smiles_to_mol, strategy_call + mol = _smiles_to_mol(smiles) if mol is None: return None @@ 
-133,6 +135,19 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list: def get_chemlog_result_info(self, smiles): """Get classification for single molecule with additional information.""" + from chemlog.alg_classification.charge_classifier import get_charge_category + from chemlog.alg_classification.peptide_size_classifier import ( + get_n_amino_acid_residues, + ) + from chemlog.alg_classification.proteinogenics_classifier import ( + get_proteinogenic_amino_acids, + ) + from chemlog.alg_classification.substructure_classifier import ( + is_diketopiperazine, + is_emericellamide, + ) + from chemlog.cli import _smiles_to_mol + mol = _smiles_to_mol(smiles) if mol is None or not smiles: return {"error": "Failed to parse SMILES"} diff --git a/chebifier/prediction_models/electra_predictor.py b/chebifier/prediction_models/electra_predictor.py index 7d64418..4843cad 100644 --- a/chebifier/prediction_models/electra_predictor.py +++ b/chebifier/prediction_models/electra_predictor.py @@ -1,9 +1,12 @@ +from typing import TYPE_CHECKING + import numpy as np -from chebai.models.electra import Electra -from chebai.preprocessing.reader import EMBEDDING_OFFSET, ChemDataReader from .nn_predictor import NNPredictor +if TYPE_CHECKING: + from chebai.models.electra import Electra + def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0): n_nodes = len(node_labels) @@ -37,10 +40,14 @@ def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0): class ElectraPredictor(NNPredictor): def __init__(self, model_name: str, ckpt_path: str, **kwargs): + from chebai.preprocessing.reader import ChemDataReader + super().__init__(model_name, ckpt_path, reader_cls=ChemDataReader, **kwargs) print(f"Initialised Electra model {self.model_name} (device: {self.device})") - def init_model(self, ckpt_path: str, **kwargs) -> Electra: + def init_model(self, ckpt_path: str, **kwargs) -> "Electra": + from chebai.models.electra import Electra + model = 
Electra.load_from_checkpoint( ckpt_path, map_location=self.device, @@ -53,6 +60,8 @@ def init_model(self, ckpt_path: str, **kwargs) -> Electra: return model def explain_smiles(self, smiles) -> dict: + from chebai.preprocessing.reader import EMBEDDING_OFFSET + reader = self.reader_cls() token_dict = reader.to_data(dict(features=smiles, labels=None)) tokens = np.array(token_dict["features"]).astype(int).tolist() diff --git a/chebifier/prediction_models/gnn_predictor.py b/chebifier/prediction_models/gnn_predictor.py index 3d6fc92..6d2a1a5 100644 --- a/chebifier/prediction_models/gnn_predictor.py +++ b/chebifier/prediction_models/gnn_predictor.py @@ -1,15 +1,16 @@ -import chebai_graph.preprocessing.properties as p -import torch -from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred -from chebai_graph.preprocessing.property_encoder import IndexEncoder, OneHotEncoder -from chebai_graph.preprocessing.reader import GraphPropertyReader -from torch_geometric.data.data import Data as GeomData +from typing import TYPE_CHECKING from .nn_predictor import NNPredictor +if TYPE_CHECKING: + from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred + class ResGatedPredictor(NNPredictor): def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwargs): + from chebai_graph.preprocessing.properties import MolecularProperty + from chebai_graph.preprocessing.reader import GraphPropertyReader + super().__init__( model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs ) @@ -23,7 +24,7 @@ def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwar properties = [] self.molecular_properties = properties assert isinstance(self.molecular_properties, list) and all( - isinstance(prop, p.MolecularProperty) for prop in self.molecular_properties + isinstance(prop, MolecularProperty) for prop in self.molecular_properties ) print(f"Initialised GNN model {self.model_name} (device: {self.device})") @@ -32,7 +33,10 @@ def 
load_class(self, class_path: str): module = __import__(module_path, fromlist=[class_name]) return getattr(module, class_name) - def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred: + def init_model(self, ckpt_path: str, **kwargs) -> "ResGatedGraphConvNetGraphPred": + import torch + from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred + model = ResGatedGraphConvNetGraphPred.load_from_checkpoint( ckpt_path, map_location=torch.device(self.device), @@ -45,6 +49,14 @@ def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred: return model def read_smiles(self, smiles): + import torch + from chebai_graph.preprocessing.properties import AtomProperty, BondProperty + from chebai_graph.preprocessing.property_encoder import ( + IndexEncoder, + OneHotEncoder, + ) + from torch_geometric.data.data import Data as GeomData + reader = self.reader_cls() d = reader.to_data(dict(features=smiles, labels=None)) geom_data = d["features"] @@ -87,9 +99,9 @@ def read_smiles(self, smiles): encoded_values = encoded_values.unsqueeze(1) else: encoded_values = torch.zeros((0, prop.encoder.get_encoding_length())) - if isinstance(prop, p.AtomProperty): + if isinstance(prop, AtomProperty): x = torch.cat([x, encoded_values], dim=1) - elif isinstance(prop, p.BondProperty): + elif isinstance(prop, BondProperty): edge_attr = torch.cat([edge_attr, encoded_values], dim=1) else: molecule_attr = torch.cat([molecule_attr, encoded_values[0]], dim=1) diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py index 79dcad9..c71e000 100644 --- a/chebifier/prediction_models/nn_predictor.py +++ b/chebifier/prediction_models/nn_predictor.py @@ -1,5 +1,4 @@ import numpy as np -import torch import tqdm from rdkit import Chem @@ -17,6 +16,8 @@ def __init__( target_labels_path: str, **kwargs, ): + import torch + super().__init__(model_name, **kwargs) self.reader_cls = reader_cls @@ -56,6 +57,8 @@ def 
read_smiles(self, smiles): def predict_smiles_list(self, smiles_list: list[str]) -> list: """Returns a list with the length of smiles_list, each element is either None (=failure) or a dictionary Of classes and predicted values.""" + import torch + token_dicts = [] could_not_parse = [] index_map = dict() diff --git a/pyproject.toml b/pyproject.toml index f01efe9..aa2a96d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,14 +20,15 @@ classifiers = [ dependencies = [ "click", "pyyaml", - "torch", "tqdm", "rdkit", - "chebai>=1.0.1", - "chemlog>=1.0.4", + # Packages to install manually if required + #"chebai>=1.0.1", + #"chemlog>=1.0.4", + # PyPI does not support git dependencies #"chemlog_extra @ git+https://github.com/ChEB-AI/chemlog-extra.git", - "c3p" + # forked version of c3p is Windows-compatible #"c3p @ git+https://github.com/sfluegel05/c3p.git" ]