10 changes: 0 additions & 10 deletions .github/workflows/black.yml

This file was deleted.

26 changes: 26 additions & 0 deletions .github/workflows/lint.yml
@@ -0,0 +1,26 @@
name: Lint

on: [push, pull_request]

jobs:
  lint:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10' # or any version your project uses

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install black==25.1.0 ruff==0.12.2

      - name: Run Black
        run: black --check .

      - name: Run Ruff (no formatting)
        run: ruff check . --no-fix
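
For reference, the same checks can be reproduced locally before pushing; a minimal sketch using the pinned versions from the workflow above:

```bash
# install the pinned formatter/linter versions used in CI
pip install black==25.1.0 ruff==0.12.2

# verify formatting without rewriting files (mirrors the "Run Black" step)
black --check .

# lint without applying autofixes (mirrors the "Run Ruff" step)
ruff check . --no-fix
```
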
34 changes: 28 additions & 6 deletions .pre-commit-config.yaml
@@ -1,9 +1,31 @@
repos:
#- repo: https://github.com/PyCQA/isort
# rev: "5.12.0"
# hooks:
# - id: isort
- repo: https://github.com/psf/black
rev: "22.10.0"
rev: "25.1.0"
hooks:
- id: black
- id: black
- id: black-jupyter # for formatting jupyter-notebook

- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
args: ["--profile=black"]

- repo: https://github.com/asottile/seed-isort-config
rev: v2.2.0
hooks:
- id: seed-isort-config

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.2
hooks:
- id: ruff
args: [--fix]
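
To run these hooks locally, the standard pre-commit workflow applies; a short sketch, assuming `pre-commit` is installed from PyPI:

```bash
# install pre-commit and register the git hook defined in .pre-commit-config.yaml
pip install pre-commit
pre-commit install

# run every configured hook (black, isort, ruff, ...) across the whole repository
pre-commit run --all-files
```
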
9 changes: 4 additions & 5 deletions README.md
@@ -7,7 +7,7 @@ Graph-based models for molecular property prediction and ontology classification

## Installation

To install this repository, download it and run
To install this repository, download it and run

```bash
pip install .
@@ -30,7 +30,7 @@ Replace:
- `${TORCH}` with a PyTorch version (e.g., `2.6.0`; for later versions, check first if they are compatible with torch_scatter and torch_geometric)
- `${CUDA}` with e.g. `cpu`, `cu118`, or `cu121` depending on your system and CUDA version

If you already have `torch` installed, make sure that `torch_scatter` and `torch_geometric` are compatible with your
If you already have `torch` installed, make sure that `torch_scatter` and `torch_geometric` are compatible with your
PyTorch version and are installed with the same CUDA version.
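
As an illustration only (the repository's actual install command is collapsed in this diff view), installing the graph extensions against an existing PyTorch build might look like this, assuming the standard PyG wheel index URL:

```bash
# illustrative values: torch 2.6.0 CPU wheels; set TORCH/CUDA to match your environment
TORCH=2.6.0
CUDA=cpu
pip install torch_scatter torch_geometric -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
```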

For a full list of currently available PyTorch versions and CUDA compatibility, please refer to libraries' official documentation:
@@ -68,11 +68,10 @@ my_projects/
### Ontology Prediction


This example command trains a Residual Gated Graph Convolutional Network on the ChEBI50 dataset (see [wiki](https://github.com/ChEB-AI/python-chebai/wiki/Data-Management)).
The dataset has a customizable list of properties for atoms, bonds and molecules that are added to the graph.
This example command trains a Residual Gated Graph Convolutional Network on the ChEBI50 dataset (see [wiki](https://github.com/ChEB-AI/python-chebai/wiki/Data-Management)).
The dataset has a customizable list of properties for atoms, bonds and molecules that are added to the graph.
The list can be found in the `configs/data/chebi50_graph_properties.yml` file.

```bash
python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/csv_logger.yml --model=../python-chebai-graph/configs/model/gnn_res_gated.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml
```

9 changes: 5 additions & 4 deletions chebai_graph/models/gin_net.py
@@ -1,10 +1,11 @@
import typing

import torch
import torch.nn.functional as F
import torch_geometric
from torch_scatter import scatter_add

from chebai_graph.models.graph import GraphBaseNet
import torch_geometric
import torch.nn.functional as F
import torch
import typing


class AggregateMLP(torch.nn.Module):
38 changes: 37 additions & 1 deletion chebai_graph/preprocessing/__init__.py
@@ -1 +1,37 @@
from chebai_graph.preprocessing.properties import *
from chebai_graph.preprocessing.properties import (
AtomAromaticity,
AtomCharge,
AtomChirality,
AtomHybridization,
AtomNumHs,
AtomProperty,
AtomType,
BondAromaticity,
BondInRing,
BondProperty,
BondType,
MolecularProperty,
MoleculeNumRings,
MoleculeProperty,
NumAtomBonds,
RDKit2DNormalized,
)

__all__ = [
"AtomAromaticity",
"AtomCharge",
"AtomChirality",
"AtomHybridization",
"AtomNumHs",
"AtomProperty",
"AtomType",
"BondAromaticity",
"BondInRing",
"BondProperty",
"BondType",
"MolecularProperty",
"MoleculeNumRings",
"MoleculeProperty",
"NumAtomBonds",
"RDKit2DNormalized",
]
4 changes: 2 additions & 2 deletions chebai_graph/preprocessing/collate.py
@@ -1,11 +1,11 @@
from typing import Dict

import torch
from chebai.preprocessing.collate import RaggedCollator
from torch_geometric.data import Data as GeomData
from torch_geometric.data.collate import collate as graph_collate
from chebai_graph.preprocessing.structures import XYGraphData

from chebai.preprocessing.collate import RaggedCollator
from chebai_graph.preprocessing.structures import XYGraphData


class GraphCollator(RaggedCollator):
11 changes: 6 additions & 5 deletions chebai_graph/preprocessing/datasets/chebi.py
@@ -96,11 +96,12 @@ def _setup_properties(self):
features = [row["features"] for row in raw_data]

# use vectorized version of encode function, apply only if value is present
enc_if_not_none = lambda encode, value: (
[encode(atom_v) for atom_v in value]
if value is not None and len(value) > 0
else None
)
def enc_if_not_none(encode, value):
return (
[encode(v) for v in value]
if value is not None and len(value) > 0
else None
)

for property in self.properties:
if not os.path.isfile(self.get_property_path(property)):
3 changes: 2 additions & 1 deletion chebai_graph/preprocessing/datasets/pubchem.py
@@ -1,6 +1,7 @@
from chebai_graph.preprocessing.datasets.chebi import GraphPropertiesMixIn
from chebai.preprocessing.datasets.pubchem import PubchemChem

from chebai_graph.preprocessing.datasets.chebi import GraphPropertiesMixIn


class PubChemGraphProperties(GraphPropertiesMixIn, PubchemChem):
pass
6 changes: 3 additions & 3 deletions chebai_graph/preprocessing/properties.py
@@ -6,11 +6,11 @@
from descriptastorus.descriptors import rdNormalizedDescriptors

from chebai_graph.preprocessing.property_encoder import (
PropertyEncoder,
IndexEncoder,
OneHotEncoder,
AsIsEncoder,
BoolEncoder,
IndexEncoder,
OneHotEncoder,
PropertyEncoder,
)


10 changes: 5 additions & 5 deletions chebai_graph/preprocessing/property_encoder.py
@@ -1,11 +1,11 @@
import abc
import inspect
import os
import sys
from itertools import islice
from typing import Optional

import torch
import sys
from itertools import islice
import inspect


class PropertyEncoder(abc.ABC):
@@ -94,7 +94,7 @@ def on_finish(self):

def encode(self, token):
"""Returns a unique number for each token, automatically adds new tokens to the cache."""
if not str(token) in self.cache:
if str(token) not in self.cache:
self.cache[(str(token))] = len(self.cache)
return torch.tensor([self.cache[str(token)] + self.offset])

@@ -111,7 +111,7 @@ def get_encoding_length(self) -> int:

@property
def name(self):
return f"one_hot"
return "one_hot"

def on_start(self, property_values):
"""To get correct number of classes during encoding, cache unique tokens beforehand"""
25 changes: 12 additions & 13 deletions chebai_graph/preprocessing/reader.py
@@ -1,19 +1,16 @@
import importlib
import os
from typing import List, Optional

import chebai.preprocessing.reader as dr
import pysmiles as ps
import rdkit.Chem as Chem
import torch
from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn
from torch_geometric.data import Data as GeomData
from torch_geometric.utils import from_networkx
from typing import Tuple, Mapping, Optional, List

import importlib
import networkx as nx
import os
import torch
import rdkit.Chem as Chem
import pysmiles as ps
import chebai.preprocessing.reader as dr
from chebai_graph.preprocessing.collate import GraphCollator
import chebai_graph.preprocessing.properties as properties
from torch_geometric.data import Data as GeomData
from lightning_utilities.core.rank_zero import rank_zero_warn, rank_zero_info
from chebai_graph.preprocessing.collate import GraphCollator


class GraphPropertyReader(dr.DataReader):
@@ -44,7 +41,7 @@ def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]:
else:
try:
Chem.SanitizeMol(mol)
except Exception as e:
except Exception:
rank_zero_warn(f"Rdkit failed at sanitizing {smiles}")
self.failed_counter += 1
self.mol_object_buffer[smiles] = mol
@@ -95,6 +92,8 @@ def name(cls):
return "graph"

def _read_data(self, raw_data) -> Optional[GeomData]:
import networkx as nx
Member Author: This is the only code where dynamic import is used; the rest of the changes are formatting changes by ruff.

Contributor: thanks


# raw_data is a SMILES string
try:
mol = ps.read_smiles(raw_data)
1 change: 1 addition & 0 deletions chebai_graph/preprocessing/transform_unlabeled.py
@@ -1,4 +1,5 @@
import random

import torch


2 changes: 1 addition & 1 deletion configs/data/chebi50_graph.yml
@@ -1 +1 @@
class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData
class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData
2 changes: 1 addition & 1 deletion configs/data/pubchem_graph.yml
@@ -16,4 +16,4 @@ init_args:
- chebai_graph.preprocessing.properties.BondInRing
- chebai_graph.preprocessing.properties.BondAromaticity
#- chebai_graph.preprocessing.properties.MoleculeNumRings
- chebai_graph.preprocessing.properties.RDKit2DNormalized
- chebai_graph.preprocessing.properties.RDKit2DNormalized
2 changes: 1 addition & 1 deletion configs/loss/mask_pretraining.yml
@@ -1 +1 @@
class_path: chebai_graph.loss.pretraining.MaskPretrainingLoss
class_path: chebai_graph.loss.pretraining.MaskPretrainingLoss
2 changes: 1 addition & 1 deletion configs/model/gnn.yml
@@ -7,4 +7,4 @@ init_args:
hidden_length: 512
dropout_rate: 0.1
n_conv_layers: 3
n_linear_layers: 3
n_linear_layers: 3
2 changes: 1 addition & 1 deletion configs/model/gnn_attention.yml
@@ -8,4 +8,4 @@ init_args:
dropout_rate: 0.1
n_conv_layers: 5
n_linear_layers: 3
n_heads: 5
n_heads: 5
2 changes: 1 addition & 1 deletion configs/model/gnn_gine.yml
@@ -8,4 +8,4 @@ init_args:
n_conv_layers: 5
n_linear_layers: 3
n_atom_properties: 125
n_bond_properties: 5
n_bond_properties: 5
2 changes: 1 addition & 1 deletion configs/model/gnn_res_gated.yml
@@ -10,4 +10,4 @@ init_args:
n_linear_layers: 3
n_atom_properties: 158
n_bond_properties: 7
n_molecule_properties: 200
n_molecule_properties: 200
2 changes: 1 addition & 1 deletion configs/model/gnn_resgated_pretrain.yml
@@ -13,4 +13,4 @@ init_args:
n_linear_layers: 3
n_atom_properties: 151
n_bond_properties: 7
n_molecule_properties: 200
n_molecule_properties: 200
9 changes: 8 additions & 1 deletion pyproject.toml
@@ -17,9 +17,16 @@ dev = [
"black"
]

inference = [
"chebai" # pip install chebai[inference]
# below packages need to be installed manually as mentioned in readme
# torch-geometric
# torch_scatter
]
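
A usage sketch for the new extra (the distribution name `chebai-graph` is an assumption here, and torch-geometric / torch_scatter still have to be installed manually as described in the README):

```bash
# install the optional inference dependencies; distribution name assumed
pip install "chebai-graph[inference]"
```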

[build-system]
build-backend = "flit_core.buildapi"
requires = ["flit_core >=3.2,<4"]

[project.entry-points.'chebai.plugins']
models = 'chebai_graph.models'
models = 'chebai_graph.models'