59 changes: 32 additions & 27 deletions chebai/preprocessing/reader.py
@@ -1,7 +1,8 @@
import os
import sys
from abc import ABC
from itertools import islice
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional

import deepsmiles
import selfies as sf
@@ -119,23 +120,14 @@ def on_finish(self) -> None:
return


class ChemDataReader(DataReader):
class TokenIndexerReader(DataReader, ABC):
"""
Data reader for chemical data using SMILES tokens.
Abstract base class for reading tokenized data and mapping tokens to unique indices.

Args:
collator_kwargs: Optional dictionary of keyword arguments for the collator.
token_path: Optional path for the token file.
kwargs: Additional keyword arguments.
This class maintains a cache of token-to-index mappings that can be extended during runtime,
and saves new tokens to a persistent file at the end of processing.
"""

COLLATOR = RaggedCollator

@classmethod
def name(cls) -> str:
"""Returns the name of the data reader."""
return "smiles_token"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
with open(self.token_path, "r") as pk:
@@ -150,21 +142,9 @@ def _get_token_index(self, token: str) -> int:
self.cache[(str(token))] = len(self.cache)
return self.cache[str(token)] + EMBEDDING_OFFSET

def _read_data(self, raw_data: str) -> List[int]:
"""
Reads and tokenizes raw SMILES data into a list of token indices.

Args:
raw_data (str): The raw SMILES string to be tokenized.

Returns:
List[int]: A list of integers representing the indices of the SMILES tokens.
"""
return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]

def on_finish(self) -> None:
"""
Saves the current cache of tokens to the token file. This method is called after all data processing is complete.
        Saves the current cache of tokens to the token file. This method is called after all data processing is complete.
"""
print(f"first 10 tokens: {list(islice(self.cache, 10))}")

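A minimal, self-contained sketch of the caching scheme this refactoring centralizes in TokenIndexerReader. The EMBEDDING_OFFSET value and the file-handling details below are assumptions for illustration, not the package's actual constants or I/O code:

# Sketch: token-to-index caching as in TokenIndexerReader (details assumed).
EMBEDDING_OFFSET = 10  # assumed placeholder; the real constant is defined elsewhere in chebai


class TokenCacheSketch:
    def __init__(self, token_path: str):
        self.token_path = token_path
        # Load existing tokens, one per line; a token's index is its line position.
        with open(token_path, "r") as pk:
            self.cache = {line.strip(): i for i, line in enumerate(pk)}
        self._loaded = len(self.cache)  # tokens known before this run

    def _get_token_index(self, token: str) -> int:
        # First-seen tokens are appended, so the new index is the cache size.
        if str(token) not in self.cache:
            self.cache[str(token)] = len(self.cache)
        return self.cache[str(token)] + EMBEDDING_OFFSET

    def on_finish(self) -> None:
        # Persist only the tokens added during this run, in index order.
        new_tokens = list(self.cache)[self._loaded:]
        with open(self.token_path, "a") as pk:
            pk.writelines(f"{c}\n" for c in new_tokens)

Since dicts preserve insertion order in Python 3.7+, list(self.cache)[self._loaded:] yields exactly the tokens added during this run, in the order their indices were assigned, so the appended file stays consistent with the in-memory indices.
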
@@ -188,6 +168,31 @@ def on_finish(self) -> None:
pk.writelines([f"{c}\n" for c in new_tokens])


class ChemDataReader(TokenIndexerReader):
"""
Data reader for chemical data using SMILES tokens.
"""

COLLATOR = RaggedCollator

@classmethod
def name(cls) -> str:
"""Returns the name of the data reader."""
return "smiles_token"

def _read_data(self, raw_data: str) -> List[int]:
"""
Reads and tokenizes raw SMILES data into a list of token indices.

Args:
raw_data (str): The raw SMILES string to be tokenized.

Returns:
List[int]: A list of integers representing the indices of the SMILES tokens.
"""
return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]


class DeepChemDataReader(ChemDataReader):
"""
Data reader for chemical data using DeepSMILES tokens.
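
A hypothetical usage sketch of the refactored ChemDataReader. The constructor keyword follows the token_path argument documented in the original docstring, and the printed indices are assumptions; actual values depend on the token file contents and EMBEDDING_OFFSET:

# Hypothetical usage; token_path keyword and index values are assumptions.
from chebai.preprocessing.reader import ChemDataReader

reader = ChemDataReader(token_path="tokens.txt")
indices = reader._read_data("CCO")  # ethanol: tokens "C", "C", "O"
print(indices)  # e.g. [10, 10, 11] -- repeated tokens share one index
reader.on_finish()  # appends any newly seen tokens to tokens.txt

Because DeepChemDataReader subclasses ChemDataReader, it presumably inherits this same cache-backed index lookup while swapping in DeepSMILES tokenization.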