diff --git a/configs/dataset/hypergraph/conjugated_ocelotv1.yaml b/configs/dataset/hypergraph/conjugated_ocelotv1.yaml new file mode 100644 index 000000000..4c0daaecd --- /dev/null +++ b/configs/dataset/hypergraph/conjugated_ocelotv1.yaml @@ -0,0 +1,35 @@ +# Dataset loader config +loader: + _target_: topobench.data.loaders.ConjugatedMoleculeDatasetLoader + parameters: + data_domain: hypergraph + data_type: conjugated_molecules + data_name: OCELOTv1 + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + target_col: ${dataset.parameters.target_col} + task: ${dataset.parameters.task} + +# Dataset parameters +parameters: + num_features: 9 # OGB atom features + num_classes: 1 # Single target regression + target_col: 0 # Use the first target column + task: regression + loss_type: mse + monitor_metric: mae + task_level: graph + +#splits +split_params: + learning_setting: inductive + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + split_type: random # OCELOTv1 doesn't have predefined splits + k: 10 # for k-fold Cross-Validation + train_prop: 0.8 # for random strategy splitting + +# Dataloader parameters +dataloader_params: + batch_size: 64 + num_workers: 4 + pin_memory: True diff --git a/configs/dataset/hypergraph/conjugated_opv_train.yaml b/configs/dataset/hypergraph/conjugated_opv_train.yaml new file mode 100644 index 000000000..4ae97432f --- /dev/null +++ b/configs/dataset/hypergraph/conjugated_opv_train.yaml @@ -0,0 +1,36 @@ +# Dataset loader config - OPV train split +loader: + _target_: topobench.data.loaders.ConjugatedMoleculeDatasetLoader + parameters: + data_domain: hypergraph + data_type: conjugated_molecules + data_name: OPV + split: train # Options: train, valid, test + task: default # Options: default, polymer + target_col: ${dataset.parameters.target_col} + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + +# Dataset parameters +parameters: + num_features: 9 # OGB atom features + num_classes: 1 + target_col: 0 + task: regression + loss_type: mse + monitor_metric: mae + task_level: graph + +#splits +split_params: + learning_setting: inductive + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + split_type: fixed # OPV has predefined train/valid/test splits + k: 10 # not used for fixed splits + train_prop: 0.8 # not used for fixed splits + +# Dataloader parameters +dataloader_params: + batch_size: 64 + num_workers: 4 + pin_memory: True diff --git a/configs/dataset/hypergraph/conjugated_opv_train_polymer.yaml b/configs/dataset/hypergraph/conjugated_opv_train_polymer.yaml new file mode 100644 index 000000000..07f83ee16 --- /dev/null +++ b/configs/dataset/hypergraph/conjugated_opv_train_polymer.yaml @@ -0,0 +1,36 @@ +# Dataset loader config - OPV train split (polymer task) +loader: + _target_: topobench.data.loaders.ConjugatedMoleculeDatasetLoader + parameters: + data_domain: hypergraph + data_type: conjugated_molecules + data_name: OPV + split: train + task: polymer # Filters molecules with complete extrapolated properties + target_col: ${dataset.parameters.target_col} + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + +# Dataset parameters +parameters: + num_features: 9 # OGB atom features + num_classes: 1 + target_col: 0 + task: regression + loss_type: mse + monitor_metric: mae + task_level: graph + +#splits 
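+# OPV ships predefined train/valid/test files, so split_type below is fixed;
+# k and train_prop are kept only for config-schema compatibility and are unused.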
+split_params: + learning_setting: inductive + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + split_type: fixed + k: 10 + train_prop: 0.8 + +# Dataloader parameters +dataloader_params: + batch_size: 64 + num_workers: 4 + pin_memory: True diff --git a/configs/dataset/hypergraph/conjugated_pcqm4mv2.yaml b/configs/dataset/hypergraph/conjugated_pcqm4mv2.yaml new file mode 100644 index 000000000..1315091f8 --- /dev/null +++ b/configs/dataset/hypergraph/conjugated_pcqm4mv2.yaml @@ -0,0 +1,32 @@ +# Dataset loader config - PCQM4Mv2 +loader: + _target_: topobench.data.loaders.ConjugatedMoleculeDatasetLoader + parameters: + data_domain: hypergraph + data_type: conjugated_molecules + data_name: PCQM4MV2 + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + +# Dataset parameters +parameters: + num_features: 9 # OGB atom features + num_classes: 1 # Single target: homolumogap + task: regression + loss_type: mse + monitor_metric: mae + task_level: graph + +#splits +split_params: + learning_setting: inductive + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + split_type: fixed # PCQM4Mv2 has predefined splits + k: 10 # not used for fixed splits + train_prop: 0.8 # not used for fixed splits + +# Dataloader parameters +dataloader_params: + batch_size: 64 + num_workers: 4 + pin_memory: True diff --git a/pyproject.toml b/pyproject.toml index 3234ea9e6..2b20c98a4 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,11 +6,11 @@ build-backend = "setuptools.build_meta" name = "TopoBench" dynamic = ["version"] authors = [ - {name = "Topological Intelligence Team Authors", email = "tlscabinet@gmail.com"} + { name = "Topological Intelligence Team Authors", email = "tlscabinet@gmail.com" }, ] readme = "README.md" description = "Topological Deep Learning" -license = {file = "LICENSE.txt"} +license = { file = "LICENSE.txt" } classifiers = [ "License :: OSI Approved :: MIT License", "Development Status :: 4 - Beta", @@ -21,10 +21,10 @@ classifiers = [ "Natural Language :: English", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11" + "Programming Language :: Python :: 3.11", ] -requires-python = ">= 3.10" -dependencies=[ +requires-python = "== 3.11.3" +dependencies = [ "tqdm", "charset-normalizer", "numpy", @@ -54,6 +54,7 @@ dependencies=[ "topomodelx @ git+https://github.com/pyt-team/TopoModelX.git", "toponetx @ git+https://github.com/pyt-team/TopoNetX.git", "lightning==2.4.0", + "rdkit", ] [project.optional-dependencies] @@ -65,30 +66,20 @@ doc = [ "sphinx", "sphinx_gallery", "pydata-sphinx-theme", - "myst_parser" -] -lint = [ - "pre-commit", - "ruff" -] -test = [ - "pytest", - "pytest-cov", - "coverage", - "jupyter", - "mypy", - "pytest-mock" + "myst_parser", ] +lint = ["pre-commit", "ruff"] +test = ["pytest", "pytest-cov", "coverage", "jupyter", "mypy", "pytest-mock"] dev = ["TopoBench[test, lint]"] all = ["TopoBench[dev, doc]"] [project.urls] -homepage="https://geometric-intelligence.github.io/topobench/index.html" -repository="https://github.com/geometric-intelligence/TopoBench" +homepage = "https://geometric-intelligence.github.io/topobench/index.html" +repository = "https://github.com/geometric-intelligence/TopoBench" [tool.black] -line-length = 79 # PEP 8 standard for maximum line length +line-length = 79 # PEP 8 standard for maximum line length target-version = ['py310'] 
[tool.docformatter] @@ -99,35 +90,35 @@ wrap-descriptions = 79 target-version = "py310" #extend-include = ["*.ipynb"] extend-exclude = ["test", "tutorials", "notebooks"] -line-length = 79 # PEP 8 standard for maximum line length +line-length = 79 # PEP 8 standard for maximum line length [tool.ruff.format] docstring-code-format = false [tool.ruff.lint] select = [ - "F", # pyflakes errors - "E", # code style - "W", # warnings - "I", # import order - "UP", # pyupgrade rules - "B", # bugbear rules - "PIE", # pie rules - "Q", # quote rules - "RET", # return rules - "SIM", # code simplifications - "NPY", # numpy rules + "F", # pyflakes errors + "E", # code style + "W", # warnings + "I", # import order + "UP", # pyupgrade rules + "B", # bugbear rules + "PIE", # pie rules + "Q", # quote rules + "RET", # return rules + "SIM", # code simplifications + "NPY", # numpy rules "PERF", # performance rules ] fixable = ["ALL"] ignore = [ - "E501", # line too long - "RET504", # Unnecessary assignment before return - "RET505", # Unnecessary `elif` after `return` statement - "NPY002", # Replace legacy `np.random.seed` call with `np.random.Generator` - "UP038", # Use `X | Y` in `isinstance` call instead of `(X, Y)` -- not compatible with python 3.9 (even with __future__ import) - "W293", # Does not allow to have empty lines in multiline comments - "PERF203", # [TODO: fix all such issues] `try`-`except` within a loop incurs performance overhead + "E501", # line too long + "RET504", # Unnecessary assignment before return + "RET505", # Unnecessary `elif` after `return` statement + "NPY002", # Replace legacy `np.random.seed` call with `np.random.Generator` + "UP038", # Use `X | Y` in `isinstance` call instead of `(X, Y)` -- not compatible with python 3.9 (even with __future__ import) + "W293", # Does not allow to have empty lines in multiline comments + "PERF203", # [TODO: fix all such issues] `try`-`except` within a loop incurs performance overhead ] [tool.ruff.lint.pydocstyle] @@ -138,13 +129,10 @@ convention = "numpy" "tests/*" = ["D"] [tool.setuptools.dynamic] -version = {attr = "topobench.__version__"} +version = { attr = "topobench.__version__" } [tool.setuptools.packages.find] -include = [ - "topobench", - "topobench.*" -] +include = ["topobench", "topobench.*"] [tool.mypy] warn_redundant_casts = true @@ -155,26 +143,18 @@ plugins = "numpy.typing.mypy_plugin" [[tool.mypy.overrides]] module = [ - "torch_cluster.*","networkx.*","scipy.spatial","scipy.sparse","toponetx.classes.simplicial_complex" + "torch_cluster.*", + "networkx.*", + "scipy.spatial", + "scipy.sparse", + "toponetx.classes.simplicial_complex", ] ignore_missing_imports = true [tool.pytest.ini_options] addopts = "--capture=no" -pythonpath = [ - "." 
-]
+pythonpath = ["."]
 
 [tool.numpydoc_validation]
-checks = [
-    "all",
-    "GL01",
-    "ES01",
-    "EX01",
-    "SA01"
-]
-exclude = [
-    '\.undocumented_method$',
-    '\.__init__$',
-    '\.__repr__$',
-]
+checks = ["all", "GL01", "ES01", "EX01", "SA01"]
+exclude = ['\.undocumented_method$', '\.__init__$', '\.__repr__$']
diff --git a/test/data/load/test_datasetloaders.py b/test/data/load/test_datasetloaders.py
index cb21fd421..0d0eee4c8 100644
--- a/test/data/load/test_datasetloaders.py
+++ b/test/data/load/test_datasetloaders.py
@@ -45,7 +45,7 @@ def _gather_config_files(self, base_dir: Path) -> List[str]:
         }
 
         # Below the datasets that takes quite some time to load and process
-        self.long_running_datasets = {"mantra_name.yaml", "mantra_orientation.yaml", "mantra_genus.yaml", "mantra_betti_numbers.yaml"}
+        self.long_running_datasets = {"mantra_name.yaml", "mantra_orientation.yaml", "mantra_genus.yaml", "mantra_betti_numbers.yaml", "conjugated_pcqm4mv2.yaml"}
 
 
         for dir_path in config_base_dir.iterdir():
diff --git a/test/pipeline/test_pipeline.py b/test/pipeline/test_pipeline.py
index 785987159..6b2dcd462 100644
--- a/test/pipeline/test_pipeline.py
+++ b/test/pipeline/test_pipeline.py
@@ -4,8 +4,8 @@
 from test._utils.simplified_pipeline import run
 
 
-DATASET = "graph/MUTAG" # ADD YOUR DATASET HERE
-MODELS = ["graph/gcn", "cell/topotune", "simplicial/topotune"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE
+DATASET = "hypergraph/conjugated_ocelotv1"
+MODELS = ["hypergraph/edgnn"]
 
 
 class TestPipeline:
@@ -23,7 +23,7 @@ def test_pipeline(self):
                 config_name="run.yaml",
                 overrides=[
                     f"model={MODEL}",
-                    f"dataset={DATASET}", # IF YOU IMPLEMENT A LARGE DATASET WITH AN OPTION TO USE A SLICE OF IT, ADD BELOW THE CORRESPONDING OPTION
+                    f"dataset={DATASET}",
                     "trainer.max_epochs=2",
                     "trainer.min_epochs=1",
                     "trainer.check_val_every_n_epoch=1",
diff --git a/topobench/data/datasets/conjugated_molecule_datasets.py b/topobench/data/datasets/conjugated_molecule_datasets.py
new file mode 100644
index 000000000..51662e04e
--- /dev/null
+++ b/topobench/data/datasets/conjugated_molecule_datasets.py
@@ -0,0 +1,489 @@
+"""Dataset class for conjugated molecular structures."""
+
+import os.path as osp
+from collections.abc import Callable
+
+import pandas as pd
+import torch
+from torch_geometric.data import (
+    Data,
+    InMemoryDataset,
+    extract_zip,
+)
+
+from topobench.data.utils import (
+    download_file_from_link,
+)
+from topobench.data.utils.conjugated_utils import (
+    get_hypergraph_data_from_smiles,
+)
+
+
+class ConjugatedMoleculeDataset(InMemoryDataset):
+    """Dataset class for conjugated molecular structures.
+
+    Parameters
+    ----------
+    root : str
+        Root directory where the dataset will be saved.
+    name : str
+        Name of the dataset. One of "OCELOTv1", "OPV", or "PCQM4MV2".
+    split : str, optional
+        Split of the dataset (e.g., "train", "valid", "test"). Only used for OPV.
+    task : str, optional
+        Task type. Default is "default". For OPV, can be "polymer" to filter
+        molecules with extrapolated properties.
+    transform : Callable, optional
+        A function/transform that takes in an :obj:`torch_geometric.data.Data` object
+        and returns a transformed version. The data object will be transformed before
+        every access.
+    pre_transform : Callable, optional
+        A function/transform that takes in an :obj:`torch_geometric.data.Data` object
+        and returns a transformed version. 
The data object will be transformed before + being saved to disk. + pre_filter : Callable, optional + A function that takes in an :obj:`torch_geometric.data.Data` object and + returns a boolean value, indicating whether the data object should be + included in the final dataset. + slice : int, optional + Number of samples to slice from the dataset. Useful for testing. + target_col : int | str | None, optional + Index or name of the target column to use. If None, all available targets are used. + **kwargs : optional + Additional keyword arguments passed to InMemoryDataset. + Common options include: + - force_reload: bool, whether to re-download and re-process the dataset. + """ + + URLS = { + "OCELOTv1": "https://data.materialsdatafacility.org/mdf_open/ocelot_chromophore_v1_v1.1/ocelot_chromophore_v1.csv", + "OPV": { + "train": "https://data.nrel.gov/system/files/236/1712697052-smiles_train.csv.gz", + "valid": "https://data.nrel.gov/system/files/236/1712697052-smiles_valid.csv.gz", + "test": "https://data.nrel.gov/system/files/236/1712697052-smiles_test.csv.gz", + }, + "PCQM4MV2": "https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m-v2.zip", + } + + def __init__( + self, + root: str, + name: str, + split: str | None = None, + task: str = "default", + transform: Callable | None = None, + pre_transform: Callable | None = None, + pre_filter: Callable | None = None, + slice: int | None = None, + target_col: int | str | None = None, + **kwargs, + ): + if name not in self.URLS: + raise ValueError(f"Unknown dataset name: {name}") + self.name = name + self.split = split + self.task = task + self.slice = slice + self.target_col = target_col + super().__init__( + root, + transform, + pre_transform, + pre_filter, + **kwargs, + ) + self.data, self.slices = torch.load( + self.processed_paths[0], weights_only=False + ) + + def mean(self, target: int = 0) -> float: + """Calculate mean of a specific target across the dataset. + + Parameters + ---------- + target : int + Index of the target to calculate mean for. Default is 0. + + Returns + ------- + float + Mean value of the specified target. + """ + y = torch.cat([self.get(i).y for i in range(len(self))], dim=0) + if y.dim() == 1: + return y.mean().item() + return y[:, target].mean().item() + + def std(self, target: int = 0) -> float: + """Calculate standard deviation of a specific target across the dataset. + + Parameters + ---------- + target : int + Index of the target to calculate std for. Default is 0. + + Returns + ------- + float + Standard deviation of the specified target. + """ + y = torch.cat([self.get(i).y for i in range(len(self))], dim=0) + if y.dim() == 1: + return y.std().item() + return y[:, target].std().item() + + @property + def raw_dir(self) -> str: + """Return the raw directory. + + Returns + ------- + str + Path to the raw directory. + """ + if self.name == "OPV" and self.split: + return osp.join(self.root, self.name, "raw", self.split) + return osp.join(self.root, self.name, "raw") + + @property + def processed_dir(self) -> str: + """Return the processed directory. + + Returns + ------- + str + Path to the processed directory. + """ + if self.name == "OPV" and self.split: + return osp.join(self.root, self.name, "processed", self.split) + return osp.join(self.root, self.name, "processed") + + @property + def raw_file_names(self) -> list[str]: + """Return the raw file names. + + Returns + ------- + list[str] + List of raw file names. 
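+            For OPV, the file name depends on the configured split.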
+ """ + if self.name == "OCELOTv1": + return ["ocelot_chromophore_v1.csv"] + if self.name == "OPV": + if self.split == "train": + return ["1712697052-smiles_train.csv.gz"] + if self.split == "valid": + return ["1712697052-smiles_valid.csv.gz"] + if self.split == "test": + return ["1712697052-smiles_test.csv.gz"] + if self.name == "PCQM4MV2": + return ["pcqm4m-v2/raw/data.csv.gz"] # Extracted path + return ["merged_data.csv"] + + @property + def processed_file_names(self) -> str: + """Return the processed file name. + + Returns + ------- + str + Name of the processed file. + """ + suffix = "" + if self.target_col is not None: + suffix += f"_target_{self.target_col}" + if self.slice is not None: + suffix += f"_slice_{self.slice}" + return f"data{suffix}.pt" + + def download(self): + """Download the dataset.""" + if self.name == "OCELOTv1": + download_file_from_link( + self.URLS["OCELOTv1"], + self.raw_dir, + dataset_name="ocelot_chromophore_v1", + file_format="csv", + ) + elif self.name == "OPV": + if self.split: + download_file_from_link( + self.URLS["OPV"][self.split], + self.raw_dir, + dataset_name=f"1712697052-smiles_{self.split}", + file_format="csv.gz", + ) + else: + for split_name in self.URLS["OPV"]: + download_file_from_link( + self.URLS["OPV"][split_name], + osp.join(self.root, self.name, "raw", split_name), + dataset_name=f"1712697052-smiles_{split_name}", + file_format="csv.gz", + ) + + elif self.name == "PCQM4MV2": + path = osp.join(self.raw_dir, "PCQM4MV2.zip") + download_file_from_link( + self.URLS["PCQM4MV2"], + self.raw_dir, + dataset_name="PCQM4MV2", + file_format="zip", + # file_format="csv", # Fix: The URL is a zip file + ) + extract_zip(path, self.raw_dir) + # The zip extracts to pcqm4m-v2 folder + else: + # Placeholder for user provided data + if not osp.exists(osp.join(self.raw_dir, "merged_data.csv")): + print( + f"Please place 'merged_data.csv' in {self.raw_dir}. " + "This file should contain a 'ready_SMILES' column." + ) + + def process(self): + """Convert data from raw files and save to disk.""" + + data_list = [] + + if self.name == "OCELOTv1": + raw_path = osp.join(self.raw_dir, "ocelot_chromophore_v1.csv") + df = pd.read_csv(raw_path) + smiles_col = "smiles" + elif self.name == "OPV": + # Filename depends on split + filename = self.raw_file_names[0] + raw_path = osp.join(self.raw_dir, filename) + df = pd.read_csv(raw_path) + smiles_col = "smiles" # Assuming standard + + # Polymer task: filter molecules with complete extrapolated properties + if self.task == "polymer": + df = df.dropna(subset=["gap_extrapolated"]) + print( + f"Polymer task: filtered to {len(df)} molecules with gap_extrapolated values" + ) + elif self.name == "PCQM4MV2": + # The extracted file is likely in a subdir + raw_path = osp.join( + self.raw_dir, "pcqm4m-v2", "raw", "data.csv.gz" + ) + df = pd.read_csv(raw_path) + smiles_col = "smiles" + else: + raw_path = osp.join(self.raw_dir, "merged_data.csv") + df = pd.read_csv(raw_path) + smiles_col = "ready_SMILES" + + if not osp.exists(raw_path): + raise FileNotFoundError(f"File not found: {raw_path}") + + if smiles_col not in df.columns: + # Fallback or check for other common names if needed + if "SMILES" in df.columns: + smiles_col = "SMILES" + elif "smile" in df.columns: # OPV uses singular 'smile' + smiles_col = "smile" + elif "ready_SMILES" in df.columns: + smiles_col = "ready_SMILES" + else: + raise ValueError( + f"CSV file must contain '{smiles_col}' column. 
Found: {df.columns}" + ) + + smiles_list = df[smiles_col].tolist() + + if self.slice is not None: + smiles_list = smiles_list[: self.slice] + + for idx, smiles in enumerate(smiles_list): + try: + atom_fvs, incidence_list, bond_fvs = ( + get_hypergraph_data_from_smiles(smiles) + ) + except (TypeError, ValueError, AttributeError): + continue + + if not incidence_list: + continue + + num_nodes = len(atom_fvs) + # incidence_matrix = create_incidence_matrix( + # incidence_list, num_nodes + # ) + + # Convert to tensors + x = torch.tensor(atom_fvs, dtype=torch.float) + + # Create edge_index from incidence list + # incidence_list is list of lists of node indices + sources = [] + targets = [] + for edge_idx, nodes in enumerate(incidence_list): + for node_idx in nodes: + sources.append(node_idx) + targets.append(edge_idx) + + edge_index = torch.tensor([sources, targets], dtype=torch.long) + + # Calculate edge order (hyperedge cardinality) + e_order = torch.tensor( + [len(nodes) for nodes in incidence_list], + dtype=torch.long, + ) + + # Hyperedge features (bond features) + hyperedge_attr = list(bond_fvs) + + hyperedge_attr = torch.tensor(hyperedge_attr, dtype=torch.float) + + # Incidence matrix as sparse tensor + incidence_hyperedges = torch.sparse_coo_tensor( + edge_index, + torch.ones(edge_index.shape[1]), + size=(num_nodes, len(incidence_list)), + ) + + # Create base data object + data = Data( + x=x, + edge_index=edge_index, + hyperedge_attr=hyperedge_attr, + incidence_hyperedges=incidence_hyperedges, + num_nodes=num_nodes, + num_hyperedges=len(incidence_list), + e_order=e_order, # Edge order tracking + smi=smiles, # Store SMILES string + ) + + # Extract target labels based on dataset + if self.name == "OPV": + # OPV has 8 regression targets (columns 2-9) + if self.target_col is not None: + if isinstance(self.target_col, int): + target = df.iloc[idx, 2 + self.target_col] + else: + target = df.iloc[idx][self.target_col] + # Use scalar tensor for single task so batching gives [N] + y = torch.tensor(target, dtype=torch.float) + else: + target = df.iloc[idx, 2:].values.astype(float) + y = torch.tensor(target, dtype=torch.float) + data.y = y + elif self.name == "PCQM4MV2": + # PCQM4MV2 has single target: homolumogap + if "homolumogap" in df.columns: + # Use scalar tensor for single task + y = torch.tensor( + df.loc[idx, "homolumogap"], dtype=torch.float + ) + data.y = y + elif self.name == "OCELOTv1": + # OCELOTv1 - extract targets from columns after identifier and smiles + # Skip non-numeric columns (identifier, smiles) + numeric_cols = [] + for col in df.columns: + if col not in ["identifier", "smiles", "smile"]: + # Try to convert to numeric + try: + pd.to_numeric(df[col].iloc[0]) + numeric_cols.append(col) + except (ValueError, TypeError): + pass + + if len(numeric_cols) > 0: + if self.target_col is not None: + if isinstance(self.target_col, int): + target_val = df.iloc[idx][ + numeric_cols[self.target_col] + ] + elif self.target_col in numeric_cols: + target_val = df.iloc[idx][self.target_col] + else: + raise ValueError( + f"Target column {self.target_col} not found in numeric columns" + ) + + # Use scalar tensor for single task + y = torch.tensor(target_val, dtype=torch.float) + else: + target = df.iloc[idx][numeric_cols].values.astype( + float + ) + y = torch.tensor(target, dtype=torch.float) + data.y = y + + if self.split: + data.split = self.split + + if self.pre_filter is not None and not self.pre_filter(data): + continue + + if self.pre_transform is not None: + data = 
self.pre_transform(data) + + data_list.append(data) + + torch.save(self.collate(data_list), self.processed_paths[0]) + + +if __name__ == "__main__": + """ + Run this script to test the ConjugatedMoleculeDataset. + `uv run python -m topobench.data.datasets.conjugated_molecule_datasets` + """ + import rootutils + + root = rootutils.setup_root( + search_from=".", + indicator="pyproject.toml", + pythonpath=True, + cwd=True, + ) + + print("=" * 60) + print("ConjugatedMoleculeDataset Verification Script") + print("=" * 60) + + results = [] + + print("Testing OCELOTv1 dataset...") + try: + dataset = ConjugatedMoleculeDataset( + root="datasets/hypergraph/conjugated_molecules", + name="OCELOTv1", + ) + print( + f"✓ OCELOTv1 dataset loaded successfully with {len(dataset)} samples" + ) + print(f" First sample: {dataset[0]}") + except Exception as e: + print(f"✗ OCELOTv1 dataset failed: {e}") + raise + + print("Testing OPV dataset...") + try: + dataset = ConjugatedMoleculeDataset( + root="datasets/hypergraph/conjugated_molecules", + name="OPV", + split="train", # OPV requires a split + ) + print(f"✓ OPV dataset loaded successfully with {len(dataset)} samples") + print(f" First sample: {dataset[0]}") + except Exception as e: + print(f"✗ OPV dataset failed: {e}") + raise + + print("Testing PCQM4MV2 dataset...") + try: + dataset = ConjugatedMoleculeDataset( + root="datasets/hypergraph/conjugated_molecules", + name="PCQM4MV2", + ) + print( + f"✓ PCQM4MV2 dataset loaded successfully with {len(dataset)} samples" + ) + print(f" First sample: {dataset[0]}") + except Exception as e: + print(f"✗ PCQM4MV2 dataset failed: {e}") + raise diff --git a/topobench/data/loaders/hypergraph/conjugated_molecule_dataset_loader.py b/topobench/data/loaders/hypergraph/conjugated_molecule_dataset_loader.py new file mode 100644 index 000000000..0b455a497 --- /dev/null +++ b/topobench/data/loaders/hypergraph/conjugated_molecule_dataset_loader.py @@ -0,0 +1,74 @@ +"""Loader for Conjugated Molecule dataset.""" + +from omegaconf import DictConfig + +from topobench.data.datasets import ConjugatedMoleculeDataset +from topobench.data.loaders.base import AbstractLoader + + +class ConjugatedMoleculeDatasetLoader(AbstractLoader): + """Load Conjugated Molecule dataset with configurable parameters. + + Parameters + ---------- + parameters : DictConfig + Configuration parameters containing: + - data_dir: Root directory for data + - data_name: Name of the dataset + - split: Split of the dataset (optional, for OPV) + - other relevant parameters + """ + + def __init__(self, parameters: DictConfig) -> None: + super().__init__(parameters) + + def load_dataset(self, **kwargs) -> ConjugatedMoleculeDataset: + """Load the Conjugated Molecule dataset. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + ConjugatedMoleculeDataset + The loaded Conjugated Molecule dataset. + """ + dataset = self._initialize_dataset(**kwargs) + + # Handle slicing if requested (e.g. for testing long-running datasets) + if "slice" in kwargs: + dataset = dataset[: kwargs["slice"]] + + self.data_dir = self.get_data_dir() + return dataset + + def _initialize_dataset(self, **kwargs) -> ConjugatedMoleculeDataset: + """Initialize the Conjugated Molecule dataset. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + ConjugatedMoleculeDataset + The initialized dataset instance. 
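+
+        Notes
+        -----
+        ``split`` is only meaningful for OPV and defaults to None; ``slice``
+        is read from ``kwargs`` when provided (e.g. to shorten test runs).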
+ """ + # Check if split is in parameters, default to None + split = self.parameters.get("split", None) + + # Check if slice is in kwargs + slice_val = kwargs.get("slice") + + return ConjugatedMoleculeDataset( + root=str(self.root_data_dir), + name=self.parameters.data_name, + split=split, + slice=slice_val, + task=self.parameters.get("task", "default"), + target_col=self.parameters.get("target_col", None), + # Pass other parameters if needed, e.g. transforms from config + ) diff --git a/topobench/data/utils/conjugated_utils.py b/topobench/data/utils/conjugated_utils.py new file mode 100644 index 000000000..88b13f287 --- /dev/null +++ b/topobench/data/utils/conjugated_utils.py @@ -0,0 +1,149 @@ +"""Utilities for conjugated structures dataset.""" + +import numpy as np +from ogb.utils.features import atom_to_feature_vector, bond_to_feature_vector +from rdkit import Chem, RDLogger + + +def contains_conjugated_bond( + mol: Chem.Mol, +) -> tuple[bool, Chem.ResonanceMolSupplier]: + """Check if a molecule contains conjugated bonds. + + Parameters + ---------- + mol : Chem.Mol + Input molecule. + + Returns + ------- + tuple[bool, Chem.ResonanceMolSupplier] + Tuple containing a boolean indicating if the molecule has conjugated bonds + and the ResonanceMolSupplier object. + """ + reso = Chem.ResonanceMolSupplier(mol) + num_he = reso.GetNumConjGrps() + return num_he > 0, reso + + +def he_conj(mol: Chem.Mol) -> list[list]: + """Get incidence list of conjugated structures in a molecule. + + Parameters + ---------- + mol : Chem.Mol + Input molecule. + + Returns + ------- + list[list] + Incidence list of conjugated structures. + """ + num_atom = mol.GetNumAtoms() + reso = Chem.ResonanceMolSupplier(mol) + num_he = reso.GetNumConjGrps() + if num_he == 0: + return [] + + incidence_list: list[list] = [[] for _ in range(num_he)] + for i in range(num_atom): + _conj = reso.GetAtomConjGrpIdx(i) + if _conj > -1 and _conj < num_he: + incidence_list[_conj].append(i) + return incidence_list + + +def edge_order(e_idx): + """Get the order (cardinality) of each edge. + + Parameters + ---------- + e_idx : list + List of edge indices. + + Returns + ------- + list + List containing the count of each edge index. + """ + return [e_idx.count(i) for i in range(len(set(e_idx)))] + + +def get_hypergraph_data_from_smiles( + smiles_string: str, +) -> tuple[list, list[list], list]: + """Convert a SMILES string to hypergraph Data object. + + Parameters + ---------- + smiles_string : str + SMILES string. + + Returns + ------- + tuple[list, list[list], list] + Tuple containing atom feature vectors, incidence list, and bond feature vectors. 
+ """ + RDLogger.DisableLog("rdApp.*") + try: + try: + mol = Chem.MolFromSmiles(smiles_string) + except TypeError as err: + raise TypeError from err + + # atoms + atom_fvs = [atom_to_feature_vector(atom) for atom in mol.GetAtoms()] + + # bonds + num_bond_features = 1 # bond type (single, double, triple, conjugated) + bonds = mol.GetBonds() + if len(bonds) > 0: # mol has bonds + incidence_list: list[list] = [[] for _ in range(len(bonds))] + bond_fvs: list[list] = [[] for _ in range(len(bonds))] + for i, bond in enumerate(bonds): + incidence_list[i] = [ + bond.GetBeginAtomIdx(), + bond.GetEndAtomIdx(), + ] + bond_type = bond_to_feature_vector(bond)[0] + bond_fvs[i] = [bond_type] + + else: # mol has no bonds + incidence_list: list[list] = [] + bond_fvs: list[list] = [] + return (atom_fvs, incidence_list, bond_fvs) + + # hyperedges for conjugated bonds + he_incidence_list = he_conj(mol) # [[3,4,5], [0,1,2]] + if len(he_incidence_list) != 0: + incidence_list.extend( + he_incidence_list + ) # [[0,1], [1,2], [3,4], [[3,4,5], [0,1,2]] + bond_fvs += len(he_incidence_list) * [num_bond_features * [5]] + + return (atom_fvs, incidence_list, bond_fvs) + finally: + RDLogger.EnableLog("rdApp.*") + + +def create_incidence_matrix(incidence_list, num_nodes): + """Create an incidence matrix from an incidence list. + + Parameters + ---------- + incidence_list : list[list] + List of edges, where each edge is a list of node indices. + num_nodes : int + Number of nodes. + + Returns + ------- + np.ndarray + Incidence matrix. + """ + num_edges = len(incidence_list) + incidence_matrix = np.zeros((num_nodes, num_edges), dtype=int) + for edge_idx, nodes in enumerate(incidence_list): + for node in nodes: + incidence_matrix[node, edge_idx] = 1 + return incidence_matrix