8 changes: 8 additions & 0 deletions .gitignore
@@ -167,3 +167,11 @@ cython_debug/
 /logs
 /results_buffer
 electra_pretrained.ckpt
+
+build
+.virtual_documents
+.jupyter
+chebai.egg-info
+lightning_logs
+logs
+.isort.cfg
35 changes: 28 additions & 7 deletions chebai/preprocessing/reader.py
@@ -1,4 +1,6 @@
 import os
+import sys
+from itertools import islice
 from typing import Any, Dict, List, Optional, Tuple

 import deepsmiles
@@ -137,13 +139,16 @@ def name(cls) -> str:
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         with open(self.token_path, "r") as pk:
-            self.cache = [x.strip() for x in pk]
+            self.cache: Dict[str, int] = {
+                token.strip(): idx for idx, token in enumerate(pk)
+            }
+        self._loaded_tokens_count = len(self.cache)

     def _get_token_index(self, token: str) -> int:
         """Returns a unique number for each token, automatically adds new tokens."""
         if not str(token) in self.cache:
-            self.cache.append(str(token))
-        return self.cache.index(str(token)) + EMBEDDING_OFFSET
+            self.cache[str(token)] = len(self.cache)
+        return self.cache[str(token)] + EMBEDDING_OFFSET

     def _read_data(self, raw_data: str) -> List[int]:
         """
@@ -161,10 +166,26 @@ def on_finish(self) -> None:
         """
         Saves the current cache of tokens to the token file. This method is called after all data processing is complete.
         """
-        with open(self.token_path, "w") as pk:
-            print(f"saving {len(self.cache)} tokens to {self.token_path}...")
-            print(f"first 10 tokens: {self.cache[:10]}")
-            pk.writelines([f"{c}\n" for c in self.cache])
+        print(f"first 10 tokens: {list(islice(self.cache, 10))}")
+
+        total_tokens = len(self.cache)
+        if total_tokens > self._loaded_tokens_count:
+            print("New tokens were added to the cache, saving them to the token file...")
+
+            assert sys.version_info >= (
+                3,
+                7,
+            ), "This code requires Python 3.7 or higher."
+            # For Python 3.7+, the standard dict type preserves insertion order and is iterated in that order:
+            # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
+            # https://mail.python.org/pipermail/python-dev/2017-December/151283.html
+            new_tokens = list(
+                islice(self.cache, self._loaded_tokens_count, total_tokens)
+            )
+
+            with open(self.token_path, "a") as pk:
+                print(f"saving {len(new_tokens)} new tokens to {self.token_path}...")
+                pk.writelines([f"{c}\n" for c in new_tokens])


 class DeepChemDataReader(ChemDataReader):
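The core pattern in this diff: the token cache becomes a dict mapping token to index, so membership tests and lookups are O(1) instead of the O(n) scans that `list.index` required, and `on_finish` appends only the tokens added after loading, relying on the insertion-order guarantee of dicts in Python 3.7+. Below is a minimal, self-contained sketch of that pattern, assuming a plain-text token file with one token per line; the class name `TokenCacheSketch`, the `EMBEDDING_OFFSET` value, and the missing-file handling are illustrative assumptions, not chebai's actual API.

from itertools import islice
from typing import Dict

EMBEDDING_OFFSET = 10  # illustrative; chebai defines its own constant


class TokenCacheSketch:
    """Stand-in for the reader's token cache; one token per line on disk."""

    def __init__(self, token_path: str) -> None:
        self.token_path = token_path
        try:
            with open(token_path, "r") as pk:
                # dict instead of list: O(1) lookups, and insertion order
                # is preserved (guaranteed since Python 3.7)
                self.cache: Dict[str, int] = {
                    token.strip(): idx for idx, token in enumerate(pk)
                }
        except FileNotFoundError:
            self.cache = {}
        self._loaded_tokens_count = len(self.cache)

    def get_token_index(self, token: str) -> int:
        # Unknown tokens get the next free index automatically.
        if token not in self.cache:
            self.cache[token] = len(self.cache)
        return self.cache[token] + EMBEDDING_OFFSET

    def on_finish(self) -> None:
        # Iterating a dict yields keys in insertion order, so slicing with
        # islice recovers exactly the tokens added after loading.
        new_tokens = list(islice(self.cache, self._loaded_tokens_count, None))
        if new_tokens:
            with open(self.token_path, "a") as pk:  # append, never rewrite
                pk.writelines(f"{t}\n" for t in new_tokens)

Appending rather than rewriting means indices already handed out stay valid across runs, and a crash mid-save no longer truncates the existing file the way `open(..., "w")` would.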
71 changes: 62 additions & 9 deletions tests/unit/readers/testChemDataReader.py
@@ -27,14 +27,16 @@ def setUpClass(cls, mock_file: mock_open) -> None:
         """
         cls.reader = ChemDataReader(token_path="/mock/path")
         # After initializing, cls.reader.cache should now be set to ['C', 'O', 'N', '=', '1', '(']
-        assert cls.reader.cache == [
-            "C",
-            "O",
-            "N",
-            "=",
-            "1",
-            "(",
-        ], "Initial cache does not match expected values."
+        assert list(cls.reader.cache.items()) == list(
+            {
+                "C": 0,
+                "O": 1,
+                "N": 2,
+                "=": 3,
+                "1": 4,
+                "(": 5,
+            }.items()
+        ), "Initial cache does not match expected values or the order doesn't match."

     def test_read_data(self) -> None:
         """
@@ -87,7 +89,7 @@ def test_read_data_with_new_token(self) -> None:
         )
         # Ensure it's at the correct index
         self.assertEqual(
-            self.reader.cache.index("[H-]"),
+            self.reader.cache["[H-]"],
             index_for_last_token,
             "The new token '[H-]' was not added at the correct index in the cache.",
         )
@@ -102,6 +104,57 @@ def test_read_data_with_invalid_input(self) -> None:
         with self.assertRaises(ValueError):
             self.reader._read_data(raw_data)

+    @patch("builtins.open", new_callable=mock_open)
+    def test_finish_method_for_new_tokens(self, mock_file: mock_open) -> None:
+        """
+        Test the on_finish method to ensure it appends only the new tokens to the token file, in order.
+        """
+        # Simulate that some tokens were already loaded
+        self.reader._loaded_tokens_count = 6  # 6 tokens already loaded
+        self.reader.cache = {
+            "C": 0,
+            "O": 1,
+            "N": 2,
+            "=": 3,
+            "1": 4,
+            "(": 5,
+            "[H-]": 6,  # New token 1
+            "Br": 7,  # New token 2
+            "Cl": 8,  # New token 3
+            "Na": 9,  # New token 4
+            "Mg": 10,  # New token 5
+        }
+
+        # Run the on_finish method
+        self.reader.on_finish()
+
+        # Check that the file was opened in append mode ('a')
+        mock_file.assert_called_with(self.reader.token_path, "a")
+
+        # Verify the new tokens were written in the correct order
+        mock_file().writelines.assert_called_with(
+            ["[H-]\n", "Br\n", "Cl\n", "Na\n", "Mg\n"]
+        )
+
+    def test_finish_method_no_new_tokens(self) -> None:
+        """
+        Test the on_finish method when no new tokens are added (the cache is unchanged).
+        """
+        self.reader._loaded_tokens_count = 6  # No new tokens
+        self.reader.cache = {
+            "C": 0,
+            "O": 1,
+            "N": 2,
+            "=": 3,
+            "1": 4,
+            "(": 5,
+        }
+
+        with patch("builtins.open", new_callable=mock_open) as mock_file:
+            self.reader.on_finish()
+            # Check that no new tokens were written
+            mock_file().writelines.assert_not_called()
+

 if __name__ == "__main__":
     unittest.main()
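These tests lean on a detail of `unittest.mock.mock_open` that is easy to miss: the patched `open` returns the same mock file handle on every call, so a handle obtained after the code under test has run still carries the recorded `writelines` calls. Here is a minimal, self-contained sketch of the idiom; the `save_tokens` helper is invented for illustration and stands in for the append step of `on_finish`.

import unittest
from unittest.mock import mock_open, patch


def save_tokens(path: str, tokens: list) -> None:
    """Toy stand-in for on_finish's append step."""
    with open(path, "a") as fh:
        fh.writelines([f"{t}\n" for t in tokens])


class MockOpenDemo(unittest.TestCase):
    def test_appends_tokens(self) -> None:
        with patch("builtins.open", new_callable=mock_open) as mocked:
            save_tokens("/mock/path", ["Br", "Cl"])
        # The mock records how open() was called...
        mocked.assert_called_with("/mock/path", "a")
        # ...and mocked() returns the shared handle, so the writelines
        # call made inside save_tokens is visible here.
        mocked().writelines.assert_called_with(["Br\n", "Cl\n"])


if __name__ == "__main__":
    unittest.main()

The same handle-sharing explains `test_finish_method_no_new_tokens` above: `mock_file().writelines.assert_not_called()` works because `mock_file()` hands back the one handle the code under test would have used.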
16 changes: 9 additions & 7 deletions tests/unit/readers/testDeepChemDataReader.py
@@ -27,12 +27,14 @@ def setUpClass(cls, mock_file: mock_open) -> None:
         """
         cls.reader = DeepChemDataReader(token_path="/mock/path")
         # After initializing, cls.reader.cache should now be set to ['C', 'O', 'c', ')']
-        assert cls.reader.cache == [
-            "C",
-            "O",
-            "c",
-            ")",
-        ], "Cache initialization did not match expected tokens."
+        assert list(cls.reader.cache.items()) == list(
+            {
+                "C": 0,
+                "O": 1,
+                "c": 2,
+                ")": 3,
+            }.items()
+        ), "Cache initialization did not match expected tokens or the expected order."

     def test_read_data(self) -> None:
         """
@@ -95,7 +97,7 @@ def test_read_data_with_new_token(self) -> None:
         )
         # Ensure it's at the correct index
         self.assertEqual(
-            self.reader.cache.index("[H-]"),
+            self.reader.cache["[H-]"],
             index_for_last_token,
             "The new token '[H-]' was not added to the correct index in the cache.",
         )
14 changes: 8 additions & 6 deletions tests/unit/readers/testSelfiesReader.py
@@ -27,11 +27,13 @@ def setUpClass(cls, mock_file: mock_open) -> None:
         """
         cls.reader = SelfiesReader(token_path="/mock/path")
         # After initializing, cls.reader.cache should now be set to ['[C]', '[O]', '[=C]']
-        assert cls.reader.cache == [
-            "[C]",
-            "[O]",
-            "[=C]",
-        ], "Cache initialization did not match expected tokens."
+        assert list(cls.reader.cache.items()) == list(
+            {
+                "[C]": 0,
+                "[O]": 1,
+                "[=C]": 2,
+            }.items()
+        ), "Cache initialization did not match expected tokens or the expected order."

     def test_read_data(self) -> None:
         """
@@ -98,7 +100,7 @@ def test_read_data_with_new_token(self) -> None:
         )
         # Ensure it's at the correct index
         self.assertEqual(
-            self.reader.cache.index("[H-1]"),
+            self.reader.cache["[H-1]"],
             index_for_last_token,
             "The new token '[H-1]' was not added at the correct index in the cache.",
         )