diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py index 497025c..f998396 100644 --- a/chebai_graph/preprocessing/property_encoder.py +++ b/chebai_graph/preprocessing/property_encoder.py @@ -1,8 +1,12 @@ import abc import os -import torch from typing import Optional +import torch +import sys +from itertools import islice +import inspect + class PropertyEncoder(abc.ABC): def __init__(self, property, **kwargs): @@ -36,11 +40,13 @@ class IndexEncoder(PropertyEncoder): def __init__(self, property, indices_dir=None, **kwargs): super().__init__(property, **kwargs) if indices_dir is None: - indices_dir = os.path.dirname(__file__) + indices_dir = os.path.dirname(inspect.getfile(self.__class__)) self.dirname = indices_dir # load already existing cache with open(self.index_path, "r") as pk: - self.cache = [x.strip() for x in pk] + self.cache: dict[str, int] = { + token.strip(): idx for idx, token in enumerate(pk) + } self.index_length_start = len(self.cache) self.offset = 0 @@ -64,19 +70,33 @@ def index_path(self): def on_finish(self): """Save cache""" - with open(self.index_path, "w") as pk: - new_length = len(self.cache) - self.index_length_start - pk.writelines([f"{c}\n" for c in self.cache]) - print( - f"saved index of property {self.property.name} to {self.index_path}, " - f"index length: {len(self.cache)} (new: {new_length})" - ) + total_tokens = len(self.cache) + if total_tokens > self.index_length_start: + print("New tokens added to the cache, saving them to the index token file...") + + assert sys.version_info >= ( + 3, + 7, + ), "This code requires Python 3.7 or higher."
+ # For Python 3.7+, the standard dict type preserves insertion order and is iterated over in the same order + # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights + # https://mail.python.org/pipermail/python-dev/2017-December/151283.html + new_tokens = list(islice(self.cache, self.index_length_start, total_tokens)) + + with open(self.index_path, "a") as pk: + pk.writelines([f"{c}\n" for c in new_tokens]) + print( + f"New {len(new_tokens)} tokens appended to the index of property {self.property.name} to {self.index_path}..." + ) + print( + f"Now, the total length of the index of property {self.property.name} is {total_tokens}" + ) def encode(self, token): """Returns a unique number for each token, automatically adds new tokens to the cache.""" if not str(token) in self.cache: - self.cache.append(str(token)) - return torch.tensor([self.cache.index(str(token)) + self.offset]) + self.cache[str(token)] = len(self.cache) + return torch.tensor([self.cache[str(token)] + self.offset]) class OneHotEncoder(IndexEncoder):