Commit bbb8f10

Merge branch 'dev' into fix/read_data
2 parents 6761079 + 836af9e

72 files changed: +3138 -91 lines
.gitignore
Lines changed: 4 additions & 0 deletions

@@ -176,3 +176,7 @@ lightning_logs
 logs
 .isort.cfg
 /.vscode
+
+*.out
+*.err
+*.sh
README.md
Lines changed: 27 additions & 6 deletions

@@ -3,14 +3,30 @@
 ChEBai is a deep learning library designed for the integration of deep learning methods with chemical ontologies, particularly ChEBI.
 The library emphasizes the incorporation of the semantic qualities of the ontology into the learning process.
 
-## Installation
+## News
+
+We now support regression tasks!
+
+## Note for developers
 
-You can install ChEBai via pip:
+If you have used ChEBai before PR #39, the file structure in which your ChEBI data is saved has changed. This means that
+datasets will be freshly generated. The data itself, however, is the same. If you want to keep the old data (including the old
+splits), you can use a migration script. It copies the old data to the new location for a specific ChEBI class
+(including the ChEBI version and other parameters). The script can be called by specifying the data module from a config:
 ```
-pip install chebai
+python chebai/preprocessing/migration/chebi_data_migration.py migrate --datamodule=[path-to-data-config]
+```
+or by specifying the class name (e.g. `ChEBIOver50`) and arguments separately:
 ```
+python chebai/preprocessing/migration/chebi_data_migration.py migrate --class_name=[data-class] [--chebi_version=[version]]
+```
+The new dataset will, by default, generate random data splits (with a given seed).
+To reuse a fixed data split, you have to provide the path of the csv file generated during the migration:
+`--data.init_args.splits_file_path=[path-to-processed_data]/splits.csv`
 
-Alternatively, you can get the latest development version directly from GitHub:
+## Installation
+
+To install ChEBai, follow these steps:
 
 1. Clone the repository:
 ```
@@ -63,11 +79,16 @@ A command with additional options may look like this:
 python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi50.yml --model.criterion=configs/loss/bce.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
 ```
 
-### Fine-tuning for Toxicity prediction
+### Fine-tuning for classification tasks, e.g. Toxicity prediction
 ```
 python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=configs/training/default_callbacks.yml --model.pretrained_checkpoint=[path-to-pretrained-model]
 ```
 
+### Fine-tuning for regression tasks, e.g. solubility prediction
+```
+python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=configs/training/solCur_callbacks.yml --model.pretrained_checkpoint=[path-to-pretrained-model]
+```
+
 ### Predicting classes given SMILES strings
 ```
 python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
@@ -81,7 +102,7 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont
 You can evaluate a model trained on the ontology extension task in one of two ways:
 
 ### 1. Using the Jupyter Notebook
-An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
+An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
 - Load your finetuned model and run the evaluation cells to compute metrics on the test set.
 
 ### 2. Using the Lightning CLI
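For readers who prefer to wire the migrated split in code rather than via CLI flags, here is a minimal sketch. It assumes that `ChEBIOver50` (named in the README above) can be imported from `chebai.preprocessing.datasets.chebi` and that it accepts `chebi_version` and `splits_file_path` as constructor arguments, mirroring the `--data.init_args.*` flags; treat the import path and the example file path as illustrative, not authoritative.
```
# Hedged sketch: reuse a fixed split programmatically. The import path and the
# splits.csv location are assumptions based on the CLI flags shown in the README.
from chebai.preprocessing.datasets.chebi import ChEBIOver50

data_module = ChEBIOver50(
    chebi_version=231,  # same value as --data.init_args.chebi_version=231 above
    splits_file_path="data/chebi_v231/ChEBI50/processed/splits.csv",  # file written by the migration script
)
```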

chebai/cli.py
Lines changed: 30 additions & 1 deletion

@@ -60,15 +60,44 @@ def call_data_methods(data: Type[XYBaseDataModule]):
         )
 
         for kind in ("train", "val", "test"):
-            for average in ("micro-f1", "macro-f1", "balanced-accuracy"):
+            for average in (
+                "micro-f1",
+                "macro-f1",
+                "balanced-accuracy",
+                "roc-auc",
+                "f1",
+                "mse",
+                "rmse",
+                "r2",
+            ):
+                # When using lightning > 2.5.1 then need to uncomment all metrics that are not used
+                # for average in ("mse", "rmse", "r2"):  # for regression
+                # for average in ("f1", "roc-auc"):  # for binary classification
+                # for average in ("micro-f1", "macro-f1", "roc-auc"):  # for multilabel classification
+                # for average in ("micro-f1", "macro-f1", "balanced-accuracy", "roc-auc"):  # for multilabel classification using balanced-accuracy
                 parser.link_arguments(
                     "data.num_of_labels",
                     f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
                     apply_on="instantiate",
                 )
+
         parser.link_arguments(
             "data.num_of_labels", "trainer.callbacks.init_args.num_labels"
         )
+        # parser.link_arguments(
+        #     "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
+        # )
+        # parser.link_arguments(
+        #     "data", "model.init_args.criterion.init_args.data_extractor"
+        # )
+        # parser.link_arguments(
+        #     "data.init_args.chebi_version",
+        #     "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version",
+        # )
+
+        parser.link_arguments(
+            "data", "model.init_args.criterion.init_args.data_extractor"
+        )
 
     @staticmethod
     def subcommands() -> Dict[str, Set[str]]:
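The `parser.link_arguments` calls above follow the standard Lightning CLI pattern: a value computed by one component (here, the data module's `num_of_labels`) is copied into the config of another component after instantiation. A minimal, self-contained illustration of that pattern is sketched below; `MyCLI` and the linked target are placeholders, not ChEBai's actual CLI.
```
# Minimal sketch of the link_arguments pattern (placeholder CLI, not ChEBai's own).
from lightning.pytorch.cli import LightningCLI


class MyCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # After the data module is instantiated, copy its num_of_labels attribute
        # into the model's out_dim init argument.
        parser.link_arguments(
            "data.num_of_labels",
            "model.init_args.out_dim",
            apply_on="instantiate",
        )
```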

chebai/loss/bce_weighted.py
Lines changed: 23 additions & 13 deletions

@@ -10,7 +10,6 @@
 class BCEWeighted(torch.nn.BCEWithLogitsLoss):
     """
     BCEWithLogitsLoss with weights automatically computed according to the beta parameter.
-    If beta is None or data_extractor is None, the loss is unweighted.
 
     This class computes weights based on the formula from the paper by Cui et al. (2019):
     https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf
@@ -22,7 +21,7 @@ class BCEWeighted(torch.nn.BCEWithLogitsLoss):
 
     def __init__(
         self,
-        beta: Optional[float] = None,
+        beta: float = 0.99,
         data_extractor: Optional[XYBaseDataModule] = None,
         **kwargs,
     ):
@@ -32,11 +31,26 @@ def __init__(
         if isinstance(data_extractor, LabeledUnlabeledMixed):
             data_extractor = data_extractor.labeled
         self.data_extractor = data_extractor
+
+        assert (
+            isinstance(beta, float) and beta > 0.0
+        ), f"Beta parameter must be a float with value greater than 0.0, for loss class {self.__class__.__name__}."
+
+        assert (
+            self.data_extractor is not None
+        ), f"Data extractor must be provided if this loss class ({self.__class__.__name__}) is used."
+
+        assert all(
+            os.path.exists(os.path.join(self.data_extractor.processed_dir, file_name))
+            for file_name in self.data_extractor.processed_file_names
+        ), "Dataset files not found. Make sure the dataset is processed before using this loss."
+
         assert (
             isinstance(self.data_extractor, _ChEBIDataExtractor)
             or self.data_extractor is None
         )
         super().__init__(**kwargs)
+        self.pos_weight: Optional[torch.Tensor] = None
 
     def set_pos_weight(self, input: torch.Tensor) -> None:
         """
@@ -45,17 +59,7 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
         Args:
             input (torch.Tensor): The input tensor for which to set the positive weights.
         """
-        if (
-            self.beta is not None
-            and self.data_extractor is not None
-            and all(
-                os.path.exists(
-                    os.path.join(self.data_extractor.processed_dir, file_name)
-                )
-                for file_name in self.data_extractor.processed_file_names
-            )
-            and self.pos_weight is None
-        ):
+        if self.pos_weight is None:
             print(
                 f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
             )
@@ -96,3 +100,9 @@ def forward(
         """
         self.set_pos_weight(input)
         return super().forward(input, target)
+
+
+class UnWeightedBCEWithLogitsLoss(torch.nn.BCEWithLogitsLoss):
+    def forward(self, input, target, **kwargs):
+        # As the custom passed kwargs are not used in BCEWithLogitsLoss, we can ignore them
+        return super().forward(input, target)
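For orientation, the weighting scheme that `BCEWeighted` cites (Cui et al., 2019) gives each class a weight inversely proportional to its "effective number" of samples, E_n = (1 - beta^n) / (1 - beta). The sketch below illustrates that formula on a toy multi-label matrix; it is not the library's exact implementation, just the underlying idea.
```
# Toy illustration of the effective-number weighting from Cui et al. (2019);
# not ChEBai's implementation, just the idea BCEWeighted builds on.
import torch


def effective_number_weights(labels: torch.Tensor, beta: float = 0.99) -> torch.Tensor:
    """labels: binary multi-label matrix of shape (num_samples, num_classes)."""
    n_pos = labels.sum(dim=0)  # positive count per class
    effective_num = (1.0 - beta**n_pos) / (1.0 - beta)  # E_n = (1 - beta^n) / (1 - beta)
    weights = 1.0 / effective_num.clamp(min=1e-8)  # rarer classes get larger weights
    return weights * labels.shape[1] / weights.sum()  # normalize so the mean weight is ~1


labels = torch.randint(0, 2, (100, 5)).float()
print(effective_number_weights(labels))  # tensor of 5 per-class weights
```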

chebai/loss/focal_loss.py
Lines changed: 152 additions & 0 deletions

@@ -0,0 +1,152 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# from https://github.com/itakurah/Focal-loss-PyTorch
+
+
+class FocalLoss(nn.Module):
+    def __init__(
+        self,
+        gamma=2,
+        alpha=None,
+        reduction="mean",
+        task_type="binary",
+        num_classes=None,
+    ):
+        """
+        Unified Focal Loss class for binary, multi-class, and multi-label classification tasks.
+        :param gamma: Focusing parameter, controls the strength of the modulating factor (1 - p_t)^gamma
+        :param alpha: Balancing factor, can be a scalar or a tensor for class-wise weights. If None, no class balancing is used.
+        :param reduction: Specifies the reduction method: 'none' | 'mean' | 'sum'
+        :param task_type: Specifies the type of task: 'binary', 'multi-class', or 'multi-label'
+        :param num_classes: Number of classes (only required for multi-class classification)
+        """
+        super(FocalLoss, self).__init__()
+        self.gamma = gamma
+        self.alpha = alpha
+        self.reduction = reduction
+        self.task_type = task_type
+        self.num_classes = num_classes
+
+        # Handle alpha for class balancing in multi-class tasks
+        if (
+            task_type == "multi-class"
+            and alpha is not None
+            and isinstance(alpha, (list, torch.Tensor))
+        ):
+            assert (
+                num_classes is not None
+            ), "num_classes must be specified for multi-class classification"
+            if isinstance(alpha, list):
+                self.alpha = torch.Tensor(alpha)
+            else:
+                self.alpha = alpha
+
+    def forward(self, inputs, targets):
+        """
+        Forward pass to compute the Focal Loss based on the specified task type.
+        :param inputs: Predictions (logits) from the model.
+            Shape:
+             - binary/multi-label: (batch_size, num_classes)
+             - multi-class: (batch_size, num_classes)
+        :param targets: Ground truth labels.
+            Shape:
+             - binary: (batch_size,)
+             - multi-label: (batch_size, num_classes)
+             - multi-class: (batch_size,)
+        """
+        if self.task_type == "binary":
+            return self.binary_focal_loss(inputs, targets)
+        elif self.task_type == "multi-class":
+            return self.multi_class_focal_loss(inputs, targets)
+        elif self.task_type == "multi-label":
+            return self.multi_label_focal_loss(inputs, targets)
+        else:
+            raise ValueError(
+                f"Unsupported task_type '{self.task_type}'. Use 'binary', 'multi-class', or 'multi-label'."
+            )
+
+    def binary_focal_loss(self, inputs, targets):
+        """Focal loss for binary classification."""
+        probs = torch.sigmoid(inputs)
+        targets = targets.float()
+
+        # Compute binary cross entropy
+        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+
+        # Compute focal weight
+        p_t = probs * targets + (1 - probs) * (1 - targets)
+        focal_weight = (1 - p_t) ** self.gamma
+
+        # Apply alpha if provided
+        if self.alpha is not None:
+            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
+            bce_loss = alpha_t * bce_loss
+
+        # Apply focal loss weighting
+        loss = focal_weight * bce_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+    def multi_class_focal_loss(self, inputs, targets):
+        """Focal loss for multi-class classification."""
+        if self.alpha is not None:
+            alpha = self.alpha.to(inputs.device)
+
+        # Convert logits to probabilities with softmax
+        probs = F.softmax(inputs, dim=1)
+
+        # One-hot encode the targets
+        targets_one_hot = F.one_hot(targets, num_classes=self.num_classes).float()
+
+        # Compute cross-entropy for each class
+        ce_loss = -targets_one_hot * torch.log(probs)
+
+        # Compute focal weight
+        p_t = torch.sum(probs * targets_one_hot, dim=1)  # p_t for each sample
+        focal_weight = (1 - p_t) ** self.gamma
+
+        # Apply alpha if provided (per-class weighting)
+        if self.alpha is not None:
+            alpha_t = alpha.gather(0, targets)
+            ce_loss = alpha_t.unsqueeze(1) * ce_loss
+
+        # Apply focal loss weight
+        loss = focal_weight.unsqueeze(1) * ce_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+    def multi_label_focal_loss(self, inputs, targets):
+        """Focal loss for multi-label classification."""
+        probs = torch.sigmoid(inputs)
+
+        # Compute binary cross entropy
+        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+
+        # Compute focal weight
+        p_t = probs * targets + (1 - probs) * (1 - targets)
+        focal_weight = (1 - p_t) ** self.gamma
+
+        # Apply alpha if provided
+        if self.alpha is not None:
+            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
+            bce_loss = alpha_t * bce_loss
+
+        # Apply focal loss weight
+        loss = focal_weight * bce_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss
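Since `FocalLoss` is a new addition, a quick smoke test of its multi-label path may help; the shapes and the alpha value below are illustrative only.
```
# Quick usage check for the FocalLoss class added above (multi-label case).
import torch

from chebai.loss.focal_loss import FocalLoss

loss_fn = FocalLoss(gamma=2, alpha=0.25, task_type="multi-label")
logits = torch.randn(8, 10)  # (batch_size, num_classes)
targets = torch.randint(0, 2, (8, 10)).float()  # binary multi-label targets
print(loss_fn(logits, targets))  # scalar tensor, since reduction="mean"
```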

chebai/loss/semantic.py
Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@
 import math
 import os
 import pickle
-from typing import TYPE_CHECKING, List, Literal, Union
+from typing import TYPE_CHECKING, List, Literal, Union, Tuple
 
 import torch
 
@@ -62,7 +62,7 @@ def __init__(
         pos_epsilon: float = 0.01,
         multiply_by_softmax: bool = False,
         use_sigmoidal_implication: bool = False,
-        weight_epoch_dependent: Union[bool | tuple[int, int]] = False,
+        weight_epoch_dependent: Union[bool, Tuple[int, int]] = False,
         start_at_epoch: int = 0,
         violations_per_cls_aggregator: Literal[
             "sum", "max", "mean", "log-sum", "log-max", "log-mean"

chebai/models/base.py
Lines changed: 6 additions & 2 deletions

@@ -42,14 +42,19 @@ def __init__(
         exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
+        # super().__init__()
         if exclude_hyperparameter_logging is None:
             exclude_hyperparameter_logging = tuple()
         self.criterion = criterion
         assert out_dim is not None, "out_dim must be specified"
         assert input_dim is not None, "input_dim must be specified"
         self.out_dim = out_dim
         self.input_dim = input_dim
+        print(
+            f"Input dimension for the model: {self.input_dim}",
+            f"Output dimension for the model: {self.out_dim}",
+        )
 
         self.save_hyperparameters(
             ignore=[
@@ -273,7 +278,6 @@ def _execute(
         loss_kwargs = dict()
         if self.pass_loss_kwargs:
             loss_kwargs = loss_kwargs_candidates
-            loss_kwargs["current_epoch"] = self.trainer.current_epoch
         loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
         if isinstance(loss, tuple):
             unnamed_loss_index = 1
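The switch from `super().__init__()` to `super().__init__(**kwargs)` forwards any extra keyword arguments to the parent class instead of silently dropping them. A toy example of that pattern (unrelated to ChEBai's actual class hierarchy) is sketched below.
```
# Toy illustration of forwarding **kwargs to the parent constructor.
class Base:
    def __init__(self, verbose: bool = False):
        self.verbose = verbose


class Child(Base):
    def __init__(self, out_dim: int, **kwargs):
        super().__init__(**kwargs)  # without this forwarding, verbose=True would be lost
        self.out_dim = out_dim


print(Child(out_dim=4, verbose=True).verbose)  # True
```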
