From 689c5ddc56e56c3275c2ea1300cc94d2e6af2153 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 17:58:16 +0200 Subject: [PATCH 01/13] dynamic imports for readers --- chebai/preprocessing/reader.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index aa9960f9..c737df75 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -5,10 +5,7 @@ from itertools import islice from typing import Any, Dict, List, Optional -import deepsmiles -import selfies as sf from pysmiles.read_smiles import _tokenize -from transformers import RobertaTokenizerFast from chebai.preprocessing.collate import DefaultCollator, RaggedCollator @@ -205,6 +202,8 @@ class DeepChemDataReader(ChemDataReader): """ def __init__(self, *args, **kwargs): + import deepsmiles + super().__init__(*args, **kwargs) self.converter = deepsmiles.Converter(rings=True, branches=True) self.error_count = 0 @@ -279,6 +278,8 @@ def __init__( vsize: int = 4000, **kwargs, ): + from transformers import RobertaTokenizerFast + super().__init__(*args, **kwargs) self.tokenizer = RobertaTokenizerFast.from_pretrained( data_path, max_len=max_len @@ -312,6 +313,8 @@ def __init__( vsize: int = 4000, **kwargs, ): + import selfies as sf + super().__init__(*args, **kwargs) self.error_count = 0 sf.set_semantic_constraints("hypervalent") @@ -323,6 +326,8 @@ def name(cls) -> str: def _read_data(self, raw_data: str) -> Optional[List[int]]: """Read and tokenize raw data using SELFIES.""" + import selfies as sf + try: tokenized = sf.split_selfies(sf.encoder(raw_data.strip(), strict=True)) tokenized = [self._get_token_index(v) for v in tokenized] From bbcf6352c626bb74f1a89e63c3826969be24b4db Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 18:16:59 +0200 Subject: [PATCH 02/13] dynamic import for base dm --- chebai/preprocessing/datasets/base.py | 26 ++++++++++++++++---------- 1 file changed, 16 
insertions(+), 10 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 4a1898bc..a229e7af 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -1,24 +1,21 @@ import os import random from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union import lightning as pl -import networkx as nx import pandas as pd import torch import tqdm -from iterstrat.ml_stratifiers import ( - MultilabelStratifiedKFold, - MultilabelStratifiedShuffleSplit, -) from lightning.pytorch.core.datamodule import LightningDataModule from lightning_utilities.core.rank_zero import rank_zero_info -from sklearn.model_selection import StratifiedShuffleSplit from torch.utils.data import DataLoader from chebai.preprocessing import reader as dr +if TYPE_CHECKING: + import networkx as nx + class XYBaseDataModule(LightningDataModule): """ @@ -818,7 +815,7 @@ def _download_required_data(self) -> str: pass @abstractmethod - def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph": """ Extracts the class hierarchy from the data. Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from @@ -833,7 +830,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: pass @abstractmethod - def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: + def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame: """ Converts the graph to a raw dataset. 
Uses the graph created by `_extract_class_hierarchy` method to extract the @@ -848,7 +845,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: pass @abstractmethod - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List: """ Selects classes from the dataset based on a specified criteria. @@ -1023,6 +1020,9 @@ def get_test_split( Raises: ValueError: If the DataFrame does not contain a column named "labels". """ + from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit + from sklearn.model_selection import StratifiedShuffleSplit + print("Get test data split") labels_list = df["labels"].tolist() @@ -1060,6 +1060,12 @@ def get_train_val_splits_given_test( and validation DataFrames. The keys are the names of the train and validation sets, and the values are the corresponding DataFrames. """ + from iterstrat.ml_stratifiers import ( + MultilabelStratifiedKFold, + MultilabelStratifiedShuffleSplit, + ) + from sklearn.model_selection import StratifiedShuffleSplit + print("Split dataset into train / val with given test set") test_ids = test_df["ident"].tolist() From 0bfd79d246be0279d082c42415aea2e73672548f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 18:17:19 +0200 Subject: [PATCH 03/13] dynamic import for chebi dm --- chebai/preprocessing/datasets/chebi.py | 30 ++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 1df144d9..06aa1e70 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -13,17 +13,18 @@ import pickle from abc import ABC from collections import OrderedDict -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union -import fastobo -import networkx as nx import pandas 
as pd -import requests import torch from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset +if TYPE_CHECKING: + import fastobo + import networkx as nx + # exclude some entities from the dataset because the violate disjointness axioms CHEBI_BLACKLIST = [ 194026, @@ -212,6 +213,8 @@ def _load_chebi(self, version: int) -> str: Returns: str: The file path of the loaded ChEBI ontology. """ + import requests + chebi_name = self.raw_file_names_dict["chebi"] chebi_path = os.path.join(self.raw_dir, chebi_name) if not os.path.isfile(chebi_path): @@ -223,7 +226,7 @@ def _load_chebi(self, version: int) -> str: open(chebi_path, "wb").write(r.content) return chebi_path - def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph": """ Extracts the class hierarchy from the ChEBI ontology. Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from @@ -235,6 +238,9 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: Returns: nx.DiGraph: The class hierarchy. """ + import fastobo + import networkx as nx + with open(data_path, encoding="utf-8") as chebi: chebi = "\n".join(line for line in chebi if not line.startswith("xref:")) @@ -262,7 +268,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: print("Compute transitive closure") return nx.transitive_closure_dag(g) - def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: + def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame: """ Converts the graph to a raw dataset. Uses the graph created by `_extract_class_hierarchy` method to extract the @@ -274,6 +280,8 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: Returns: pd.DataFrame: The raw dataset created from the graph. 
""" + import networkx as nx + smiles = nx.get_node_attributes(g, "smiles") names = nx.get_node_attributes(g, "name") @@ -574,7 +582,7 @@ def _name(self) -> str: """ return f"ChEBI{self.THRESHOLD}" - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List: """ Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold. @@ -599,6 +607,8 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: - The `THRESHOLD` attribute should be defined in the subclass of this class. - Nodes without a 'smiles' attribute are ignored in the successor count. """ + import networkx as nx + smiles = nx.get_node_attributes(g, "smiles") nodes = list( sorted( @@ -731,7 +741,7 @@ def processed_dir_main(self) -> str: "processed", ) - def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph: + def _extract_class_hierarchy(self, chebi_path: str) -> "nx.DiGraph": """ Extracts a subset of ChEBI based on subclasses of the top class ID. @@ -786,7 +796,7 @@ def chebi_to_int(s: str) -> int: return int(s[s.index(":") + 1 :]) -def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]: +def term_callback(doc: "fastobo.term.TermFrame") -> Union[Dict, bool]: """ Extracts information from a ChEBI term document. This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents, @@ -803,6 +813,8 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]: - "name": The name of the ChEBI term. - "smiles": The SMILES string associated with the ChEBI term, if available. 
""" + import fastobo + parts = set() parents = [] name = None From 109723cf1c3931f393fa64748f27c3c13a6ece60 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 18:28:12 +0200 Subject: [PATCH 04/13] dynamic imports for log and struc --- chebai/loggers/custom.py | 3 ++- chebai/preprocessing/structures.py | 8 ++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/chebai/loggers/custom.py b/chebai/loggers/custom.py index d1b4282d..04c48849 100644 --- a/chebai/loggers/custom.py +++ b/chebai/loggers/custom.py @@ -2,7 +2,6 @@ from datetime import datetime from typing import List, Literal, Optional, Union -import wandb from lightning.fabric.utilities.types import _PATH from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import WandbLogger @@ -105,6 +104,8 @@ def set_fold(self, fold: int) -> None: Args: fold (int): Cross-validation fold number. """ + import wandb + if fold != self._fold: self._fold = fold # Start new experiment diff --git a/chebai/preprocessing/structures.py b/chebai/preprocessing/structures.py index 5cfe7966..2ab5de5d 100644 --- a/chebai/preprocessing/structures.py +++ b/chebai/preprocessing/structures.py @@ -1,8 +1,10 @@ -from typing import Any, Tuple, Union +from typing import TYPE_CHECKING, Any, Tuple, Union -import networkx as nx import torch +if TYPE_CHECKING: + import networkx as nx + class XYData(torch.utils.data.Dataset): """ @@ -129,6 +131,8 @@ def to_x(self, device: torch.device) -> Tuple[nx.Graph, ...]: Returns: A tuple of molecular graphs with node attributes on the specified device. 
""" + import networkx as nx + l_ = [] for g in self.x: graph = g.copy() From b87129da8fdd5b2655dceb24f4295519af09baa4 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 18:28:41 +0200 Subject: [PATCH 05/13] to avoid access to pubchem file: dynamic import --- chebai/loss/bce_weighted.py | 3 ++- chebai/loss/semantic.py | 8 ++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index 1d5ea763..993d535e 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -5,7 +5,6 @@ from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor -from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed class BCEWeighted(torch.nn.BCEWithLogitsLoss): @@ -27,6 +26,8 @@ def __init__( data_extractor: Optional[XYBaseDataModule] = None, **kwargs, ): + from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed + self.beta = beta if isinstance(data_extractor, LabeledUnlabeledMixed): data_extractor = data_extractor.labeled diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 18485269..877e0060 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -2,14 +2,16 @@ import math import os import pickle -from typing import List, Literal, Union +from typing import TYPE_CHECKING, List, Literal, Union import torch from chebai.loss.bce_weighted import BCEWeighted from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor -from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed + +if TYPE_CHECKING: + from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed class ImplicationLoss(torch.nn.Module): @@ -68,6 +70,8 @@ def __init__( multiply_with_base_loss: bool = True, no_grads: bool = False, ): + from chebai.preprocessing.datasets.pubchem import 
LabeledUnlabeledMixed + super().__init__() # automatically choose labeled subset for implication filter in case of mixed dataset if isinstance(data_extractor, LabeledUnlabeledMixed): From 610d9d4b2b50cfb362ceda62057121aa0367346c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 18:36:56 +0200 Subject: [PATCH 06/13] fix action error: add string literals --- chebai/loss/semantic.py | 2 +- chebai/preprocessing/structures.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 877e0060..89abb175 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -342,7 +342,7 @@ class DisjointLoss(ImplicationLoss): def __init__( self, path_to_disjointness: str, - data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed], + data_extractor: Union[_ChEBIDataExtractor, "LabeledUnlabeledMixed"], base_loss: torch.nn.Module = None, disjoint_loss_weight: float = 100, **kwargs, diff --git a/chebai/preprocessing/structures.py b/chebai/preprocessing/structures.py index 2ab5de5d..4a69ea4f 100644 --- a/chebai/preprocessing/structures.py +++ b/chebai/preprocessing/structures.py @@ -121,7 +121,7 @@ class XYMolData(XYData): kwargs: Additional fields to store in the dataset. """ - def to_x(self, device: torch.device) -> Tuple[nx.Graph, ...]: + def to_x(self, device: torch.device) -> Tuple["nx.Graph", ...]: """ Moves the node attributes of the molecular graphs to the specified device. 
From 925eea556439e8f2f2f82040a34c4e03d7bfa067 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 18:42:26 +0200 Subject: [PATCH 07/13] add inference dependencies --- pyproject.toml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 5723d78a..99728a13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,16 @@ dev = ["black", "isort", "pre-commit"] plot = ["matplotlib", "seaborn"] wandb = ["wandb"] +inference = [ + "numpy", + "pandas", + "torch", + "transformers", + "pysmiles==1.1.2", + "rdkit", + "lightning>=2.5", +] + [tool.setuptools] include-package-data = true license-files = ["LICEN[CS]E*"] From 2a039b6425d6ea4170a8b5d9e39085276f819db2 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 10 Aug 2025 00:01:01 +0200 Subject: [PATCH 08/13] fix nx dynamic import --- chebai/preprocessing/datasets/chebi.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index e9afd854..cbd04895 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -783,8 +783,10 @@ def _extract_class_hierarchy(self, chebi_path: str) -> "nx.DiGraph": ) return g - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List: """Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself).""" + import networkx as nx + smiles = nx.get_node_attributes(g, "smiles") nodes = list( sorted( From 1718315d71dc292b23df72db0f2fd7296c2a37e1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 2 Sep 2025 11:27:39 +0200 Subject: [PATCH 09/13] move inference dep to main dep. 
--- pyproject.toml | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 99728a13..92c2359a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,42 +15,44 @@ maintainers = [ readme = "README.md" license = { text = "AGPL-3.0" } requires-python = ">=3.9,<3.13" + dependencies = [ - "networkx", "numpy", "pandas", - "requests", - "scikit-learn", - "scipy", "torch", "transformers", - "fastobo", "pysmiles==1.1.2", "rdkit", - "selfies", "lightning>=2.5", - "jsonargparse[signatures]>=4.17", - "omegaconf", - "deepsmiles", - "iterative-stratification", - "torchmetrics" ] [project.optional-dependencies] -dev = ["black", "isort", "pre-commit"] +linter = ["black", "isort", "pre-commit"] plot = ["matplotlib", "seaborn"] wandb = ["wandb"] -inference = [ +dev = [ + "networkx", "numpy", "pandas", + "requests", + "scikit-learn", + "scipy", "torch", "transformers", + "fastobo", "pysmiles==1.1.2", "rdkit", + "selfies", "lightning>=2.5", + "jsonargparse[signatures]>=4.17", + "omegaconf", + "deepsmiles", + "iterative-stratification", + "torchmetrics" ] + [tool.setuptools] include-package-data = true license-files = ["LICEN[CS]E*"] From d558826e4e7c5fc79031208e2ff381da9fb1c715 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 2 Sep 2025 11:31:15 +0200 Subject: [PATCH 10/13] install dev dep for test --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0ad21157..e6dcd009 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,7 +24,7 @@ jobs: python -m pip install --upgrade pip python -m pip install --upgrade pip setuptools wheel python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - python -m pip install -e . 
+ python -m pip install -e .[dev] - name: Display Python & Installed Packages run: | From e30ca3de25cff908b04c478c8b96895008d32cf7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 2 Sep 2025 11:46:10 +0200 Subject: [PATCH 11/13] include linters in dev dep --- pyproject.toml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 92c2359a..21ae71ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ dependencies = [ ] [project.optional-dependencies] -linter = ["black", "isort", "pre-commit"] plot = ["matplotlib", "seaborn"] wandb = ["wandb"] @@ -49,7 +48,10 @@ dev = [ "omegaconf", "deepsmiles", "iterative-stratification", - "torchmetrics" + "torchmetrics", + "black", + "isort", + "pre-commit", ] From ea2095fedef89a8d6cf547d41722b908e174e7b1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 13 Sep 2025 14:48:55 +0200 Subject: [PATCH 12/13] separate dependency for linters --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 21ae71ca..ff78b28d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,9 @@ dev = [ "deepsmiles", "iterative-stratification", "torchmetrics", +] + +linters = [ "black", "isort", "pre-commit", From d51e4b5d3d7d2db96be980add5621f07ab394c60 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 19 Sep 2025 14:11:11 +0200 Subject: [PATCH 13/13] remove duplicate dependencies, pin lightning to 2.5.1, only search in chebai directory --- pyproject.toml | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1cbf64f3..eb75643d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "transformers", "pysmiles==1.1.2", "rdkit", - "lightning>=2.5", + "lightning==2.5.1", ] [project.optional-dependencies] @@ -31,18 +31,11 @@ wandb = ["wandb"] dev = [ "networkx", - "numpy", - "pandas", "requests", "scikit-learn", "scipy", - "torch", 
- "transformers", "fastobo", - "pysmiles==1.1.2", - "rdkit", "selfies", - "lightning<=2.5.1", "jsonargparse[signatures]>=4.17", "omegaconf", "deepsmiles", @@ -63,7 +56,13 @@ license-files = ["LICEN[CS]E*"] [tool.setuptools.packages.find] where = ["."] -exclude = ["tests*"] +include = ["chebai*"] [tool.setuptools.package-data] -"*" = ["**/*.txt", "**/*.json"] +# Include essential config files and preprocessing tokens +"chebai" = [ + "preprocessing/bin/**/*.txt", + "preprocessing/bin/**/*.json", + "configs/**/*.yml", + "configs/**/*.yaml", +]