diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py
index 3dd7d57..3226d21 100644
--- a/chebai_graph/models/base.py
+++ b/chebai_graph/models/base.py
@@ -184,6 +184,7 @@ def forward(self, batch: dict) -> torch.Tensor:
             torch.Tensor: Predicted output.
         """
         graph_data = batch["features"][0]
+        graph_data.to(self.device)
         assert isinstance(graph_data, GraphData)
         a = self.gnn(batch)
         a = scatter_add(a, graph_data.batch, dim=0)
diff --git a/chebai_graph/preprocessing/datasets/__init__.py b/chebai_graph/preprocessing/datasets/__init__.py
index d13b1c3..8708c28 100644
--- a/chebai_graph/preprocessing/datasets/__init__.py
+++ b/chebai_graph/preprocessing/datasets/__init__.py
@@ -12,12 +12,14 @@
     ChEBI50_WFGE_WGN_GraphProp,
     ChEBI50GraphData,
     ChEBI50GraphProperties,
+    ChEBI100GraphProperties,
 )
 from .pubchem import PubChemGraphProperties
 
 __all__ = [
     "ChEBI50GraphFGAugmentorReader",
     "ChEBI50GraphProperties",
+    "ChEBI100GraphProperties",
     "ChEBI50GraphData",
     "PubChemGraphProperties",
     "ChEBI50_Atom_WGNOnly_GraphProp",
diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py
index c94b772..4ae441a 100644
--- a/chebai_graph/preprocessing/datasets/chebi.py
+++ b/chebai_graph/preprocessing/datasets/chebi.py
@@ -2,6 +2,7 @@
 from abc import ABC
 from collections.abc import Callable
 from pprint import pformat
+from typing import Optional
 
 import pandas as pd
 import torch
@@ -281,7 +282,9 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
             molecule_attr=molecule_attr,
         )
 
-    def load_processed_data_from_file(self, filename: str) -> list[dict]:
+    def load_processed_data(
+        self, kind: Optional[str] = None, filename: Optional[str] = None
+    ) -> list[dict]:
         """
         Load dataset and merge cached properties into base features.
         Args:
@@ -291,7 +294,7 @@
         Returns:
             List of data entries, each a dictionary.
         """
-        base_data = super().load_processed_data_from_file(filename)
+        base_data = super().load_processed_data(kind, filename)
         base_df = pd.DataFrame(base_data)
 
         for property in self.properties:
@@ -379,7 +382,9 @@ def __init__(self, properties=None, transform=None, **kwargs):
             f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}",
         )
 
-    def load_processed_data_from_file(self, filename: str) -> list[dict]:
+    def load_processed_data(
+        self, kind: Optional[str] = None, filename: Optional[str] = None
+    ) -> list[dict]:
         """
         Load dataset and merge cached properties into base features.
         Args:
@@ -389,9 +394,8 @@
         Returns:
             List of data entries, each a dictionary.
         """
-        base_data = super().load_processed_data_from_file(filename)
+        base_data = super().load_processed_data(kind, filename)
         base_df = pd.DataFrame(base_data)
-
         props_categories = {
             "AllNodeTypeProperties": [],
             "FGNodeTypeProperties": [],
@@ -442,6 +446,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
         )
 
         for property in self.properties:
+            rank_zero_info(f"Loading property {property.name}...")
             property_data = torch.load(
                 self.get_property_path(property), weights_only=False
             )
diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py
index 1487163..991ab16 100644
--- a/chebai_graph/preprocessing/property_encoder.py
+++ b/chebai_graph/preprocessing/property_encoder.py
@@ -16,9 +16,10 @@ class PropertyEncoder(abc.ABC):
         **kwargs: Additional keyword arguments.
     """
 
-    def __init__(self, property, **kwargs) -> None:
+    def __init__(self, property, eval=False, **kwargs) -> None:
         self.property = property
         self._encoding_length: int = 1
+        self.eval = eval  # if True, do not update cache (for index encoder)
 
     @property
     def name(self) -> str:
@@ -150,6 +151,10 @@ def encode(self, token: str | None) -> torch.Tensor:
             self._count_for_unk_token += 1
             return torch.tensor([self._unk_token_idx])
 
+        if self.eval and str(token) not in self.cache:
+            self._count_for_unk_token += 1
+            return torch.tensor([self._unk_token_idx])
+
         if str(token) not in self.cache:
             self.cache[str(token)] = len(self.cache)
         return torch.tensor([self.cache[str(token)] + self.offset])
@@ -213,6 +218,15 @@ def encode(self, token: str | None) -> torch.Tensor:
         Returns:
             One-hot encoded tensor of shape (1, encoding_length).
         """
+        if self.eval:
+            if token is None or str(token) not in self.cache:
+                self._count_for_unk_token += 1
+                return torch.zeros(self.get_encoding_length(), dtype=torch.int64)
+            index = self.cache[str(token)] + self.offset
+            return torch.nn.functional.one_hot(
+                torch.tensor(index), num_classes=self.get_encoding_length()
+            )
+
         if token not in self.tokens_dict:
             self._count_for_unk_token += 1
             return torch.zeros(1, self.get_encoding_length(), dtype=torch.int64)