Merged. Changes from all commits (20 commits).
28 changes: 25 additions & 3 deletions chebai/cli.py
```diff
@@ -1,7 +1,8 @@
-from typing import Dict, Set
+from typing import Dict, Set, Type

 from lightning.pytorch.cli import LightningArgumentParser, LightningCLI

+from chebai.preprocessing.datasets import XYBaseDataModule
 from chebai.trainer.CustomTrainer import CustomTrainer


@@ -38,14 +39,35 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
         Args:
             parser (LightningArgumentParser): Argument parser instance.
         """
+
+        def call_data_methods(data: Type[XYBaseDataModule]):
+            if data._num_of_labels is None:
+                data.prepare_data()
+                data.setup()
+            return data.num_of_labels
+
+        parser.link_arguments(
+            "data",
+            "model.init_args.out_dim",
+            apply_on="instantiate",
+            compute_fn=call_data_methods,
+        )
+
+        parser.link_arguments(
+            "data.feature_vector_size",
+            "model.init_args.input_dim",
+            apply_on="instantiate",
+        )
+
         for kind in ("train", "val", "test"):
             for average in ("micro-f1", "macro-f1", "balanced-accuracy"):
                 parser.link_arguments(
-                    "model.init_args.out_dim",
+                    "data.num_of_labels",
                     f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
                     apply_on="instantiate",
                 )
         parser.link_arguments(
-            "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
+            "data.num_of_labels", "trainer.callbacks.init_args.num_labels"
         )

     @staticmethod
```
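The core of this change is the `compute_fn` hook: with `apply_on="instantiate"`, jsonargparse builds the datamodule first, passes the instance to the function, and injects the returned value into the model's constructor, so `out_dim` no longer has to be configured by hand. Below is a minimal, self-contained sketch of that wiring; the `Toy*` names are hypothetical stand-ins (not chebai classes), and the link target is `model.out_dim` rather than `model.init_args.out_dim` because this sketch fixes the model class instead of selecting it via `class_path`:

```python
# Hypothetical, minimal reproduction of the compute_fn wiring used above.
# Run as: python toy_cli.py fit --trainer.max_epochs=1
import lightning.pytorch as pl
import torch
from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
from torch.utils.data import DataLoader, TensorDataset


class ToyDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self._num_of_labels = None

    def prepare_data(self):
        # chebai derives this count from the processed dataset on disk.
        self._num_of_labels = 3

    @property
    def num_of_labels(self) -> int:
        assert self._num_of_labels is not None, "num of labels must be set"
        return self._num_of_labels

    def train_dataloader(self):
        x = torch.randn(32, 16)
        y = torch.randint(0, 3, (32,))
        return DataLoader(TensorDataset(x, y), batch_size=8)


class ToyModel(pl.LightningModule):
    def __init__(self, out_dim: int):
        super().__init__()
        self.head = torch.nn.Linear(16, out_dim)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.head(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


class ToyCLI(LightningCLI):
    def add_arguments_to_parser(self, parser: LightningArgumentParser):
        def compute_out_dim(data: ToyDataModule) -> int:
            if data._num_of_labels is None:  # prepare lazily, exactly once
                data.prepare_data()
            return data.num_of_labels

        # apply_on="instantiate": the datamodule is instantiated first, then
        # the value computed from it is injected as the model's out_dim.
        parser.link_arguments(
            "data",
            "model.out_dim",
            apply_on="instantiate",
            compute_fn=compute_out_dim,
        )


if __name__ == "__main__":
    ToyCLI(ToyModel, ToyDataModule)
```

Because the argument is linked, `--model.out_dim` disappears from the CLI entirely; the value can only come from the datamodule, which is exactly the invariant the PR is after.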
6 changes: 6 additions & 0 deletions chebai/models/base.py
```diff
@@ -36,6 +36,7 @@ def __init__(
         self,
         criterion: torch.nn.Module = None,
         out_dim: Optional[int] = None,
+        input_dim: Optional[int] = None,
         train_metrics: Optional[torch.nn.Module] = None,
         val_metrics: Optional[torch.nn.Module] = None,
         test_metrics: Optional[torch.nn.Module] = None,
@@ -57,7 +58,12 @@
                 *exclude_hyperparameter_logging,
             ]
         )
+
         self.out_dim = out_dim
+        self.input_dim = input_dim
+        assert out_dim is not None, "out_dim must be specified"
+        assert input_dim is not None, "input_dim must be specified"
+
         if optimizer_kwargs:
             self.optimizer_kwargs = optimizer_kwargs
         else:
```
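The new assertions make the constructor fail fast: both dimensions stay `Optional` in the signature so the CLI can inject them through the links above, but a model can no longer be silently built without them. A minimal sketch of that contract, with `BaseNet` as a hypothetical stand-in for the base model class:

```python
# Fail-fast constructor contract: Optional in the signature (filled in by
# the CLI's link_arguments), mandatory in practice.
from typing import Optional


class BaseNet:
    def __init__(self, out_dim: Optional[int] = None, input_dim: Optional[int] = None):
        self.out_dim = out_dim
        self.input_dim = input_dim
        assert out_dim is not None, "out_dim must be specified"
        assert input_dim is not None, "input_dim must be specified"


BaseNet(out_dim=854, input_dim=1024)  # ok
# BaseNet()  # raises AssertionError: out_dim must be specified
```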
78 changes: 56 additions & 22 deletions chebai/preprocessing/datasets/base.py
```diff
@@ -119,6 +119,23 @@ def __init__(
         os.makedirs(os.path.join(self.processed_dir, self.fold_dir), exist_ok=True)
         self.save_hyperparameters()

+        self._num_of_labels = None
+        self._feature_vector_size = None
+        self._prepare_data_flag = 1
+        self._setup_data_flag = 1
+
+    @property
+    def num_of_labels(self):
+        assert self._num_of_labels is not None, "num of labels must be set"
+        return self._num_of_labels
+
+    @property
+    def feature_vector_size(self):
+        assert (
+            self._feature_vector_size is not None
+        ), "size of feature vector must be set"
+        return self._feature_vector_size
+
     @property
     def identifier(self) -> tuple:
         """Identifier for the dataset."""
@@ -390,7 +407,17 @@ def predict_dataloader(
         """
         return self.dataloader(self.prediction_kind, shuffle=False, **kwargs)

-    def setup(self, **kwargs):
+    def prepare_data(self, *args, **kwargs) -> None:
+        if self._prepare_data_flag != 1:
+            return
+
+        self._prepare_data_flag += 1
+        self._perform_data_preparation(*args, **kwargs)
+
+    def _perform_data_preparation(self, *args, **kwargs) -> None:
+        raise NotImplementedError
+
+    def setup(self, *args, **kwargs) -> None:
         """
         Setup the data module.

@@ -399,6 +426,11 @@ def setup(self, **kwargs):
         Args:
             **kwargs: Additional keyword arguments.
         """
+        if self._setup_data_flag != 1:
+            return
+
+        self._setup_data_flag += 1
+
         rank_zero_info(f"Check for processed data in {self.processed_dir}")
         rank_zero_info(f"Cross-validation enabled: {self.use_inner_cross_validation}")
         if any(
@@ -410,6 +442,21 @@ def setup(self, **kwargs):
         if not ("keep_reader" in kwargs and kwargs["keep_reader"]):
             self.reader.on_finish()

+        self._set_processed_data_props()
+
+    def _set_processed_data_props(self):
+
+        data_pt = torch.load(
+            os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
+            weights_only=False,
+        )
+
+        self._num_of_labels = len(data_pt[0]["labels"])
+        self._feature_vector_size = max(len(d["features"]) for d in data_pt)
+
+        print(f"Number of labels for loaded data: {self._num_of_labels}")
+        print(f"Feature vector size: {self._feature_vector_size}")
+
     def setup_processed(self):
         """
         Setup the processed data.
@@ -482,18 +529,6 @@ def raw_file_names_dict(self) -> dict:
         """
         raise NotImplementedError

-    @property
-    def label_number(self) -> int:
-        """
-        Returns the number of labels.
-
-        This property should be implemented by subclasses to provide the number of labels.
-
-        Returns:
-            int: The number of labels. Returns -1 for seq2seq encoding.
-        """
-        raise NotImplementedError
-

 class MergedDataset(XYBaseDataModule):
     MERGED = []
@@ -531,7 +566,7 @@ def __init__(
         os.makedirs(self.processed_dir, exist_ok=True)
         super(pl.LightningDataModule, self).__init__(**kwargs)

-    def prepare_data(self):
+    def _perform_data_preparation(self):
         """
         Placeholder for data preparation logic.
         """
@@ -547,9 +582,15 @@ def setup(self, **kwargs):
         Args:
             **kwargs: Additional keyword arguments.
         """
+        if self._setup_data_flag != 1:
+            return
+
+        self._setup_data_flag += 1
         for s in self.subsets:
             s.setup(**kwargs)

+        self._set_processed_data_props()
+
     def dataloader(self, kind: str, **kwargs) -> DataLoader:
         """
         Creates a DataLoader for a specific subset.
@@ -623,13 +664,6 @@ def processed_file_names(self) -> List[str]:
         """
         return ["test.pt", "train.pt", "validation.pt"]

-    @property
-    def label_number(self) -> int:
-        """
-        Returns the number of labels from the first subset.
-        """
-        return self.subsets[0].label_number
-
     @property
     def limits(self):
         """
@@ -725,7 +759,7 @@ def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]
         return splits_file_path

     # ------------------------------ Phase: Prepare data -----------------------------------
-    def prepare_data(self, *args: Any, **kwargs: Any) -> None:
+    def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None:
         """
         Prepares the data for the dataset.

```
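Two patterns land here. First, `prepare_data`/`setup` become idempotent via integer guard flags, which matters because the CLI's `call_data_methods` may trigger them before the `Trainer` does; the second invocation must be a no-op. Second, the hard-coded `label_number` is gone: `_set_processed_data_props` derives the label count from the first stored sample and the feature-vector size from the longest feature sequence. A self-contained sketch of both, where an in-memory list stands in for the `torch.load` of the processed `.pt` file:

```python
# Run-once guard plus property inference, mirroring the diff above. The
# fake data_pt list stands in for torch.load(<processed .pt>, weights_only=False);
# the record layout (dicts with "features"/"labels") matches what the diff reads.
class RunOnceDataModule:
    def __init__(self):
        self._num_of_labels = None
        self._feature_vector_size = None
        self._prepare_data_flag = 1
        self._setup_data_flag = 1

    def prepare_data(self, *args, **kwargs):
        if self._prepare_data_flag != 1:
            return  # already ran, e.g. triggered early by the CLI's compute_fn
        self._prepare_data_flag += 1
        self._perform_data_preparation(*args, **kwargs)

    def _perform_data_preparation(self, *args, **kwargs):
        raise NotImplementedError  # subclasses download/parse raw data here

    def setup(self, *args, **kwargs):
        if self._setup_data_flag != 1:
            return
        self._setup_data_flag += 1
        self._set_processed_data_props()  # real setup also builds processed files

    def _set_processed_data_props(self):
        data_pt = [
            {"features": [1, 2, 3], "labels": [True, False]},
            {"features": [4, 5], "labels": [False, True]},
        ]
        # Label count from the first sample; input width from the longest
        # feature sequence in the split.
        self._num_of_labels = len(data_pt[0]["labels"])
        self._feature_vector_size = max(len(d["features"]) for d in data_pt)
```

A plain boolean would serve as well as the integer counter; the sketch simply mirrors the flag convention the diff uses.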
46 changes: 3 additions & 43 deletions chebai/preprocessing/datasets/chebi.py
```diff
@@ -59,7 +59,7 @@ def download(self):
     def raw_file_names(self):
         return ["test.pkl", "train.pkl", "validation.pkl"]

-    def prepare_data(self, *args, **kwargs):
+    def _perform_data_preparation(self, *args, **kwargs):
         print("Check for raw data in", self.raw_dir)
         if any(
             not os.path.isfile(os.path.join(self.raw_dir, f))
@@ -88,10 +88,6 @@ def setup_processed(self):
                 os.path.join(self.processed_dir, f"{k}.pt"),
             )

-    @property
-    def label_number(self):
-        return 500
-

 class JCIData(JCIBase):
     READER = dr.OrdReader
@@ -158,7 +154,7 @@ def __init__(
         )

     # ------------------------------ Phase: Prepare data -----------------------------------
-    def prepare_data(self, *args: Any, **kwargs: Any) -> None:
+    def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None:
         """
         Prepares the data for the Chebi dataset.

@@ -179,7 +175,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
         Returns:
             None
         """
-        super().prepare_data(args, kwargs)
+        super()._perform_data_preparation(args, kwargs)

         if self.chebi_version_train is not None:
             if not os.path.isfile(
@@ -545,10 +541,6 @@ def raw_file_names_dict(self) -> dict:

 class JCIExtendedBase(_ChEBIDataExtractor):

-    @property
-    def label_number(self):
-        return 500
-
     @property
     def _name(self):
         return "JCI_extended"
@@ -573,16 +565,6 @@ class ChEBIOverX(_ChEBIDataExtractor):
     READER: dr.ChemDataReader = dr.ChemDataReader
     THRESHOLD: int = None

-    @property
-    def label_number(self) -> int:
-        """
-        Returns the number of labels in the dataset.
-
-        Returns:
-            int: The number of labels.
-        """
-        return 854
-
     @property
     def _name(self) -> str:
         """
@@ -675,17 +657,6 @@ class ChEBIOver100(ChEBIOverX):

     THRESHOLD: int = 100

-    def label_number(self) -> int:
-        """
-        Returns the number of labels in the dataset.
-
-        Overrides the base class method to return the correct number of labels for this threshold.
-
-        Returns:
-            int: The number of labels.
-        """
-        return 854
-

 class ChEBIOver50(ChEBIOverX):
     """
@@ -699,17 +670,6 @@ class ChEBIOver50(ChEBIOverX):

     THRESHOLD: int = 50

-    def label_number(self) -> int:
-        """
-        Returns the number of labels in the dataset.
-
-        Overrides the base class method to return the correct number of labels for this threshold.
-
-        Returns:
-            int: The number of labels.
-        """
-        return 1332
-

 class ChEBIOver100DeepSMILES(ChEBIOverXDeepSMILES, ChEBIOver100):
     """
```
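With the rename, concrete datasets follow a template-method pattern: they override `_perform_data_preparation` and inherit the guarded `prepare_data` unchanged, so no subclass can accidentally bypass the run-once check. A hypothetical sketch, with a two-line `GuardedBase` standing in for `XYBaseDataModule`:

```python
# Hypothetical subclass showing the renamed hook; GuardedBase stands in for
# XYBaseDataModule's guarded prepare_data from base.py above.
class GuardedBase:
    _prepare_data_flag = 1

    def prepare_data(self, *args, **kwargs):
        if self._prepare_data_flag != 1:
            return
        self._prepare_data_flag += 1
        self._perform_data_preparation(*args, **kwargs)


class MyChEBISubset(GuardedBase):
    def _perform_data_preparation(self, *args, **kwargs):
        print("download + preprocess the ChEBI release, exactly once")


dm = MyChEBISubset()
dm.prepare_data()  # performs the preparation
dm.prepare_data()  # no-op: the guard short-circuits
```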
30 changes: 1 addition & 29 deletions chebai/preprocessing/datasets/pubchem.py
```diff
@@ -179,7 +179,7 @@ def processed_file_names(self) -> List[str]:
         """
         return ["test.pt", "train.pt", "validation.pt"]

-    def prepare_data(self, *args, **kwargs):
+    def _perform_data_preparation(self, *args, **kwargs):
         """
         Checks for raw data and downloads if necessary.
         """
@@ -692,13 +692,6 @@ class PubchemChem(PubChem):

     READER: Type[dr.ChemDataReader] = dr.ChemDataReader

-    @property
-    def label_number(self) -> int:
-        """
-        Returns the label number.
-        """
-        return -1
-

 class PubchemBPE(PubChem):
     """
@@ -712,13 +705,6 @@ class PubchemBPE(PubChem):

     READER: Type[dr.ChemBPEReader] = dr.ChemBPEReader

-    @property
-    def label_number(self) -> int:
-        """
-        Returns the label number.
-        """
-        return -1
-

 class SWJChem(SWJPreChem):
     """
@@ -732,13 +718,6 @@ class SWJChem(SWJPreChem):

     READER: Type[dr.ChemDataUnlabeledReader] = dr.ChemDataUnlabeledReader

-    @property
-    def label_number(self) -> int:
-        """
-        Returns the label number.
-        """
-        return -1
-

 class SWJBPE(SWJPreChem):
     """
@@ -752,13 +731,6 @@ class SWJBPE(SWJPreChem):

     READER: Type[dr.ChemBPEReader] = dr.ChemBPEReader

-    @property
-    def label_number(self) -> int:
-        """
-        Returns the label number.
-        """
-        return -1
-

 class PubChemTokens(PubChem):
     """
```
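The pretraining datasets had all pinned `label_number` to the sentinel `-1`; after this PR a missing label count surfaces as an assertion failure instead of a magic value. A sketch of that behavioral difference, assuming a module whose `_num_of_labels` is never populated (as with unlabeled SMILES data):

```python
# Sentinel vs. assertion: before this PR, unlabeled datasets answered -1;
# the property-based API fails loudly instead. UnlabeledModule is a
# hypothetical stand-in, not a chebai class.
class UnlabeledModule:
    def __init__(self):
        self._num_of_labels = None  # assumed never set for pretraining data

    @property
    def num_of_labels(self):
        assert self._num_of_labels is not None, "num of labels must be set"
        return self._num_of_labels


dm = UnlabeledModule()
try:
    dm.num_of_labels
except AssertionError:
    print("unlabeled dataset: no label count available")  # replaces the -1 sentinel
```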