Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions chebai/preprocessing/datasets/pubchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,13 @@ def setup_processed(self):
print("Load data from file", filename)
data = self._load_data_from_file(filename)
print("Create splits")
train, test = train_test_split(data, train_size=self.train_split)
train, test = train_test_split(
data, train_size=1 - (self.validation_split + self.test_split)
)
del data
test, val = train_test_split(test, train_size=self.train_split)
test, val = train_test_split(
test, train_size=self.test_split / (self.validation_split + self.test_split)
)
torch.save(train, os.path.join(self.processed_dir, "train.pt"))
torch.save(test, os.path.join(self.processed_dir, "test.pt"))
torch.save(val, os.path.join(self.processed_dir, "validation.pt"))
Expand All @@ -179,6 +183,21 @@ def processed_file_names(self) -> List[str]:
"""
return ["test.pt", "train.pt", "validation.pt"]

def _set_processed_data_props(self):
"""
Self-supervised learning with PubChem does not use this metadata, therefore set them to zero.

Sets:
- self._num_of_labels: 0
- self._feature_vector_size: 0.
"""

self._num_of_labels = 0
self._feature_vector_size = 0

print(f"Number of labels for loaded data: {self._num_of_labels}")
print(f"Feature vector size: {self._feature_vector_size}")

def _perform_data_preparation(self, *args, **kwargs):
"""
Checks for raw data and downloads if necessary.
Expand Down