From d85bf8d1d3fbda513647daba0f612bb77c1ef441 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 18:52:06 +0200 Subject: [PATCH 1/6] add lint workflow --- .github/workflows/black.yml | 10 ---------- .github/workflows/lint.yml | 26 ++++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 10 deletions(-) delete mode 100644 .github/workflows/black.yml create mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml deleted file mode 100644 index b04fb15..0000000 --- a/.github/workflows/black.yml +++ /dev/null @@ -1,10 +0,0 @@ -name: Lint - -on: [push, pull_request] - -jobs: - lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: psf/black@stable diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..bb9154f --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,26 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' # or any version your project uses + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install black==25.1.0 ruff==0.12.2 + + - name: Run Black + run: black --check . + + - name: Run Ruff (no formatting) + run: ruff check . 
--no-fix From 838e3ad22fc9116ddffb74409000787e4d7b7a09 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 19:07:03 +0200 Subject: [PATCH 2/6] update pre-commit --- .pre-commit-config.yaml | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 866c153..cbb7284 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,31 @@ repos: -#- repo: https://github.com/PyCQA/isort -# rev: "5.12.0" -# hooks: -# - id: isort - repo: https://github.com/psf/black - rev: "22.10.0" + rev: "25.1.0" hooks: - - id: black \ No newline at end of file + - id: black + - id: black-jupyter # for formatting jupyter-notebook + +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) + args: ["--profile=black"] + +- repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.2 + hooks: + - id: ruff + args: [--fix] From 9c765c0536e55be92364790ad7f40521d024d851 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 20:20:54 +0200 Subject: [PATCH 3/6] ruff format -- all --- README.md | 9 ++--- chebai_graph/models/gin_net.py | 9 +++-- chebai_graph/preprocessing/__init__.py | 38 ++++++++++++++++++- chebai_graph/preprocessing/collate.py | 4 +- .../preprocessing/datasets/pubchem.py | 3 +- chebai_graph/preprocessing/properties.py | 6 +-- .../preprocessing/property_encoder.py | 10 ++--- chebai_graph/preprocessing/reader.py | 24 ++++++------ .../preprocessing/transform_unlabeled.py | 1 + configs/data/chebi50_graph.yml | 2 +- configs/data/pubchem_graph.yml | 2 +- configs/loss/mask_pretraining.yml | 2 +- configs/model/gnn.yml | 2 +- 
configs/model/gnn_attention.yml | 2 +- configs/model/gnn_gine.yml | 2 +- configs/model/gnn_res_gated.yml | 2 +- configs/model/gnn_resgated_pretrain.yml | 2 +- pyproject.toml | 2 +- 18 files changed, 79 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 20b0eab..35f64d3 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Graph-based models for molecular property prediction and ontology classification ## Installation -To install this repository, download it and run +To install this repository, download it and run ```bash pip install . @@ -30,7 +30,7 @@ Replace: - `${TORCH}` with a PyTorch version (e.g., `2.6.0`; for later versions, check first if they are compatible with torch_scatter and torch_geometric) - `${CUDA}` with e.g. `cpu`, `cu118`, or `cu121` depending on your system and CUDA version -If you already have `torch` installed, make sure that `torch_scatter` and `torch_geometric` are compatible with your +If you already have `torch` installed, make sure that `torch_scatter` and `torch_geometric` are compatible with your PyTorch version and are installed with the same CUDA version. For a full list of currently available PyTorch versions and CUDA compatibility, please refer to libraries' official documentation: @@ -68,11 +68,10 @@ my_projects/ ### Ontology Prediction -This example command trains a Residual Gated Graph Convolutional Network on the ChEBI50 dataset (see [wiki](https://github.com/ChEB-AI/python-chebai/wiki/Data-Management)). -The dataset has a customizable list of properties for atoms, bonds and molecules that are added to the graph. +This example command trains a Residual Gated Graph Convolutional Network on the ChEBI50 dataset (see [wiki](https://github.com/ChEB-AI/python-chebai/wiki/Data-Management)). +The dataset has a customizable list of properties for atoms, bonds and molecules that are added to the graph. The list can be found in the `configs/data/chebi50_graph_properties.yml` file. 
```bash python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/csv_logger.yml --model=../python-chebai-graph/configs/model/gnn_res_gated.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml ``` - diff --git a/chebai_graph/models/gin_net.py b/chebai_graph/models/gin_net.py index 75c2c45..6fed4c6 100644 --- a/chebai_graph/models/gin_net.py +++ b/chebai_graph/models/gin_net.py @@ -1,10 +1,11 @@ +import typing + +import torch +import torch.nn.functional as F +import torch_geometric from torch_scatter import scatter_add from chebai_graph.models.graph import GraphBaseNet -import torch_geometric -import torch.nn.functional as F -import torch -import typing class AggregateMLP(torch.nn.Module): diff --git a/chebai_graph/preprocessing/__init__.py b/chebai_graph/preprocessing/__init__.py index 2b98ba8..80488cc 100644 --- a/chebai_graph/preprocessing/__init__.py +++ b/chebai_graph/preprocessing/__init__.py @@ -1 +1,37 @@ -from chebai_graph.preprocessing.properties import * +from chebai_graph.preprocessing.properties import ( + AtomAromaticity, + AtomCharge, + AtomChirality, + AtomHybridization, + AtomNumHs, + AtomProperty, + AtomType, + BondAromaticity, + BondInRing, + BondProperty, + BondType, + MolecularProperty, + MoleculeNumRings, + MoleculeProperty, + NumAtomBonds, + RDKit2DNormalized, +) + +__all__ = [ + "AtomAromaticity", + "AtomCharge", + "AtomChirality", + "AtomHybridization", + "AtomNumHs", + "AtomProperty", + "AtomType", + "BondAromaticity", + "BondInRing", + "BondProperty", + 
"BondType", + "MolecularProperty", + "MoleculeNumRings", + "MoleculeProperty", + "NumAtomBonds", + "RDKit2DNormalized", +] diff --git a/chebai_graph/preprocessing/collate.py b/chebai_graph/preprocessing/collate.py index 2c5f696..4be36cf 100644 --- a/chebai_graph/preprocessing/collate.py +++ b/chebai_graph/preprocessing/collate.py @@ -1,11 +1,11 @@ from typing import Dict import torch +from chebai.preprocessing.collate import RaggedCollator from torch_geometric.data import Data as GeomData from torch_geometric.data.collate import collate as graph_collate -from chebai_graph.preprocessing.structures import XYGraphData -from chebai.preprocessing.collate import RaggedCollator +from chebai_graph.preprocessing.structures import XYGraphData class GraphCollator(RaggedCollator): diff --git a/chebai_graph/preprocessing/datasets/pubchem.py b/chebai_graph/preprocessing/datasets/pubchem.py index 210b7ab..6f5d118 100644 --- a/chebai_graph/preprocessing/datasets/pubchem.py +++ b/chebai_graph/preprocessing/datasets/pubchem.py @@ -1,6 +1,7 @@ -from chebai_graph.preprocessing.datasets.chebi import GraphPropertiesMixIn from chebai.preprocessing.datasets.pubchem import PubchemChem +from chebai_graph.preprocessing.datasets.chebi import GraphPropertiesMixIn + class PubChemGraphProperties(GraphPropertiesMixIn, PubchemChem): pass diff --git a/chebai_graph/preprocessing/properties.py b/chebai_graph/preprocessing/properties.py index 543a344..2b3acf8 100644 --- a/chebai_graph/preprocessing/properties.py +++ b/chebai_graph/preprocessing/properties.py @@ -6,11 +6,11 @@ from descriptastorus.descriptors import rdNormalizedDescriptors from chebai_graph.preprocessing.property_encoder import ( - PropertyEncoder, - IndexEncoder, - OneHotEncoder, AsIsEncoder, BoolEncoder, + IndexEncoder, + OneHotEncoder, + PropertyEncoder, ) diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py index f998396..532c91f 100644 --- 
a/chebai_graph/preprocessing/property_encoder.py +++ b/chebai_graph/preprocessing/property_encoder.py @@ -1,11 +1,11 @@ import abc +import inspect import os +import sys +from itertools import islice from typing import Optional import torch -import sys -from itertools import islice -import inspect class PropertyEncoder(abc.ABC): @@ -94,7 +94,7 @@ def on_finish(self): def encode(self, token): """Returns a unique number for each token, automatically adds new tokens to the cache.""" - if not str(token) in self.cache: + if str(token) not in self.cache: self.cache[(str(token))] = len(self.cache) return torch.tensor([self.cache[str(token)] + self.offset]) @@ -111,7 +111,7 @@ def get_encoding_length(self) -> int: @property def name(self): - return f"one_hot" + return "one_hot" def on_start(self, property_values): """To get correct number of classes during encoding, cache unique tokens beforehand""" diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index e48c6ab..1d66de0 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -1,19 +1,17 @@ -import importlib - -from torch_geometric.utils import from_networkx -from typing import Tuple, Mapping, Optional, List +import os +from typing import List, Optional -import importlib +import chebai.preprocessing.reader as dr import networkx as nx -import os -import torch -import rdkit.Chem as Chem import pysmiles as ps -import chebai.preprocessing.reader as dr -from chebai_graph.preprocessing.collate import GraphCollator -import chebai_graph.preprocessing.properties as properties +import rdkit.Chem as Chem +import torch +from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn from torch_geometric.data import Data as GeomData -from lightning_utilities.core.rank_zero import rank_zero_warn, rank_zero_info +from torch_geometric.utils import from_networkx + +import chebai_graph.preprocessing.properties as properties +from 
chebai_graph.preprocessing.collate import GraphCollator class GraphPropertyReader(dr.DataReader): @@ -44,7 +42,7 @@ def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: else: try: Chem.SanitizeMol(mol) - except Exception as e: + except Exception: rank_zero_warn(f"Rdkit failed at sanitizing {smiles}") self.failed_counter += 1 self.mol_object_buffer[smiles] = mol diff --git a/chebai_graph/preprocessing/transform_unlabeled.py b/chebai_graph/preprocessing/transform_unlabeled.py index 3920659..0cc4b35 100644 --- a/chebai_graph/preprocessing/transform_unlabeled.py +++ b/chebai_graph/preprocessing/transform_unlabeled.py @@ -1,4 +1,5 @@ import random + import torch diff --git a/configs/data/chebi50_graph.yml b/configs/data/chebi50_graph.yml index 14cc489..19c8753 100644 --- a/configs/data/chebi50_graph.yml +++ b/configs/data/chebi50_graph.yml @@ -1 +1 @@ -class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData \ No newline at end of file +class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData diff --git a/configs/data/pubchem_graph.yml b/configs/data/pubchem_graph.yml index af04491..c21f188 100644 --- a/configs/data/pubchem_graph.yml +++ b/configs/data/pubchem_graph.yml @@ -16,4 +16,4 @@ init_args: - chebai_graph.preprocessing.properties.BondInRing - chebai_graph.preprocessing.properties.BondAromaticity #- chebai_graph.preprocessing.properties.MoleculeNumRings - - chebai_graph.preprocessing.properties.RDKit2DNormalized \ No newline at end of file + - chebai_graph.preprocessing.properties.RDKit2DNormalized diff --git a/configs/loss/mask_pretraining.yml b/configs/loss/mask_pretraining.yml index c677559..6d2a560 100644 --- a/configs/loss/mask_pretraining.yml +++ b/configs/loss/mask_pretraining.yml @@ -1 +1 @@ -class_path: chebai_graph.loss.pretraining.MaskPretrainingLoss \ No newline at end of file +class_path: chebai_graph.loss.pretraining.MaskPretrainingLoss diff --git a/configs/model/gnn.yml b/configs/model/gnn.yml index 
b0b119d..f85fa76 100644 --- a/configs/model/gnn.yml +++ b/configs/model/gnn.yml @@ -7,4 +7,4 @@ init_args: hidden_length: 512 dropout_rate: 0.1 n_conv_layers: 3 - n_linear_layers: 3 \ No newline at end of file + n_linear_layers: 3 diff --git a/configs/model/gnn_attention.yml b/configs/model/gnn_attention.yml index b1c553b..0c11ced 100644 --- a/configs/model/gnn_attention.yml +++ b/configs/model/gnn_attention.yml @@ -8,4 +8,4 @@ init_args: dropout_rate: 0.1 n_conv_layers: 5 n_linear_layers: 3 - n_heads: 5 \ No newline at end of file + n_heads: 5 diff --git a/configs/model/gnn_gine.yml b/configs/model/gnn_gine.yml index 0d0ed20..c84ea61 100644 --- a/configs/model/gnn_gine.yml +++ b/configs/model/gnn_gine.yml @@ -8,4 +8,4 @@ init_args: n_conv_layers: 5 n_linear_layers: 3 n_atom_properties: 125 - n_bond_properties: 5 \ No newline at end of file + n_bond_properties: 5 diff --git a/configs/model/gnn_res_gated.yml b/configs/model/gnn_res_gated.yml index d9ddc05..27d1e78 100644 --- a/configs/model/gnn_res_gated.yml +++ b/configs/model/gnn_res_gated.yml @@ -10,4 +10,4 @@ init_args: n_linear_layers: 3 n_atom_properties: 158 n_bond_properties: 7 - n_molecule_properties: 200 \ No newline at end of file + n_molecule_properties: 200 diff --git a/configs/model/gnn_resgated_pretrain.yml b/configs/model/gnn_resgated_pretrain.yml index c26db76..fad8c27 100644 --- a/configs/model/gnn_resgated_pretrain.yml +++ b/configs/model/gnn_resgated_pretrain.yml @@ -13,4 +13,4 @@ init_args: n_linear_layers: 3 n_atom_properties: 151 n_bond_properties: 7 - n_molecule_properties: 200 \ No newline at end of file + n_molecule_properties: 200 diff --git a/pyproject.toml b/pyproject.toml index dc30793..138ab8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,4 +22,4 @@ build-backend = "flit_core.buildapi" requires = ["flit_core >=3.2,<4"] [project.entry-points.'chebai.plugins'] -models = 'chebai_graph.models' \ No newline at end of file +models = 'chebai_graph.models' From 
2ba49324c90aa09631bd6f615ac9024307ce704b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 20:21:42 +0200 Subject: [PATCH 4/6] ruff error - use func instead of line func --- chebai_graph/preprocessing/datasets/chebi.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 60a711f..8532bf0 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -96,11 +96,12 @@ def _setup_properties(self): features = [row["features"] for row in raw_data] # use vectorized version of encode function, apply only if value is present - enc_if_not_none = lambda encode, value: ( - [encode(atom_v) for atom_v in value] - if value is not None and len(value) > 0 - else None - ) + def enc_if_not_none(encode, value): + return ( + [encode(v) for v in value] + if value is not None and len(value) > 0 + else None + ) for property in self.properties: if not os.path.isfile(self.get_property_path(property)): From f34fe80c4dd17f8ad59e79ad5fbe509c7c571ccd Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 20:26:26 +0200 Subject: [PATCH 5/6] dynamic import for nx --- chebai_graph/preprocessing/reader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index 1d66de0..e71fcfe 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -2,7 +2,6 @@ from typing import List, Optional import chebai.preprocessing.reader as dr -import networkx as nx import pysmiles as ps import rdkit.Chem as Chem import torch @@ -93,6 +92,8 @@ def name(cls): return "graph" def _read_data(self, raw_data) -> Optional[GeomData]: + import networkx as nx + # raw_data is a SMILES string try: mol = ps.read_smiles(raw_data) From cd2655d66fe7280a41c45639080129470f700751 Mon Sep 17 00:00:00 2001 From: 
aditya0by0 Date: Wed, 30 Jul 2025 20:31:25 +0200 Subject: [PATCH 6/6] update pyproject for inference --- pyproject.toml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 138ab8b..3617dfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,13 @@ dev = [ "black" ] +inference = [ + "chebai" # pip install chebai[inference] + # below packages need to be manually installed as mentioned in readme + # torch-geometric + # torch_scatter +] + [build-system] build-backend = "flit_core.buildapi" requires = ["flit_core >=3.2,<4"]