From 187572bb093dd1195a4f4a7200f02230c1355017 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 19 Aug 2025 11:19:16 +0200 Subject: [PATCH 1/6] Simplify getting started This PR updates the BaseEnsemble constructor to allow the following: 1. Passing a string or path to the configuration 2. Not passing a configuration at all, which will automatically load the default configuration. This is now the default, since most users won't want to have to configure it (it should have reasonable defaults) --- README.md | 12 ++++-------- chebifier/__init__.py | 5 +++++ chebifier/cli.py | 8 ++------ chebifier/ensemble/base_ensemble.py | 14 ++++++++++++-- chebifier/utils.py | 12 ++++++++++++ 5 files changed, 35 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index fcbbb37..8d59280 100644 --- a/README.md +++ b/README.md @@ -75,15 +75,11 @@ python -m chebifier predict --help You can also use the package programmatically: ```python -from chebifier.ensemble.base_ensemble import BaseEnsemble -import yaml +from chebifier import BaseEnsemble -# Load configuration from YAML file -with open('configs/example_config.yml', 'r') as f: - config = yaml.safe_load(f) - -# Instantiate ensemble model -ensemble = BaseEnsemble(config) +# Instantiate ensemble model. If desired, can pass +# a path to a configuration, like 'configs/example_config.yml' +ensemble = BaseEnsemble() # Make predictions smiles_list = ["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C(=O)O"] diff --git a/chebifier/__init__.py b/chebifier/__init__.py index aa1e6ec..3ddcfe7 100644 --- a/chebifier/__init__.py +++ b/chebifier/__init__.py @@ -2,5 +2,10 @@ # even if multiple subpackages are imported later. from ._custom_cache import PerSmilesPerModelLRUCache +from chebifier.ensemble.base_ensemble import BaseEnsemble + +__all__ = [ + "BaseEnsemble", +] modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100) diff --git a/chebifier/cli.py b/chebifier/cli.py index c201187..1267c69 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -4,6 +4,7 @@ import yaml from chebifier.model_registry import ENSEMBLES +from chebifier.utils import get_default_configs @click.group() @@ -75,12 +76,7 @@ def predict( # Load configuration from YAML file if not ensemble_config: print("Using default ensemble configuration") - with ( - importlib.resources.files("chebifier") - .joinpath("ensemble.yml") - .open("r") as f - ): - config = yaml.safe_load(f) + config = get_default_configs() else: print(f"Loading ensemble configuration from {ensemble_config}") with open(ensemble_config, "r") as f: diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index 6a3acef..e434281 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -1,23 +1,33 @@ import os import time +from pathlib import Path +from typing import Union import torch import tqdm +import yaml from chebifier.check_env import check_package_installed from chebifier.hugging_face import download_model_files from chebifier.inconsistency_resolution import PredictionSmoother from chebifier.prediction_models.base_predictor import BasePredictor -from chebifier.utils import get_disjoint_files, load_chebi_graph +from chebifier.utils import get_disjoint_files, load_chebi_graph, get_default_configs class BaseEnsemble: def __init__( self, - model_configs: dict, + model_configs: Union[str, Path, dict, None] = None, chebi_version: int = 241, resolve_inconsistencies: bool = True, ): + if model_configs is None: + model_configs = get_default_configs() + elif isinstance(model_configs, (str, Path)): + # Load configuration from YAML file + with open(model_configs) as file: + model_configs = yaml.safe_load(file) + # Deferred Import: To avoid circular import error from chebifier.model_registry import MODEL_TYPES diff --git a/chebifier/utils.py b/chebifier/utils.py index e6fefae..7a2e021 100644 --- a/chebifier/utils.py +++ b/chebifier/utils.py @@ -1,8 +1,11 @@ +import importlib.resources import os import networkx as nx import requests import fastobo +import yaml + from chebifier.hugging_face import download_model_files import pickle @@ -129,3 +132,12 @@ def get_disjoint_files(): # pickle.dump(chebi_graph, open("chebi_graph.pkl", "wb")) chebi_graph = load_chebi_graph() print(chebi_graph) + + +def get_default_configs(): + with ( + importlib.resources.files("chebifier") + .joinpath("ensemble.yml") + .open("r") as f + ): + return yaml.safe_load(f) From e82f6d1755abc8a192730d697293d7ff05c53add Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 16 Sep 2025 16:39:55 +0200 Subject: [PATCH 2/6] remove BaseEnsemble from __all__ --- README.md | 2 +- chebifier/__init__.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/README.md b/README.md index 8d59280..d3b1114 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ python -m chebifier predict --help You can also use the package programmatically: ```python -from chebifier import BaseEnsemble +from chebifier.ensemble.base_ensemble import BaseEnsemble # Instantiate ensemble model. If desired, can pass # a path to a configuration, like 'configs/example_config.yml' diff --git a/chebifier/__init__.py b/chebifier/__init__.py index 3ddcfe7..bae4747 100644 --- a/chebifier/__init__.py +++ b/chebifier/__init__.py @@ -2,10 +2,6 @@ # even if multiple subpackages are imported later. from ._custom_cache import PerSmilesPerModelLRUCache -from chebifier.ensemble.base_ensemble import BaseEnsemble -__all__ = [ - "BaseEnsemble", -] modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100) From ed082896a57dd8c546809d7182826afe013f3852 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 16 Sep 2025 16:41:08 +0200 Subject: [PATCH 3/6] move config processing logic to base ensemble --- chebifier/cli.py | 32 +---------------------------- chebifier/ensemble/base_ensemble.py | 31 ++++++++++++++++++++-------- chebifier/utils.py | 27 ++++++++++++++++-------- 3 files changed, 41 insertions(+), 49 deletions(-) diff --git a/chebifier/cli.py b/chebifier/cli.py index 1267c69..6fb0c69 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -4,8 +4,6 @@ import yaml from chebifier.model_registry import ENSEMBLES -from chebifier.utils import get_default_configs - @click.group() def cli(): @@ -73,38 +71,10 @@ def predict( resolve_inconsistencies=True, ): """Predict ChEBI classes for SMILES strings using an ensemble model.""" - # Load configuration from YAML file - if not ensemble_config: - print("Using default ensemble configuration") - config = get_default_configs() - else: - print(f"Loading ensemble configuration from {ensemble_config}") - with open(ensemble_config, "r") as f: - config = yaml.safe_load(f) - - with ( - importlib.resources.files("chebifier") - .joinpath("model_registry.yml") - .open("r") as f - ): - model_registry = yaml.safe_load(f) - - new_config = {} - for model_name, entry in config.items(): - if "load_model" in entry: - if entry["load_model"] not in model_registry: - raise ValueError( - f"Model {entry['load_model']} not found in model registry. " - f"Available models are: {','.join(model_registry.keys())}." - ) - new_config[model_name] = {**model_registry[entry["load_model"]], **entry} - else: - new_config[model_name] = entry - config = new_config # Instantiate ensemble model ensemble = ENSEMBLES[ensemble_type]( - config, + ensemble_config, chebi_version=chebi_version, resolve_inconsistencies=resolve_inconsistencies, ) diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index e434281..925f4fc 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -6,12 +6,13 @@ import torch import tqdm import yaml +import importlib from chebifier.check_env import check_package_installed from chebifier.hugging_face import download_model_files from chebifier.inconsistency_resolution import PredictionSmoother from chebifier.prediction_models.base_predictor import BasePredictor -from chebifier.utils import get_disjoint_files, load_chebi_graph, get_default_configs +from chebifier.utils import get_disjoint_files, load_chebi_graph, get_default_configs, process_config class BaseEnsemble: @@ -21,22 +22,34 @@ def __init__( chebi_version: int = 241, resolve_inconsistencies: bool = True, ): - if model_configs is None: - model_configs = get_default_configs() - elif isinstance(model_configs, (str, Path)): - # Load configuration from YAML file - with open(model_configs) as file: - model_configs = yaml.safe_load(file) - # Deferred Import: To avoid circular import error from chebifier.model_registry import MODEL_TYPES + # Load configuration from YAML file + if not model_configs: + config = get_default_configs() + elif isinstance(model_configs, dict): + config = model_configs + else: + print(f"Loading ensemble configuration from {model_configs}") + with open(model_configs, "r") as f: + config = yaml.safe_load(f) + + with ( + importlib.resources.files("chebifier") + .joinpath("model_registry.yml") + .open("r") as f + ): + model_registry = yaml.safe_load(f) + + processed_configs = process_config(config, model_registry) + self.chebi_graph = load_chebi_graph() self.disjoint_files = get_disjoint_files() self.models = [] self.positive_prediction_threshold = 0.5 - for model_name, model_config in model_configs.items(): + for model_name, model_config in processed_configs.items(): model_cls = MODEL_TYPES[model_config["type"]] if "hugging_face" in model_config: hugging_face_kwargs = download_model_files(model_config["hugging_face"]) diff --git a/chebifier/utils.py b/chebifier/utils.py index 7a2e021..1b1a485 100644 --- a/chebifier/utils.py +++ b/chebifier/utils.py @@ -126,18 +126,27 @@ def get_disjoint_files(): return disjoint_files -if __name__ == "__main__": - # chebi_graph = build_chebi_graph(chebi_version=241) - # save the graph to a file - # pickle.dump(chebi_graph, open("chebi_graph.pkl", "wb")) - chebi_graph = load_chebi_graph() - print(chebi_graph) - - def get_default_configs(): + default_config_name = "ensemble.yml" + print(f"Using default ensemble configuration from {default_config_name}") with ( importlib.resources.files("chebifier") - .joinpath("ensemble.yml") + .joinpath(default_config_name) .open("r") as f ): return yaml.safe_load(f) + + +def process_config(config, model_registry): + new_config = {} + for model_name, entry in config.items(): + if "load_model" in entry: + if entry["load_model"] not in model_registry: + raise ValueError( + f"Model {entry['load_model']} not found in model registry. " + f"Available models are: {','.join(model_registry.keys())}." + ) + new_config[model_name] = {**model_registry[entry["load_model"]], **entry} + else: + new_config[model_name] = entry + return new_config From ecff107242c79b843f198c1e082e42d3212e8321 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 16 Sep 2025 16:52:23 +0200 Subject: [PATCH 4/6] run pre-commit hooks --- chebifier/cli.py | 4 +--- chebifier/ensemble/base_ensemble.py | 13 +++++++++---- chebifier/inconsistency_resolution.py | 3 ++- chebifier/model_registry.py | 4 ++-- chebifier/prediction_models/chemlog_predictor.py | 2 +- chebifier/utils.py | 8 ++++---- 6 files changed, 19 insertions(+), 15 deletions(-) diff --git a/chebifier/cli.py b/chebifier/cli.py index 6fb0c69..a3db5d6 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -1,10 +1,8 @@ -import importlib.resources - import click -import yaml from chebifier.model_registry import ENSEMBLES + @click.group() def cli(): """Command line interface for Chebifier.""" diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index 925f4fc..ad4efba 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -1,3 +1,4 @@ +import importlib import os import time from pathlib import Path @@ -6,13 +7,17 @@ import torch import tqdm import yaml -import importlib from chebifier.check_env import check_package_installed from chebifier.hugging_face import download_model_files from chebifier.inconsistency_resolution import PredictionSmoother from chebifier.prediction_models.base_predictor import BasePredictor -from chebifier.utils import get_disjoint_files, load_chebi_graph, get_default_configs, process_config +from chebifier.utils import ( + get_default_configs, + get_disjoint_files, + load_chebi_graph, + process_config, +) class BaseEnsemble: @@ -37,8 +42,8 @@ def __init__( with ( importlib.resources.files("chebifier") - .joinpath("model_registry.yml") - .open("r") as f + .joinpath("model_registry.yml") + .open("r") as f ): model_registry = yaml.safe_load(f) diff --git a/chebifier/inconsistency_resolution.py b/chebifier/inconsistency_resolution.py index f6640c2..6a9a45e 100644 --- a/chebifier/inconsistency_resolution.py +++ b/chebifier/inconsistency_resolution.py @@ -1,8 +1,9 @@ import csv import os -import torch from pathlib import Path +import torch + def get_disjoint_groups(disjoint_files): if disjoint_files is None: diff --git a/chebifier/model_registry.py b/chebifier/model_registry.py index 76a3cde..fc36a87 100644 --- a/chebifier/model_registry.py +++ b/chebifier/model_registry.py @@ -4,15 +4,15 @@ WMVwithPPVNPVEnsemble, ) from chebifier.prediction_models import ( + ChEBILookupPredictor, ChemlogPeptidesPredictor, ElectraPredictor, ResGatedPredictor, - ChEBILookupPredictor, ) from chebifier.prediction_models.c3p_predictor import C3PPredictor from chebifier.prediction_models.chemlog_predictor import ( - ChemlogXMolecularEntityPredictor, ChemlogOrganoXCompoundPredictor, + ChemlogXMolecularEntityPredictor, ) ENSEMBLES = { diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index 5a402d1..5581e6e 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -2,8 +2,8 @@ import tqdm -from .base_predictor import BasePredictor from .. import modelwise_smiles_lru_cache +from .base_predictor import BasePredictor AA_DICT = { "A": "L-alanine", diff --git a/chebifier/utils.py b/chebifier/utils.py index 1b1a485..d9610f7 100644 --- a/chebifier/utils.py +++ b/chebifier/utils.py @@ -1,13 +1,13 @@ import importlib.resources import os +import pickle +import fastobo import networkx as nx import requests -import fastobo import yaml from chebifier.hugging_face import download_model_files -import pickle def load_chebi_graph(filename=None): @@ -131,8 +131,8 @@ def get_default_configs(): print(f"Using default ensemble configuration from {default_config_name}") with ( importlib.resources.files("chebifier") - .joinpath(default_config_name) - .open("r") as f + .joinpath(default_config_name) + .open("r") as f ): return yaml.safe_load(f) From 9cf6973cb535f040bea49aab9ad27cb90f46e624 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 16 Sep 2025 17:18:15 +0200 Subject: [PATCH 5/6] Fix circular imports --- README.md | 2 +- chebifier/__init__.py | 10 +++++++--- chebifier/_custom_cache.py | 8 ++++++++ chebifier/prediction_models/base_predictor.py | 2 +- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index d3b1114..8d59280 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ python -m chebifier predict --help You can also use the package programmatically: ```python -from chebifier.ensemble.base_ensemble import BaseEnsemble +from chebifier import BaseEnsemble # Instantiate ensemble model. If desired, can pass # a path to a configuration, like 'configs/example_config.yml' diff --git a/chebifier/__init__.py b/chebifier/__init__.py index bae4747..0f7a891 100644 --- a/chebifier/__init__.py +++ b/chebifier/__init__.py @@ -1,7 +1,11 @@ # Note: The top-level package __init__.py runs only once, # even if multiple subpackages are imported later. -from ._custom_cache import PerSmilesPerModelLRUCache +from ._custom_cache import PerSmilesPerModelLRUCache, modelwise_smiles_lru_cache +from chebifier.ensemble.base_ensemble import BaseEnsemble - -modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100) +__all__ = [ + "BaseEnsemble", + "PerSmilesPerModelLRUCache", + "modelwise_smiles_lru_cache", +] diff --git a/chebifier/_custom_cache.py b/chebifier/_custom_cache.py index 38b500f..81204aa 100644 --- a/chebifier/_custom_cache.py +++ b/chebifier/_custom_cache.py @@ -6,6 +6,11 @@ from functools import wraps from typing import Any, Callable +__all__ = [ + "PerSmilesPerModelLRUCache", + "modelwise_smiles_lru_cache", +] + class PerSmilesPerModelLRUCache: """ @@ -206,3 +211,6 @@ def _load_cache(self) -> None: self._cache = loaded except Exception as e: print(f"[Cache Load Error] {e}") + + +modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100) diff --git a/chebifier/prediction_models/base_predictor.py b/chebifier/prediction_models/base_predictor.py index a175366..37851b2 100644 --- a/chebifier/prediction_models/base_predictor.py +++ b/chebifier/prediction_models/base_predictor.py @@ -1,7 +1,7 @@ import json from abc import ABC -from chebifier import modelwise_smiles_lru_cache +from .._custom_cache import modelwise_smiles_lru_cache class BasePredictor(ABC): From edcf93ef4cd1f4ea8237c67e6c882cdcd3f7d715 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 16 Sep 2025 17:18:55 +0200 Subject: [PATCH 6/6] Update __init__.py --- chebifier/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebifier/__init__.py b/chebifier/__init__.py index 0f7a891..a4f770c 100644 --- a/chebifier/__init__.py +++ b/chebifier/__init__.py @@ -2,7 +2,7 @@ # even if multiple subpackages are imported later. from ._custom_cache import PerSmilesPerModelLRUCache, modelwise_smiles_lru_cache -from chebifier.ensemble.base_ensemble import BaseEnsemble +from .ensemble.base_ensemble import BaseEnsemble __all__ = [ "BaseEnsemble",