diff --git a/README.md b/README.md index fcbbb37..8d59280 100644 --- a/README.md +++ b/README.md @@ -75,15 +75,11 @@ python -m chebifier predict --help You can also use the package programmatically: ```python -from chebifier.ensemble.base_ensemble import BaseEnsemble -import yaml +from chebifier import BaseEnsemble -# Load configuration from YAML file -with open('configs/example_config.yml', 'r') as f: - config = yaml.safe_load(f) - -# Instantiate ensemble model -ensemble = BaseEnsemble(config) +# Instantiate ensemble model. If desired, can pass +# a path to a configuration, like 'configs/example_config.yml' +ensemble = BaseEnsemble() # Make predictions smiles_list = ["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C(=O)O"] diff --git a/chebifier/__init__.py b/chebifier/__init__.py index aa1e6ec..a4f770c 100644 --- a/chebifier/__init__.py +++ b/chebifier/__init__.py @@ -1,6 +1,11 @@ # Note: The top-level package __init__.py runs only once, # even if multiple subpackages are imported later. -from ._custom_cache import PerSmilesPerModelLRUCache +from ._custom_cache import PerSmilesPerModelLRUCache, modelwise_smiles_lru_cache +from .ensemble.base_ensemble import BaseEnsemble -modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100) +__all__ = [ + "BaseEnsemble", + "PerSmilesPerModelLRUCache", + "modelwise_smiles_lru_cache", +] diff --git a/chebifier/_custom_cache.py b/chebifier/_custom_cache.py index 38b500f..81204aa 100644 --- a/chebifier/_custom_cache.py +++ b/chebifier/_custom_cache.py @@ -6,6 +6,11 @@ from functools import wraps from typing import Any, Callable +__all__ = [ + "PerSmilesPerModelLRUCache", + "modelwise_smiles_lru_cache", +] + class PerSmilesPerModelLRUCache: """ @@ -206,3 +211,6 @@ def _load_cache(self) -> None: self._cache = loaded except Exception as e: print(f"[Cache Load Error] {e}") + + +modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100) diff --git a/chebifier/cli.py b/chebifier/cli.py index c201187..a3db5d6 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -1,7 +1,4 @@ -import importlib.resources - import click -import yaml from chebifier.model_registry import ENSEMBLES @@ -72,43 +69,10 @@ def predict( resolve_inconsistencies=True, ): """Predict ChEBI classes for SMILES strings using an ensemble model.""" - # Load configuration from YAML file - if not ensemble_config: - print("Using default ensemble configuration") - with ( - importlib.resources.files("chebifier") - .joinpath("ensemble.yml") - .open("r") as f - ): - config = yaml.safe_load(f) - else: - print(f"Loading ensemble configuration from {ensemble_config}") - with open(ensemble_config, "r") as f: - config = yaml.safe_load(f) - - with ( - importlib.resources.files("chebifier") - .joinpath("model_registry.yml") - .open("r") as f - ): - model_registry = yaml.safe_load(f) - - new_config = {} - for model_name, entry in config.items(): - if "load_model" in entry: - if entry["load_model"] not in model_registry: - raise ValueError( - f"Model {entry['load_model']} not found in model registry. " - f"Available models are: {','.join(model_registry.keys())}." - ) - new_config[model_name] = {**model_registry[entry["load_model"]], **entry} - else: - new_config[model_name] = entry - config = new_config # Instantiate ensemble model ensemble = ENSEMBLES[ensemble_type]( - config, + ensemble_config, chebi_version=chebi_version, resolve_inconsistencies=resolve_inconsistencies, ) diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index 6a3acef..ad4efba 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -1,32 +1,60 @@ +import importlib import os import time +from pathlib import Path +from typing import Union import torch import tqdm +import yaml from chebifier.check_env import check_package_installed from chebifier.hugging_face import download_model_files from chebifier.inconsistency_resolution import PredictionSmoother from chebifier.prediction_models.base_predictor import BasePredictor -from chebifier.utils import get_disjoint_files, load_chebi_graph +from chebifier.utils import ( + get_default_configs, + get_disjoint_files, + load_chebi_graph, + process_config, +) class BaseEnsemble: def __init__( self, - model_configs: dict, + model_configs: Union[str, Path, dict, None] = None, chebi_version: int = 241, resolve_inconsistencies: bool = True, ): # Deferred Import: To avoid circular import error from chebifier.model_registry import MODEL_TYPES + # Load configuration from YAML file + if not model_configs: + config = get_default_configs() + elif isinstance(model_configs, dict): + config = model_configs + else: + print(f"Loading ensemble configuration from {model_configs}") + with open(model_configs, "r") as f: + config = yaml.safe_load(f) + + with ( + importlib.resources.files("chebifier") + .joinpath("model_registry.yml") + .open("r") as f + ): + model_registry = yaml.safe_load(f) + + processed_configs = process_config(config, model_registry) + self.chebi_graph = load_chebi_graph() self.disjoint_files = get_disjoint_files() self.models = [] self.positive_prediction_threshold = 0.5 - for model_name, model_config in model_configs.items(): + for model_name, model_config in processed_configs.items(): model_cls = MODEL_TYPES[model_config["type"]] if "hugging_face" in model_config: hugging_face_kwargs = download_model_files(model_config["hugging_face"]) diff --git a/chebifier/inconsistency_resolution.py b/chebifier/inconsistency_resolution.py index f6640c2..6a9a45e 100644 --- a/chebifier/inconsistency_resolution.py +++ b/chebifier/inconsistency_resolution.py @@ -1,8 +1,9 @@ import csv import os -import torch from pathlib import Path +import torch + def get_disjoint_groups(disjoint_files): if disjoint_files is None: diff --git a/chebifier/model_registry.py b/chebifier/model_registry.py index 76a3cde..fc36a87 100644 --- a/chebifier/model_registry.py +++ b/chebifier/model_registry.py @@ -4,15 +4,15 @@ WMVwithPPVNPVEnsemble, ) from chebifier.prediction_models import ( + ChEBILookupPredictor, ChemlogPeptidesPredictor, ElectraPredictor, ResGatedPredictor, - ChEBILookupPredictor, ) from chebifier.prediction_models.c3p_predictor import C3PPredictor from chebifier.prediction_models.chemlog_predictor import ( - ChemlogXMolecularEntityPredictor, ChemlogOrganoXCompoundPredictor, + ChemlogXMolecularEntityPredictor, ) ENSEMBLES = { diff --git a/chebifier/prediction_models/base_predictor.py b/chebifier/prediction_models/base_predictor.py index a175366..37851b2 100644 --- a/chebifier/prediction_models/base_predictor.py +++ b/chebifier/prediction_models/base_predictor.py @@ -1,7 +1,7 @@ import json from abc import ABC -from chebifier import modelwise_smiles_lru_cache +from .._custom_cache import modelwise_smiles_lru_cache class BasePredictor(ABC): diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index 5a402d1..5581e6e 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -2,8 +2,8 @@ import tqdm -from .base_predictor import BasePredictor from .. import modelwise_smiles_lru_cache +from .base_predictor import BasePredictor AA_DICT = { "A": "L-alanine", diff --git a/chebifier/utils.py b/chebifier/utils.py index e6fefae..d9610f7 100644 --- a/chebifier/utils.py +++ b/chebifier/utils.py @@ -1,10 +1,13 @@ +import importlib.resources import os +import pickle +import fastobo import networkx as nx import requests -import fastobo +import yaml + from chebifier.hugging_face import download_model_files -import pickle def load_chebi_graph(filename=None): @@ -123,9 +126,27 @@ def get_disjoint_files(): return disjoint_files -if __name__ == "__main__": - # chebi_graph = build_chebi_graph(chebi_version=241) - # save the graph to a file - # pickle.dump(chebi_graph, open("chebi_graph.pkl", "wb")) - chebi_graph = load_chebi_graph() - print(chebi_graph) +def get_default_configs(): + default_config_name = "ensemble.yml" + print(f"Using default ensemble configuration from {default_config_name}") + with ( + importlib.resources.files("chebifier") + .joinpath(default_config_name) + .open("r") as f + ): + return yaml.safe_load(f) + + +def process_config(config, model_registry): + new_config = {} + for model_name, entry in config.items(): + if "load_model" in entry: + if entry["load_model"] not in model_registry: + raise ValueError( + f"Model {entry['load_model']} not found in model registry. " + f"Available models are: {','.join(model_registry.keys())}." + ) + new_config[model_name] = {**model_registry[entry["load_model"]], **entry} + else: + new_config[model_name] = entry + return new_config