1 change: 1 addition & 0 deletions chebai_graph/models/base.py
@@ -184,6 +184,7 @@ def forward(self, batch: dict) -> torch.Tensor:
torch.Tensor: Predicted output.
"""
graph_data = batch["features"][0]
+ graph_data.to(self.device)
assert isinstance(graph_data, GraphData)
a = self.gnn(batch)
a = scatter_add(a, graph_data.batch, dim=0)
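The added graph_data.to(self.device) moves the batched graph tensors onto the model's device before pooling. A minimal sketch of the same pattern, assuming a PyTorch Geometric Data object and torch_scatter; the toy graph and variable names below are illustrative, not taken from this repository:

import torch
from torch_geometric.data import Data
from torch_scatter import scatter_add

# Illustrative two-graph batch whose tensors start on the CPU.
data = Data(
    x=torch.randn(4, 8),                              # 4 nodes, 8 features each
    edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),
    batch=torch.tensor([0, 0, 1, 1]),                 # node -> graph assignment
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data.to(device)  # Data.to() moves every tensor attribute, mirroring the line added above

node_emb = data.x                                     # stand-in for the GNN output
graph_emb = scatter_add(node_emb, data.batch, dim=0)  # per-graph sum pooling; both tensors must share a device
print(graph_emb.shape)                                # torch.Size([2, 8])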
2 changes: 2 additions & 0 deletions chebai_graph/preprocessing/datasets/__init__.py
@@ -12,12 +12,14 @@
ChEBI50_WFGE_WGN_GraphProp,
ChEBI50GraphData,
ChEBI50GraphProperties,
+ ChEBI100GraphProperties,
)
from .pubchem import PubChemGraphProperties

__all__ = [
"ChEBI50GraphFGAugmentorReader",
"ChEBI50GraphProperties",
"ChEBI100GraphProperties",
"ChEBI50GraphData",
"PubChemGraphProperties",
"ChEBI50_Atom_WGNOnly_GraphProp",
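With the re-export above, the new dataset class is importable from the package namespace; a one-line sketch (its constructor and behaviour are not part of this diff):

from chebai_graph.preprocessing.datasets import ChEBI100GraphProperties  # now listed in __all__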
15 changes: 10 additions & 5 deletions chebai_graph/preprocessing/datasets/chebi.py
@@ -2,6 +2,7 @@
from abc import ABC
from collections.abc import Callable
from pprint import pformat
+ from typing import Optional

import pandas as pd
import torch
@@ -281,7 +282,9 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
molecule_attr=molecule_attr,
)

- def load_processed_data_from_file(self, filename: str) -> list[dict]:
+ def load_processed_data(
+     self, kind: Optional[str] = None, filename: Optional[str] = None
+ ) -> list[dict]:
"""
Load dataset and merge cached properties into base features.

@@ -291,7 +294,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
Returns:
List of data entries, each a dictionary.
"""
- base_data = super().load_processed_data_from_file(filename)
+ base_data = super().load_processed_data(kind, filename)
base_df = pd.DataFrame(base_data)

for property in self.properties:
@@ -379,7 +382,9 @@ def __init__(self, properties=None, transform=None, **kwargs):
f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}",
)

- def load_processed_data_from_file(self, filename: str) -> list[dict]:
+ def load_processed_data(
+     self, kind: Optional[str] = None, filename: Optional[str] = None
+ ) -> list[dict]:
"""
Load dataset and merge cached properties into base features.

@@ -389,9 +394,8 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
Returns:
List of data entries, each a dictionary.
"""
- base_data = super().load_processed_data_from_file(filename)
+ base_data = super().load_processed_data(kind, filename)
base_df = pd.DataFrame(base_data)
-
props_categories = {
"AllNodeTypeProperties": [],
"FGNodeTypeProperties": [],
@@ -442,6 +446,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
)

for property in self.properties:
+ rank_zero_info(f"Loading property {property.name}...")
property_data = torch.load(
self.get_property_path(property), weights_only=False
)
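A hedged sketch of how the widened load_processed_data hook might be called, assuming the base class resolves kind (for example "train", "validation", "test") to a processed file when filename is omitted; the constructor arguments and the exact kind values are assumptions, not taken from this diff:

# Hypothetical call sites for the renamed hook.
from chebai_graph.preprocessing.datasets import ChEBI50GraphProperties

dm = ChEBI50GraphProperties()  # constructor arguments omitted; see the data module's own docs

# Resolve the processed file from a split name (assuming the base class maps kind -> filename) ...
train_entries = dm.load_processed_data(kind="train")

# ... or keep passing an explicit filename, as the old load_processed_data_from_file did.
test_entries = dm.load_processed_data(filename="test.pt")

print(len(train_entries), list(train_entries[0].keys()))  # each entry is a dict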
16 changes: 15 additions & 1 deletion chebai_graph/preprocessing/property_encoder.py
@@ -16,9 +16,10 @@ class PropertyEncoder(abc.ABC):
**kwargs: Additional keyword arguments.
"""

- def __init__(self, property, **kwargs) -> None:
+ def __init__(self, property, eval=False, **kwargs) -> None:
self.property = property
self._encoding_length: int = 1
+ self.eval = eval  # if True, do not update cache (for index encoder)

@property
def name(self) -> str:
@@ -150,6 +151,10 @@ def encode(self, token: str | None) -> torch.Tensor:
self._count_for_unk_token += 1
return torch.tensor([self._unk_token_idx])

+ if self.eval and str(token) not in self.cache:
+     self._count_for_unk_token += 1
+     return torch.tensor([self._unk_token_idx])
+
if str(token) not in self.cache:
self.cache[str(token)] = len(self.cache)
return torch.tensor([self.cache[str(token)] + self.offset])
@@ -213,6 +218,15 @@ def encode(self, token: str | None) -> torch.Tensor:
Returns:
One-hot encoded tensor of shape (1, encoding_length).
"""
+ if self.eval:
+     if token is None or str(token) not in self.cache:
+         self._count_for_unk_token += 1
+         return torch.zeros(self.get_encoding_length(), dtype=torch.int64)
+     index = self.cache[str(token)] + self.offset
+     return torch.nn.functional.one_hot(
+         torch.tensor(index), num_classes=self.get_encoding_length()
+     )
+
if token not in self.tokens_dict:
self._count_for_unk_token += 1
return torch.zeros(1, self.get_encoding_length(), dtype=torch.int64)
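The eval flag freezes the encoder vocabulary: an index encoder built with eval=True maps unseen tokens to the unknown index instead of growing its cache, and the one-hot encoder returns an all-zero vector for them. (In the eval branch the one-hot encoder appears to return tensors of shape (encoding_length,) while the non-eval unknown case returns (1, encoding_length); downstream code is assumed to accept both.) A self-contained toy that mirrors the added guard, not the repository's encoder classes:

import torch

# Toy index encoder reproducing the cache-freezing logic in miniature.
class ToyIndexEncoder:
    def __init__(self, eval=False, offset=1, unk_token_idx=0):
        self.eval = eval              # if True, never grow the cache
        self.cache = {}               # token -> index collected while encoding training data
        self.offset = offset
        self._unk_token_idx = unk_token_idx
        self._count_for_unk_token = 0

    def encode(self, token):
        if token is None or (self.eval and str(token) not in self.cache):
            self._count_for_unk_token += 1
            return torch.tensor([self._unk_token_idx])   # frozen cache: unseen -> unknown index
        if str(token) not in self.cache:
            self.cache[str(token)] = len(self.cache)     # training mode: extend the vocabulary
        return torch.tensor([self.cache[str(token)] + self.offset])

enc = ToyIndexEncoder()
print(enc.encode("C"))   # tensor([1]) -- "C" enters the cache
print(enc.encode("N"))   # tensor([2]) -- "N" enters the cache
enc.eval = True
print(enc.encode("P"))   # tensor([0]) -- eval mode: cache stays frozen, unknown index returned
print(enc.encode("C"))   # tensor([1]) -- previously cached tokens still resolve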