diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py
index 2b3b1b0e..7cd74b17 100644
--- a/chebai/preprocessing/reader.py
+++ b/chebai/preprocessing/reader.py
@@ -205,17 +205,19 @@ def _read_data(self, raw_data: str) -> List[int]:
             if mol is not None:
                 raw_data = Chem.MolToSmiles(mol, canonical=True)
         except Exception as e:
-            print(f"RDKit failed to process {raw_data}")
+            print(f"RDKit failed to canonicalize the SMILES: {raw_data}")
             print(f"\t{e}")
         try:
+            mol = Chem.MolFromSmiles(raw_data.strip())
+            if mol is None:
+                raise ValueError(f"Invalid SMILES: {raw_data}")
             return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
         except ValueError as e:
             print(f"could not process {raw_data}")
-            print(f"\t{e}")
+            print(f"\tError: {e}")
             return None
 
     def _back_to_smiles(self, smiles_encoded):
-        token_file = self.reader.token_path
         token_coding = {}
         counter = 0
diff --git a/pyproject.toml b/pyproject.toml
index eb75643d..4ba71f8f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ dependencies = [
     "torch",
     "transformers",
     "pysmiles==1.1.2",
-    "rdkit",
+    "rdkit==2024.3.6",
     "lightning==2.5.1",
 ]
diff --git a/tests/unit/readers/testChemDataReader.py b/tests/unit/readers/testChemDataReader.py
index ec018f00..9d322f27 100644
--- a/tests/unit/readers/testChemDataReader.py
+++ b/tests/unit/readers/testChemDataReader.py
@@ -42,19 +42,22 @@ def test_read_data(self) -> None:
         """
         Test the _read_data method with a SMILES string to ensure it correctly tokenizes the string.
         """
-        raw_data = "CC(=O)NC1[Mg-2]"
+        raw_data = "CC(=O)NC1CC1[Mg-2]"
         # Expected output as per the tokens already in the cache, and ")" getting added to it.
         expected_output: List[int] = [
             EMBEDDING_OFFSET + 0,  # C
             EMBEDDING_OFFSET + 0,  # C
-            EMBEDDING_OFFSET + 5,  # =
-            EMBEDDING_OFFSET + 3,  # O
-            EMBEDDING_OFFSET + 1,  # N
-            EMBEDDING_OFFSET + len(self.reader.cache),  # (
-            EMBEDDING_OFFSET + 2,  # C
+            EMBEDDING_OFFSET + 5,  # (
+            EMBEDDING_OFFSET + 3,  # =
+            EMBEDDING_OFFSET + 1,  # O
+            EMBEDDING_OFFSET + len(self.reader.cache),  # ) - new token
+            EMBEDDING_OFFSET + 2,  # N
             EMBEDDING_OFFSET + 0,  # C
             EMBEDDING_OFFSET + 4,  # 1
-            EMBEDDING_OFFSET + len(self.reader.cache) + 1,  # [Mg-2]
+            EMBEDDING_OFFSET + 0,  # C
+            EMBEDDING_OFFSET + 0,  # C
+            EMBEDDING_OFFSET + 4,  # 1
+            EMBEDDING_OFFSET + len(self.reader.cache) + 1,  # [Mg-2] - new token
         ]
         result = self.reader._read_data(raw_data)
         self.assertEqual(
@@ -99,13 +102,14 @@ def test_read_data_with_invalid_input(self) -> None:
         Test the _read_data method with an invalid input.
         The invalid token should prompt a return value None
         """
-        raw_data = "%INVALID%"
-
-        result = self.reader._read_data(raw_data)
-        self.assertIsNone(
-            result,
-            "The output for invalid token '%INVALID%' should be None.",
-        )
+        # see https://github.com/ChEB-AI/python-chebai/issues/137
+        raw_datas = ["%INVALID%", "ADADAD", "ADASDAD", "CC(=O)NC1[Mg-2]"]
+        for raw_data in raw_datas:
+            result = self.reader._read_data(raw_data)
+            self.assertIsNone(
+                result,
+                f"The output for invalid token '{raw_data}' should be None.",
+            )
 
     @patch("builtins.open", new_callable=mock_open)
     def test_finish_method_for_new_tokens(self, mock_file: mock_open) -> None:
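
Note on the first hunk (a reviewer sketch, not part of the diff): RDKit's Chem.MolFromSmiles returns None instead of raising when a string is not valid SMILES, so the explicit None check is what turns inputs like "ADADAD" or the unclosed-ring "CC(=O)NC1[Mg-2]" into the ValueError that the existing handler reports. A minimal standalone illustration of that pattern follows; the tokenize_smiles name and the character-level placeholder tokenizer are hypothetical stand-ins for the reader's real tokenizer, not part of the codebase:

from typing import List, Optional

from rdkit import Chem


def tokenize_smiles(raw_data: str) -> Optional[List[str]]:
    # MolFromSmiles returns None (it does not raise) for invalid SMILES,
    # so the None check below converts a bad input into a ValueError
    # that the existing except branch reports.
    try:
        mol = Chem.MolFromSmiles(raw_data.strip())
        if mol is None:
            raise ValueError(f"Invalid SMILES: {raw_data}")
        # Placeholder tokenizer: the real reader maps tokens to indices.
        return list(raw_data)
    except ValueError as e:
        print(f"could not process {raw_data}")
        print(f"\tError: {e}")
        return None


print(tokenize_smiles("CC(=O)NC1CC1[Mg-2]"))  # token list for a valid SMILES
print(tokenize_smiles("ADADAD"))              # None, after the error report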