diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml
deleted file mode 100644
index b04fb15..0000000
--- a/.github/workflows/black.yml
+++ /dev/null
@@ -1,10 +0,0 @@
-name: Lint
-
-on: [push, pull_request]
-
-jobs:
-  lint:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v2
-      - uses: psf/black@stable
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
new file mode 100644
index 0000000..bb9154f
--- /dev/null
+++ b/.github/workflows/lint.yml
@@ -0,0 +1,26 @@
+name: Lint
+
+on: [push, pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v2
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10' # or any version your project uses
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install black==25.1.0 ruff==0.12.2
+
+      - name: Run Black
+        run: black --check .
+
+      - name: Run Ruff (no formatting)
+        run: ruff check . --no-fix
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 866c153..cbb7284 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,9 +1,31 @@
 repos:
-#- repo: https://github.com/PyCQA/isort
-#  rev: "5.12.0"
-#  hooks:
-#    - id: isort
 - repo: https://github.com/psf/black
-  rev: "22.10.0"
+  rev: "25.1.0"
   hooks:
-    - id: black
\ No newline at end of file
+    - id: black
+    - id: black-jupyter # for formatting jupyter-notebook
+
+- repo: https://github.com/pycqa/isort
+  rev: 5.13.2
+  hooks:
+    - id: isort
+      name: isort (python)
+      args: ["--profile=black"]
+
+- repo: https://github.com/asottile/seed-isort-config
+  rev: v2.2.0
+  hooks:
+    - id: seed-isort-config
+
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.6.0
+  hooks:
+    - id: check-yaml
+    - id: end-of-file-fixer
+    - id: trailing-whitespace
+
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: v0.12.2
+  hooks:
+    - id: ruff
+      args: [--fix]
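The new workflow and the expanded pre-commit configuration run the same tools. As a rough local-usage sketch (standard pre-commit/Black/Ruff invocations, not taken from this diff), the checks can be reproduced before pushing:

```bash
# Mirror the CI lint job locally (standard tool commands, not part of this diff)
pip install pre-commit black==25.1.0 ruff==0.12.2
pre-commit install          # run the hooks from .pre-commit-config.yaml on every commit
pre-commit run --all-files  # or run all hooks once over the whole repository
black --check .             # the two steps that the lint.yml workflow executes
ruff check . --no-fix
```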
diff --git a/README.md b/README.md
index 20b0eab..35f64d3 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@ Graph-based models for molecular property prediction and ontology classification
 
 ## Installation
 
-To install this repository, download it and run 
+To install this repository, download it and run
 
 ```bash
 pip install .
@@ -30,7 +30,7 @@ Replace:
 - `${TORCH}` with a PyTorch version (e.g., `2.6.0`; for later versions, check first if they are compatible with torch_scatter and torch_geometric)
 - `${CUDA}` with e.g. `cpu`, `cu118`, or `cu121` depending on your system and CUDA version
 
-If you already have `torch` installed, make sure that `torch_scatter` and `torch_geometric` are compatible with your 
+If you already have `torch` installed, make sure that `torch_scatter` and `torch_geometric` are compatible with your
 PyTorch version and are installed with the same CUDA version.
 For a full list of currently available PyTorch versions and CUDA compatibility, please refer to libraries' official
 documentation:
@@ -68,11 +68,10 @@ my_projects/
 
 ### Ontology Prediction
 
-This example command trains a Residual Gated Graph Convolutional Network on the ChEBI50 dataset (see [wiki](https://github.com/ChEB-AI/python-chebai/wiki/Data-Management)). 
-The dataset has a customizable list of properties for atoms, bonds and molecules that are added to the graph. 
+This example command trains a Residual Gated Graph Convolutional Network on the ChEBI50 dataset (see [wiki](https://github.com/ChEB-AI/python-chebai/wiki/Data-Management)).
+The dataset has a customizable list of properties for atoms, bonds and molecules that are added to the graph.
 The list can be found in the `configs/data/chebi50_graph_properties.yml` file.
 
 ```bash
 python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/csv_logger.yml --model=../python-chebai-graph/configs/model/gnn_res_gated.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml
 ```
-
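The installation notes in the README hunk above keep `${TORCH}` and `${CUDA}` as placeholders. Purely as a hedged illustration (assuming the standard PyG wheel index; the authoritative command is the one in the README itself), a CPU-only setup for PyTorch 2.6.0 might look like this:

```bash
# Illustrative only: placeholders filled in for torch 2.6.0 on CPU,
# assuming the standard PyG wheel index used for torch_scatter/torch_geometric builds.
pip install torch==2.6.0
pip install torch_scatter torch_geometric -f https://data.pyg.org/whl/torch-2.6.0+cpu.html
```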
diff --git a/chebai_graph/models/gin_net.py b/chebai_graph/models/gin_net.py
index 75c2c45..6fed4c6 100644
--- a/chebai_graph/models/gin_net.py
+++ b/chebai_graph/models/gin_net.py
@@ -1,10 +1,11 @@
+import typing
+
+import torch
+import torch.nn.functional as F
+import torch_geometric
 from torch_scatter import scatter_add
 
 from chebai_graph.models.graph import GraphBaseNet
-import torch_geometric
-import torch.nn.functional as F
-import torch
-import typing
 
 
 class AggregateMLP(torch.nn.Module):
diff --git a/chebai_graph/preprocessing/__init__.py b/chebai_graph/preprocessing/__init__.py
index 2b98ba8..80488cc 100644
--- a/chebai_graph/preprocessing/__init__.py
+++ b/chebai_graph/preprocessing/__init__.py
@@ -1 +1,37 @@
-from chebai_graph.preprocessing.properties import *
+from chebai_graph.preprocessing.properties import (
+    AtomAromaticity,
+    AtomCharge,
+    AtomChirality,
+    AtomHybridization,
+    AtomNumHs,
+    AtomProperty,
+    AtomType,
+    BondAromaticity,
+    BondInRing,
+    BondProperty,
+    BondType,
+    MolecularProperty,
+    MoleculeNumRings,
+    MoleculeProperty,
+    NumAtomBonds,
+    RDKit2DNormalized,
+)
+
+__all__ = [
+    "AtomAromaticity",
+    "AtomCharge",
+    "AtomChirality",
+    "AtomHybridization",
+    "AtomNumHs",
+    "AtomProperty",
+    "AtomType",
+    "BondAromaticity",
+    "BondInRing",
+    "BondProperty",
+    "BondType",
+    "MolecularProperty",
+    "MoleculeNumRings",
+    "MoleculeProperty",
+    "NumAtomBonds",
+    "RDKit2DNormalized",
+]
diff --git a/chebai_graph/preprocessing/collate.py b/chebai_graph/preprocessing/collate.py
index 2c5f696..4be36cf 100644
--- a/chebai_graph/preprocessing/collate.py
+++ b/chebai_graph/preprocessing/collate.py
@@ -1,11 +1,11 @@
 from typing import Dict
 
 import torch
+from chebai.preprocessing.collate import RaggedCollator
 from torch_geometric.data import Data as GeomData
 from torch_geometric.data.collate import collate as graph_collate
 
-from chebai_graph.preprocessing.structures import XYGraphData
-from chebai.preprocessing.collate import RaggedCollator
+from chebai_graph.preprocessing.structures import XYGraphData
 
 
 class GraphCollator(RaggedCollator):
diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py
index 60a711f..8532bf0 100644
--- a/chebai_graph/preprocessing/datasets/chebi.py
+++ b/chebai_graph/preprocessing/datasets/chebi.py
@@ -96,11 +96,12 @@ def _setup_properties(self):
         features = [row["features"] for row in raw_data]
 
         # use vectorized version of encode function, apply only if value is present
-        enc_if_not_none = lambda encode, value: (
-            [encode(atom_v) for atom_v in value]
-            if value is not None and len(value) > 0
-            else None
-        )
+        def enc_if_not_none(encode, value):
+            return (
+                [encode(v) for v in value]
+                if value is not None and len(value) > 0
+                else None
+            )
 
         for property in self.properties:
             if not os.path.isfile(self.get_property_path(property)):
diff --git a/chebai_graph/preprocessing/datasets/pubchem.py b/chebai_graph/preprocessing/datasets/pubchem.py
index 210b7ab..6f5d118 100644
--- a/chebai_graph/preprocessing/datasets/pubchem.py
+++ b/chebai_graph/preprocessing/datasets/pubchem.py
@@ -1,6 +1,7 @@
-from chebai_graph.preprocessing.datasets.chebi import GraphPropertiesMixIn
 from chebai.preprocessing.datasets.pubchem import PubchemChem
 
+from chebai_graph.preprocessing.datasets.chebi import GraphPropertiesMixIn
+
 
 class PubChemGraphProperties(GraphPropertiesMixIn, PubchemChem):
     pass
diff --git a/chebai_graph/preprocessing/properties.py b/chebai_graph/preprocessing/properties.py
index 543a344..2b3acf8 100644
--- a/chebai_graph/preprocessing/properties.py
+++ b/chebai_graph/preprocessing/properties.py
@@ -6,11 +6,11 @@
 from descriptastorus.descriptors import rdNormalizedDescriptors
 
 from chebai_graph.preprocessing.property_encoder import (
-    PropertyEncoder,
-    IndexEncoder,
-    OneHotEncoder,
     AsIsEncoder,
     BoolEncoder,
+    IndexEncoder,
+    OneHotEncoder,
+    PropertyEncoder,
 )
 
 
diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py
index f998396..532c91f 100644
--- a/chebai_graph/preprocessing/property_encoder.py
+++ b/chebai_graph/preprocessing/property_encoder.py
@@ -1,11 +1,11 @@
 import abc
+import inspect
 import os
+import sys
+from itertools import islice
 from typing import Optional
 
 import torch
-import sys
-from itertools import islice
-import inspect
 
 
 class PropertyEncoder(abc.ABC):
@@ -94,7 +94,7 @@ def on_finish(self):
 
     def encode(self, token):
         """Returns a unique number for each token, automatically adds new tokens to the cache."""
-        if not str(token) in self.cache:
+        if str(token) not in self.cache:
             self.cache[(str(token))] = len(self.cache)
         return torch.tensor([self.cache[str(token)] + self.offset])
 
@@ -111,7 +111,7 @@ def get_encoding_length(self) -> int:
 
     @property
     def name(self):
-        return f"one_hot"
+        return "one_hot"
 
     def on_start(self, property_values):
         """To get correct number of classes during encoding, cache unique tokens beforehand"""
diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py
index e48c6ab..e71fcfe 100644
--- a/chebai_graph/preprocessing/reader.py
+++ b/chebai_graph/preprocessing/reader.py
@@ -1,19 +1,16 @@
-import importlib
+import os
+from typing import List, Optional
+
+import chebai.preprocessing.reader as dr
+import pysmiles as ps
+import rdkit.Chem as Chem
+import torch
+from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn
+from torch_geometric.data import Data as GeomData
 from torch_geometric.utils import from_networkx
-from typing import Tuple, Mapping, Optional, List
-import importlib
-import networkx as nx
-import os
-import torch
-import rdkit.Chem as Chem
-import pysmiles as ps
-import chebai.preprocessing.reader as dr
-from chebai_graph.preprocessing.collate import GraphCollator
 
 import chebai_graph.preprocessing.properties as properties
-from torch_geometric.data import Data as GeomData
-from lightning_utilities.core.rank_zero import rank_zero_warn, rank_zero_info
+from chebai_graph.preprocessing.collate import GraphCollator
 
 
 class GraphPropertyReader(dr.DataReader):
@@ -44,7 +41,7 @@ def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]:
         else:
             try:
                 Chem.SanitizeMol(mol)
-            except Exception as e:
+            except Exception:
                 rank_zero_warn(f"Rdkit failed at sanitizing {smiles}")
                 self.failed_counter += 1
         self.mol_object_buffer[smiles] = mol
@@ -95,6 +92,8 @@ def name(cls):
         return "graph"
 
     def _read_data(self, raw_data) -> Optional[GeomData]:
+        import networkx as nx
+
         # raw_data is a SMILES string
         try:
             mol = ps.read_smiles(raw_data)
diff --git a/chebai_graph/preprocessing/transform_unlabeled.py b/chebai_graph/preprocessing/transform_unlabeled.py
index 3920659..0cc4b35 100644
--- a/chebai_graph/preprocessing/transform_unlabeled.py
+++ b/chebai_graph/preprocessing/transform_unlabeled.py
@@ -1,4 +1,5 @@
 import random
+
 import torch
 
 
diff --git a/configs/data/chebi50_graph.yml b/configs/data/chebi50_graph.yml
index 14cc489..19c8753 100644
--- a/configs/data/chebi50_graph.yml
+++ b/configs/data/chebi50_graph.yml
@@ -1 +1 @@
-class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData
\ No newline at end of file
+class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData
diff --git a/configs/data/pubchem_graph.yml b/configs/data/pubchem_graph.yml
index af04491..c21f188 100644
--- a/configs/data/pubchem_graph.yml
+++ b/configs/data/pubchem_graph.yml
@@ -16,4 +16,4 @@ init_args:
     - chebai_graph.preprocessing.properties.BondInRing
     - chebai_graph.preprocessing.properties.BondAromaticity
 #- chebai_graph.preprocessing.properties.MoleculeNumRings
-    - chebai_graph.preprocessing.properties.RDKit2DNormalized
\ No newline at end of file
+    - chebai_graph.preprocessing.properties.RDKit2DNormalized
diff --git a/configs/loss/mask_pretraining.yml b/configs/loss/mask_pretraining.yml
index c677559..6d2a560 100644
--- a/configs/loss/mask_pretraining.yml
+++ b/configs/loss/mask_pretraining.yml
@@ -1 +1 @@
-class_path: chebai_graph.loss.pretraining.MaskPretrainingLoss
\ No newline at end of file
+class_path: chebai_graph.loss.pretraining.MaskPretrainingLoss
diff --git a/configs/model/gnn.yml b/configs/model/gnn.yml
index b0b119d..f85fa76 100644
--- a/configs/model/gnn.yml
+++ b/configs/model/gnn.yml
@@ -7,4 +7,4 @@ init_args:
   hidden_length: 512
   dropout_rate: 0.1
   n_conv_layers: 3
-  n_linear_layers: 3
\ No newline at end of file
+  n_linear_layers: 3
diff --git a/configs/model/gnn_attention.yml b/configs/model/gnn_attention.yml
index b1c553b..0c11ced 100644
--- a/configs/model/gnn_attention.yml
+++ b/configs/model/gnn_attention.yml
@@ -8,4 +8,4 @@ init_args:
   dropout_rate: 0.1
   n_conv_layers: 5
   n_linear_layers: 3
-  n_heads: 5
\ No newline at end of file
+  n_heads: 5
diff --git a/configs/model/gnn_gine.yml b/configs/model/gnn_gine.yml
index 0d0ed20..c84ea61 100644
--- a/configs/model/gnn_gine.yml
+++ b/configs/model/gnn_gine.yml
@@ -8,4 +8,4 @@ init_args:
   n_conv_layers: 5
   n_linear_layers: 3
   n_atom_properties: 125
-  n_bond_properties: 5
\ No newline at end of file
+  n_bond_properties: 5
diff --git a/configs/model/gnn_res_gated.yml b/configs/model/gnn_res_gated.yml
index d9ddc05..27d1e78 100644
--- a/configs/model/gnn_res_gated.yml
+++ b/configs/model/gnn_res_gated.yml
@@ -10,4 +10,4 @@ init_args:
  n_linear_layers: 3
  n_atom_properties: 158
  n_bond_properties: 7
- n_molecule_properties: 200
\ No newline at end of file
+ n_molecule_properties: 200
diff --git a/configs/model/gnn_resgated_pretrain.yml b/configs/model/gnn_resgated_pretrain.yml
index c26db76..fad8c27 100644
--- a/configs/model/gnn_resgated_pretrain.yml
+++ b/configs/model/gnn_resgated_pretrain.yml
@@ -13,4 +13,4 @@ init_args:
  n_linear_layers: 3
  n_atom_properties: 151
  n_bond_properties: 7
- n_molecule_properties: 200
\ No newline at end of file
+ n_molecule_properties: 200
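The model and data configs above are consumed by the `python -m chebai fit` CLI shown in the README hunk earlier. As a purely hypothetical variant of that command (not part of this diff), the plain GNN model config and the graph data config could be swapped in as sketched below; depending on the setup, the extra flags from the README example (metrics, batch size, `--model.pass_loss_kwargs=false`) may still be needed:

```bash
# Hypothetical variant of the README training command: same CLI, but using the
# gnn.yml model config and the chebi50_graph.yml data config touched in this diff.
python -m chebai fit \
    --trainer=configs/training/default_trainer.yml \
    --model=../python-chebai-graph/configs/model/gnn.yml \
    --data=../python-chebai-graph/configs/data/chebi50_graph.yml \
    --model.criterion=configs/loss/bce.yml
```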
diff --git a/pyproject.toml b/pyproject.toml
index dc30793..3617dfc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,9 +17,16 @@ dev = [
     "black"
 ]
+inference = [
+    "chebai" # pip install chebai[inference]
+    # below packages need to manually installed as mentioned in readme
+    # torch-geometric
+    # torch_scatter
+]
+
 
 [build-system]
 build-backend = "flit_core.buildapi"
 requires = ["flit_core >=3.2,<4"]
 
 [project.entry-points.'chebai.plugins']
-models = 'chebai_graph.models'
\ No newline at end of file
+models = 'chebai_graph.models'
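The new `inference` extra only declares `chebai`; as its comments note, the graph-specific packages still have to be installed by hand. A hypothetical install sequence from a local checkout (the extra name is taken from the pyproject above, the rest follows the README):

```bash
# Hypothetical: install this package with the new extra from a local checkout,
# then add the graph dependencies manually as the pyproject comments instruct.
pip install ".[inference]"
# followed by torch, torch_scatter and torch_geometric as described in the README
```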