diff --git a/chebai/cli.py b/chebai/cli.py
index b7e78d17..61a1da8a 100644
--- a/chebai/cli.py
+++ b/chebai/cli.py
@@ -1,7 +1,8 @@
 from typing import Dict, Set
 
 from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
 
+from chebai.preprocessing.datasets import XYBaseDataModule
 from chebai.trainer.CustomTrainer import CustomTrainer
 
@@ -38,14 +39,35 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
         Args:
             parser (LightningArgumentParser): Argument parser instance.
         """
+
+        def call_data_methods(data: XYBaseDataModule):
+            if data._num_of_labels is None:
+                data.prepare_data()
+                data.setup()
+            return data.num_of_labels
+
+        parser.link_arguments(
+            "data",
+            "model.init_args.out_dim",
+            apply_on="instantiate",
+            compute_fn=call_data_methods,
+        )
+
+        parser.link_arguments(
+            "data.feature_vector_size",
+            "model.init_args.input_dim",
+            apply_on="instantiate",
+        )
+
         for kind in ("train", "val", "test"):
             for average in ("micro-f1", "macro-f1", "balanced-accuracy"):
                 parser.link_arguments(
-                    "model.init_args.out_dim",
+                    "data.num_of_labels",
                     f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
+                    apply_on="instantiate",
                 )
         parser.link_arguments(
-            "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
+            "data.num_of_labels", "trainer.callbacks.init_args.num_labels"
         )
 
     @staticmethod
diff --git a/chebai/models/base.py b/chebai/models/base.py
index 4ba27bbc..fd02c6ce 100644
--- a/chebai/models/base.py
+++ b/chebai/models/base.py
@@ -36,6 +36,7 @@ def __init__(
         self,
         criterion: torch.nn.Module = None,
         out_dim: Optional[int] = None,
+        input_dim: Optional[int] = None,
         train_metrics: Optional[torch.nn.Module] = None,
         val_metrics: Optional[torch.nn.Module] = None,
         test_metrics: Optional[torch.nn.Module] = None,
@@ -57,7 +58,12 @@ def __init__(
                 *exclude_hyperparameter_logging,
             ]
         )
+
         self.out_dim = out_dim
+        self.input_dim = input_dim
+        assert out_dim is not None, "out_dim must be specified"
+        assert input_dim is not None, "input_dim must be specified"
+
         if optimizer_kwargs:
             self.optimizer_kwargs = optimizer_kwargs
         else:
diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index 39e5fbec..5cd210be 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -119,6 +119,23 @@ def __init__(
         os.makedirs(os.path.join(self.processed_dir, self.fold_dir), exist_ok=True)
         self.save_hyperparameters()
+        self._num_of_labels = None
+        self._feature_vector_size = None
+        self._prepare_data_flag = 1
+        self._setup_data_flag = 1
+
+    @property
+    def num_of_labels(self):
+        assert self._num_of_labels is not None, "num of labels must be set"
+        return self._num_of_labels
+
+    @property
+    def feature_vector_size(self):
+        assert (
+            self._feature_vector_size is not None
+        ), "size of feature vector must be set"
+        return self._feature_vector_size
+
     @property
     def identifier(self) -> tuple:
         """Identifier for the dataset."""
@@ -390,7 +407,17 @@ def predict_dataloader(
         """
         return self.dataloader(self.prediction_kind, shuffle=False, **kwargs)
 
-    def setup(self, **kwargs):
+    def prepare_data(self, *args, **kwargs) -> None:
+        if self._prepare_data_flag != 1:
+            return
+
+        self._prepare_data_flag += 1
+        self._perform_data_preparation(*args, **kwargs)
+
+    def _perform_data_preparation(self, *args, **kwargs) -> None:
+        raise NotImplementedError
+
+    def setup(self, *args, **kwargs) -> None:
         """
         Setup the data module.
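Both `call_data_methods` in cli.py and the Lightning trainer end up invoking `prepare_data()`/`setup()`, so without a guard the expensive work would run twice; the `_prepare_data_flag`/`_setup_data_flag` counters above make every call after the first a no-op. A minimal, self-contained sketch of that guard pattern, with a hypothetical `DemoDataModule` standing in for `XYBaseDataModule`:

# --- Illustration (not part of the patch) ---------------------------------
class DemoDataModule:
    """Hypothetical stand-in for XYBaseDataModule; shows only the guard."""

    def __init__(self):
        self._prepare_data_flag = 1  # 1 means "not executed yet"

    def prepare_data(self, *args, **kwargs):
        if self._prepare_data_flag != 1:
            return  # already ran once; repeat calls are no-ops
        self._prepare_data_flag += 1
        self._perform_data_preparation(*args, **kwargs)

    def _perform_data_preparation(self, *args, **kwargs):
        print("expensive download/processing runs exactly once")


dm = DemoDataModule()
dm.prepare_data()  # prints
dm.prepare_data()  # silently skipped
# ---------------------------------------------------------------------------

A plain boolean would express the same intent; the patch uses an integer it increments past 1.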
@@ -399,6 +426,11 @@ def setup(self, **kwargs):
         Args:
             **kwargs: Additional keyword arguments.
         """
+        if self._setup_data_flag != 1:
+            return
+
+        self._setup_data_flag += 1
+
         rank_zero_info(f"Check for processed data in {self.processed_dir}")
         rank_zero_info(f"Cross-validation enabled: {self.use_inner_cross_validation}")
         if any(
@@ -410,6 +442,21 @@ def setup(self, **kwargs):
             if not ("keep_reader" in kwargs and kwargs["keep_reader"]):
                 self.reader.on_finish()
 
+        self._set_processed_data_props()
+
+    def _set_processed_data_props(self):
+
+        data_pt = torch.load(
+            os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
+            weights_only=False,
+        )
+
+        self._num_of_labels = len(data_pt[0]["labels"])
+        self._feature_vector_size = max(len(d["features"]) for d in data_pt)
+
+        print(f"Number of labels for loaded data: {self._num_of_labels}")
+        print(f"Feature vector size: {self._feature_vector_size}")
+
     def setup_processed(self):
         """
         Setup the processed data.
@@ -482,18 +529,6 @@ def raw_file_names_dict(self) -> dict:
         """
         raise NotImplementedError
 
-    @property
-    def label_number(self) -> int:
-        """
-        Returns the number of labels.
-
-        This property should be implemented by subclasses to provide the number of labels.
-
-        Returns:
-            int: The number of labels. Returns -1 for seq2seq encoding.
-        """
-        raise NotImplementedError
-
 
 class MergedDataset(XYBaseDataModule):
     MERGED = []
@@ -531,7 +566,7 @@ def __init__(
         os.makedirs(self.processed_dir, exist_ok=True)
         super(pl.LightningDataModule, self).__init__(**kwargs)
 
-    def prepare_data(self):
+    def _perform_data_preparation(self):
         """
         Placeholder for data preparation logic.
         """
@@ -547,9 +582,15 @@ def setup(self, **kwargs):
         Args:
             **kwargs: Additional keyword arguments.
         """
+        if self._setup_data_flag != 1:
+            return
+
+        self._setup_data_flag += 1
         for s in self.subsets:
             s.setup(**kwargs)
 
+        self._set_processed_data_props()
+
     def dataloader(self, kind: str, **kwargs) -> DataLoader:
         """
         Creates a DataLoader for a specific subset.
@@ -623,13 +664,6 @@ def processed_file_names(self) -> List[str]:
         """
         return ["test.pt", "train.pt", "validation.pt"]
 
-    @property
-    def label_number(self) -> int:
-        """
-        Returns the number of labels from the first subset.
-        """
-        return self.subsets[0].label_number
-
     @property
     def limits(self):
         """
@@ -725,7 +759,7 @@ def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]:
         return splits_file_path
 
     # ------------------------------ Phase: Prepare data -----------------------------------
-    def prepare_data(self, *args: Any, **kwargs: Any) -> None:
+    def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None:
         """
         Prepares the data for the dataset.
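That completes the changes to datasets/base.py. The key addition is `_set_processed_data_props`, which replaces the hard-coded `label_number` constants: the processed file holds a list of dicts with "features" and "labels" keys, so the label count is taken from the first sample's label vector and the feature-vector size from the longest feature sequence. A minimal sketch of that derivation with invented sample data (no `torch.load` needed):

# --- Illustration (not part of the patch) ---------------------------------
# Invented two-sample dataset mirroring the structure loaded above.
data_pt = [
    {"features": [3, 1, 4, 1, 5], "labels": [True, False, True]},
    {"features": [9, 2, 6], "labels": [False, False, True]},
]

num_of_labels = len(data_pt[0]["labels"])  # 3: label width of the first sample
feature_vector_size = max(len(d["features"]) for d in data_pt)  # 5: longest sequence

print(num_of_labels, feature_vector_size)  # -> 3 5
# ---------------------------------------------------------------------------

Note that only the first sample is inspected for the label count, so this assumes every sample carries the same number of labels.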
diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index d927a44c..d3387a05 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -59,7 +59,7 @@ def download(self):
     def raw_file_names(self):
         return ["test.pkl", "train.pkl", "validation.pkl"]
 
-    def prepare_data(self, *args, **kwargs):
+    def _perform_data_preparation(self, *args, **kwargs):
         print("Check for raw data in", self.raw_dir)
         if any(
             not os.path.isfile(os.path.join(self.raw_dir, f))
@@ -88,10 +88,6 @@ def setup_processed(self):
                 os.path.join(self.processed_dir, f"{k}.pt"),
             )
 
-    @property
-    def label_number(self):
-        return 500
-
 
 class JCIData(JCIBase):
     READER = dr.OrdReader
@@ -158,7 +154,7 @@ def __init__(
         )
 
     # ------------------------------ Phase: Prepare data -----------------------------------
-    def prepare_data(self, *args: Any, **kwargs: Any) -> None:
+    def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None:
         """
         Prepares the data for the Chebi dataset.
@@ -179,7 +175,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
         Returns:
             None
         """
-        super().prepare_data(args, kwargs)
+        super()._perform_data_preparation(*args, **kwargs)
 
         if self.chebi_version_train is not None:
             if not os.path.isfile(
@@ -545,10 +541,6 @@ def raw_file_names_dict(self) -> dict:
 
 
 class JCIExtendedBase(_ChEBIDataExtractor):
-    @property
-    def label_number(self):
-        return 500
-
     @property
     def _name(self):
         return "JCI_extended"
@@ -573,16 +565,6 @@ class ChEBIOverX(_ChEBIDataExtractor):
     READER: dr.ChemDataReader = dr.ChemDataReader
     THRESHOLD: int = None
 
-    @property
-    def label_number(self) -> int:
-        """
-        Returns the number of labels in the dataset.
-
-        Returns:
-            int: The number of labels.
-        """
-        return 854
-
     @property
     def _name(self) -> str:
         """
@@ -675,17 +657,6 @@ class ChEBIOver100(ChEBIOverX):
 
     THRESHOLD: int = 100
 
-    def label_number(self) -> int:
-        """
-        Returns the number of labels in the dataset.
-
-        Overrides the base class method to return the correct number of labels for this threshold.
-
-        Returns:
-            int: The number of labels.
-        """
-        return 854
-
 
 class ChEBIOver50(ChEBIOverX):
     """
@@ -699,17 +670,6 @@ class ChEBIOver50(ChEBIOverX):
 
     THRESHOLD: int = 50
 
-    def label_number(self) -> int:
-        """
-        Returns the number of labels in the dataset.
-
-        Overrides the base class method to return the correct number of labels for this threshold.
-
-        Returns:
-            int: The number of labels.
-        """
-        return 1332
-
 
 class ChEBIOver100DeepSMILES(ChEBIOverXDeepSMILES, ChEBIOver100):
     """
diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py
index 9e43302a..f6f2cdb3 100644
--- a/chebai/preprocessing/datasets/pubchem.py
+++ b/chebai/preprocessing/datasets/pubchem.py
@@ -179,7 +179,7 @@ def processed_file_names(self) -> List[str]:
         """
         return ["test.pt", "train.pt", "validation.pt"]
 
-    def prepare_data(self, *args, **kwargs):
+    def _perform_data_preparation(self, *args, **kwargs):
         """
         Checks for raw data and downloads if necessary.
         """
@@ -692,13 +692,6 @@ class PubchemChem(PubChem):
 
     READER: Type[dr.ChemDataReader] = dr.ChemDataReader
 
-    @property
-    def label_number(self) -> int:
-        """
-        Returns the label number.
-        """
-        return -1
-
 
 class PubchemBPE(PubChem):
     """
@@ -712,13 +705,6 @@ class PubchemBPE(PubChem):
 
     READER: Type[dr.ChemBPEReader] = dr.ChemBPEReader
 
-    @property
-    def label_number(self) -> int:
-        """
-        Returns the label number.
- """ - return -1 - class SWJChem(SWJPreChem): """ @@ -732,13 +718,6 @@ class SWJChem(SWJPreChem): READER: Type[dr.ChemDataUnlabeledReader] = dr.ChemDataUnlabeledReader - @property - def label_number(self) -> int: - """ - Returns the label number. - """ - return -1 - class SWJBPE(SWJPreChem): """ @@ -752,13 +731,6 @@ class SWJBPE(SWJPreChem): READER: Type[dr.ChemBPEReader] = dr.ChemBPEReader - @property - def label_number(self) -> int: - """ - Returns the label number. - """ - return -1 - class PubChemTokens(PubChem): """ diff --git a/chebai/preprocessing/datasets/tox21.py b/chebai/preprocessing/datasets/tox21.py index 4bdfbdee..95c60cdd 100644 --- a/chebai/preprocessing/datasets/tox21.py +++ b/chebai/preprocessing/datasets/tox21.py @@ -39,11 +39,6 @@ def _name(self) -> str: """Returns the name of the dataset.""" return "Tox21MN" - @property - def label_number(self) -> int: - """Returns the number of labels.""" - return 12 - @property def raw_file_names(self) -> List[str]: """Returns a list of raw file names.""" @@ -118,6 +113,10 @@ def setup_processed(self) -> None: def setup(self, **kwargs) -> None: """Sets up the dataset by downloading and processing if necessary.""" + if self._setup_data_flag != 1: + return + + self._setup_data_flag += 1 if any( not os.path.isfile(os.path.join(self.raw_dir, f)) for f in self.raw_file_names @@ -129,6 +128,8 @@ def setup(self, **kwargs) -> None: ): self.setup_processed() + self._set_processed_data_props() + def _load_data_from_file(self, input_file_path: str) -> List[Dict]: """Loads data from a CSV file. @@ -171,11 +172,6 @@ def _name(self) -> str: """Returns the name of the dataset.""" return "Tox21Chal" - @property - def label_number(self) -> int: - """Returns the number of labels.""" - return 12 - @property def raw_file_names(self) -> List[str]: """Returns a list of raw file names.""" @@ -300,6 +296,10 @@ def setup_processed(self) -> None: def setup(self, **kwargs) -> None: """Sets up the dataset by downloading and processing if necessary.""" + if self._setup_data_flag != 1: + return + + self._setup_data_flag += 1 if any( not os.path.isfile(os.path.join(self.raw_dir, f)) for f in self.raw_file_names @@ -311,6 +311,8 @@ def setup(self, **kwargs) -> None: ): self.setup_processed() + self._set_processed_data_props() + def _load_dict(self, input_file_path: str) -> Generator[Dict, None, None]: """Loads data from a CSV file as a generator.